chore: move MeidaChunkHolder and Manager to MediaCacheManager.py

This commit is contained in:
hehesheng 2024-06-01 16:33:28 +08:00
parent f48a35ad17
commit d6e46533df
6 changed files with 227 additions and 208 deletions

View File

@ -1,12 +1,197 @@
import functools
import logging
import bisect
import collections
import asyncio
import collections
from typing import Union, Optional
import diskcache
from fastapi import Request
from telethon import types
logger = logging.getLogger(__file__.split("/")[-1])
@functools.total_ordering
class MediaChunkHolder(object):
waiters: collections.deque[asyncio.Future]
requester: list[Request] = []
chunk_id: int = 0
is_done: bool = False
def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
self.chat_id = chat_id
self.msg_id = msg_id
self.start = start
self.target_len = target_len
self.mem = mem or bytes()
self.length = len(self.mem)
self.waiters = collections.deque()
def __repr__(self) -> str:
return f"MediaChunk,start:{self.start},len:{self.length}"
def __eq__(self, other: 'MediaChunkHolder'):
if isinstance(other, int):
return self.start == other
return self.start == other.start
def __le__(self, other: 'MediaChunkHolder'):
if isinstance(other, int):
return self.start <= other
return self.start <= other.start
def __gt__(self, other: 'MediaChunkHolder'):
if isinstance(other, int):
return self.start > other
return self.start > other.start
def __add__(self, other: Union['MediaChunkHolder', bytes]):
if isinstance(other, MediaChunkHolder):
other = other.mem
self.append_chunk_mem(other)
def is_completed(self) -> bool:
return self.length >= self.target_len
def set_done(self) -> None:
# self.is_done = True
# self.notify_waiters()
self.requester.clear()
def notify_waiters(self) -> None:
while self.waiters:
waiter = self.waiters.popleft()
if not waiter.done():
waiter.set_result(None)
def _set_chunk_mem(self, mem: Optional[bytes]) -> None:
self.mem = mem
self.length = len(self.mem)
if self.length > self.target_len:
raise RuntimeWarning(
f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")
def append_chunk_mem(self, mem: bytes) -> None:
self.mem = self.mem + mem
self.length = len(self.mem)
if self.length > self.target_len:
raise RuntimeWarning(
f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")
self.notify_waiters()
def add_chunk_requester(self, req: Request) -> None:
self.requester.append(req)
async def is_disconneted(self) -> bool:
while self.requester:
res = await self.requester[0].is_disconnected()
if res:
self.requester.pop(0)
continue
return res
return True
async def wait_chunk_update(self) -> None:
if self.is_done:
return
waiter = asyncio.Future()
self.waiters.append(waiter)
try:
await waiter
except:
waiter.cancel()
try:
self.waiters.remove(waiter)
except ValueError:
pass
class MediaChunkHolderManager(object):
MAX_CACHE_SIZE = 1024 * 1024 * 1024 # 1Gb
current_cache_size: int = 0
# chat_id -> msg_id -> offset -> mem
chunk_cache: dict[int, dict[int,
list[MediaChunkHolder]]] = {}
# ChunkHolderId -> ChunkHolder
unique_chunk_id: int = 0
chunk_lru: collections.OrderedDict[int, MediaChunkHolder]
def __init__(self) -> None:
self.chunk_lru = collections.OrderedDict()
def _get_media_msg_cache(self, msg: types.Message) -> Optional[list[MediaChunkHolder]]:
chat_cache = self.chunk_cache.get(msg.chat_id)
if chat_cache is None:
return None
return chat_cache.get(msg.id)
def _get_media_chunk_cache(self, msg: types.Message, start: int) -> Optional[MediaChunkHolder]:
msg_cache = self._get_media_msg_cache(msg)
if msg_cache is None or len(msg_cache) == 0:
return None
pos = bisect.bisect_left(msg_cache, start)
if pos == len(msg_cache):
pos = pos - 1
if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start:
return msg_cache[pos]
return None
elif msg_cache[pos].start == start:
return msg_cache[pos]
elif pos > 0:
pos = pos - 1
if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start:
return msg_cache[pos]
return None
return None
def _remove_pop_chunk(self, pop_chunk: MediaChunkHolder) -> None:
self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id].remove(
pop_chunk.start)
self.current_cache_size -= pop_chunk.target_len
if len(self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id]) == 0:
self.chunk_cache[pop_chunk.chat_id].pop(pop_chunk.msg_id)
if len(self.chunk_cache[pop_chunk.chat_id]) == 0:
self.chunk_cache.pop(pop_chunk.chat_id)
def get_media_chunk(self, msg: types.Message, start: int, lru: bool = True) -> Optional[MediaChunkHolder]:
res = self._get_media_chunk_cache(msg, start)
if res is None:
return None
if lru:
self.chunk_lru.move_to_end(res.chunk_id)
return res
def set_media_chunk(self, chunk: MediaChunkHolder) -> None:
cache_chat = self.chunk_cache.get(chunk.chat_id)
if cache_chat is None:
self.chunk_cache[chunk.chat_id] = {}
cache_chat = self.chunk_cache[chunk.chat_id]
cache_msg = cache_chat.get(chunk.msg_id)
if cache_msg is None:
cache_chat[chunk.msg_id] = []
cache_msg = cache_chat[chunk.msg_id]
chunk.chunk_id = self.unique_chunk_id
self.unique_chunk_id += 1
bisect.insort(cache_msg, chunk)
self.chunk_lru[chunk.chunk_id] = chunk
self.current_cache_size += chunk.target_len
while self.current_cache_size > self.MAX_CACHE_SIZE:
dummy = self.chunk_lru.popitem(last=False)
self._remove_pop_chunk(dummy[1])
def cancel_media_chunk(self, chunk: MediaChunkHolder) -> None:
cache_chat = self.chunk_cache.get(chunk.chat_id)
if cache_chat is None:
return
cache_msg = cache_chat.get(chunk.msg_id)
if cache_msg is None:
return
dummy = self.chunk_lru.pop(chunk.chunk_id, None)
if dummy is None:
return
self._remove_pop_chunk(dummy)
@functools.total_ordering
class MediaBlockHolder(object):
waiters: collections.deque[asyncio.Future]
@ -24,28 +209,25 @@ class MediaBlockHolder(object):
def __repr__(self) -> str:
return f"MediaBlockHolder,id:{self.chat_id}-{self.msg_id},start:{self.start},len:{self.length}/{self.target_len}"
def __eq__(self, other: 'MediaBlockHolder'|int):
def __eq__(self, other: Union['MediaBlockHolder', int]):
if isinstance(other, int):
return self.start == other
return self.start == other.start
def __le__(self, other: 'MediaBlockHolder'|int):
def __le__(self, other: Union['MediaBlockHolder', int]):
if isinstance(other, int):
return self.start <= other
return self.start <= other.start
def __gt__(self, other: 'MediaBlockHolder'|int):
def __gt__(self, other: Union['MediaBlockHolder', int]):
if isinstance(other, int):
return self.start > other
return self.start > other.start
def __add__(self, other: 'MediaBlockHolder'|bytes):
if isinstance(other, bytes):
self.append_mem(other)
elif isinstance(other, MediaBlockHolder):
self.append_mem(other.mem)
else:
raise RuntimeError(f"{self} can't add {type(other)}")
def __add__(self, other: Union['MediaBlockHolder', bytes]):
if isinstance(other, MediaBlockHolder):
other = other.mem
self.append_mem(other.mem)
def is_completed(self) -> bool:
return self.length >= self.target_len
@ -85,12 +267,12 @@ class BlockInfo(object):
self.length = length
self.in_mem = in_mem
def __eq__(self, other: 'BlockInfo'|int):
def __eq__(self, other: Union['BlockInfo', int]):
if isinstance(other, int):
return self.offset == other
return self.offset == other.offset
def __le__(self, other: 'BlockInfo'|int):
def __le__(self, other: Union['BlockInfo', int]):
if isinstance(other, int):
return self.offset <= other
return self.offset <= other.offset

View File

@ -1,15 +1,12 @@
import asyncio
import json
import bisect
import time
import re
import rsa
import os
import functools
import collections
import traceback
import logging
from collections import OrderedDict
from typing import Union, Optional
from telethon import TelegramClient, types, hints, events
@ -18,192 +15,11 @@ from fastapi import Request
import configParse
from backend import apiutils
from backend.UserManager import UserManager
from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager
logger = logging.getLogger(__file__.split("/")[-1])
class TgFileSystemClient(object):
@functools.total_ordering
class MediaChunkHolder(object):
waiters: collections.deque[asyncio.Future]
requester: list[Request] = []
chunk_id: int = 0
is_done: bool = False
def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
self.chat_id = chat_id
self.msg_id = msg_id
self.start = start
self.target_len = target_len
self.mem = mem or bytes()
self.length = len(self.mem)
self.waiters = collections.deque()
def __repr__(self) -> str:
return f"MediaChunk,start:{self.start},len:{self.length}"
def __eq__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
if isinstance(other, int):
return self.start == other
return self.start == other.start
def __le__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
if isinstance(other, int):
return self.start <= other
return self.start <= other.start
def __gt__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
if isinstance(other, int):
return self.start > other
return self.start > other.start
def __add__(self, other):
if isinstance(other, bytes):
self.append_chunk_mem(other)
elif isinstance(other, TgFileSystemClient.MediaChunkHolder):
self.append_chunk_mem(other.mem)
else:
raise RuntimeError("does not suported this type to add")
def is_completed(self) -> bool:
return self.length >= self.target_len
def set_done(self) -> None:
# self.is_done = True
# self.notify_waiters()
self.requester.clear()
def notify_waiters(self) -> None:
while self.waiters:
waiter = self.waiters.popleft()
if not waiter.done():
waiter.set_result(None)
def _set_chunk_mem(self, mem: Optional[bytes]) -> None:
self.mem = mem
self.length = len(self.mem)
if self.length > self.target_len:
raise RuntimeWarning(
f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")
def append_chunk_mem(self, mem: bytes) -> None:
self.mem = self.mem + mem
self.length = len(self.mem)
if self.length > self.target_len:
raise RuntimeWarning(
f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")
self.notify_waiters()
def add_chunk_requester(self, req: Request) -> None:
self.requester.append(req)
async def is_disconneted(self) -> bool:
while self.requester:
res = await self.requester[0].is_disconnected()
if res:
self.requester.pop(0)
continue
return res
return True
async def wait_chunk_update(self) -> None:
if self.is_done:
return
waiter = asyncio.Future()
self.waiters.append(waiter)
try:
await waiter
except:
waiter.cancel()
try:
self.waiters.remove(waiter)
except ValueError:
pass
class MediaChunkHolderManager(object):
MAX_CACHE_SIZE = 1024 * 1024 * 1024 # 1Gb
current_cache_size: int = 0
# chat_id -> msg_id -> offset -> mem
chunk_cache: dict[int, dict[int,
list['TgFileSystemClient.MediaChunkHolder']]] = {}
# ChunkHolderId -> ChunkHolder
unique_chunk_id: int = 0
chunk_lru: OrderedDict[int, 'TgFileSystemClient.MediaChunkHolder']
def __init__(self) -> None:
self.chunk_lru = OrderedDict()
def _get_media_msg_cache(self, msg: types.Message) -> Optional[list['TgFileSystemClient.MediaChunkHolder']]:
chat_cache = self.chunk_cache.get(msg.chat_id)
if chat_cache is None:
return None
return chat_cache.get(msg.id)
def _get_media_chunk_cache(self, msg: types.Message, start: int) -> Optional['TgFileSystemClient.MediaChunkHolder']:
msg_cache = self._get_media_msg_cache(msg)
if msg_cache is None or len(msg_cache) == 0:
return None
pos = bisect.bisect_left(msg_cache, start)
if pos == len(msg_cache):
pos = pos - 1
if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start:
return msg_cache[pos]
return None
elif msg_cache[pos].start == start:
return msg_cache[pos]
elif pos > 0:
pos = pos - 1
if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start:
return msg_cache[pos]
return None
return None
def _remove_pop_chunk(self, pop_chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id].remove(
pop_chunk.start)
self.current_cache_size -= pop_chunk.target_len
if len(self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id]) == 0:
self.chunk_cache[pop_chunk.chat_id].pop(pop_chunk.msg_id)
if len(self.chunk_cache[pop_chunk.chat_id]) == 0:
self.chunk_cache.pop(pop_chunk.chat_id)
def get_media_chunk(self, msg: types.Message, start: int, lru: bool = True) -> Optional['TgFileSystemClient.MediaChunkHolder']:
res = self._get_media_chunk_cache(msg, start)
if res is None:
return None
if lru:
self.chunk_lru.move_to_end(res.chunk_id)
return res
def set_media_chunk(self, chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
cache_chat = self.chunk_cache.get(chunk.chat_id)
if cache_chat is None:
self.chunk_cache[chunk.chat_id] = {}
cache_chat = self.chunk_cache[chunk.chat_id]
cache_msg = cache_chat.get(chunk.msg_id)
if cache_msg is None:
cache_chat[chunk.msg_id] = []
cache_msg = cache_chat[chunk.msg_id]
chunk.chunk_id = self.unique_chunk_id
self.unique_chunk_id += 1
bisect.insort(cache_msg, chunk)
self.chunk_lru[chunk.chunk_id] = chunk
self.current_cache_size += chunk.target_len
while self.current_cache_size > self.MAX_CACHE_SIZE:
dummy = self.chunk_lru.popitem(last=False)
self._remove_pop_chunk(dummy[1])
def cancel_media_chunk(self, chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
cache_chat = self.chunk_cache.get(chunk.chat_id)
if cache_chat is None:
return
cache_msg = cache_chat.get(chunk.msg_id)
if cache_msg is None:
return
dummy = self.chunk_lru.pop(chunk.chunk_id, None)
if dummy is None:
return
self._remove_pop_chunk(dummy)
MAX_WORKER_ROUTINE = 4
SINGLE_NET_CHUNK_SIZE = 256 * 1024 # 256kb
SINGLE_MEDIA_SIZE = 5 * 1024 * 1024 # 5mb
@ -236,7 +52,7 @@ class TgFileSystemClient(object):
self.task_queue = asyncio.Queue()
self.client = TelegramClient(
f"{os.path.dirname(__file__)}/db/{self.session_name}.session", self.api_id, self.api_hash, proxy=self.proxy_param)
self.media_chunk_manager = TgFileSystemClient.MediaChunkHolderManager()
self.media_chunk_manager = MediaChunkHolderManager()
self.db = db
def __del__(self) -> None:
@ -315,6 +131,7 @@ class TgFileSystemClient(object):
t.cancel()
except Exception as err:
logger.error(f"{err=}")
logger.error(traceback.format_exc())
async def _cache_whitelist_chat2(self):
for chat_id in self.client_param.whitelist_chat:
@ -366,6 +183,7 @@ class TgFileSystemClient(object):
await task[1]
except Exception as err:
logger.error(f"{err=}")
logger.error(traceback.format_exc())
finally:
self.task_queue.task_done()
@ -434,7 +252,7 @@ class TgFileSystemClient(object):
self.media_chunk_manager.cancel_media_chunk(media_holder)
except Exception as err:
logger.error(
f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{traceback.format_exc()}")
f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}")
finally:
media_holder.set_done()
logger.debug(
@ -457,7 +275,7 @@ class TgFileSystemClient(object):
align_pos = pos
align_size = min(self.SINGLE_MEDIA_SIZE,
file_size - align_pos)
holder = TgFileSystemClient.MediaChunkHolder(
holder = MediaChunkHolder(
msg.chat_id, msg.id, align_pos, align_size)
holder.add_chunk_requester(req)
self.media_chunk_manager.set_media_chunk(holder)
@ -486,8 +304,8 @@ class TgFileSystemClient(object):
pos = pos + need_len
yield cache_chunk.mem[offset:offset+need_len]
except Exception as err:
traceback.print_exc()
logger.error(f"stream iter:{err=}")
logger.error(traceback.format_exc())
finally:
async def _cancel_task_by_id(task_id: int):
for _ in range(self.task_queue.qsize()):

View File

@ -64,9 +64,11 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
for item in res:
msg_info = json.loads(item)
file_name = apiutils.get_message_media_name_from_dict(msg_info)
chat_id = apiutils.get_message_chat_id_from_dict(msg_info)
msg_id = apiutils.get_message_msg_id_from_dict(msg_info)
msg_info['file_name'] = file_name
msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{msg_info.get('id')}/{file_name}"
msg_info['src_tg_link'] = f"https://t.me/c/1216816802/21206"
msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}"
msg_info['src_tg_link'] = f"https://t.me/c/{chat_id}/{msg_id}"
res_dict.append(msg_info)
client_dict = json.loads(client.to_json())

View File

@ -52,6 +52,20 @@ def get_message_media_name_from_dict(msg: dict[str, any]) -> str:
file_name = "unknown.tmp"
return file_name
def get_message_chat_id_from_dict(msg: dict[str, any]) -> int:
try:
return msg['peer_id']['channel_id']
except:
pass
return 0
def get_message_msg_id_from_dict(msg: dict[str, any]) -> int:
try:
return msg['id']
except:
pass
return 0
def timeit_sec(func):
@wraps(func)
def timeit_wrapper(*args, **kwargs):

View File

@ -94,7 +94,7 @@ def do_search_req():
st.session_state.force_skip = True
st.rerun()
def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str):
def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str):
file_size_str = f"{file_size/1024/1024:.2f}MB"
container = st.container()
container_columns = container.columns([1, 99])
@ -104,7 +104,7 @@ def do_search_req():
expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size_str}*"
popover = container_columns[1].popover(expender_title, use_container_width=True)
popover_columns = popover.columns([1, 3])
popover_columns = popover.columns([1, 3, 1])
if url:
popover_columns[0].video(url)
else:
@ -112,7 +112,8 @@ def do_search_req():
popover_columns[1].markdown(f'{msg_ctx}')
popover_columns[1].markdown(f'**{file_name}**')
popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
popover_columns[1].link_button('Download Link', url)
popover_columns[2].link_button('Download Link', url, use_container_width=True)
popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True)
@st.experimental_fragment
def show_search_res(res: dict[str, any]):
@ -134,7 +135,9 @@ def do_search_req():
file_name = None
file_size = 0
download_url = ""
src_link = ""
try:
src_link = v['src_tg_link']
msg_ctx = v['message']
msg_id = str(v['id'])
doc = v['media']['document']
@ -152,7 +155,7 @@ def do_search_req():
except Exception as err:
msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
media_file_res_container(
i, msg_ctx, file_name, file_size, download_url)
i, msg_ctx, file_name, file_size, download_url, src_link)
page_switch_render()
show_text = ""

View File

@ -16,7 +16,7 @@ with open('logging_config.yaml', 'r') as f:
logging.config.dictConfig(yaml.safe_load(f.read()))
LOGGING_CONFIG["formatters"]["default"]["fmt"] = "[%(levelname)s] %(asctime)s [uvicorn.default]:%(message)s"
LOGGING_CONFIG["formatters"]["access"]["fmt"] = '[%(levelname)s]%(asctime)s [uvicorn.access]:%(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["formatters"]["access"]["fmt"] = '[%(levelname)s] %(asctime)s [uvicorn.access]:%(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["handlers"]["timed_rotating_api_file"] = {
"class": "logging.handlers.TimedRotatingFileHandler",
"filename": "logs/app.log",