TgToFileSystem/backend/MediaCacheManager.py
2024-06-02 18:13:48 +08:00

276 lines
9.9 KiB
Python

import os
import functools
import logging
import bisect
import collections
import asyncio
import traceback
import hashlib
import collections
from typing import Union, Optional, Callable
import diskcache
from fastapi import Request
from telethon import types
# Module-level logger named after this source file (e.g. "MediaCacheManager.py").
# os.path.basename is used instead of splitting on "/" so the name is also
# correct on platforms whose path separator is not "/".
logger = logging.getLogger(os.path.basename(__file__))
@functools.total_ordering
class ChunkInfo(object):
    """Small, orderable record describing one media chunk.

    Equality and ordering look at ``start`` only. An instance can be compared
    with another object exposing a ``start`` attribute or directly with an
    ``int`` offset, which lets a sorted ``list[ChunkInfo]`` be searched with
    ``bisect`` (or pruned with ``list.remove``) using a bare offset.

    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.

    Attributes:
        id: Identifier passed in as ``md5id``.
        chat_id: Chat identifier passed to the constructor.
        msg_id: Message identifier passed to the constructor.
        start: Start offset of the chunk; the only field used for comparison.
        length: Length passed to the constructor.
    """

    def __init__(self, md5id: str, chat_id: int, msg_id: int, start: int, length: int) -> None:
        self.id = md5id
        self.chat_id = chat_id
        self.msg_id = msg_id
        self.start = start
        self.length = length

    def __repr__(self) -> str:
        return f"chunkinfo:id:{self.id},cid:{self.chat_id},mid:{self.msg_id},offset:{self.start},len:{self.length}"

    def __eq__(self, other: Union["ChunkInfo", int]):
        """Return whether both sides have the same start offset.

        Returns ``NotImplemented`` for operands that are neither an ``int``
        nor expose ``start``, so Python falls back to its default handling
        instead of raising ``AttributeError``.
        """
        if isinstance(other, int):
            return self.start == other
        try:
            return self.start == other.start
        except AttributeError:
            return NotImplemented

    def __le__(self, other: Union["ChunkInfo", int]):
        """Return whether this chunk starts at or before ``other``.

        ``functools.total_ordering`` derives the remaining rich comparisons
        from this method and ``__eq__``. Unsupported operands yield
        ``NotImplemented``.
        """
        if isinstance(other, int):
            return self.start <= other
        try:
            return self.start <= other.start
        except AttributeError:
            return NotImplemented
@functools.total_ordering
class MediaChunkHolder(object):
    """Growing in-memory buffer for one media chunk, plus its waiters.

    Bytes are appended with :meth:`append_chunk_mem` until ``length`` reaches
    the target length recorded in ``info``. Coroutines may park in
    :meth:`wait_chunk_update` and are woken on every append. Holders are
    compared and ordered by their start offset only and may be compared with
    another holder, a ``ChunkInfo`` or a plain ``int`` offset.

    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """

    # Futures of coroutines currently parked in wait_chunk_update().
    waiters: collections.deque[asyncio.Future]
    # Requests registered through add_chunk_requester(); one list per holder.
    requester: list[Request]
    unique_id: str = ""
    info: ChunkInfo
    # Called at most once by set_done(); reset to None before it is invoked.
    callback: Optional[Callable] = None

    @staticmethod
    def generate_id(chat_id: int, msg_id: int, start: int) -> str:
        """Return the ``"chat_id:msg_id:start"`` identifier string."""
        return f"{chat_id}:{msg_id}:{start}"

    def __init__(
        self, chat_id: int, msg_id: int, start: int, target_len: int, callback: Optional[Callable] = None
    ) -> None:
        """Create an empty holder.

        Args:
            chat_id: Chat identifier, stored in ``info``.
            msg_id: Message identifier, stored in ``info``.
            start: Start offset of the chunk.
            target_len: Number of bytes after which the chunk counts as complete.
            callback: Optional callable invoked with this holder by ``set_done``.
        """
        self.unique_id = MediaChunkHolder.generate_id(chat_id, msg_id, start)
        # The chunk id is the md5 hex digest of the unique id string.
        self.info = ChunkInfo(hashlib.md5(self.unique_id.encode()).hexdigest(), chat_id, msg_id, start, target_len)
        self.mem = bytes()
        self.length = len(self.mem)
        self.waiters = collections.deque()
        # Per-instance list: a class-level list would be shared by every holder.
        self.requester = []
        self.callback = callback

    def __getstate__(self) -> dict:
        """Return the picklable state, without runtime-only members.

        Pending futures, live request objects and the completion callback are
        only meaningful in the running process, so they are replaced by empty
        values before serialization (e.g. when a holder is stored in the disk
        cache).
        """
        state = self.__dict__.copy()
        state["waiters"] = collections.deque()
        state["requester"] = []
        state["callback"] = None
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore pickled state and recreate the runtime-only members.

        Also covers state pickled before ``requester`` became an instance
        attribute, which does not contain that key.
        """
        self.__dict__.update(state)
        self.waiters = collections.deque()
        self.requester = []
        self.callback = None

    def __repr__(self) -> str:
        return f"MediaChunk,{self.info},len:{self.length}"

    @staticmethod
    def _offset_of(other) -> Optional[int]:
        """Return the start offset represented by ``other``, or None if unsupported."""
        if isinstance(other, int):
            return other
        if isinstance(other, ChunkInfo):
            return other.start
        # Anything holder-like: read other.info.start when it exists.
        return getattr(getattr(other, "info", None), "start", None)

    def __eq__(self, other: Union["MediaChunkHolder", ChunkInfo, int]):
        """Return whether both sides have the same start offset."""
        offset = self._offset_of(other)
        if offset is None:
            return NotImplemented
        return self.info.start == offset

    def __le__(self, other: Union["MediaChunkHolder", ChunkInfo, int]):
        """Return whether this chunk starts at or before ``other``."""
        offset = self._offset_of(other)
        if offset is None:
            return NotImplemented
        return self.info.start <= offset

    def is_completed(self) -> bool:
        """Return True once at least the target number of bytes is buffered."""
        return self.length >= self.info.length

    @property
    def chunk_id(self) -> str:
        """Identifier stored in ``info`` (md5 hex digest of ``unique_id``)."""
        return self.info.id

    @property
    def start(self) -> int:
        """Start offset of the chunk."""
        return self.info.start

    @property
    def target_len(self) -> int:
        """Number of bytes after which the chunk counts as complete."""
        return self.info.length

    def notify_waiters(self) -> None:
        """Wake every parked waiter and empty the waiter queue."""
        while self.waiters:
            waiter = self.waiters.popleft()
            # Cancelled/finished futures must not be resolved a second time.
            if not waiter.done():
                waiter.set_result(None)

    def append_chunk_mem(self, mem: bytes) -> None:
        """Append ``mem`` to the buffer, update ``length`` and wake waiters.

        Buffering more than ``target_len`` bytes is only logged, not rejected.
        """
        self.mem = self.mem + mem
        self.length = len(self.mem)
        if self.length > self.target_len:
            logger.warning(RuntimeWarning(
                f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}"))
        self.notify_waiters()

    def add_chunk_requester(self, req: Request) -> None:
        """Register ``req`` as interested in this chunk unless it is already complete."""
        if self.is_completed():
            return
        self.requester.append(req)

    async def is_disconneted(self) -> bool:
        """Return True when every registered requester has disconnected.

        Disconnected requesters are dropped from the list as they are found.
        Returns False as soon as one requester is still connected, or when a
        requester could not be removed from the list.
        """
        while self.requester:
            req = self.requester[0]
            if not await req.is_disconnected():
                return False
            try:
                self.requester.remove(req)
            except Exception as err:
                logger.warning(f"{err=}, trace:{traceback.format_exc()}")
                return False
        return True

    async def wait_chunk_update(self) -> None:
        """Suspend until the next ``notify_waiters`` call.

        Returns immediately when the chunk is already complete. If the wait is
        interrupted (typically by task cancellation) the waiter is withdrawn
        from the queue and the interruption is re-raised, so cancellation of
        the calling task is not silently swallowed.
        """
        if self.is_completed():
            return
        waiter = asyncio.get_running_loop().create_future()
        self.waiters.append(waiter)
        try:
            await waiter
        except BaseException:
            waiter.cancel()
            try:
                self.waiters.remove(waiter)
            except ValueError:
                # Already popped by notify_waiters().
                pass
            raise

    def set_done(self) -> None:
        """Invoke the completion callback once, passing this holder."""
        if self.callback is None:
            return
        callback = self.callback
        # Cleared first so a repeated set_done() call is a no-op.
        self.callback = None
        callback(self)

    def can_store_in_disk(self) -> bool:
        """Return True when the chunk is complete; wakes all waiters in that case.

        Requester liveness is not consulted: ``is_disconneted`` is a coroutine
        function and cannot be awaited from this synchronous method. Calling it
        without ``await`` only produced an always-truthy coroutine object (and
        a "never awaited" warning), so that check never influenced the result.
        """
        if not self.is_completed():
            return False
        # clear all waiters
        self.notify_waiters()
        return True
class MediaChunkHolderManager(object):
    """Index of media chunks backed by an in-memory dict and a disk cache.

    Incomplete holders live in ``incompleted_chunk``; holders handed over by
    their completion callback (or already complete when registered) are
    stored in ``disk_chunk_cache``. ``chunk_cache`` keeps, per chat and
    message, a list of ``ChunkInfo`` sorted by start offset, and ``chunk_lru``
    orders chunk ids for eviction once ``current_cache_size`` (the sum of the
    indexed chunks' target lengths) exceeds ``MAX_CACHE_SIZE``.

    NOTE(review): ``diskcache.Cache`` is created with default settings and
    presumably applies its own size limit/eviction; confirm it is compatible
    with ``MAX_CACHE_SIZE``, otherwise index entries may outlive their data
    (lookups then simply return None).
    """

    MAX_CACHE_SIZE = 2**32  # 4GB
    current_cache_size: int = 0
    # chunk id -> MediaChunkHolder stored on disk
    disk_chunk_cache: diskcache.Cache
    # chunk id -> holder that is still kept in memory
    incompleted_chunk: dict[str, MediaChunkHolder]
    # chunk id -> ChunkInfo, oldest entry first
    chunk_lru: collections.OrderedDict[str, ChunkInfo]
    # chat_id -> msg_id -> list[ChunkInfo] sorted by start offset
    chunk_cache: dict[int, dict[int, list[ChunkInfo]]]

    def __init__(self) -> None:
        """Open the disk cache next to this file and rebuild the index from it."""
        # All mutable state is created per instance; class-level dicts would be
        # shared between managers while chunk_lru/current_cache_size are not.
        self.current_cache_size = 0
        self.incompleted_chunk = {}
        self.chunk_cache = {}
        self.chunk_lru = collections.OrderedDict()
        self.disk_chunk_cache = diskcache.Cache(f"{os.path.dirname(__file__)}/cache_media")
        self._restore_cache()

    def _restore_cache(self) -> None:
        """Index every holder currently present in the disk cache.

        Entries that fail to load or index are logged and skipped.
        """
        for chunk_id in self.disk_chunk_cache.iterkeys():
            try:
                holder: Optional[MediaChunkHolder] = self.disk_chunk_cache.get(chunk_id)
                if holder is not None:
                    self._set_media_chunk_index(holder.info)
            except Exception as err:
                logger.warning(f"restore, {err=},{traceback.format_exc()}")

    def get_chunk_holder_by_info(self, info: ChunkInfo) -> Optional[MediaChunkHolder]:
        """Return the holder for ``info``: in-memory first, then the disk cache.

        Returns None when neither store has an entry for ``info.id``.
        """
        holder = self.incompleted_chunk.get(info.id)
        if holder is not None:
            return holder
        return self.disk_chunk_cache.get(info.id)

    def _get_media_msg_cache(self, msg: types.Message) -> Optional[list[ChunkInfo]]:
        """Return the sorted ``ChunkInfo`` list of ``msg``, or None if not indexed."""
        chat_cache = self.chunk_cache.get(msg.chat_id)
        if chat_cache is None:
            return None
        return chat_cache.get(msg.id)

    def _get_media_chunk_cache(self, msg: types.Message, start: int) -> Optional[MediaChunkHolder]:
        """Return the holder of the chunk of ``msg`` that covers offset ``start``.

        A chunk matches when it begins exactly at ``start``, or when it is the
        closest chunk beginning before ``start`` and
        ``chunk.start + chunk.length > start``. Returns None otherwise.
        """
        msg_cache = self._get_media_msg_cache(msg)
        if not msg_cache:
            return None
        # First index whose chunk starts at or after `start`.
        pos = bisect.bisect_left(msg_cache, start)
        if pos < len(msg_cache) and msg_cache[pos].start == start:
            return self.get_chunk_holder_by_info(msg_cache[pos])
        if pos == 0:
            # Every indexed chunk starts after `start`.
            return None
        candidate = msg_cache[pos - 1]
        if candidate.start <= start < candidate.start + candidate.length:
            return self.get_chunk_holder_by_info(candidate)
        return None

    def _remove_pop_chunk(self, pop_chunk: ChunkInfo = None) -> None:
        """Remove one chunk from the index and from its backing store.

        Args:
            pop_chunk: Chunk to remove. When None, the oldest entry of
                ``chunk_lru`` is popped and removed. When given, the caller is
                expected to have removed it from ``chunk_lru`` already (this
                method does not touch ``chunk_lru`` in that case).

        Any error is logged and swallowed.
        """
        try:
            if pop_chunk is None:
                _, pop_chunk = self.chunk_lru.popitem(last=False)
            chat_cache = self.chunk_cache[pop_chunk.chat_id]
            msg_cache = chat_cache[pop_chunk.msg_id]
            # ChunkInfo compares equal to its start offset, so this drops the
            # first entry beginning at that offset.
            msg_cache.remove(pop_chunk.start)
            self.current_cache_size -= pop_chunk.length
            if not msg_cache:
                chat_cache.pop(pop_chunk.msg_id)
            if not chat_cache:
                self.chunk_cache.pop(pop_chunk.chat_id)
            if self.incompleted_chunk.get(pop_chunk.id) is not None:
                # Still in memory: nothing was written to disk for it.
                self.incompleted_chunk.pop(pop_chunk.id)
                return
            suc = self.disk_chunk_cache.delete(pop_chunk.id)
            if not suc:
                logger.warning(f"could not del, {pop_chunk}")
        except Exception as err:
            logger.warning(f"remove chunk,{err=},{traceback.format_exc()}")

    def create_media_chunk_holder(self, chat_id: int, msg_id: int, start: int, target_len: int) -> MediaChunkHolder:
        """Create a holder whose completion callback moves it to the disk cache.

        The returned holder is not registered; call ``set_media_chunk`` for that.
        """

        def holder_completed_callback(holder: MediaChunkHolder):
            # Move the holder from memory to disk; skip if it is no longer
            # tracked in memory.
            cache_holder = self.incompleted_chunk.pop(holder.chunk_id, None)
            if cache_holder is None:
                logger.warning(f"the holder not in mem, {holder}")
                return
            self.disk_chunk_cache.set(holder.chunk_id, holder)

        return MediaChunkHolder(chat_id, msg_id, start, target_len, callback=holder_completed_callback)

    def get_media_chunk(self, msg: types.Message, start: int, lru: bool = True) -> Optional[MediaChunkHolder]:
        """Return the holder covering ``start`` in ``msg``, or None.

        Args:
            msg: Message whose chunks are searched.
            start: Offset that must fall inside the returned chunk.
            lru: When True, mark the chunk as most recently used.
        """
        res = self._get_media_chunk_cache(msg, start)
        logger.debug(f"get_media_chunk:{res}")
        if res is None:
            return None
        # Guarded: move_to_end raises KeyError for ids missing from the LRU.
        if lru and res.chunk_id in self.chunk_lru:
            self.chunk_lru.move_to_end(res.chunk_id)
        return res

    def _set_media_chunk_index(self, info: ChunkInfo) -> None:
        """Add ``info`` to the LRU, the sorted per-message list and the size total."""
        self.chunk_lru[info.id] = info
        msg_cache = self.chunk_cache.setdefault(info.chat_id, {}).setdefault(info.msg_id, [])
        bisect.insort(msg_cache, info)
        self.current_cache_size += info.length

    def set_media_chunk(self, chunk: MediaChunkHolder) -> None:
        """Register ``chunk`` and evict oldest chunks while over ``MAX_CACHE_SIZE``.

        The holder goes straight to the disk cache when
        ``chunk.can_store_in_disk()`` is true, otherwise it is kept in memory.
        """
        if chunk.can_store_in_disk():
            self.disk_chunk_cache.set(chunk.chunk_id, chunk)
        else:
            self.incompleted_chunk[chunk.chunk_id] = chunk
        self._set_media_chunk_index(chunk.info)
        # Stop once the LRU is empty: _remove_pop_chunk swallows the resulting
        # KeyError, so without this guard the loop could never terminate.
        while self.current_cache_size > self.MAX_CACHE_SIZE and self.chunk_lru:
            self._remove_pop_chunk()

    def cancel_media_chunk(self, chunk: MediaChunkHolder) -> None:
        """Forget ``chunk`` entirely; does nothing if it is not in the LRU."""
        dummy = self.chunk_lru.pop(chunk.chunk_id, None)
        if dummy is None:
            return
        self._remove_pop_chunk(dummy)