import asyncio
import functools
import json
import logging
import os
import re
import time
import traceback
from typing import Any, Callable, Literal, Optional, Union

import rsa
from fastapi import Request
from telethon import TelegramClient, types, hints, events
from telethon.custom import QRLogin

import configParse
from backend import apiutils
from backend.UserManager import UserManager
from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager

logger = logging.getLogger(__file__.split("/")[-1])


class TgFileSystemClient(object):
    """Wrap one Telethon session and expose message lookup and media streaming.

    Work that must run in the background (whitelist caching, media chunk
    downloads) is pushed onto ``task_queue`` as ``(task_id, coroutine)`` tuples
    and consumed by ``MAX_WORKER_ROUTINE`` worker tasks created in ``start()``.
    """

    MAX_WORKER_ROUTINE = 8
    SINGLE_NET_CHUNK_SIZE = 256 * 1024  # 256kb
    SINGLE_MEDIA_SIZE = 5 * 1024 * 1024  # 5mb
    api_id: int
    api_hash: str
    session_name: str
    proxy_param: dict[str, Any]
    client: TelegramClient
    media_chunk_manager: MediaChunkHolderManager
    db: UserManager
    dialogs_cache: Optional[hints.TotalList] = None
    worker_routines: list[asyncio.Task]
    qr_login: QRLogin | None = None
    login_task: asyncio.Task | None = None
    # rsa key
    sign: str
    public_key: rsa.PublicKey
    private_key: rsa.PrivateKey
    # task should: (task_id, callabledFunc)
    task_queue: asyncio.Queue
    task_id: int = 0
    me: Optional[Union[types.User, types.InputPeerUser]] = None
    # client config
    client_param: configParse.TgToFileSystemParameter.ClientConfigPatameter

    def __init__(
        self,
        session_name: str,
        param: configParse.TgToFileSystemParameter,
        db: UserManager,
    ) -> None:
        """Build the Telethon client for ``session_name`` without connecting.

        Args:
            session_name: Name of the session; also used to look up the
                matching entry in ``param.clients`` by token.
            param: Parsed application configuration.
            db: Message store used by the caching and search methods.
        """
        self.api_id = param.tgApi.api_id
        self.api_hash = param.tgApi.api_hash
        self.session_name = session_name
        self.proxy_param = (
            {
                "proxy_type": param.proxy.proxy_type,
                "addr": param.proxy.addr,
                "port": param.proxy.port,
            }
            if param.proxy.enable
            else {}
        )
        # NOTE(review): 1024-bit RSA is below current recommendations —
        # confirm what these keys protect before relying on them.
        self.public_key, self.private_key = rsa.newkeys(1024)
        # Fall back to a default client config when no entry matches the token.
        self.client_param = next(
            (client_param for client_param in param.clients if client_param.token == session_name),
            configParse.TgToFileSystemParameter.ClientConfigPatameter(),
        )
        self.sign = self.client_param.token
        self.task_queue = asyncio.Queue()
        self.client = TelegramClient(
            f"{os.path.dirname(__file__)}/db/{self.session_name}.session",
            self.api_id,
            self.api_hash,
            proxy=self.proxy_param,
        )
        self.media_chunk_manager = MediaChunkHolderManager()
        self.db = db
        self.worker_routines = []

    def __del__(self) -> None:
        """Best-effort shutdown of the client when the object is collected."""
        # __init__ may have failed before ``client`` was assigned.
        client = getattr(self, "client", None)
        if client is None:
            return
        loop = client.loop
        if loop.is_closed():
            # Nothing can be scheduled any more; avoid creating a coroutine
            # that would never be awaited.
            return
        if loop.is_running():
            loop.create_task(self.stop())
        else:
            loop.run_until_complete(self.stop())

    def __repr__(self) -> str:
        # ``is_connected`` is a method; it must be called. ``me`` is only set
        # after a successful ``start()``, so guard it as well.
        if not self.client.is_connected() or self.me is None:
            return f"client disconnected, session_name:{self.session_name}"
        return f"client connected, session_name:{self.session_name}, username:{self.me.username}, phone:{self.me.phone}, detail:{self.me.stringify()}"

    def _check_before_call(func):
        """Decorator: raise ``RuntimeError`` unless ``is_valid()`` (sync methods)."""

        @functools.wraps(func)
        def call_check_wrapper(self, *args, **kwargs):
            if not self.is_valid():
                raise RuntimeError("Client does not run.")
            return func(self, *args, **kwargs)

        return call_check_wrapper

    def _acheck_before_call(func):
        """Decorator: raise ``RuntimeError`` unless ``is_valid()`` (async methods)."""

        @functools.wraps(func)
        async def call_check_wrapper(self, *args, **kwargs):
            if not self.is_valid():
                raise RuntimeError("Client does not run.")
            return await func(self, *args, **kwargs)

        return call_check_wrapper

    @_check_before_call
    def to_dict(self) -> dict:
        """Return ``self.me`` as a dict."""
        return self.me.to_dict()

    @_check_before_call
    def to_json(self) -> str:
        """Return ``self.me`` as a JSON string."""
        return self.me.to_json()

    def is_valid(self) -> bool:
        """Return True when the client is connected and ``me`` has been fetched."""
        return self.client.is_connected() and self.me is not None

    @_check_before_call
    def _register_update_event(self, from_users: Optional[list[int]] = None) -> None:
        """Store every incoming message sent by ``from_users`` into the db.

        ``None`` is treated as an empty list, which is what the previous
        mutable default ``[]`` passed to ``events.NewMessage``.
        """
        users = [] if from_users is None else from_users

        @self.client.on(events.NewMessage(incoming=True, from_users=users))
        async def _incoming_new_message_handler(event) -> None:
            msg: types.Message = event.message
            self.db.insert_by_message(self.me, msg)

    async def login(self, mode: Literal["phone", "qrcode"] = "qrcode") -> str:
        """Start (or continue) a QR-code login and return the login URL.

        Returns an empty string when the client is already valid.

        Raises:
            NotImplementedError: when ``mode`` is ``"phone"``.
        """
        if self.is_valid():
            return ""
        if mode == "phone":
            raise NotImplementedError
        if self.qr_login is not None:
            # A login is already pending; hand out the same URL.
            return self.qr_login.url
        self.qr_login = await self.client.qr_login()

        async def wait_for_qr_login():
            try:
                await self.qr_login.wait()
                await self.start()
            except Exception as err:
                logger.warning(f"wait for login, {err=}, {traceback.format_exc()}")
            finally:
                self.login_task = None
                self.qr_login = None

        self.login_task = self.client.loop.create_task(wait_for_qr_login())
        return self.qr_login.url

    async def start(self) -> None:
        """Connect, fetch ``me``, spawn the workers and queue whitelist caching.

        Raises:
            RuntimeError: when the session is not logged in (``get_me()``
                returned ``None``).
        """
        if self.is_valid():
            return
        if not self.client.is_connected():
            await self.client.connect()
        self.me = await self.client.get_me()
        if self.me is None:
            raise RuntimeError(f"The {self.session_name} Client Does Not Login")
        for _ in range(self.MAX_WORKER_ROUTINE):
            worker_routine = self.client.loop.create_task(self._worker_routine_handler())
            self.worker_routines.append(worker_routine)
        if len(self.client_param.whitelist_chat) > 0:
            self._register_update_event(from_users=self.client_param.whitelist_chat)
            await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat()))

    async def stop(self) -> None:
        """Cancel the workers, drop everything still queued and disconnect."""
        await self._cancel_tasks()
        self._discard_queued_tasks()
        await self.client.disconnect()

    async def _cancel_tasks(self) -> None:
        """Request cancellation of every worker task and forget them."""
        for t in self.worker_routines:
            try:
                t.cancel()
            except Exception as err:
                logger.error(f"{err=}")
                logger.error(traceback.format_exc())
        # Without this a later start() would keep appending to a list of
        # already-cancelled tasks.
        self.worker_routines.clear()

    def _discard_queued_tasks(self, task_id: Optional[int] = None) -> None:
        """Remove queued tasks; all of them, or only those tagged ``task_id``.

        Dropped coroutines are closed so Python does not emit a
        "coroutine ... was never awaited" warning for them.
        """
        for _ in range(self.task_queue.qsize()):
            task = self.task_queue.get_nowait()
            self.task_queue.task_done()
            if task_id is not None and task[0] != task_id:
                # Not ours: put it back at the end of the queue.
                self.task_queue.put_nowait(task)
                continue
            close = getattr(task[1], "close", None)
            if callable(close):
                close()

    async def _cache_whitelist_chat_full_policy(self, chat_id: int, callback: Optional[Callable] = None):
        """Walk every message of ``chat_id`` and insert those missing from the db.

        ``callback`` (if given) is invoked when the task ends, including when
        it ends with an exception.
        """
        try:
            async for msg in self.client.iter_messages(chat_id):
                if len(self.db.get_msg_by_unique_id(UserManager.generate_unique_id_by_msg(self.me, msg))) != 0:
                    continue
                self.db.insert_by_message(self.me, msg)
        finally:
            # Always notify: the caller uses this to release a semaphore slot.
            if callback is not None:
                callback()
            logger.info(f"{chat_id} quit cache task.")

    async def _cache_whitelist_chat_lazy_policy(self, chat_id: int, callback: Optional[Callable] = None):
        """Cache only the messages newer/older than what the db already holds.

        ``callback`` (if given) is invoked when the task ends, including when
        it ends with an exception.
        """
        try:
            # update newest msg
            newest_msg = self.db.get_newest_msg_by_chat_id(chat_id)
            if len(newest_msg) > 0:
                newest_id = self.db.get_column_msg_id(newest_msg[0])
                async for msg in self.client.iter_messages(chat_id):
                    if msg.id <= newest_id:
                        break
                    self.db.insert_by_message(self.me, msg)
            # update oldest msg
            oldest_msg = self.db.get_oldest_msg_by_chat_id(chat_id)
            if len(oldest_msg) > 0:
                offset = self.db.get_column_msg_id(oldest_msg[0])
                async for msg in self.client.iter_messages(chat_id, offset_id=offset):
                    self.db.insert_by_message(self.me, msg)
            else:
                # Nothing cached yet for this chat: take everything.
                async for msg in self.client.iter_messages(chat_id):
                    self.db.insert_by_message(self.me, msg)
        finally:
            # Always notify: the caller uses this to release a semaphore slot.
            if callback is not None:
                callback()
            logger.info(f"{chat_id} quit cache task.")

    async def _cache_whitelist_chat(self):
        """Queue one lazy cache task per whitelisted chat.

        At most ``MAX_WORKER_ROUTINE // 2`` cache tasks are outstanding at a
        time; each task releases its semaphore slot through its callback.
        """
        max_cache_tasks_num = TgFileSystemClient.MAX_WORKER_ROUTINE // 2
        tasks_sem = asyncio.Semaphore(value=max_cache_tasks_num)

        def _sem_release_callback():
            tasks_sem.release()

        for chat_id in self.client_param.whitelist_chat:
            await tasks_sem.acquire()
            await self.task_queue.put(
                (self._get_unique_task_id(), self._cache_whitelist_chat_lazy_policy(chat_id, callback=_sem_release_callback))
            )

    @_acheck_before_call
    async def get_message(self, chat_id: int | str, msg_id: int) -> types.Message:
        """Fetch the single message ``msg_id`` from ``chat_id``."""
        msg = await self.client.get_messages(chat_id, ids=msg_id)
        return msg

    @_acheck_before_call
    async def get_dialogs(self, limit: int = 10, offset: int = 0, refresh: bool = False) -> hints.TotalList:
        """Return ``limit`` dialogs starting at ``offset``.

        The full dialog list is fetched once and cached; pass ``refresh=True``
        to fetch it again.
        """
        if self.dialogs_cache is None or refresh is not False:
            self.dialogs_cache = await self.client.get_dialogs()
        return self.dialogs_cache[offset : offset + limit]

    async def _worker_routine_handler(self) -> None:
        """Consume ``task_queue`` and await each queued coroutine.

        Exceptions raised by a task are logged and do not stop the worker.
        """
        while self.client.is_connected():
            # Keep get() outside the try: if it is cancelled, no item was
            # taken and task_done() must not be called.
            task = await self.task_queue.get()
            try:
                await task[1]
            except Exception as err:
                logger.error(f"{err=}")
                logger.error(traceback.format_exc())
            finally:
                self.task_queue.task_done()

    def _get_unique_task_id(self) -> int:
        """Return the next value of the per-client task counter."""
        self.task_id += 1
        return self.task_id

    async def _get_offset_msg_id(self, chat_id: int, offset: int) -> int:
        """Translate a relative ``offset`` into a message id usable as ``offset_id``.

        A zero offset is returned unchanged. Otherwise the id of the newest
        message in the chat is added to ``offset``. When the chat has no
        messages, ``0`` is returned.
        """
        if offset != 0:
            begin = await self.client.get_messages(chat_id, limit=1)
            if len(begin) == 0:
                # Previously an empty TotalList was returned here, which is
                # not a valid ``offset_id`` for the callers.
                return 0
            first_id = begin[0].id
            offset = first_id + offset
        return offset

    @_acheck_before_call
    async def get_entity(self, chat_id_or_name) -> hints.Entity:
        """Resolve ``chat_id_or_name`` through the Telethon client."""
        return await self.client.get_entity(chat_id_or_name)

    @_acheck_before_call
    async def get_messages(self, chat_id: int, limit: int = 10, offset: int = 0) -> hints.TotalList:
        """Return up to ``limit`` messages of ``chat_id`` starting at ``offset``."""
        offset = await self._get_offset_msg_id(chat_id, offset)
        res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset)
        return res_list

    @_acheck_before_call
    async def get_messages_by_search(
        self,
        chat_id: int,
        search_word: str,
        limit: int = 10,
        offset: int = 0,
        inner_search: bool = False,
        ignore_case: bool = False,
    ) -> hints.TotalList:
        """Search ``chat_id`` for ``search_word``.

        Args:
            chat_id: Chat to search.
            search_word: Text to look for.
            limit: Maximum number of messages to return.
            offset: Relative offset, see ``_get_offset_msg_id``.
            inner_search: When True, delegate to the client's ``search=``
                parameter. Otherwise scan at most 1000 messages locally and
                match against the message text and the media name.
            ignore_case: Only used by the local scan; compare case-insensitively.
        """
        offset = await self._get_offset_msg_id(chat_id, offset)
        if inner_search:
            res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset, search=search_word)
            return res_list

        # search by myself
        needle = search_word.casefold() if ignore_case else search_word

        def _contains(haystack: Optional[str]) -> bool:
            # Messages without text yield None; treat that as empty.
            haystack = haystack or ""
            if ignore_case:
                haystack = haystack.casefold()
            return needle in haystack

        res_list = hints.TotalList()
        cnt = 0
        async for msg in self.client.iter_messages(chat_id, offset_id=offset):
            if cnt >= 1_000:
                break
            cnt += 1
            if not _contains(msg.text) and not _contains(apiutils.get_message_media_name(msg)):
                continue
            res_list.append(msg)
            if len(res_list) >= limit:
                break
        return res_list

    async def get_messages_by_search_db(
        self,
        chat_ids: list[int],
        search_word: str,
        limit: int = 10,
        offset: int = 0,
        inc: bool = False,
        ignore_case: bool = False,
    ) -> list[Any]:
        """Search the local db and return ``get_column_msg_js`` of each hit."""
        res = self.db.get_msg_by_chat_id_and_keyword(
            chat_ids,
            search_word,
            limit=limit,
            offset=offset,
            inc=inc,
            ignore_case=ignore_case,
        )
        return [self.db.get_column_msg_js(v) for v in res]

    async def _download_media_chunk(self, msg: types.Message, media_holder: MediaChunkHolder) -> None:
        """Download the remaining bytes of ``media_holder`` from ``msg``.

        Data is appended to the holder as it arrives. On completion the holder
        is moved to disk; if every requester has disconnected the download is
        cancelled and the holder is removed from the manager.
        """
        logger.info(f"start downloading new chunk:{media_holder=}")
        # Bound up front so the error log below can always format them.
        offset = None
        target_size = None
        try:
            offset = media_holder.start + media_holder.length
            target_size = media_holder.target_len - media_holder.length
            remain_size = target_size
            async for chunk in self.client.iter_download(msg, offset=offset, chunk_size=self.SINGLE_NET_CHUNK_SIZE):
                if not isinstance(chunk, bytes):
                    chunk = chunk.tobytes()
                remain_size -= len(chunk)
                if remain_size <= 0:
                    # Last chunk: trim whatever exceeds the holder's target.
                    media_holder.append_chunk_mem(chunk[: len(chunk) + remain_size])
                else:
                    media_holder.append_chunk_mem(chunk)
                if media_holder.is_completed():
                    break
                if await media_holder.is_disconneted():
                    raise asyncio.CancelledError("all requester canceled.")
        except asyncio.CancelledError:
            logger.info(f"cancel holder:{media_holder}")
            self.media_chunk_manager.cancel_media_chunk(media_holder)
        except Exception as err:
            logger.error(
                f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}"
            )
        else:
            if not self.media_chunk_manager.move_media_chunk_to_disk(media_holder):
                logger.warning(f"move to disk failed, {media_holder=}")
            logger.debug(f"downloaded chunk:{offset=},{target_size=},{media_holder}")

    async def streaming_get_iter(self, msg: types.Message, start: int, end: int, req: Request):
        """Yield the bytes ``start..end`` (inclusive) of ``msg``'s document.

        Chunks are served from ``media_chunk_manager``; a missing chunk causes
        a download task to be queued. When the generator finishes, download
        tasks queued by this request that have not started yet are dropped.
        Exceptions are logged and end the stream.
        """
        # Bound before the try so the cleanup below can always reference it.
        cur_task_id = self._get_unique_task_id()
        try:
            logger.debug(f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]")
            pos = start
            while not await req.is_disconnected() and pos <= end:
                cache_chunk = self.media_chunk_manager.get_media_chunk(msg, pos)
                if cache_chunk is None:
                    # post download task
                    # The holder starts exactly at ``pos``; alignment to
                    # SINGLE_MEDIA_SIZE boundaries is not applied.
                    file_size = msg.media.document.size
                    align_pos = pos
                    align_size = min(self.SINGLE_MEDIA_SIZE, file_size - align_pos)
                    holder = self.media_chunk_manager.create_media_chunk_holder(msg.chat_id, msg.id, align_pos, align_size)
                    logger.info(f"new holder create:{holder}")
                    holder.add_chunk_requester(req)
                    self.media_chunk_manager.set_media_chunk(holder)
                    self.task_queue.put_nowait((cur_task_id, self._download_media_chunk(msg, holder)))
                elif not cache_chunk.is_completed():
                    # yield return completed part
                    # await until completed or pos > end
                    cache_chunk.add_chunk_requester(req)
                    while pos < cache_chunk.start + cache_chunk.target_len and pos <= end:
                        if await req.is_disconnected():
                            break
                        offset = pos - cache_chunk.start
                        if offset >= cache_chunk.length:
                            # Nothing new yet for this position.
                            await cache_chunk.wait_chunk_update()
                            continue
                        need_len = min(cache_chunk.length - offset, end - pos + 1)
                        pos = pos + need_len
                        yield cache_chunk.mem[offset : offset + need_len]
                else:
                    offset = pos - cache_chunk.start
                    if offset >= cache_chunk.length:
                        raise RuntimeError(f"lru cache missed!{pos=},{cache_chunk=}")
                    need_len = min(cache_chunk.length - offset, end - pos + 1)
                    pos = pos + need_len
                    yield cache_chunk.mem[offset : offset + need_len]
        except Exception as err:
            logger.error(f"stream iter:{err=}")
            logger.error(traceback.format_exc())
        finally:
            self._discard_queued_tasks(cur_task_id)
            logger.debug(f"yield quit,{msg.chat_id=},{msg.id=},[{start}:{end}]")

    def __enter__(self):
        # Only the async context manager protocol is supported.
        raise NotImplementedError

    def __exit__(self, exc_type=None, exc_value=None, tb=None):
        raise NotImplementedError

    async def __aenter__(self):
        """Start the client and return it, so ``async with ... as c`` works."""
        await self.start()
        return self

    async def __aexit__(self, exc_type=None, exc_value=None, tb=None):
        """Stop the client; exceptions from the ``async with`` body propagate."""
        await self.stop()