chore: move MediaChunkHolder and Manager to MediaCacheManager.py

This commit is contained in:
hehesheng 2024-06-01 16:33:28 +08:00
parent f48a35ad17
commit d6e46533df
6 changed files with 227 additions and 208 deletions

View File

@ -1,12 +1,197 @@
import functools import functools
import logging import logging
import bisect
import collections import collections
import asyncio import asyncio
import collections
from typing import Union, Optional
import diskcache import diskcache
from fastapi import Request
from telethon import types
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
@functools.total_ordering
class MediaChunkHolder(object):
    """In-memory buffer for one contiguous byte range of a message's media.

    A holder is identified by ``chat_id``/``msg_id`` plus the byte offset
    ``start`` and is filled incrementally until it holds ``target_len`` bytes.
    Holders compare and order by ``start`` only, and they also compare against
    a plain ``int`` offset.
    """

    # Futures of coroutines currently blocked in wait_chunk_update().
    waiters: collections.deque[asyncio.Future]
    # Requests interested in this chunk. Created per instance in __init__:
    # a class-level list would be shared by every holder, so set_done() on one
    # chunk would wipe the requesters of all the others.
    requester: "list[Request]"
    chunk_id: int = 0
    is_done: bool = False

    def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
        """Create a holder for ``target_len`` bytes starting at offset ``start``.

        ``mem`` optionally pre-fills the buffer; ``None`` means empty.
        """
        self.chat_id = chat_id
        self.msg_id = msg_id
        self.start = start
        self.target_len = target_len
        self.mem = mem or bytes()
        self.length = len(self.mem)
        self.waiters = collections.deque()
        self.requester = []

    def __repr__(self) -> str:
        return f"MediaChunk,start:{self.start},len:{self.length}"

    @staticmethod
    def _offset_of(other: Union['MediaChunkHolder', int]) -> int:
        """Return the start offset ``other`` stands for (an int is its own offset)."""
        return other if isinstance(other, int) else other.start

    def __eq__(self, other: Union['MediaChunkHolder', int]):
        return self.start == self._offset_of(other)

    def __le__(self, other: Union['MediaChunkHolder', int]):
        return self.start <= self._offset_of(other)

    def __gt__(self, other: Union['MediaChunkHolder', int]):
        return self.start > self._offset_of(other)

    def __add__(self, other: Union['MediaChunkHolder', bytes]):
        """Append ``other`` (raw bytes or another holder's bytes) in place.

        Returns ``self`` so the expression ``holder + data`` has a usable value.
        """
        if isinstance(other, MediaChunkHolder):
            other = other.mem
        self.append_chunk_mem(other)
        return self

    def is_completed(self) -> bool:
        """Return True once at least ``target_len`` bytes are buffered."""
        return self.length >= self.target_len

    def set_done(self) -> None:
        """Drop every registered requester of this chunk.

        NOTE(review): ``is_done`` is never set to True in this class, so
        wait_chunk_update() never takes its early-return path.
        """
        self.requester.clear()

    def notify_waiters(self) -> None:
        """Wake every coroutine blocked in wait_chunk_update()."""
        while self.waiters:
            waiter = self.waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)

    def _check_overflow(self) -> None:
        """Raise RuntimeWarning when more than ``target_len`` bytes are held."""
        if self.length > self.target_len:
            raise RuntimeWarning(
                f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")

    def _set_chunk_mem(self, mem: Optional[bytes]) -> None:
        """Replace the buffer; ``None`` resets it to empty, as in ``__init__``.

        Raises:
            RuntimeWarning: if the new buffer is longer than ``target_len``.
        """
        self.mem = mem or bytes()
        self.length = len(self.mem)
        self._check_overflow()

    def append_chunk_mem(self, mem: bytes) -> None:
        """Append ``mem`` to the buffer and wake the waiters.

        Raises:
            RuntimeWarning: if the buffer grows past ``target_len``. The bytes
                are already appended at that point and waiters are not woken.
        """
        self.mem = self.mem + mem
        self.length = len(self.mem)
        self._check_overflow()
        self.notify_waiters()

    def add_chunk_requester(self, req: "Request") -> None:
        """Register ``req`` as interested in this chunk."""
        self.requester.append(req)

    async def is_disconneted(self) -> bool:
        """Return True when no registered requester is still connected.

        Disconnected requesters found at the front of the list are removed
        along the way.
        """
        while self.requester:
            res = await self.requester[0].is_disconnected()
            if res:
                self.requester.pop(0)
                continue
            return res
        return True

    async def wait_chunk_update(self) -> None:
        """Block until the next notify_waiters() call (e.g. after an append)."""
        if self.is_done:
            return
        waiter = asyncio.get_running_loop().create_future()
        self.waiters.append(waiter)
        try:
            await waiter
        except BaseException:
            # Same reach as the historical bare ``except``: any failure,
            # including task cancellation, unregisters the waiter and is then
            # swallowed. NOTE(review): confirm that callers rely on
            # cancellation being absorbed here before changing this.
            waiter.cancel()
            try:
                self.waiters.remove(waiter)
            except ValueError:
                pass
class MediaChunkHolderManager(object):
    """Size-bounded LRU index of media chunk holders.

    Holders are indexed by chat id, then message id, and kept sorted by start
    offset. When the summed ``target_len`` of all holders exceeds
    ``MAX_CACHE_SIZE`` the least recently used holders are evicted.
    """

    MAX_CACHE_SIZE = 1024 * 1024 * 1024  # 1Gb
    current_cache_size: int = 0
    # chat_id -> msg_id -> holders sorted by start offset.
    # Created per instance in __init__: a class-level dict would be shared by
    # every manager while chunk_lru and the id counter are per instance, so a
    # second manager could find chunks that are missing from its own LRU.
    chunk_cache: "dict[int, dict[int, list[MediaChunkHolder]]]"
    # ChunkHolderId -> ChunkHolder
    unique_chunk_id: int = 0
    chunk_lru: "collections.OrderedDict[int, MediaChunkHolder]"

    def __init__(self) -> None:
        self.current_cache_size = 0
        self.unique_chunk_id = 0
        self.chunk_cache = {}
        self.chunk_lru = collections.OrderedDict()

    def _get_media_msg_cache(self, msg: "types.Message") -> "Optional[list[MediaChunkHolder]]":
        """Return the sorted holder list of ``msg``, or None when nothing is cached."""
        chat_cache = self.chunk_cache.get(msg.chat_id)
        if chat_cache is None:
            return None
        return chat_cache.get(msg.id)

    def _get_media_chunk_cache(self, msg: "types.Message", start: int) -> "Optional[MediaChunkHolder]":
        """Return the holder of ``msg`` whose byte range contains ``start``, if any."""
        msg_cache = self._get_media_msg_cache(msg)
        if not msg_cache:
            return None
        # First holder whose start offset is >= ``start``.
        pos = bisect.bisect_left(msg_cache, start)
        if pos < len(msg_cache) and msg_cache[pos].start == start:
            return msg_cache[pos]
        if pos == 0:
            # Every cached holder begins after ``start``.
            return None
        # Otherwise only the holder just before ``pos`` can cover ``start``.
        candidate = msg_cache[pos - 1]
        if candidate.start <= start < candidate.start + candidate.target_len:
            return candidate
        return None

    def _remove_pop_chunk(self, pop_chunk: "MediaChunkHolder") -> None:
        """Remove ``pop_chunk`` from the index and prune containers left empty."""
        chat_cache = self.chunk_cache[pop_chunk.chat_id]
        msg_cache = chat_cache[pop_chunk.msg_id]
        # Holders compare equal to their start offset, so this removes by offset.
        msg_cache.remove(pop_chunk.start)
        self.current_cache_size -= pop_chunk.target_len
        if not msg_cache:
            chat_cache.pop(pop_chunk.msg_id)
            if not chat_cache:
                self.chunk_cache.pop(pop_chunk.chat_id)

    def get_media_chunk(self, msg: "types.Message", start: int, lru: bool = True) -> "Optional[MediaChunkHolder]":
        """Look up the holder covering ``start``; with ``lru`` mark it most recently used."""
        res = self._get_media_chunk_cache(msg, start)
        if res is None:
            return None
        if lru:
            self.chunk_lru.move_to_end(res.chunk_id)
        return res

    def set_media_chunk(self, chunk: "MediaChunkHolder") -> None:
        """Register ``chunk``, assign its ``chunk_id`` and evict old holders if over budget."""
        cache_msg = self.chunk_cache.setdefault(
            chunk.chat_id, {}).setdefault(chunk.msg_id, [])
        chunk.chunk_id = self.unique_chunk_id
        self.unique_chunk_id += 1
        bisect.insort(cache_msg, chunk)
        self.chunk_lru[chunk.chunk_id] = chunk
        # The budget is charged with target_len, not the bytes held so far.
        self.current_cache_size += chunk.target_len
        while self.current_cache_size > self.MAX_CACHE_SIZE:
            _, evicted = self.chunk_lru.popitem(last=False)
            self._remove_pop_chunk(evicted)

    def cancel_media_chunk(self, chunk: "MediaChunkHolder") -> None:
        """Forget ``chunk`` if it is still registered; otherwise do nothing."""
        cache_chat = self.chunk_cache.get(chunk.chat_id)
        if cache_chat is None:
            return
        if cache_chat.get(chunk.msg_id) is None:
            return
        dummy = self.chunk_lru.pop(chunk.chunk_id, None)
        if dummy is None:
            return
        self._remove_pop_chunk(dummy)
@functools.total_ordering @functools.total_ordering
class MediaBlockHolder(object): class MediaBlockHolder(object):
waiters: collections.deque[asyncio.Future] waiters: collections.deque[asyncio.Future]
@ -24,28 +209,25 @@ class MediaBlockHolder(object):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"MediaBlockHolder,id:{self.chat_id}-{self.msg_id},start:{self.start},len:{self.length}/{self.target_len}" return f"MediaBlockHolder,id:{self.chat_id}-{self.msg_id},start:{self.start},len:{self.length}/{self.target_len}"
def __eq__(self, other: 'MediaBlockHolder'|int): def __eq__(self, other: Union['MediaBlockHolder', int]):
if isinstance(other, int): if isinstance(other, int):
return self.start == other return self.start == other
return self.start == other.start return self.start == other.start
def __le__(self, other: 'MediaBlockHolder'|int): def __le__(self, other: Union['MediaBlockHolder', int]):
if isinstance(other, int): if isinstance(other, int):
return self.start <= other return self.start <= other
return self.start <= other.start return self.start <= other.start
def __gt__(self, other: 'MediaBlockHolder'|int): def __gt__(self, other: Union['MediaBlockHolder', int]):
if isinstance(other, int): if isinstance(other, int):
return self.start > other return self.start > other
return self.start > other.start return self.start > other.start
def __add__(self, other: 'MediaBlockHolder'|bytes): def __add__(self, other: Union['MediaBlockHolder', bytes]):
if isinstance(other, bytes): if isinstance(other, MediaBlockHolder):
self.append_mem(other) other = other.mem
elif isinstance(other, MediaBlockHolder):
self.append_mem(other.mem) self.append_mem(other.mem)
else:
raise RuntimeError(f"{self} can't add {type(other)}")
def is_completed(self) -> bool: def is_completed(self) -> bool:
return self.length >= self.target_len return self.length >= self.target_len
@ -85,12 +267,12 @@ class BlockInfo(object):
self.length = length self.length = length
self.in_mem = in_mem self.in_mem = in_mem
def __eq__(self, other: 'BlockInfo'|int): def __eq__(self, other: Union['BlockInfo', int]):
if isinstance(other, int): if isinstance(other, int):
return self.offset == other return self.offset == other
return self.offset == other.offset return self.offset == other.offset
def __le__(self, other: 'BlockInfo'|int): def __le__(self, other: Union['BlockInfo', int]):
if isinstance(other, int): if isinstance(other, int):
return self.offset <= other return self.offset <= other
return self.offset <= other.offset return self.offset <= other.offset

View File

@ -1,15 +1,12 @@
import asyncio import asyncio
import json import json
import bisect
import time import time
import re import re
import rsa import rsa
import os import os
import functools import functools
import collections
import traceback import traceback
import logging import logging
from collections import OrderedDict
from typing import Union, Optional from typing import Union, Optional
from telethon import TelegramClient, types, hints, events from telethon import TelegramClient, types, hints, events
@ -18,192 +15,11 @@ from fastapi import Request
import configParse import configParse
from backend import apiutils from backend import apiutils
from backend.UserManager import UserManager from backend.UserManager import UserManager
from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
class TgFileSystemClient(object): class TgFileSystemClient(object):
@functools.total_ordering
class MediaChunkHolder(object):
    """In-memory buffer for one contiguous byte range of a message's media.

    Identified by ``chat_id``/``msg_id`` plus the byte offset ``start`` and
    filled incrementally until ``target_len`` bytes are held. Holders compare
    and order by ``start`` only, also against a plain ``int`` offset.
    """

    # Futures of coroutines currently blocked in wait_chunk_update().
    waiters: collections.deque[asyncio.Future]
    # Created per instance in __init__: a class-level list would be shared by
    # every holder, so set_done() on one chunk would clear all the others.
    requester: "list[Request]"
    chunk_id: int = 0
    is_done: bool = False

    def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
        """Create a holder for ``target_len`` bytes at offset ``start``; ``mem`` pre-fills it."""
        self.chat_id = chat_id
        self.msg_id = msg_id
        self.start = start
        self.target_len = target_len
        self.mem = mem or bytes()
        self.length = len(self.mem)
        self.waiters = collections.deque()
        self.requester = []

    def __repr__(self) -> str:
        return f"MediaChunk,start:{self.start},len:{self.length}"

    @staticmethod
    def _offset_of(other) -> int:
        """Return the start offset ``other`` stands for (an int is its own offset)."""
        return other if isinstance(other, int) else other.start

    def __eq__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
        return self.start == self._offset_of(other)

    def __le__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
        return self.start <= self._offset_of(other)

    def __gt__(self, other: 'TgFileSystemClient.MediaChunkHolder'):
        return self.start > self._offset_of(other)

    def __add__(self, other):
        """Append raw bytes or another holder's bytes in place and return ``self``.

        Raises:
            RuntimeError: if ``other`` is neither bytes nor a holder.
        """
        if isinstance(other, bytes):
            self.append_chunk_mem(other)
        elif isinstance(other, TgFileSystemClient.MediaChunkHolder):
            self.append_chunk_mem(other.mem)
        else:
            raise RuntimeError("does not suported this type to add")
        return self

    def is_completed(self) -> bool:
        """Return True once at least ``target_len`` bytes are buffered."""
        return self.length >= self.target_len

    def set_done(self) -> None:
        """Drop every registered requester of this chunk.

        NOTE(review): ``is_done`` is never set to True in this class, so
        wait_chunk_update() never takes its early-return path.
        """
        self.requester.clear()

    def notify_waiters(self) -> None:
        """Wake every coroutine blocked in wait_chunk_update()."""
        while self.waiters:
            waiter = self.waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)

    def _check_overflow(self) -> None:
        """Raise RuntimeWarning when more than ``target_len`` bytes are held."""
        if self.length > self.target_len:
            raise RuntimeWarning(
                f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}")

    def _set_chunk_mem(self, mem: Optional[bytes]) -> None:
        """Replace the buffer; ``None`` resets it to empty, as in ``__init__``."""
        self.mem = mem or bytes()
        self.length = len(self.mem)
        self._check_overflow()

    def append_chunk_mem(self, mem: bytes) -> None:
        """Append ``mem`` and wake the waiters; RuntimeWarning on overflow."""
        self.mem = self.mem + mem
        self.length = len(self.mem)
        self._check_overflow()
        self.notify_waiters()

    def add_chunk_requester(self, req: "Request") -> None:
        """Register ``req`` as interested in this chunk."""
        self.requester.append(req)

    async def is_disconneted(self) -> bool:
        """Return True when no registered requester is still connected.

        Disconnected requesters found at the front of the list are removed.
        """
        while self.requester:
            res = await self.requester[0].is_disconnected()
            if res:
                self.requester.pop(0)
                continue
            return res
        return True

    async def wait_chunk_update(self) -> None:
        """Block until the next notify_waiters() call (e.g. after an append)."""
        if self.is_done:
            return
        waiter = asyncio.get_running_loop().create_future()
        self.waiters.append(waiter)
        try:
            await waiter
        except BaseException:
            # Same reach as the historical bare ``except``: any failure,
            # including task cancellation, unregisters the waiter and is then
            # swallowed. NOTE(review): confirm callers rely on this.
            waiter.cancel()
            try:
                self.waiters.remove(waiter)
            except ValueError:
                pass
class MediaChunkHolderManager(object):
    """Size-bounded LRU index of media chunk holders.

    Holders are indexed by chat id, then message id, and kept sorted by start
    offset. When the summed ``target_len`` exceeds ``MAX_CACHE_SIZE`` the
    least recently used holders are evicted.
    """

    MAX_CACHE_SIZE = 1024 * 1024 * 1024  # 1Gb
    current_cache_size: int = 0
    # chat_id -> msg_id -> holders sorted by start offset.
    # Created per instance in __init__: a class-level dict would be shared by
    # every manager while chunk_lru and the id counter are per instance.
    chunk_cache: "dict[int, dict[int, list[TgFileSystemClient.MediaChunkHolder]]]"
    # ChunkHolderId -> ChunkHolder
    unique_chunk_id: int = 0
    chunk_lru: "OrderedDict[int, TgFileSystemClient.MediaChunkHolder]"

    def __init__(self) -> None:
        self.current_cache_size = 0
        self.unique_chunk_id = 0
        self.chunk_cache = {}
        self.chunk_lru = OrderedDict()

    def _get_media_msg_cache(self, msg: "types.Message") -> "Optional[list[TgFileSystemClient.MediaChunkHolder]]":
        """Return the sorted holder list of ``msg``, or None when nothing is cached."""
        chat_cache = self.chunk_cache.get(msg.chat_id)
        if chat_cache is None:
            return None
        return chat_cache.get(msg.id)

    def _get_media_chunk_cache(self, msg: "types.Message", start: int) -> "Optional[TgFileSystemClient.MediaChunkHolder]":
        """Return the holder of ``msg`` whose byte range contains ``start``, if any."""
        msg_cache = self._get_media_msg_cache(msg)
        if not msg_cache:
            return None
        # First holder whose start offset is >= ``start``.
        pos = bisect.bisect_left(msg_cache, start)
        if pos < len(msg_cache) and msg_cache[pos].start == start:
            return msg_cache[pos]
        if pos == 0:
            # Every cached holder begins after ``start``.
            return None
        # Otherwise only the holder just before ``pos`` can cover ``start``.
        candidate = msg_cache[pos - 1]
        if candidate.start <= start < candidate.start + candidate.target_len:
            return candidate
        return None

    def _remove_pop_chunk(self, pop_chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
        """Remove ``pop_chunk`` from the index and prune containers left empty."""
        chat_cache = self.chunk_cache[pop_chunk.chat_id]
        msg_cache = chat_cache[pop_chunk.msg_id]
        # Holders compare equal to their start offset, so this removes by offset.
        msg_cache.remove(pop_chunk.start)
        self.current_cache_size -= pop_chunk.target_len
        if not msg_cache:
            chat_cache.pop(pop_chunk.msg_id)
            if not chat_cache:
                self.chunk_cache.pop(pop_chunk.chat_id)

    def get_media_chunk(self, msg: "types.Message", start: int, lru: bool = True) -> "Optional[TgFileSystemClient.MediaChunkHolder]":
        """Look up the holder covering ``start``; with ``lru`` mark it most recently used."""
        res = self._get_media_chunk_cache(msg, start)
        if res is None:
            return None
        if lru:
            self.chunk_lru.move_to_end(res.chunk_id)
        return res

    def set_media_chunk(self, chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
        """Register ``chunk``, assign its ``chunk_id`` and evict old holders if over budget."""
        cache_msg = self.chunk_cache.setdefault(
            chunk.chat_id, {}).setdefault(chunk.msg_id, [])
        chunk.chunk_id = self.unique_chunk_id
        self.unique_chunk_id += 1
        bisect.insort(cache_msg, chunk)
        self.chunk_lru[chunk.chunk_id] = chunk
        # The budget is charged with target_len, not the bytes held so far.
        self.current_cache_size += chunk.target_len
        while self.current_cache_size > self.MAX_CACHE_SIZE:
            _, evicted = self.chunk_lru.popitem(last=False)
            self._remove_pop_chunk(evicted)

    def cancel_media_chunk(self, chunk: 'TgFileSystemClient.MediaChunkHolder') -> None:
        """Forget ``chunk`` if it is still registered; otherwise do nothing."""
        cache_chat = self.chunk_cache.get(chunk.chat_id)
        if cache_chat is None:
            return
        if cache_chat.get(chunk.msg_id) is None:
            return
        dummy = self.chunk_lru.pop(chunk.chunk_id, None)
        if dummy is None:
            return
        self._remove_pop_chunk(dummy)
MAX_WORKER_ROUTINE = 4 MAX_WORKER_ROUTINE = 4
SINGLE_NET_CHUNK_SIZE = 256 * 1024 # 256kb SINGLE_NET_CHUNK_SIZE = 256 * 1024 # 256kb
SINGLE_MEDIA_SIZE = 5 * 1024 * 1024 # 5mb SINGLE_MEDIA_SIZE = 5 * 1024 * 1024 # 5mb
@ -236,7 +52,7 @@ class TgFileSystemClient(object):
self.task_queue = asyncio.Queue() self.task_queue = asyncio.Queue()
self.client = TelegramClient( self.client = TelegramClient(
f"{os.path.dirname(__file__)}/db/{self.session_name}.session", self.api_id, self.api_hash, proxy=self.proxy_param) f"{os.path.dirname(__file__)}/db/{self.session_name}.session", self.api_id, self.api_hash, proxy=self.proxy_param)
self.media_chunk_manager = TgFileSystemClient.MediaChunkHolderManager() self.media_chunk_manager = MediaChunkHolderManager()
self.db = db self.db = db
def __del__(self) -> None: def __del__(self) -> None:
@ -315,6 +131,7 @@ class TgFileSystemClient(object):
t.cancel() t.cancel()
except Exception as err: except Exception as err:
logger.error(f"{err=}") logger.error(f"{err=}")
logger.error(traceback.format_exc())
async def _cache_whitelist_chat2(self): async def _cache_whitelist_chat2(self):
for chat_id in self.client_param.whitelist_chat: for chat_id in self.client_param.whitelist_chat:
@ -366,6 +183,7 @@ class TgFileSystemClient(object):
await task[1] await task[1]
except Exception as err: except Exception as err:
logger.error(f"{err=}") logger.error(f"{err=}")
logger.error(traceback.format_exc())
finally: finally:
self.task_queue.task_done() self.task_queue.task_done()
@ -434,7 +252,7 @@ class TgFileSystemClient(object):
self.media_chunk_manager.cancel_media_chunk(media_holder) self.media_chunk_manager.cancel_media_chunk(media_holder)
except Exception as err: except Exception as err:
logger.error( logger.error(
f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{traceback.format_exc()}") f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}")
finally: finally:
media_holder.set_done() media_holder.set_done()
logger.debug( logger.debug(
@ -457,7 +275,7 @@ class TgFileSystemClient(object):
align_pos = pos align_pos = pos
align_size = min(self.SINGLE_MEDIA_SIZE, align_size = min(self.SINGLE_MEDIA_SIZE,
file_size - align_pos) file_size - align_pos)
holder = TgFileSystemClient.MediaChunkHolder( holder = MediaChunkHolder(
msg.chat_id, msg.id, align_pos, align_size) msg.chat_id, msg.id, align_pos, align_size)
holder.add_chunk_requester(req) holder.add_chunk_requester(req)
self.media_chunk_manager.set_media_chunk(holder) self.media_chunk_manager.set_media_chunk(holder)
@ -486,8 +304,8 @@ class TgFileSystemClient(object):
pos = pos + need_len pos = pos + need_len
yield cache_chunk.mem[offset:offset+need_len] yield cache_chunk.mem[offset:offset+need_len]
except Exception as err: except Exception as err:
traceback.print_exc()
logger.error(f"stream iter:{err=}") logger.error(f"stream iter:{err=}")
logger.error(traceback.format_exc())
finally: finally:
async def _cancel_task_by_id(task_id: int): async def _cancel_task_by_id(task_id: int):
for _ in range(self.task_queue.qsize()): for _ in range(self.task_queue.qsize()):

View File

@ -64,9 +64,11 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
for item in res: for item in res:
msg_info = json.loads(item) msg_info = json.loads(item)
file_name = apiutils.get_message_media_name_from_dict(msg_info) file_name = apiutils.get_message_media_name_from_dict(msg_info)
chat_id = apiutils.get_message_chat_id_from_dict(msg_info)
msg_id = apiutils.get_message_msg_id_from_dict(msg_info)
msg_info['file_name'] = file_name msg_info['file_name'] = file_name
msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{msg_info.get('id')}/{file_name}" msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}"
msg_info['src_tg_link'] = f"https://t.me/c/1216816802/21206" msg_info['src_tg_link'] = f"https://t.me/c/{chat_id}/{msg_id}"
res_dict.append(msg_info) res_dict.append(msg_info)
client_dict = json.loads(client.to_json()) client_dict = json.loads(client.to_json())

View File

@ -52,6 +52,20 @@ def get_message_media_name_from_dict(msg: dict[str, any]) -> str:
file_name = "unknown.tmp" file_name = "unknown.tmp"
return file_name return file_name
def get_message_chat_id_from_dict(msg: dict[str, any]) -> int:
    """Return ``msg['peer_id']['channel_id']``, or 0 when it cannot be read.

    A missing key or a non-subscriptable value (for example ``None``) yields 0
    instead of an exception.
    """
    try:
        return msg['peer_id']['channel_id']
    except (LookupError, TypeError):
        # Only the failures a JSON-shaped message can produce are absorbed;
        # a bare ``except`` would also hide KeyboardInterrupt and real bugs.
        return 0
def get_message_msg_id_from_dict(msg: dict[str, any]) -> int:
    """Return ``msg['id']``, or 0 when it cannot be read.

    A missing key or a non-subscriptable ``msg`` (for example ``None``) yields
    0 instead of an exception.
    """
    try:
        return msg['id']
    except (LookupError, TypeError):
        # Only the failures a JSON-shaped message can produce are absorbed;
        # a bare ``except`` would also hide KeyboardInterrupt and real bugs.
        return 0
def timeit_sec(func): def timeit_sec(func):
@wraps(func) @wraps(func)
def timeit_wrapper(*args, **kwargs): def timeit_wrapper(*args, **kwargs):

View File

@ -94,7 +94,7 @@ def do_search_req():
st.session_state.force_skip = True st.session_state.force_skip = True
st.rerun() st.rerun()
def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str): def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str):
file_size_str = f"{file_size/1024/1024:.2f}MB" file_size_str = f"{file_size/1024/1024:.2f}MB"
container = st.container() container = st.container()
container_columns = container.columns([1, 99]) container_columns = container.columns([1, 99])
@ -104,7 +104,7 @@ def do_search_req():
expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size_str}*" expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size_str}*"
popover = container_columns[1].popover(expender_title, use_container_width=True) popover = container_columns[1].popover(expender_title, use_container_width=True)
popover_columns = popover.columns([1, 3]) popover_columns = popover.columns([1, 3, 1])
if url: if url:
popover_columns[0].video(url) popover_columns[0].video(url)
else: else:
@ -112,7 +112,8 @@ def do_search_req():
popover_columns[1].markdown(f'{msg_ctx}') popover_columns[1].markdown(f'{msg_ctx}')
popover_columns[1].markdown(f'**{file_name}**') popover_columns[1].markdown(f'**{file_name}**')
popover_columns[1].markdown(f'文件大小:*{file_size_str}*') popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
popover_columns[1].link_button('Download Link', url) popover_columns[2].link_button('Download Link', url, use_container_width=True)
popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True)
@st.experimental_fragment @st.experimental_fragment
def show_search_res(res: dict[str, any]): def show_search_res(res: dict[str, any]):
@ -134,7 +135,9 @@ def do_search_req():
file_name = None file_name = None
file_size = 0 file_size = 0
download_url = "" download_url = ""
src_link = ""
try: try:
src_link = v['src_tg_link']
msg_ctx = v['message'] msg_ctx = v['message']
msg_id = str(v['id']) msg_id = str(v['id'])
doc = v['media']['document'] doc = v['media']['document']
@ -152,7 +155,7 @@ def do_search_req():
except Exception as err: except Exception as err:
msg_ctx = f"{err=}\r\n\r\n" + msg_ctx msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
media_file_res_container( media_file_res_container(
i, msg_ctx, file_name, file_size, download_url) i, msg_ctx, file_name, file_size, download_url, src_link)
page_switch_render() page_switch_render()
show_text = "" show_text = ""