289 lines
9.7 KiB
Python
289 lines
9.7 KiB
Python
import functools
|
|
import logging
|
|
import bisect
|
|
import collections
|
|
import asyncio
|
|
import collections
|
|
from typing import Union, Optional
|
|
|
|
import diskcache
|
|
from fastapi import Request
|
|
from telethon import types
|
|
|
|
# Module logger, named after the part of this file's path that follows the last "/".
logger = logging.getLogger(__file__.rpartition("/")[2])
|
|
|
|
@functools.total_ordering
class MediaChunkHolder(object):
    """Byte buffer for one chunk of a message's media, ordered by ``start``.

    A holder is identified by ``chat_id``/``msg_id`` and covers the byte range
    that begins at ``start`` with an expected size of ``target_len``.  Data is
    accumulated in ``mem`` (``length`` mirrors ``len(mem)``), and coroutines
    can wait for new data with :meth:`wait_chunk_update`.

    Holders compare (``==``, ``<``, ``<=`` ...) by ``start`` only, either with
    another holder or with a plain ``int`` offset, so a sorted list of holders
    can be searched with :mod:`bisect` using an integer key.  Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    # Futures of coroutines currently blocked in wait_chunk_update().
    waiters: collections.deque[asyncio.Future]
    # Requests interested in this chunk.  Created per instance in __init__;
    # a class-level list would be shared (and cleared) across all chunks.
    requester: list['Request']
    # Identifier assigned by the owning manager; 0 until then.
    chunk_id: int = 0
    # Checked by wait_chunk_update(); nothing in this class sets it to True.
    is_done: bool = False

    def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int,
                 mem: Optional[bytes] = None) -> None:
        """Create a holder.

        Args:
            chat_id: Chat the media message belongs to.
            msg_id: Id of the media message.
            start: Byte offset at which this chunk begins.
            target_len: Number of bytes the chunk is expected to hold.
            mem: Optional initial content; ``None`` means empty.
        """
        self.chat_id = chat_id
        self.msg_id = msg_id
        self.start = start
        self.target_len = target_len
        self.mem = mem or bytes()
        self.length = len(self.mem)
        self.waiters = collections.deque()
        self.requester = []

    def __repr__(self) -> str:
        return f"MediaChunk,start:{self.start},len:{self.length}"

    @staticmethod
    def _start_of(other) -> Optional[int]:
        """Return the offset ``other`` stands for, or ``None`` if unsupported."""
        if isinstance(other, int):
            return other
        if isinstance(other, MediaChunkHolder):
            return other.start
        return None

    def __eq__(self, other: Union['MediaChunkHolder', int]):
        key = self._start_of(other)
        if key is None:
            # Let Python fall back to its default handling for foreign types.
            return NotImplemented
        return self.start == key

    def __le__(self, other: Union['MediaChunkHolder', int]):
        key = self._start_of(other)
        if key is None:
            return NotImplemented
        return self.start <= key

    def __gt__(self, other: Union['MediaChunkHolder', int]):
        key = self._start_of(other)
        if key is None:
            return NotImplemented
        return self.start > key

    def __add__(self, other: Union['MediaChunkHolder', bytes]):
        """Append ``other`` (its ``mem`` if it is a holder) to this holder.

        This mutates ``self`` in place and returns ``self``.

        Raises:
            RuntimeWarning: If the result exceeds ``target_len``.
        """
        if isinstance(other, MediaChunkHolder):
            other = other.mem
        self.append_chunk_mem(other)
        return self

    def is_completed(self) -> bool:
        """Return True once at least ``target_len`` bytes are buffered."""
        return self.length >= self.target_len

    def set_done(self) -> None:
        """Drop all registered requesters.

        NOTE(review): setting ``is_done`` and waking the waiters were
        commented out in the original implementation; that is kept as-is.
        """
        self.requester.clear()

    def notify_waiters(self) -> None:
        """Wake every coroutine blocked in :meth:`wait_chunk_update`."""
        while self.waiters:
            waiter = self.waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)

    def _check_overflow(self) -> None:
        """Raise ``RuntimeWarning`` if more than ``target_len`` bytes are held."""
        if self.length > self.target_len:
            raise RuntimeWarning(
                f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")

    def _set_chunk_mem(self, mem: Optional[bytes]) -> None:
        """Replace the buffered content; ``None`` is treated as empty.

        Raises:
            RuntimeWarning: If ``mem`` is longer than ``target_len`` (the
                content has already been replaced at that point).
        """
        self.mem = mem or bytes()
        self.length = len(self.mem)
        self._check_overflow()

    def append_chunk_mem(self, mem: bytes) -> None:
        """Append ``mem`` to the buffer and wake waiters.

        Raises:
            RuntimeWarning: If the buffer now exceeds ``target_len``.  The
                data stays appended and waiters are NOT notified in that case.
        """
        self.mem = self.mem + mem
        self.length = len(self.mem)
        self._check_overflow()
        self.notify_waiters()

    def add_chunk_requester(self, req: 'Request') -> None:
        """Register a request that is interested in this chunk."""
        self.requester.append(req)

    async def is_disconneted(self) -> bool:
        """Return True when no registered request is still connected.

        Requests are probed from the front of the list; disconnected ones are
        dropped.  Returns False as soon as a connected request is found.
        """
        while self.requester:
            front = self.requester[0]
            if not await front.is_disconnected():
                return False
            # The list may have been cleared or popped while we awaited, so
            # only drop the entry if it is still the one we probed.
            if self.requester and self.requester[0] is front:
                self.requester.pop(0)
        return True

    async def wait_chunk_update(self) -> None:
        """Block until the next :meth:`notify_waiters` call.

        Returns immediately when ``is_done`` is set.

        NOTE(review): any exception raised while waiting, including task
        cancellation, is absorbed here after the waiter is cleaned up, so the
        coroutine returns normally.  This mirrors the original bare
        ``except:``; confirm whether cancellation should propagate instead.
        """
        if self.is_done:
            return
        waiter = asyncio.Future()
        self.waiters.append(waiter)
        try:
            await waiter
        except BaseException:
            waiter.cancel()
            try:
                self.waiters.remove(waiter)
            except ValueError:
                # Already removed by notify_waiters().
                pass
|
|
|
|
class MediaChunkHolderManager(object):
    """Size-bounded, LRU-evicted in-memory index of media chunks.

    Chunks are stored per ``chat_id`` and ``msg_id`` in a list kept sorted by
    chunk ``start`` (chunks must be orderable against each other and against
    ``int`` offsets).  Every stored chunk is also tracked in ``chunk_lru``;
    when the sum of the chunks' ``target_len`` exceeds ``MAX_CACHE_SIZE`` the
    least recently used chunks are dropped.
    """

    MAX_CACHE_SIZE = 1024 * 1024 * 1024  # 1 GiB
    # Sum of target_len over all cached chunks.
    current_cache_size: int = 0
    # chat_id -> msg_id -> chunks sorted by start offset.
    # Created per instance in __init__ so managers never share state.
    chunk_cache: dict[int, dict[int, list['MediaChunkHolder']]]
    # Next id handed out by set_media_chunk().
    unique_chunk_id: int = 0
    # chunk_id -> chunk, least recently used first.
    chunk_lru: collections.OrderedDict[int, 'MediaChunkHolder']

    def __init__(self) -> None:
        self.current_cache_size = 0
        self.unique_chunk_id = 0
        self.chunk_cache = {}
        self.chunk_lru = collections.OrderedDict()

    def _get_media_msg_cache(self, msg: 'types.Message') -> Optional[list['MediaChunkHolder']]:
        """Return the sorted chunk list for ``msg``, or None if nothing is cached."""
        chat_cache = self.chunk_cache.get(msg.chat_id)
        if chat_cache is None:
            return None
        return chat_cache.get(msg.id)

    def _get_media_chunk_cache(self, msg: 'types.Message', start: int) -> Optional['MediaChunkHolder']:
        """Return the cached chunk of ``msg`` whose range contains ``start``.

        A chunk matches when it begins exactly at ``start`` or when
        ``chunk.start <= start < chunk.start + chunk.target_len``.
        """
        msg_cache = self._get_media_msg_cache(msg)
        if not msg_cache:
            return None
        # Index of the first chunk whose start is >= the requested offset.
        pos = bisect.bisect_left(msg_cache, start)
        if pos < len(msg_cache) and msg_cache[pos].start == start:
            return msg_cache[pos]
        if pos == 0:
            # Every cached chunk begins after the requested offset.
            return None
        # Otherwise only the chunk just before `pos` can still cover `start`.
        candidate = msg_cache[pos - 1]
        if candidate.start <= start < candidate.start + candidate.target_len:
            return candidate
        return None

    def _remove_pop_chunk(self, pop_chunk: 'MediaChunkHolder') -> None:
        """Remove ``pop_chunk`` from ``chunk_cache`` and release its size.

        Empty per-message and per-chat containers are pruned.  The LRU entry
        is not touched; callers remove it themselves.

        Raises:
            KeyError: If the chunk's chat or message is not cached.
            ValueError: If the chunk object itself is not in the list.
        """
        chat_cache = self.chunk_cache[pop_chunk.chat_id]
        msg_cache = chat_cache[pop_chunk.msg_id]
        # Remove by identity so another chunk with the same start survives.
        pos = bisect.bisect_left(msg_cache, pop_chunk.start)
        while pos < len(msg_cache) and msg_cache[pos].start == pop_chunk.start:
            if msg_cache[pos] is pop_chunk:
                del msg_cache[pos]
                break
            pos += 1
        else:
            raise ValueError(f"{pop_chunk!r} is not cached")
        self.current_cache_size -= pop_chunk.target_len
        if not msg_cache:
            chat_cache.pop(pop_chunk.msg_id)
            if not chat_cache:
                self.chunk_cache.pop(pop_chunk.chat_id)

    def get_media_chunk(self, msg: 'types.Message', start: int, lru: bool = True) -> Optional['MediaChunkHolder']:
        """Look up the chunk of ``msg`` covering offset ``start``.

        Args:
            msg: Message whose media is requested (``chat_id`` and ``id`` are read).
            start: Byte offset that must fall inside the chunk.
            lru: When True, mark the found chunk as most recently used.

        Returns:
            The matching chunk, or None when nothing cached covers ``start``.
        """
        res = self._get_media_chunk_cache(msg, start)
        if res is None:
            return None
        if lru:
            self.chunk_lru.move_to_end(res.chunk_id)
        return res

    def set_media_chunk(self, chunk: 'MediaChunkHolder') -> None:
        """Register ``chunk``, assign its ``chunk_id`` and evict if over budget.

        The chunk is charged with its ``target_len`` regardless of how much
        data it currently holds.  Least recently used chunks are evicted
        until the total fits ``MAX_CACHE_SIZE``; a chunk larger than the whole
        budget therefore evicts everything, including itself.
        """
        cache_msg = self.chunk_cache.setdefault(chunk.chat_id, {}).setdefault(chunk.msg_id, [])
        chunk.chunk_id = self.unique_chunk_id
        self.unique_chunk_id += 1
        bisect.insort(cache_msg, chunk)
        self.chunk_lru[chunk.chunk_id] = chunk
        self.current_cache_size += chunk.target_len
        while self.current_cache_size > self.MAX_CACHE_SIZE and self.chunk_lru:
            _, evicted = self.chunk_lru.popitem(last=False)
            self._remove_pop_chunk(evicted)

    def cancel_media_chunk(self, chunk: 'MediaChunkHolder') -> None:
        """Remove ``chunk`` from the cache if this manager holds it; else no-op."""
        cache_chat = self.chunk_cache.get(chunk.chat_id)
        if cache_chat is None:
            return
        if cache_chat.get(chunk.msg_id) is None:
            return
        # A chunk that was never registered still carries the default
        # chunk_id 0; make sure the LRU entry really is this object so we do
        # not evict an unrelated chunk that happens to own that id.
        if self.chunk_lru.get(chunk.chunk_id) is not chunk:
            return
        del self.chunk_lru[chunk.chunk_id]
        self._remove_pop_chunk(chunk)
|
|
|
|
|
|
@functools.total_ordering
class MediaBlockHolder(object):
    """Byte buffer for one block of a message's media, ordered by ``start``.

    A holder is identified by ``chat_id``/``msg_id`` and covers the byte range
    that begins at ``start`` with an expected size of ``target_len``.  Data is
    accumulated in ``mem`` (``length`` mirrors ``len(mem)``), and coroutines
    can wait for new data with :meth:`wait_update`.

    Holders compare by ``start`` only, either with another holder or with a
    plain ``int`` offset.  Because ``__eq__`` is defined without
    ``__hash__``, instances are unhashable.
    """

    # Futures of coroutines currently blocked in wait_update().
    waiters: collections.deque[asyncio.Future]
    chunk_id: int = 0

    def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int) -> None:
        """Create an empty holder.

        Args:
            chat_id: Chat the media message belongs to.
            msg_id: Id of the media message.
            start: Byte offset at which this block begins.
            target_len: Number of bytes the block is expected to hold.
        """
        self.chat_id = chat_id
        self.msg_id = msg_id
        self.start = start
        self.target_len = target_len
        self.mem = bytes()
        self.length = len(self.mem)
        self.waiters = collections.deque()

    def __repr__(self) -> str:
        return f"MediaBlockHolder,id:{self.chat_id}-{self.msg_id},start:{self.start},len:{self.length}/{self.target_len}"

    @staticmethod
    def _start_of(other) -> Optional[int]:
        """Return the offset ``other`` stands for, or ``None`` if unsupported."""
        if isinstance(other, int):
            return other
        if isinstance(other, MediaBlockHolder):
            return other.start
        return None

    def __eq__(self, other: Union['MediaBlockHolder', int]):
        key = self._start_of(other)
        if key is None:
            # Let Python fall back to its default handling for foreign types.
            return NotImplemented
        return self.start == key

    def __le__(self, other: Union['MediaBlockHolder', int]):
        key = self._start_of(other)
        if key is None:
            return NotImplemented
        return self.start <= key

    def __gt__(self, other: Union['MediaBlockHolder', int]):
        key = self._start_of(other)
        if key is None:
            return NotImplemented
        return self.start > key

    def __add__(self, other: Union['MediaBlockHolder', bytes]):
        """Append ``other`` (its ``mem`` if it is a holder) to this holder.

        This mutates ``self`` in place and returns ``self``.  The previous
        implementation dereferenced ``.mem`` on the bytes value and therefore
        always raised ``AttributeError``.
        """
        if isinstance(other, MediaBlockHolder):
            other = other.mem
        self.append_mem(other)
        return self

    def is_completed(self) -> bool:
        """Return True once at least ``target_len`` bytes are buffered."""
        return self.length >= self.target_len

    def notify_waiters(self) -> None:
        """Wake every coroutine blocked in :meth:`wait_update`."""
        while self.waiters:
            waiter = self.waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)

    def append_mem(self, mem: bytes) -> None:
        """Append ``mem`` to the buffer and wake waiters.

        Growing past ``target_len`` is not an error; it is only logged as a
        warning through the module-level ``logger``.
        """
        self.mem = self.mem + mem
        self.length = len(self.mem)
        self.notify_waiters()
        if self.length > self.target_len:
            logger.warning(f"MeidaBlock Overflow:{self}")

    async def wait_update(self) -> None:
        """Block until the next :meth:`notify_waiters` call.

        Returns immediately when the block is already complete.

        NOTE(review): any exception raised while waiting, including task
        cancellation, is absorbed here after the waiter is cleaned up, so the
        coroutine returns normally.  This mirrors the original bare
        ``except:``; confirm whether cancellation should propagate instead.
        """
        if self.is_completed():
            return
        waiter = asyncio.Future()
        self.waiters.append(waiter)
        try:
            await waiter
        except BaseException:
            waiter.cancel()
            try:
                self.waiters.remove(waiter)
            except ValueError:
                # Already removed by notify_waiters().
                pass
|
|
|
|
@functools.total_ordering
class BlockInfo(object):
    """Descriptor of one cached block, ordered by its ``offset``.

    Instances compare by ``offset`` only, either with another ``BlockInfo``
    or with a plain ``int`` offset; the remaining ordering operators are
    derived by ``functools.total_ordering``.  Because ``__eq__`` is defined
    without ``__hash__``, instances are unhashable.
    """

    def __init__(self, hashid: int, offset: int, length: int, in_mem: bool) -> None:
        """Store the block's attributes unchanged.

        Args:
            hashid: Identifier of the block (its exact meaning is not defined
                in this class).
            offset: Offset of the block; the sole comparison key.
            length: Length of the block.
            in_mem: Flag stored as given; presumably whether the block is
                held in memory -- confirm against the callers.
        """
        self.hashid = hashid
        self.offset = offset
        self.length = length
        self.in_mem = in_mem

    @staticmethod
    def _offset_of(other) -> Optional[int]:
        """Return the offset ``other`` stands for, or ``None`` if unsupported."""
        if isinstance(other, int):
            return other
        if isinstance(other, BlockInfo):
            return other.offset
        return None

    def __eq__(self, other: Union['BlockInfo', int]):
        key = self._offset_of(other)
        if key is None:
            # Let Python fall back to its default handling for foreign types
            # instead of failing with AttributeError.
            return NotImplemented
        return self.offset == key

    def __le__(self, other: Union['BlockInfo', int]):
        key = self._offset_of(other)
        if key is None:
            return NotImplemented
        return self.offset <= key
|
|
|
|
class MediaBlockHolderManager(object):
    """Placeholder manager for cached media blocks.

    Only the constants and the constructor signature exist so far; the
    constructor body is empty, so creating an instance stores nothing.
    NOTE(review): the parameter names suggest a size-limited cache backed by
    a directory -- confirm the intended design before building on this class.
    """

    DEFAULT_MAX_CACHE_SIZE = 1024 * 1024 * 1024 # 1Gb
    # chat_id -> msg_id -> list[BlockInfo]
    # NOTE: defined on the class, so this dict is shared by all instances.
    chunk_cache: dict[int, dict[int, list[BlockInfo]]] = {}

    def __init__(self, limit_size: int = DEFAULT_MAX_CACHE_SIZE, dir: str = 'cache') -> None:
        """Accept the cache configuration; both arguments are currently unused.

        Args:
            limit_size: Defaults to ``DEFAULT_MAX_CACHE_SIZE``; presumably the
                maximum cache size in bytes -- TODO confirm.
            dir: Defaults to ``'cache'``; presumably the cache directory --
                TODO confirm.
        """
        pass
|
|
|