# TgToFileSystem/backend/MediaCacheManager.py
# (file-viewer header captured with the source: 289 lines, 9.7 KiB, Python)

import functools
import logging
import bisect
import collections
import asyncio
import collections
from typing import Union, Optional
import diskcache
from fastapi import Request
from telethon import types
# Module logger, named after the part of this file's path that follows the
# last "/" (i.e. the file name on POSIX-style paths).
logger = logging.getLogger(__file__.rsplit("/", 1)[-1])
@functools.total_ordering
class MediaChunkHolder(object):
    """Growable in-memory buffer for one chunk of a message's media.

    A holder is identified by ``(chat_id, msg_id, start)`` and is expected to
    receive up to ``target_len`` bytes.  Holders order and compare by
    ``start`` only, and may be compared directly against a plain ``int``
    offset, which is what makes them usable with :mod:`bisect`.

    Because ``__eq__`` is overridden without ``__hash__``, instances are
    unhashable.
    """

    # Futures of coroutines currently blocked in ``wait_chunk_update``.
    waiters: collections.deque[asyncio.Future]
    # Requests registered through ``add_chunk_requester``.  This is created
    # per instance in ``__init__``; a class-level list would be shared by
    # every holder, so clearing one holder would clear them all.
    requester: list["Request"]
    chunk_id: int = 0
    is_done: bool = False

    def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
        """Create a holder, optionally pre-filled with ``mem``.

        Args:
            chat_id: Chat the media belongs to.
            msg_id: Message the media belongs to.
            start: Offset this chunk starts at; the sort/compare key.
            target_len: Number of bytes the chunk should hold when complete.
            mem: Initial content; ``None`` (or empty) starts an empty buffer.
        """
        self.chat_id = chat_id
        self.msg_id = msg_id
        self.start = start
        self.target_len = target_len
        self.mem = mem or bytes()
        self.length = len(self.mem)
        self.waiters = collections.deque()
        self.requester = []

    def __repr__(self) -> str:
        return f"MediaChunk,start:{self.start},len:{self.length}"

    @staticmethod
    def _other_start(other) -> Optional[int]:
        """Return the offset ``other`` stands for, or None if unsupported."""
        if isinstance(other, int):
            return other
        if isinstance(other, MediaChunkHolder):
            return other.start
        return None

    def __eq__(self, other: Union['MediaChunkHolder', int]):
        key = self._other_start(other)
        if key is None:
            return NotImplemented
        return self.start == key

    def __le__(self, other: Union['MediaChunkHolder', int]):
        key = self._other_start(other)
        if key is None:
            return NotImplemented
        return self.start <= key

    def __gt__(self, other: Union['MediaChunkHolder', int]):
        key = self._other_start(other)
        if key is None:
            return NotImplemented
        return self.start > key

    def __add__(self, other: Union['MediaChunkHolder', bytes]) -> 'MediaChunkHolder':
        """Append ``other`` (bytes, or another holder's bytes) in place.

        Returns ``self`` so that ``holder += data`` keeps ``holder`` bound to
        this object instead of rebinding it to ``None``.

        Raises:
            RuntimeWarning: if the buffer grows past ``target_len``.
        """
        if isinstance(other, MediaChunkHolder):
            other = other.mem
        self.append_chunk_mem(other)
        return self

    def is_completed(self) -> bool:
        """Return True once at least ``target_len`` bytes are buffered."""
        return self.length >= self.target_len

    def set_done(self) -> None:
        """Drop every request registered on this holder.

        NOTE(review): this neither sets ``is_done`` nor wakes waiters (the
        original had both steps commented out), so ``is_done`` stays False
        unless it is assigned from outside this class.
        """
        self.requester.clear()

    def notify_waiters(self) -> None:
        """Resolve and discard every pending waiter future."""
        while self.waiters:
            waiter = self.waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)

    def _raise_if_overflow(self) -> None:
        """Raise RuntimeWarning when more than ``target_len`` bytes are held."""
        if self.length > self.target_len:
            raise RuntimeWarning(
                f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")

    def _set_chunk_mem(self, mem: Optional[bytes]) -> None:
        """Replace the buffer with ``mem``; ``None`` means an empty buffer.

        Raises:
            RuntimeWarning: if ``mem`` is longer than ``target_len`` (the
                buffer has already been replaced when this is raised).
        """
        self.mem = mem if mem is not None else bytes()
        self.length = len(self.mem)
        self._raise_if_overflow()

    def append_chunk_mem(self, mem: bytes) -> None:
        """Append ``mem`` to the buffer and wake all waiters.

        Raises:
            RuntimeWarning: if the buffer grows past ``target_len``.  The
                bytes are already appended at that point and waiters are
                not notified.
        """
        self.mem = self.mem + mem
        self.length = len(self.mem)
        self._raise_if_overflow()
        self.notify_waiters()

    def add_chunk_requester(self, req: "Request") -> None:
        """Register ``req`` as a request interested in this chunk."""
        self.requester.append(req)

    async def is_disconneted(self) -> bool:
        """Return True when no registered request is still connected.

        Requests are probed front to back; disconnected ones are dropped.
        Returns False at the first request that is still connected.
        """
        while self.requester:
            req = self.requester[0]
            if not await req.is_disconnected():
                return False
            # The list can be modified (e.g. cleared by ``set_done``) while
            # we were awaiting, so only drop the entry we actually probed.
            if self.requester and self.requester[0] is req:
                self.requester.pop(0)
        return True

    async def wait_chunk_update(self) -> None:
        """Suspend until ``notify_waiters`` runs; return at once if done."""
        if self.is_done:
            return
        waiter = asyncio.Future()
        self.waiters.append(waiter)
        try:
            await waiter
        except BaseException:
            # NOTE(review): kept from the original bare ``except`` - any
            # exception delivered while waiting, including task
            # cancellation, is swallowed after cleaning up the waiter.
            # Confirm callers rely on this before re-raising here.
            waiter.cancel()
            try:
                self.waiters.remove(waiter)
            except ValueError:
                pass
class MediaChunkHolderManager(object):
    """Size-bounded LRU cache of media chunks.

    Chunks are indexed two ways:

    * ``chunk_cache``: ``chat_id -> msg_id -> list of chunks`` kept sorted by
      ``start`` so a chunk covering an offset can be found with ``bisect``.
    * ``chunk_lru``: ``chunk_id -> chunk`` in least-recently-used-first order.

    Size accounting uses each chunk's ``target_len`` (not the bytes received
    so far).  When the total exceeds ``MAX_CACHE_SIZE`` the least recently
    used chunks are evicted.
    """

    MAX_CACHE_SIZE = 1024 * 1024 * 1024  # 1Gb
    current_cache_size: int = 0
    # chat_id -> msg_id -> chunks sorted by start offset.  Created per
    # instance in ``__init__``: a class-level dict would be shared between
    # managers while their LRU state and counters are not.
    chunk_cache: "dict[int, dict[int, list[MediaChunkHolder]]]"
    # Next id handed out by ``set_media_chunk``.
    unique_chunk_id: int = 0
    # chunk_id -> chunk, oldest first.
    chunk_lru: "collections.OrderedDict[int, MediaChunkHolder]"

    def __init__(self) -> None:
        self.current_cache_size = 0
        self.unique_chunk_id = 0
        self.chunk_cache = {}
        self.chunk_lru = collections.OrderedDict()

    def _get_media_msg_cache(self, msg: "types.Message") -> "Optional[list[MediaChunkHolder]]":
        """Return the sorted chunk list for ``msg``, or None if not cached."""
        chat_cache = self.chunk_cache.get(msg.chat_id)
        if chat_cache is None:
            return None
        return chat_cache.get(msg.id)

    def _get_media_chunk_cache(self, msg: "types.Message", start: int) -> "Optional[MediaChunkHolder]":
        """Return the cached chunk of ``msg`` that covers offset ``start``.

        A chunk covers ``start`` when it begins exactly there, or when it is
        the closest chunk beginning before ``start`` and
        ``chunk.start + chunk.target_len > start``.
        """
        msg_cache = self._get_media_msg_cache(msg)
        if not msg_cache:
            return None
        # First index whose chunk starts at or after ``start``.
        pos = bisect.bisect_left(msg_cache, start)
        if pos < len(msg_cache) and msg_cache[pos].start == start:
            return msg_cache[pos]
        if pos == 0:
            return None
        prev = msg_cache[pos - 1]
        if prev.start <= start < prev.start + prev.target_len:
            return prev
        return None

    def _remove_pop_chunk(self, pop_chunk: "MediaChunkHolder") -> None:
        """Remove ``pop_chunk`` from ``chunk_cache`` and release its size.

        The chunk is matched by identity, so another chunk that happens to
        share the same ``start`` is never removed in its place.  Emptied
        per-message and per-chat containers are pruned.

        Raises:
            KeyError: if the chunk's chat or message is not cached.
            ValueError: if the chunk itself is not in the message's list.
        """
        chat_cache = self.chunk_cache[pop_chunk.chat_id]
        msg_cache = chat_cache[pop_chunk.msg_id]
        for idx, held in enumerate(msg_cache):
            if held is pop_chunk:
                del msg_cache[idx]
                break
        else:
            raise ValueError(f"{pop_chunk!r} is not in the chunk cache")
        self.current_cache_size -= pop_chunk.target_len
        if not msg_cache:
            chat_cache.pop(pop_chunk.msg_id)
            if not chat_cache:
                self.chunk_cache.pop(pop_chunk.chat_id)

    def get_media_chunk(self, msg: "types.Message", start: int, lru: bool = True) -> "Optional[MediaChunkHolder]":
        """Look up the chunk covering ``start``; None on a cache miss.

        Args:
            msg: Message whose media is requested (``chat_id`` and ``id``
                are read).
            start: Offset that must fall inside the returned chunk.
            lru: When True, mark the hit as most recently used.
        """
        res = self._get_media_chunk_cache(msg, start)
        if res is None:
            return None
        if lru and res.chunk_id in self.chunk_lru:
            self.chunk_lru.move_to_end(res.chunk_id)
        return res

    def set_media_chunk(self, chunk: "MediaChunkHolder") -> None:
        """Insert ``chunk``, assign it a fresh ``chunk_id`` and evict if full.

        Eviction pops least-recently-used chunks until the accounted size is
        back within ``MAX_CACHE_SIZE``; this can include ``chunk`` itself
        when nothing older is left.
        """
        cache_msg = self.chunk_cache.setdefault(chunk.chat_id, {}).setdefault(chunk.msg_id, [])
        chunk.chunk_id = self.unique_chunk_id
        self.unique_chunk_id += 1
        bisect.insort(cache_msg, chunk)
        self.chunk_lru[chunk.chunk_id] = chunk
        self.current_cache_size += chunk.target_len
        # ``and self.chunk_lru`` stops the loop instead of letting
        # ``popitem`` raise if the accounting ever disagrees with the LRU.
        while self.current_cache_size > self.MAX_CACHE_SIZE and self.chunk_lru:
            _, evicted = self.chunk_lru.popitem(last=False)
            self._remove_pop_chunk(evicted)

    def cancel_media_chunk(self, chunk: "MediaChunkHolder") -> None:
        """Drop ``chunk`` from the cache; a no-op if it is not cached."""
        cache_chat = self.chunk_cache.get(chunk.chat_id)
        if cache_chat is None:
            return
        if cache_chat.get(chunk.msg_id) is None:
            return
        dropped = self.chunk_lru.pop(chunk.chunk_id, None)
        if dropped is None:
            return
        self._remove_pop_chunk(dropped)
@functools.total_ordering
class MediaBlockHolder(object):
    """Growable in-memory buffer for one block of a message's media.

    A holder is identified by ``(chat_id, msg_id, start)`` and is expected to
    receive ``target_len`` bytes.  Holders order and compare by ``start``
    only, and may be compared directly against a plain ``int`` offset.

    Because ``__eq__`` is overridden without ``__hash__``, instances are
    unhashable.
    """

    # Futures of coroutines currently blocked in ``wait_update``.
    waiters: collections.deque[asyncio.Future]
    chunk_id: int = 0

    def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int) -> None:
        """Create an empty holder.

        Args:
            chat_id: Chat the media belongs to.
            msg_id: Message the media belongs to.
            start: Offset this block starts at; the sort/compare key.
            target_len: Number of bytes the block should hold when complete.
        """
        self.chat_id = chat_id
        self.msg_id = msg_id
        self.start = start
        self.target_len = target_len
        self.mem = bytes()
        self.length = len(self.mem)
        self.waiters = collections.deque()

    def __repr__(self) -> str:
        return f"MediaBlockHolder,id:{self.chat_id}-{self.msg_id},start:{self.start},len:{self.length}/{self.target_len}"

    @staticmethod
    def _other_start(other) -> Optional[int]:
        """Return the offset ``other`` stands for, or None if unsupported."""
        if isinstance(other, int):
            return other
        if isinstance(other, MediaBlockHolder):
            return other.start
        return None

    def __eq__(self, other: Union['MediaBlockHolder', int]):
        key = self._other_start(other)
        if key is None:
            return NotImplemented
        return self.start == key

    def __le__(self, other: Union['MediaBlockHolder', int]):
        key = self._other_start(other)
        if key is None:
            return NotImplemented
        return self.start <= key

    def __gt__(self, other: Union['MediaBlockHolder', int]):
        key = self._other_start(other)
        if key is None:
            return NotImplemented
        return self.start > key

    def __add__(self, other: Union['MediaBlockHolder', bytes]) -> 'MediaBlockHolder':
        """Append ``other`` (bytes, or another holder's bytes) in place.

        The previous implementation dereferenced ``.mem`` on the bytes value
        in every code path and therefore always raised ``AttributeError``.
        Returns ``self`` so that ``holder += data`` keeps ``holder`` bound
        to this object.
        """
        if isinstance(other, MediaBlockHolder):
            other = other.mem
        self.append_mem(other)
        return self

    def is_completed(self) -> bool:
        """Return True once at least ``target_len`` bytes are buffered."""
        return self.length >= self.target_len

    def notify_waiters(self) -> None:
        """Resolve and discard every pending waiter future."""
        while self.waiters:
            waiter = self.waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)

    def append_mem(self, mem: bytes) -> None:
        """Append ``mem``, wake all waiters, and log a warning on overflow.

        Unlike an exception, the overflow warning does not stop the append:
        the extra bytes stay in the buffer.
        """
        self.mem = self.mem + mem
        self.length = len(self.mem)
        self.notify_waiters()
        if self.length > self.target_len:
            logger.warning(f"MeidaBlock Overflow:{self}")

    async def wait_update(self) -> None:
        """Suspend until ``notify_waiters`` runs; return at once if complete."""
        if self.is_completed():
            return
        waiter = asyncio.Future()
        self.waiters.append(waiter)
        try:
            await waiter
        except BaseException:
            # NOTE(review): kept from the original bare ``except`` - any
            # exception delivered while waiting, including task
            # cancellation, is swallowed after cleaning up the waiter.
            # Confirm callers rely on this before re-raising here.
            waiter.cancel()
            try:
                self.waiters.remove(waiter)
            except ValueError:
                pass
@functools.total_ordering
class BlockInfo(object):
    """Record of a block's ``hashid``, ``offset``, ``length`` and ``in_mem`` flag.

    Instances compare and order by ``offset`` alone, either against another
    ``BlockInfo`` or against a bare ``int`` offset; the remaining comparison
    operators are derived by ``functools.total_ordering``.
    """

    def __init__(self, hashid: int, offset: int, length: int, in_mem: bool) -> None:
        self.hashid, self.offset = hashid, offset
        self.length, self.in_mem = length, in_mem

    def __eq__(self, other: Union['BlockInfo', int]):
        # An int is taken as the offset itself; anything else must expose
        # an ``offset`` attribute.
        rhs = other if isinstance(other, int) else other.offset
        return self.offset == rhs

    def __le__(self, other: Union['BlockInfo', int]):
        rhs = other if isinstance(other, int) else other.offset
        return self.offset <= rhs
class MediaBlockHolderManager(object):
    """Placeholder manager for block metadata; no behaviour is implemented yet.

    ``chunk_cache`` is a class attribute, so the same dict object is seen by
    every instance.
    """

    DEFAULT_MAX_CACHE_SIZE = 1 << 30  # 1Gb
    # chat_id -> msg_id -> list[BlockInfo]
    chunk_cache: "dict[int, dict[int, list[BlockInfo]]]" = {}

    def __init__(self, limit_size: int = DEFAULT_MAX_CACHE_SIZE, dir: str = 'cache') -> None:
        """Accept ``limit_size`` and ``dir`` without using them (stub)."""
        return None