277 lines
10 KiB
Python
277 lines
10 KiB
Python
import os
|
|
import functools
|
|
import logging
|
|
import bisect
|
|
import collections
|
|
import asyncio
|
|
import traceback
|
|
import hashlib
|
|
import collections
|
|
from typing import Union, Optional, Callable
|
|
|
|
import diskcache
|
|
from fastapi import Request
|
|
from telethon import types
|
|
|
|
logger = logging.getLogger(__file__.split("/")[-1])
|
|
|
|
|
|
@functools.total_ordering
|
|
class ChunkInfo(object):
|
|
def __init__(self, md5id: str, chat_id: int, msg_id: int, start: int, length: int) -> None:
|
|
self.id = md5id
|
|
self.chat_id = chat_id
|
|
self.msg_id = msg_id
|
|
self.start = start
|
|
self.length = length
|
|
|
|
def __repr__(self) -> str:
|
|
return f"chunkinfo:id:{self.id},cid:{self.chat_id},mid:{self.msg_id},offset:{self.start},len:{self.length}"
|
|
|
|
def __eq__(self, other: Union["ChunkInfo", int]):
|
|
if isinstance(other, int):
|
|
return self.start == other
|
|
return self.start == other.start
|
|
|
|
def __le__(self, other: Union["ChunkInfo", int]):
|
|
if isinstance(other, int):
|
|
return self.start <= other
|
|
return self.start <= other.start
|
|
|
|
|
|
@functools.total_ordering
|
|
class MediaChunkHolder(object):
|
|
waiters: collections.deque[asyncio.Future]
|
|
requester: list[Request] = []
|
|
unique_id: str = ""
|
|
info: ChunkInfo
|
|
callback: Callable = None
|
|
|
|
@staticmethod
|
|
def generate_id(chat_id: int, msg_id: int, start: int) -> str:
|
|
return f"{chat_id}:{msg_id}:{start}"
|
|
|
|
def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, callback: Callable = None) -> None:
|
|
self.unique_id = MediaChunkHolder.generate_id(chat_id, msg_id, start)
|
|
self.info = ChunkInfo(hashlib.md5(self.unique_id.encode()).hexdigest(), chat_id, msg_id, start, target_len)
|
|
self.mem = bytes()
|
|
self.length = len(self.mem)
|
|
self.waiters = collections.deque()
|
|
self.callback = callback
|
|
|
|
def __repr__(self) -> str:
|
|
return f"MediaChunk,{self.info},len:{self.length}"
|
|
|
|
def __eq__(self, other: Union["MediaChunkHolder", ChunkInfo, int]):
|
|
if isinstance(other, int):
|
|
return self.info.start == other
|
|
if isinstance(other, ChunkInfo):
|
|
return self.info.start == other.start
|
|
return self.info.start == other.info.start
|
|
|
|
def __le__(self, other: Union["MediaChunkHolder", ChunkInfo, int]):
|
|
if isinstance(other, int):
|
|
return self.info.start <= other
|
|
if isinstance(other, ChunkInfo):
|
|
return self.info.start <= other.start
|
|
return self.info.start <= other.info.start
|
|
|
|
def is_completed(self) -> bool:
|
|
return self.length >= self.info.length
|
|
|
|
@property
|
|
def chunk_id(self) -> str:
|
|
return self.info.id
|
|
|
|
@property
|
|
def start(self) -> int:
|
|
return self.info.start
|
|
|
|
@property
|
|
def target_len(self) -> int:
|
|
return self.info.length
|
|
|
|
def notify_waiters(self) -> None:
|
|
while self.waiters:
|
|
waiter = self.waiters.popleft()
|
|
if not waiter.done():
|
|
waiter.set_result(None)
|
|
|
|
def append_chunk_mem(self, mem: bytes) -> None:
|
|
self.mem = self.mem + mem
|
|
self.length = len(self.mem)
|
|
if self.length > self.target_len:
|
|
logger.warning(RuntimeWarning(
|
|
f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}"))
|
|
self.notify_waiters()
|
|
|
|
def add_chunk_requester(self, req: Request) -> None:
|
|
if self.is_completed():
|
|
return
|
|
self.requester.append(req)
|
|
|
|
async def is_disconneted(self) -> bool:
|
|
while self.requester:
|
|
req = self.requester[0]
|
|
if not await req.is_disconnected():
|
|
return False
|
|
try:
|
|
self.requester.remove(req)
|
|
except Exception as err:
|
|
logger.warning(f"{err=}, trace:{traceback.format_exc()}")
|
|
return False
|
|
return True
|
|
|
|
async def wait_chunk_update(self) -> None:
|
|
if self.is_completed():
|
|
return
|
|
waiter = asyncio.Future()
|
|
self.waiters.append(waiter)
|
|
try:
|
|
await waiter
|
|
except:
|
|
waiter.cancel()
|
|
try:
|
|
self.waiters.remove(waiter)
|
|
except ValueError:
|
|
pass
|
|
|
|
def set_done(self) -> None:
|
|
if self.callback is None:
|
|
return
|
|
callback = self.callback
|
|
self.callback = None
|
|
callback(self)
|
|
|
|
def can_store_in_disk(self) -> bool:
|
|
if not self.is_completed():
|
|
return False
|
|
if not self.is_disconneted():
|
|
return False
|
|
# clear all waiter and requester
|
|
self.notify_waiters()
|
|
return True
|
|
|
|
|
|
class MediaChunkHolderManager(object):
|
|
MAX_CACHE_SIZE = 2**32 # 4GB
|
|
current_cache_size: int = 0
|
|
# chunk unique id -> ChunkHolder
|
|
disk_chunk_cache: diskcache.Cache
|
|
# incompleted chunk
|
|
incompleted_chunk: dict[str, MediaChunkHolder] = {}
|
|
# chunk id -> ChunkInfo
|
|
chunk_lru: collections.OrderedDict[str, ChunkInfo]
|
|
# chat_id -> msg_id -> list[ChunkInfo]
|
|
chunk_cache: dict[int, dict[int, list[ChunkInfo]]] = {}
|
|
|
|
def __init__(self) -> None:
|
|
self.chunk_lru = collections.OrderedDict()
|
|
self.disk_chunk_cache = diskcache.Cache(f"{os.path.dirname(__file__)}/cache_media")
|
|
self._restore_cache()
|
|
|
|
def _restore_cache(self) -> None:
|
|
for id in self.disk_chunk_cache.iterkeys():
|
|
try:
|
|
holder: MediaChunkHolder = self.disk_chunk_cache.get(id)
|
|
if holder is not None:
|
|
self._set_media_chunk_index(holder.info)
|
|
except Exception as err:
|
|
logger.warning(f"restore, {err=},{traceback.format_exc()}")
|
|
|
|
def get_chunk_holder_by_info(self, info: ChunkInfo) -> MediaChunkHolder:
|
|
holder = self.incompleted_chunk.get(info.id)
|
|
if holder is not None:
|
|
return holder
|
|
holder = self.disk_chunk_cache.get(info.id)
|
|
return holder
|
|
|
|
def _get_media_msg_cache(self, msg: types.Message) -> Optional[list[ChunkInfo]]:
|
|
chat_cache = self.chunk_cache.get(msg.chat_id)
|
|
if chat_cache is None:
|
|
return None
|
|
return chat_cache.get(msg.id)
|
|
|
|
def _get_media_chunk_cache(self, msg: types.Message, start: int) -> Optional[MediaChunkHolder]:
|
|
msg_cache = self._get_media_msg_cache(msg)
|
|
if msg_cache is None or len(msg_cache) == 0:
|
|
return None
|
|
pos = bisect.bisect_left(msg_cache, start)
|
|
if pos == len(msg_cache):
|
|
pos = pos - 1
|
|
if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].length > start:
|
|
return self.get_chunk_holder_by_info(msg_cache[pos])
|
|
return None
|
|
elif msg_cache[pos].start == start:
|
|
return self.get_chunk_holder_by_info(msg_cache[pos])
|
|
elif pos > 0:
|
|
pos = pos - 1
|
|
if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].length > start:
|
|
return self.get_chunk_holder_by_info(msg_cache[pos])
|
|
return None
|
|
return None
|
|
|
|
def _remove_pop_chunk(self, pop_chunk: ChunkInfo = None) -> None:
|
|
try:
|
|
if pop_chunk is None:
|
|
dummy = self.chunk_lru.popitem(last=False)
|
|
pop_chunk = dummy[1]
|
|
self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id].remove(pop_chunk.start)
|
|
self.current_cache_size -= pop_chunk.length
|
|
if len(self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id]) == 0:
|
|
self.chunk_cache[pop_chunk.chat_id].pop(pop_chunk.msg_id)
|
|
if len(self.chunk_cache[pop_chunk.chat_id]) == 0:
|
|
self.chunk_cache.pop(pop_chunk.chat_id)
|
|
pop_holder = self.incompleted_chunk.get(pop_chunk.id)
|
|
if pop_holder is not None:
|
|
self.incompleted_chunk.pop(pop_chunk.id)
|
|
return
|
|
suc = self.disk_chunk_cache.delete(pop_chunk.id)
|
|
if not suc:
|
|
logger.warning(f"could not del, {pop_chunk}")
|
|
except Exception as err:
|
|
logger.warning(f"remove chunk,{err=},{traceback.format_exc()}")
|
|
|
|
def create_media_chunk_holder(self, chat_id: int, msg_id: int, start: int, target_len: int) -> MediaChunkHolder:
|
|
def holder_completed_callback(holder: MediaChunkHolder):
|
|
cache_holder = self.incompleted_chunk.pop(holder.chunk_id, None)
|
|
if cache_holder is None:
|
|
logger.warning(f"the holder not in mem, {holder}")
|
|
return
|
|
logger.info(f"cache new chunk:{holder}")
|
|
self.disk_chunk_cache.set(holder.chunk_id, holder)
|
|
|
|
return MediaChunkHolder(chat_id, msg_id, start, target_len, callback=holder_completed_callback)
|
|
|
|
def get_media_chunk(self, msg: types.Message, start: int, lru: bool = True) -> Optional[MediaChunkHolder]:
|
|
res = self._get_media_chunk_cache(msg, start)
|
|
logger.debug(f"get_media_chunk:{res}")
|
|
if res is None:
|
|
return None
|
|
if lru:
|
|
self.chunk_lru.move_to_end(res.chunk_id)
|
|
return res
|
|
|
|
def _set_media_chunk_index(self, info: ChunkInfo) -> None:
|
|
self.chunk_lru[info.id] = info
|
|
self.chunk_cache.setdefault(info.chat_id, {})
|
|
self.chunk_cache[info.chat_id].setdefault(info.msg_id, [])
|
|
bisect.insort(self.chunk_cache[info.chat_id][info.msg_id], info)
|
|
self.current_cache_size += info.length
|
|
|
|
def set_media_chunk(self, chunk: MediaChunkHolder) -> None:
|
|
can_store = chunk.can_store_in_disk()
|
|
if can_store:
|
|
self.disk_chunk_cache.set(chunk.chunk_id, chunk)
|
|
else:
|
|
self.incompleted_chunk[chunk.chunk_id] = chunk
|
|
self._set_media_chunk_index(chunk.info)
|
|
while self.current_cache_size > self.MAX_CACHE_SIZE:
|
|
self._remove_pop_chunk()
|
|
|
|
def cancel_media_chunk(self, chunk: MediaChunkHolder) -> None:
|
|
dummy = self.chunk_lru.pop(chunk.chunk_id, None)
|
|
if dummy is None:
|
|
return
|
|
self._remove_pop_chunk(dummy)
|