diff --git a/backend/MediaCacheManager.py b/backend/MediaCacheManager.py
index 153dca5..20fdd39 100644
--- a/backend/MediaCacheManager.py
+++ b/backend/MediaCacheManager.py
@@ -1,12 +1,197 @@
 import functools
 import logging
+import bisect
 import collections
 import asyncio
+from typing import Union, Optional
 import diskcache
+from fastapi import Request
+from telethon import types
 
 logger = logging.getLogger(__file__.split("/")[-1])
 
 
+@functools.total_ordering
+class MediaChunkHolder(object):
+    waiters: collections.deque[asyncio.Future]
+    requester: list[Request]
+    chunk_id: int = 0
+    is_done: bool = False
+
+    def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
+        self.chat_id = chat_id
+        self.msg_id = msg_id
+        self.start = start
+        self.target_len = target_len
+        self.mem = mem or bytes()
+        self.length = len(self.mem)
+        self.waiters = collections.deque()
+        self.requester = []
+
+    def __repr__(self) -> str:
+        return f"MediaChunk,start:{self.start},len:{self.length}"
+
+    def __eq__(self, other: Union['MediaChunkHolder', int]):
+        if isinstance(other, int):
+            return self.start == other
+        return self.start == other.start
+
+    def __le__(self, other: Union['MediaChunkHolder', int]):
+        if isinstance(other, int):
+            return self.start <= other
+        return self.start <= other.start
+
+    def __gt__(self, other: Union['MediaChunkHolder', int]):
+        if isinstance(other, int):
+            return self.start > other
+        return self.start > other.start
+
+    def __add__(self, other: Union['MediaChunkHolder', bytes]):
+        if isinstance(other, MediaChunkHolder):
+            other = other.mem
+        self.append_chunk_mem(other)
+
+    def is_completed(self) -> bool:
+        return self.length >= self.target_len
+
+    def set_done(self) -> None:
+        # self.is_done = True
+        # self.notify_waiters()
+        self.requester.clear()
+
+    def notify_waiters(self) -> None:
+        while self.waiters:
+            waiter = self.waiters.popleft()
+            if not waiter.done():
+                waiter.set_result(None)
+
+    def _set_chunk_mem(self, mem: Optional[bytes]) -> None:
+        self.mem = mem or bytes()
+        self.length = len(self.mem)
+        if self.length > self.target_len:
+            raise RuntimeWarning(
+                f"MediaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")
+
+    def append_chunk_mem(self, mem: bytes) -> None:
+        self.mem = self.mem + mem
+        self.length = len(self.mem)
+        if self.length > self.target_len:
+            raise RuntimeWarning(
+                f"MediaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")
+        self.notify_waiters()
+
+    def add_chunk_requester(self, req: Request) -> None:
+        self.requester.append(req)
+
+    async def is_disconneted(self) -> bool:
+        while self.requester:
+            res = await self.requester[0].is_disconnected()
+            if res:
+                self.requester.pop(0)
+                continue
+            return res
+        return True
+
+    async def wait_chunk_update(self) -> None:
+        if self.is_done:
+            return
+        waiter = asyncio.Future()
+        self.waiters.append(waiter)
+        try:
+            await waiter
+        except:
+            waiter.cancel()
+            try:
+                self.waiters.remove(waiter)
+            except ValueError:
+                pass
+
+
+class MediaChunkHolderManager(object):
+    MAX_CACHE_SIZE = 1024 * 1024 * 1024  # 1Gb
+    current_cache_size: int = 0
+    # chat_id -> msg_id -> offset -> mem
+    chunk_cache: dict[int, dict[int,
+                                list[MediaChunkHolder]]] = {}
+    # ChunkHolderId -> ChunkHolder
+    unique_chunk_id: int = 0
+    chunk_lru: collections.OrderedDict[int, MediaChunkHolder]
+
+    def __init__(self) -> None:
+        self.chunk_lru = collections.OrderedDict()
+
+    def _get_media_msg_cache(self, msg: types.Message) -> Optional[list[MediaChunkHolder]]:
+        chat_cache = self.chunk_cache.get(msg.chat_id)
+        if chat_cache is None:
+            return None
+        return chat_cache.get(msg.id)
+
+    def _get_media_chunk_cache(self, msg: types.Message, start: int) -> Optional[MediaChunkHolder]:
+        msg_cache = self._get_media_msg_cache(msg)
+        if msg_cache is None or len(msg_cache) == 0:
+            return None
+        pos = bisect.bisect_left(msg_cache, start)
+        if pos == len(msg_cache):
+            pos = pos - 1
+            if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start:
+                return msg_cache[pos]
+            return None
+        elif msg_cache[pos].start == start:
+            return msg_cache[pos]
+        elif pos > 0:
+            pos = pos - 1
+            if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start:
+                return msg_cache[pos]
+            return None
+        return None
+
+    def _remove_pop_chunk(self, pop_chunk: MediaChunkHolder) -> None:
+        self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id].remove(
+            pop_chunk.start)
+        self.current_cache_size -= pop_chunk.target_len
+        if len(self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id]) == 0:
+            self.chunk_cache[pop_chunk.chat_id].pop(pop_chunk.msg_id)
+            if len(self.chunk_cache[pop_chunk.chat_id]) == 0:
+                self.chunk_cache.pop(pop_chunk.chat_id)
+
+    def get_media_chunk(self, msg: types.Message, start: int, lru: bool = True) -> Optional[MediaChunkHolder]:
+        res = self._get_media_chunk_cache(msg, start)
+        if res is None:
+            return None
+        if lru:
+            self.chunk_lru.move_to_end(res.chunk_id)
+        return res
+
+    def set_media_chunk(self, chunk: MediaChunkHolder) -> None:
+        cache_chat = self.chunk_cache.get(chunk.chat_id)
+        if cache_chat is None:
+            self.chunk_cache[chunk.chat_id] = {}
+            cache_chat = self.chunk_cache[chunk.chat_id]
+        cache_msg = cache_chat.get(chunk.msg_id)
+        if cache_msg is None:
+            cache_chat[chunk.msg_id] = []
+            cache_msg = cache_chat[chunk.msg_id]
+        chunk.chunk_id = self.unique_chunk_id
+        self.unique_chunk_id += 1
+        bisect.insort(cache_msg, chunk)
+        self.chunk_lru[chunk.chunk_id] = chunk
+        self.current_cache_size += chunk.target_len
+        while self.current_cache_size > self.MAX_CACHE_SIZE:
+            dummy = self.chunk_lru.popitem(last=False)
+            self._remove_pop_chunk(dummy[1])
+
+    def cancel_media_chunk(self, chunk: MediaChunkHolder) -> None:
+        cache_chat = self.chunk_cache.get(chunk.chat_id)
+        if cache_chat is None:
+            return
+        cache_msg = cache_chat.get(chunk.msg_id)
+        if cache_msg is None:
+            return
+        dummy = self.chunk_lru.pop(chunk.chunk_id, None)
+        if dummy is None:
+            return
+        self._remove_pop_chunk(dummy)
+
+
 @functools.total_ordering
 class MediaBlockHolder(object):
     waiters: collections.deque[asyncio.Future]
@@ -24,28 +209,25 @@ class MediaBlockHolder(object):
     def __repr__(self) -> str:
         return f"MediaBlockHolder,id:{self.chat_id}-{self.msg_id},start:{self.start},len:{self.length}/{self.target_len}"
 
-    def __eq__(self, other: 'MediaBlockHolder'|int):
+    def __eq__(self, other: Union['MediaBlockHolder', int]):
         if isinstance(other, int):
             return self.start == other
         return self.start == other.start
 
-    def __le__(self, other: 'MediaBlockHolder'|int):
+    def __le__(self, other: Union['MediaBlockHolder', int]):
         if isinstance(other, int):
             return self.start <= other
         return self.start <= other.start
 
-    def __gt__(self, other: 'MediaBlockHolder'|int):
+    def __gt__(self, other: Union['MediaBlockHolder', int]):
         if isinstance(other, int):
             return self.start > other
         return self.start > other.start
 
-    def __add__(self, other: 'MediaBlockHolder'|bytes):
-        if isinstance(other, bytes):
-            self.append_mem(other)
-        elif isinstance(other, MediaBlockHolder):
-            self.append_mem(other.mem)
-        else:
-            raise RuntimeError(f"{self} can't add {type(other)}")
+    def __add__(self, other: Union['MediaBlockHolder', bytes]):
+        if isinstance(other, MediaBlockHolder):
+            other = other.mem
+        self.append_mem(other)
 
     def is_completed(self) -> bool:
         return self.length >= self.target_len
@@ -85,12 +267,12 @@ class BlockInfo(object):
         self.length = length
         self.in_mem = in_mem
 
-    def __eq__(self, other: 'BlockInfo'|int):
+    def __eq__(self, other: Union['BlockInfo', int]):
         if isinstance(other, int):
             return self.offset == other
         return self.offset == other.offset
 
-    def __le__(self, other: 'BlockInfo'|int):
+    def __le__(self, other: Union['BlockInfo', int]):
         if isinstance(other, int):
             return self.offset <= other
         return self.offset <= other.offset
diff --git a/backend/TgFileSystemClient.py b/backend/TgFileSystemClient.py
index bc26b79..97acf8f 100644
--- a/backend/TgFileSystemClient.py
+++ b/backend/TgFileSystemClient.py
@@ -1,15 +1,12 @@
 import asyncio
 import json
-import bisect
 import time
 import re
 import rsa
 import os
 import functools
-import collections
 import traceback
 import logging
-from collections import OrderedDict
 from typing import Union, Optional
 
 from telethon import TelegramClient, types, hints, events
@@ -18,192 +15,11 @@ from fastapi import Request
 import configParse
 from backend import apiutils
 from backend.UserManager import UserManager
+from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager
 
 logger = logging.getLogger(__file__.split("/")[-1])
 
 class TgFileSystemClient(object):
-    @functools.total_ordering
-    class MediaChunkHolder(object):
-        waiters: collections.deque[asyncio.Future]
-        requester: list[Request] = []
-        chunk_id: int = 0
-        is_done: bool = False
-
-        def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
-            self.chat_id = chat_id
-            self.msg_id = msg_id
-            self.start = start
-            self.target_len = target_len
-            self.mem = mem or bytes()
-            self.length = len(self.mem)
-            self.waiters = collections.deque()
-
-        def __repr__(self) -> str:
-            return f"MediaChunk,start:{self.start},len:{self.length}"
-
-        def __eq__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
-            if isinstance(other, int):
-                return self.start == other
-            return self.start == other.start
-
-        def __le__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
-            if isinstance(other, int):
-                return self.start <= other
-            return self.start <= other.start
-
-        def __gt__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
-            if isinstance(other, int):
-                return self.start > other
-            return self.start > other.start
-
-        def __add__(self, other):
-            if isinstance(other, bytes):
-                self.append_chunk_mem(other)
-            elif isinstance(other, TgFileSystemClient.MediaChunkHolder):
-                self.append_chunk_mem(other.mem)
-            else:
-                raise RuntimeError("does not suported this type to add")
-
-        def is_completed(self) -> bool:
-            return self.length >= self.target_len
-
-        def set_done(self) -> None:
-            # self.is_done = True
-            # self.notify_waiters()
-            self.requester.clear()
-
-        def notify_waiters(self) -> None:
-            while self.waiters:
-                waiter = self.waiters.popleft()
-                if not waiter.done():
-                    waiter.set_result(None)
-
-        def _set_chunk_mem(self, mem: Optional[bytes]) -> None:
-            self.mem = mem
-            self.length = len(self.mem)
-            if self.length > self.target_len:
-                raise RuntimeWarning(
-                    f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")
-
-        def append_chunk_mem(self, mem: bytes) -> None:
-            self.mem = self.mem + mem
-            self.length = len(self.mem)
-            if self.length > self.target_len:
-                raise RuntimeWarning(
-                    f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")
-            self.notify_waiters()
-
-        def add_chunk_requester(self, req: Request) -> None:
-            self.requester.append(req)
-
-        async def is_disconneted(self) -> bool:
-            while self.requester:
-                res = await self.requester[0].is_disconnected()
-                if res:
-                    self.requester.pop(0)
-                    continue
-                return res
-            return True
-
-        async def wait_chunk_update(self) -> None:
-            if self.is_done:
-                return
-            waiter = asyncio.Future()
-            self.waiters.append(waiter)
-            try:
-                await waiter
-            except:
-                waiter.cancel()
-                try:
-                    self.waiters.remove(waiter)
-                except ValueError:
-                    pass
-
-    class MediaChunkHolderManager(object):
-        MAX_CACHE_SIZE = 1024 * 1024 * 1024  # 1Gb
-        current_cache_size: int = 0
-        # chat_id -> msg_id -> offset -> mem
-        chunk_cache: dict[int, dict[int,
-                                    list['TgFileSystemClient.MediaChunkHolder']]] = {}
-        # ChunkHolderId -> ChunkHolder
-        unique_chunk_id: int = 0
-        chunk_lru: OrderedDict[int, 'TgFileSystemClient.MediaChunkHolder']
-
-        def __init__(self) -> None:
-            self.chunk_lru = OrderedDict()
-
-        def _get_media_msg_cache(self, msg: types.Message) -> Optional[list['TgFileSystemClient.MediaChunkHolder']]:
-            chat_cache = self.chunk_cache.get(msg.chat_id)
-            if chat_cache is None:
-                return None
-            return chat_cache.get(msg.id)
-
-        def _get_media_chunk_cache(self, msg: types.Message, start: int) -> Optional['TgFileSystemClient.MediaChunkHolder']:
-            msg_cache = self._get_media_msg_cache(msg)
-            if msg_cache is None or len(msg_cache) == 0:
-                return None
-            pos = bisect.bisect_left(msg_cache, start)
-            if pos == len(msg_cache):
-                pos = pos - 1
-                if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start:
-                    return msg_cache[pos]
-                return None
-            elif msg_cache[pos].start == start:
-                return msg_cache[pos]
-            elif pos > 0:
-                pos = pos - 1
-                if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start:
-                    return msg_cache[pos]
-                return None
-            return None
-
-        def _remove_pop_chunk(self, pop_chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
-            self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id].remove(
-                pop_chunk.start)
-            self.current_cache_size -= pop_chunk.target_len
-            if len(self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id]) == 0:
-                self.chunk_cache[pop_chunk.chat_id].pop(pop_chunk.msg_id)
-                if len(self.chunk_cache[pop_chunk.chat_id]) == 0:
-                    self.chunk_cache.pop(pop_chunk.chat_id)
-
-        def get_media_chunk(self, msg: types.Message, start: int, lru: bool = True) -> Optional['TgFileSystemClient.MediaChunkHolder']:
-            res = self._get_media_chunk_cache(msg, start)
-            if res is None:
-                return None
-            if lru:
-                self.chunk_lru.move_to_end(res.chunk_id)
-            return res
-
-        def set_media_chunk(self, chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
-            cache_chat = self.chunk_cache.get(chunk.chat_id)
-            if cache_chat is None:
-                self.chunk_cache[chunk.chat_id] = {}
-                cache_chat = self.chunk_cache[chunk.chat_id]
-            cache_msg = cache_chat.get(chunk.msg_id)
-            if cache_msg is None:
-                cache_chat[chunk.msg_id] = []
-                cache_msg = cache_chat[chunk.msg_id]
-            chunk.chunk_id = self.unique_chunk_id
-            self.unique_chunk_id += 1
-            bisect.insort(cache_msg, chunk)
-            self.chunk_lru[chunk.chunk_id] = chunk
-            self.current_cache_size += chunk.target_len
-            while self.current_cache_size > self.MAX_CACHE_SIZE:
-                dummy = self.chunk_lru.popitem(last=False)
-                self._remove_pop_chunk(dummy[1])
-
-        def cancel_media_chunk(self, chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
-            cache_chat = self.chunk_cache.get(chunk.chat_id)
-            if cache_chat is None:
-                return
-            cache_msg = cache_chat.get(chunk.msg_id)
-            if cache_msg is None:
-                return
-            dummy = self.chunk_lru.pop(chunk.chunk_id, None)
-            if dummy is None:
-                return
-            self._remove_pop_chunk(dummy)
-
     MAX_WORKER_ROUTINE = 4
     SINGLE_NET_CHUNK_SIZE = 256 * 1024  # 256kb
     SINGLE_MEDIA_SIZE = 5 * 1024 * 1024  # 5mb
@@ -236,7 +52,7 @@ class TgFileSystemClient(object):
         self.task_queue = asyncio.Queue()
         self.client = TelegramClient(
             f"{os.path.dirname(__file__)}/db/{self.session_name}.session", self.api_id, self.api_hash, proxy=self.proxy_param)
-        self.media_chunk_manager = TgFileSystemClient.MediaChunkHolderManager()
+        self.media_chunk_manager = MediaChunkHolderManager()
         self.db = db
 
     def __del__(self) -> None:
@@ -315,6 +131,7 @@ class TgFileSystemClient(object):
                 t.cancel()
         except Exception as err:
             logger.error(f"{err=}")
+            logger.error(traceback.format_exc())
 
     async def _cache_whitelist_chat2(self):
         for chat_id in self.client_param.whitelist_chat:
@@ -366,6 +183,7 @@ class TgFileSystemClient(object):
                 await task[1]
             except Exception as err:
                 logger.error(f"{err=}")
+                logger.error(traceback.format_exc())
             finally:
                 self.task_queue.task_done()
 
@@ -434,7 +252,7 @@ class TgFileSystemClient(object):
                 self.media_chunk_manager.cancel_media_chunk(media_holder)
         except Exception as err:
             logger.error(
-                f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{traceback.format_exc()}")
+                f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}")
         finally:
             media_holder.set_done()
             logger.debug(
@@ -457,7 +275,7 @@ class TgFileSystemClient(object):
                 align_pos = pos
                 align_size = min(self.SINGLE_MEDIA_SIZE,
                                  file_size - align_pos)
-                holder = TgFileSystemClient.MediaChunkHolder(
+                holder = MediaChunkHolder(
                     msg.chat_id, msg.id, align_pos, align_size)
                 holder.add_chunk_requester(req)
                 self.media_chunk_manager.set_media_chunk(holder)
@@ -486,8 +304,8 @@ class TgFileSystemClient(object):
                     pos = pos + need_len
                     yield cache_chunk.mem[offset:offset+need_len]
         except Exception as err:
-            traceback.print_exc()
             logger.error(f"stream iter:{err=}")
+            logger.error(traceback.format_exc())
         finally:
             async def _cancel_task_by_id(task_id: int):
                 for _ in range(self.task_queue.qsize()):
diff --git a/backend/api.py b/backend/api.py
index a195609..8ba3cc8 100644
--- a/backend/api.py
+++ b/backend/api.py
@@ -64,9 +64,11 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
     for item in res:
         msg_info = json.loads(item)
         file_name = apiutils.get_message_media_name_from_dict(msg_info)
+        chat_id = apiutils.get_message_chat_id_from_dict(msg_info)
+        msg_id = apiutils.get_message_msg_id_from_dict(msg_info)
         msg_info['file_name'] = file_name
-        msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{msg_info.get('id')}/{file_name}"
-        msg_info['src_tg_link'] = f"https://t.me/c/1216816802/21206"
+        msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}"
+        msg_info['src_tg_link'] = f"https://t.me/c/{chat_id}/{msg_id}"
         res_dict.append(msg_info)
 
     client_dict = json.loads(client.to_json())
diff --git a/backend/apiutils.py b/backend/apiutils.py
index a478b23..1cfeba1 100644
--- a/backend/apiutils.py
+++ b/backend/apiutils.py
@@ -52,6 +52,20 @@ def get_message_media_name_from_dict(msg: dict[str, any]) -> str:
         file_name = "unknown.tmp"
     return file_name
 
+def get_message_chat_id_from_dict(msg: dict[str, any]) -> int:
+    try:
+        return msg['peer_id']['channel_id']
+    except:
+        pass
+    return 0
+
+def get_message_msg_id_from_dict(msg: dict[str, any]) -> int:
+    try:
+        return msg['id']
+    except:
+        pass
+    return 0
+
 def timeit_sec(func):
     @wraps(func)
     def timeit_wrapper(*args, **kwargs):
diff --git a/frontend/home.py b/frontend/home.py
index 88ae868..884d50b 100644
--- a/frontend/home.py
+++ b/frontend/home.py
@@ -94,7 +94,7 @@ def do_search_req():
             st.session_state.force_skip = True
             st.rerun()
 
-    def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str):
+    def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str):
         file_size_str = f"{file_size/1024/1024:.2f}MB"
         container = st.container()
         container_columns = container.columns([1, 99])
@@ -104,7 +104,7 @@ def do_search_req():
 
         expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} — *{file_size_str}*"
         popover = container_columns[1].popover(expender_title, use_container_width=True)
-        popover_columns = popover.columns([1, 3])
+        popover_columns = popover.columns([1, 3, 1])
         if url:
             popover_columns[0].video(url)
         else:
@@ -112,7 +112,8 @@ def do_search_req():
         popover_columns[1].markdown(f'{msg_ctx}')
         popover_columns[1].markdown(f'**{file_name}**')
         popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
-        popover_columns[1].link_button('⬇️Download Link', url)
+        popover_columns[2].link_button('⬇️Download Link', url, use_container_width=True)
+        popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True)
 
     @st.experimental_fragment
     def show_search_res(res: dict[str, any]):
@@ -134,7 +135,9 @@ def do_search_req():
             file_name = None
             file_size = 0
             download_url = ""
+            src_link = ""
             try:
+                src_link = v['src_tg_link']
                 msg_ctx = v['message']
                 msg_id = str(v['id'])
                 doc = v['media']['document']
@@ -152,7 +155,7 @@ def do_search_req():
             except Exception as err:
                 msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
             media_file_res_container(
-                i, msg_ctx, file_name, file_size, download_url)
+                i, msg_ctx, file_name, file_size, download_url, src_link)
 
     page_switch_render()
     show_text = ""
diff --git a/start.py b/start.py
index ab3e6dc..81a3199 100644
--- a/start.py
+++ b/start.py
@@ -16,7 +16,7 @@ with open('logging_config.yaml', 'r') as f:
     logging.config.dictConfig(yaml.safe_load(f.read()))
 
 LOGGING_CONFIG["formatters"]["default"]["fmt"] = "[%(levelname)s] %(asctime)s [uvicorn.default]:%(message)s"
-LOGGING_CONFIG["formatters"]["access"]["fmt"] = '[%(levelname)s]%(asctime)s [uvicorn.access]:%(client_addr)s - "%(request_line)s" %(status_code)s'
+LOGGING_CONFIG["formatters"]["access"]["fmt"] = '[%(levelname)s] %(asctime)s [uvicorn.access]:%(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["handlers"]["timed_rotating_api_file"] = {
     "class": "logging.handlers.TimedRotatingFileHandler",
     "filename": "logs/app.log",
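
A minimal usage sketch of the extracted cache module, not part of the patch itself. It assumes only what the diff shows: MediaChunkHolderManager keys holders by chat_id and msg_id, keeps them sorted by start offset, and evicts the least-recently-used holder once MAX_CACHE_SIZE is exceeded. _FakeMessage is a hypothetical stand-in for telethon's types.Message, since the manager only reads msg.chat_id and msg.id.

from dataclasses import dataclass

from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager


@dataclass
class _FakeMessage:
    # Hypothetical stand-in for telethon types.Message; the manager only
    # touches these two attributes.
    chat_id: int
    id: int


manager = MediaChunkHolderManager()
msg = _FakeMessage(chat_id=1000000001, id=42)

# Reserve a 5 MB chunk starting at offset 0 and register it in the cache.
holder = MediaChunkHolder(msg.chat_id, msg.id, start=0, target_len=5 * 1024 * 1024)
manager.set_media_chunk(holder)

# Feeding downloaded bytes wakes any coroutine blocked in wait_chunk_update().
holder.append_chunk_mem(b"\x00" * (256 * 1024))

# Any offset inside [start, start + target_len) resolves to the same holder,
# and the lookup refreshes its LRU position.
assert manager.get_media_chunk(msg, 100 * 1024) is holder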