diff --git a/TgFileSystemClient.py b/TgFileSystemClient.py index 5a5354a..dfe7428 100644 --- a/TgFileSystemClient.py +++ b/TgFileSystemClient.py @@ -1,5 +1,13 @@ import asyncio import json +import bisect +import time +import re +import rsa +import functools +import collections +import traceback +from collections import OrderedDict from typing import Union, Optional from telethon import TelegramClient, types, hints @@ -9,12 +17,161 @@ import apiutils class TgFileSystemClient(object): + @functools.total_ordering + class MediaChunkHolder(object): + waiters: collections.deque[asyncio.Future] + chunk_id: int = 0 + + def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None: + self.chat_id = chat_id + self.msg_id = msg_id + self.start = start + self.target_len = target_len + self.mem = mem or bytes() + self.length = len(self.mem) + self.waiters = collections.deque() + + def __repr__(self) -> str: + return f"MediaChunk,start:{self.start},len:{self.length}" + + def __eq__(self, other: 'TgFileSystemClient.MediaChunkHolder'): + if isinstance(other, int): + return self.start == other + return self.start == other.start + + def __le__(self, other: 'TgFileSystemClient.MediaChunkHolder'): + if isinstance(other, int): + return self.start <= other + return self.start <= other.start + + def __gt__(self, other: 'TgFileSystemClient.MediaChunkHolder'): + if isinstance(other, int): + return self.start > other + return self.start > other.start + + def __add__(self, other): + if isinstance(other, bytes): + self.append_chunk_mem(other) + elif isinstance(other, TgFileSystemClient.MediaChunkHolder): + self.append_chunk_mem(other.mem) + else: + raise RuntimeError("does not suported this type to add") + + def is_completed(self) -> bool: + return self.length >= self.target_len + + def _set_chunk_mem(self, mem: Optional[bytes]) -> None: + self.mem = mem + self.length = len(self.mem) + if self.length > self.target_len: + raise RuntimeWarning( + f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}") + + def append_chunk_mem(self, mem: bytes) -> None: + self.mem = self.mem + mem + self.length = len(self.mem) + if self.length > self.target_len: + raise RuntimeWarning( + f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}") + while self.waiters: + waiter = self.waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + + async def wait_chunk_update(self): + waiter = asyncio.Future() + self.waiters.append(waiter) + try: + await waiter + except: + waiter.cancel() + try: + self.waiters.remove(waiter) + except ValueError: + pass + + class MediaChunkHolderManager(object): + MAX_CACHE_SIZE = 1024 * 1024 * 1024 # 1Gb + current_cache_size: int = 0 + # chat_id -> msg_id -> offset -> mem + chunk_cache: dict[int, dict[int, + list['TgFileSystemClient.MediaChunkHolder']]] = {} + # ChunkHolderId -> ChunkHolder + unique_chunk_id: int = 0 + chunk_lru: OrderedDict[int, 'TgFileSystemClient.MediaChunkHolder'] + + def __init__(self) -> None: + self.chunk_lru = OrderedDict() + + def _get_media_msg_cache(self, msg: types.Message) -> Optional[list['TgFileSystemClient.MediaChunkHolder']]: + chat_cache = self.chunk_cache.get(msg.chat_id) + if chat_cache is None: + return None + return chat_cache.get(msg.id) + + def _get_media_chunk_cache(self, msg: types.Message, start: int) -> Optional['TgFileSystemClient.MediaChunkHolder']: + msg_cache = self._get_media_msg_cache(msg) + if msg_cache is None or len(msg_cache) == 0: + return None + pos = bisect.bisect_left(msg_cache, start) + if pos == len(msg_cache): + pos = pos - 1 + if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start: + return msg_cache[pos] + return None + elif msg_cache[pos].start == start: + return msg_cache[pos] + elif pos > 0: + pos = pos - 1 + if msg_cache[pos].start <= start and msg_cache[pos].start + msg_cache[pos].target_len > start: + return msg_cache[pos] + return None + return None + + def get_media_chunk(self, msg: types.Message, start: int, lru: bool = True) -> Optional['TgFileSystemClient.MediaChunkHolder']: + res = self._get_media_chunk_cache(msg, start) + if res is None: + return None + if lru: + self.chunk_lru.move_to_end(res.chunk_id) + return res + + def set_media_chunk(self, chunk: 'TgFileSystemClient.MediaChunkHolder') -> None: + if self.chunk_cache.get(chunk.chat_id) is None: + self.chunk_cache[chunk.chat_id] = {} + if self.chunk_cache[chunk.chat_id].get(chunk.msg_id) is None: + self.chunk_cache[chunk.chat_id][chunk.msg_id] = [] + chunk.chunk_id = self.unique_chunk_id + self.unique_chunk_id += 1 + bisect.insort(self.chunk_cache[chunk.chat_id][chunk.msg_id], chunk) + self.chunk_lru[chunk.chunk_id] = chunk + self.current_cache_size += chunk.target_len + while self.current_cache_size > self.MAX_CACHE_SIZE: + dummy = self.chunk_lru.popitem(last=False) + pop_chunk = dummy[1] + self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id].remove( + pop_chunk.start) + self.current_cache_size -= pop_chunk.target_len + if len(self.chunk_cache[pop_chunk.chat_id][pop_chunk.msg_id]) == 0: + self.chunk_cache[pop_chunk.chat_id].pop(pop_chunk.msg_id) + if len(self.chunk_cache[pop_chunk.chat_id]) == 0: + self.chunk_cache.pop(pop_chunk.chat_id) + + MAX_DOWNLOAD_ROUTINE = 4 + SINGLE_NET_CHUNK_SIZE = 256 * 1024 # 256kb + SINGLE_MEDIA_SIZE = 5 * 1024 * 1024 # 5mb api_id: int api_hash: str session_name: str proxy_param: dict[str, any] client: TelegramClient + media_chunk_manager: MediaChunkHolderManager dialogs_cache: Optional[hints.TotalList] = None + msg_cache: list[types.Message] = [] + download_routines: list[asyncio.Task] = [] + # task should: (task_id, callabledFunc) + task_queue: asyncio.Queue + task_id: int = 0 me: Union[types.User, types.InputPeerUser] def __init__(self, session_name: str, param: configParse.TgToFileSystemParameter) -> None: @@ -26,11 +183,16 @@ class TgFileSystemClient(object): 'addr': param.proxy.addr, 'port': param.proxy.port, } if param.proxy.enable else {} + self.task_queue = asyncio.Queue() self.client = TelegramClient( self.session_name, self.api_id, self.api_hash, proxy=self.proxy_param) + self.media_chunk_manager = TgFileSystemClient.MediaChunkHolderManager() def __del__(self) -> None: - self.client.disconnect() + if self.client.loop.is_running(): + self.client.loop.create_task(self.stop()) + else: + self.client.loop.run_until_complete(self.stop()) def __repr__(self) -> str: if not self.client.is_connected: @@ -65,16 +227,33 @@ class TgFileSystemClient(object): return self.client.is_connected() and self.me is not None async def start(self) -> None: + if self.is_valid(): + return if not self.client.is_connected(): await self.client.connect() self.me = await self.client.get_me() if self.me is None: raise RuntimeError( f"The {self.session_name} Client Does Not Login") + for _ in range(self.MAX_DOWNLOAD_ROUTINE): + download_rt = self.client.loop.create_task( + self._download_routine_handler()) + self.download_routines.append(download_rt) async def stop(self) -> None: + await self.client.loop.create_task(self._cancel_tasks()) + while not self.task_queue.empty(): + self.task_queue.get_nowait() + self.task_queue.task_done() await self.client.disconnect() + async def _cancel_tasks(self) -> None: + for t in self.download_routines: + try: + t.cancel() + except Exception as err: + print(f"{err=}") + @_acall_before_check async def get_message(self, chat_id: int, msg_id: int) -> types.Message: msg = await self.client.get_messages(chat_id, ids=msg_id) @@ -82,16 +261,18 @@ class TgFileSystemClient(object): @_acall_before_check async def get_dialogs(self, limit: int = 10, offset: int = 0, refresh: bool = False) -> hints.TotalList: - def _to_json(item) -> str: - return json.dumps({"id": item.id, "is_channel": item.is_channel, - "is_group": item.is_group, "is_user": item.is_user, "name": item.name, }) if self.dialogs_cache is not None and refresh is False: return self.dialogs_cache[offset:offset+limit] self.dialogs_cache = await self.client.get_dialogs() - for item in self.dialogs_cache: - item.to_json = _to_json return self.dialogs_cache[offset:offset+limit] + async def _download_routine_handler(self) -> None: + while self.client.is_connected(): + task = await self.task_queue.get() + await task[1] + self.task_queue.task_done() + print("task quit!!!!") + async def _get_offset_msg_id(self, chat_id: int, offset: int) -> int: if offset != 0: begin = await self.client.get_messages(chat_id, limit=1) @@ -115,7 +296,11 @@ class TgFileSystemClient(object): return res_list # search by myself res_list = hints.TotalList() + cnt = 0 async for msg in self.client.iter_messages(chat_id, offset_id=offset): + if cnt >= 10_000: + break + cnt += 1 if msg.text.find(search_word) == -1 and apiutils.get_message_media_name(msg).find(search_word) == -1: continue res_list.append(msg) @@ -123,11 +308,104 @@ class TgFileSystemClient(object): break return res_list + async def _download_media_chunk(self, msg: types.Message, media_holder: MediaChunkHolder) -> None: + try: + offset = media_holder.start + media_holder.length + target_size = media_holder.target_len - media_holder.length + remain_size = target_size + async for chunk in self.client.iter_download(msg, offset=offset, chunk_size=self.SINGLE_NET_CHUNK_SIZE): + if not isinstance(chunk, bytes): + chunk = chunk.tobytes() + remain_size -= len(chunk) + if remain_size <= 0: + media_holder.append_chunk_mem( + chunk[:len(chunk)+remain_size]) + break + media_holder.append_chunk_mem(chunk) + except Exception as err: + print( + f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder}") + finally: + pass + # print( + # f"downloaded chunk:{time.time()}.{offset=},{target_size=},{media_holder}") + + async def streaming_get_iter(self, msg: types.Message, start: int, end: int): + try: + # print( + # f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]") + self.task_id += 1 + cur_task_id = self.task_id + pos = start + while pos <= end: + cache_chunk = self.media_chunk_manager.get_media_chunk( + msg, pos) + if cache_chunk is None: + # post download task + # align pos download task + file_size = msg.media.document.size + # align_pos = pos // self.SINGLE_MEDIA_SIZE * self.SINGLE_MEDIA_SIZE + align_pos = pos + align_size = min(self.SINGLE_MEDIA_SIZE, + file_size - align_pos) + holder = TgFileSystemClient.MediaChunkHolder( + msg.chat_id, msg.id, align_pos, align_size) + self.media_chunk_manager.set_media_chunk(holder) + await self.task_queue.put((cur_task_id, self._download_media_chunk(msg, holder))) + # while self.task_queue.qsize() < self.MAX_DOWNLOAD_ROUTINE and align_pos <= end: + # align_pos = align_pos + align_size + # align_size = min(self.SINGLE_MEDIA_SIZE, + # file_size - align_pos) + # cache_chunk = self.media_chunk_manager.get_media_chunk( + # msg, align_pos, lru=False) + # if cache_chunk is not None: + # break + # holder = TgFileSystemClient.MediaChunkHolder( + # msg.chat_id, msg.id, align_pos, align_size) + # self.media_chunk_manager.set_media_chunk(holder) + # await self.task_queue.put((cur_task_id, self._download_media_chunk(msg, holder))) + elif not cache_chunk.is_completed(): + # yield return completed part + # await untill completed or pos > end + while pos < cache_chunk.start + cache_chunk.target_len and pos <= end: + offset = pos - cache_chunk.start + if offset >= cache_chunk.length: + await cache_chunk.wait_chunk_update() + continue + need_len = min(cache_chunk.length - + offset, end - pos + 1) + # print( + # f"return missed {need_len} bytes:[{pos}:{pos+need_len}].{cache_chunk=}") + pos = pos + need_len + yield cache_chunk.mem[offset:offset+need_len] + else: + offset = pos - cache_chunk.start + if offset >= cache_chunk.length: + raise RuntimeError( + f"lru cache missed!{pos=},{cache_chunk=}") + need_len = min(cache_chunk.length - offset, end - pos + 1) + # print( + # f"return hited {need_len} bytes:[{pos}:{pos+need_len}].{cache_chunk=}") + pos = pos + need_len + yield cache_chunk.mem[offset:offset+need_len] + except Exception as err: + traceback.print_exc() + print(f"stream iter:{err=}") + finally: + async def _cancel_task_by_id(task_id: int): + for _ in range(self.task_queue.qsize()): + task = self.task_queue.get_nowait() + self.task_queue.task_done() + if task[0] != task_id: + self.task_queue.put_nowait(task) + await self.client.loop.create_task(_cancel_task_by_id(cur_task_id)) + # print("yield quit") + def __enter__(self): - raise NotImplemented + raise NotImplementedError def __exit__(self): - raise NotImplemented + raise NotImplementedError async def __aenter__(self): await self.start() diff --git a/UserManager.py b/UserManager.py index a83a6b1..bdf7312 100644 --- a/UserManager.py +++ b/UserManager.py @@ -32,13 +32,13 @@ class UserManager(object): self.con.close() def update_user(self) -> None: - raise NotImplemented + raise NotImplementedError def update_message(self) -> None: - raise NotImplemented + raise NotImplementedError def get_user_info() -> None: - raise NotImplemented + raise NotImplementedError def _table_has_been_inited(self) -> bool: res = self.cur.execute("SELECT name FROM sqlite_master") diff --git a/apiutils.py b/apiutils.py index 74804fa..ece7e3d 100644 --- a/apiutils.py +++ b/apiutils.py @@ -38,12 +38,14 @@ def timeit(func): if configParse.get_TgToFileSystemParameter().base.timeit_enable: @wraps(func) def timeit_wrapper(*args, **kwargs): + print( + f'Function called {func.__name__}{args} {kwargs}') start_time = time.perf_counter() result = func(*args, **kwargs) end_time = time.perf_counter() total_time = end_time - start_time print( - f'Function {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') + f'Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') return result return timeit_wrapper return func @@ -53,12 +55,14 @@ def atimeit(func): if configParse.get_TgToFileSystemParameter().base.timeit_enable: @wraps(func) async def timeit_wrapper(*args, **kwargs): + print( + f'AFunction called {func.__name__}{args} {kwargs}') start_time = time.perf_counter() result = await func(*args, **kwargs) end_time = time.perf_counter() total_time = end_time - start_time print( - f'AFunction {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') + f'AFunction quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') return result return timeit_wrapper return func diff --git a/configParse.py b/configParse.py index 3c0a2f4..075a68a 100644 --- a/configParse.py +++ b/configParse.py @@ -1,4 +1,6 @@ import toml + +import functools from pydantic import BaseModel @@ -21,10 +23,8 @@ class TgToFileSystemParameter(BaseModel): port: int proxy: TgProxyParameter -__cache_res = None +@functools.lru_cache def get_TgToFileSystemParameter(path: str = "./config.toml", force_reload: bool = False) -> TgToFileSystemParameter: - global __cache_res - if __cache_res is not None and not force_reload: - return __cache_res - __cache_res = TgToFileSystemParameter.model_validate(toml.load(path)) - return __cache_res + if force_reload: + get_TgToFileSystemParameter.cache_clear() + return TgToFileSystemParameter.model_validate(toml.load(path)) diff --git a/requirements.txt b/requirements.txt index c5f835a..2dfa0c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ telethon # python-socks[asyncio] fastapi uvicorn[standard] +streamlit +qrcode diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..0b433f2 --- /dev/null +++ b/run.sh @@ -0,0 +1,3 @@ +source ./.venv/bin/activate +uvicorn start:app --port 7777 --host="0.0.0.0" +# python ./api.py diff --git a/start.py b/start.py index ce4e215..2f87386 100644 --- a/start.py +++ b/start.py @@ -1,5 +1,4 @@ import asyncio -import time import json import uvicorn @@ -16,6 +15,7 @@ from TgFileSystemClientManager import TgFileSystemClientManager from TgFileSystemClient import TgFileSystemClient clients_mgr: TgFileSystemClientManager = None +web_front_task = None @asynccontextmanager @@ -23,6 +23,18 @@ async def lifespan(app: FastAPI): global clients_mgr param = configParse.get_TgToFileSystemParameter() clients_mgr = TgFileSystemClientManager(param) + + async def run_web_server(): + cmd = "streamlit run ./web_streamlit.py --server.port 2000" + proc = await asyncio.create_subprocess_shell(cmd, stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + stdout, stderr = await proc.communicate() + print(f'[{cmd!r} exited with {proc.returncode}]') + if stdout: + print(f'[stdout]\n{stdout.decode()}') + if stderr: + print(f'[stderr]\n{stderr.decode()}') + web_front_task = asyncio.create_task(run_web_server()) yield app = FastAPI(lifespan=lifespan) @@ -36,6 +48,12 @@ app.add_middleware( ) +@app.post("/tg/api/v1/file/login") +@apiutils.atimeit +async def login_new_tg_file_client(): + raise NotImplementedError + + class TgToFileListRequestBody(BaseModel): token: str search: str = "" @@ -75,18 +93,12 @@ async def get_tg_file_list(body: TgToFileListRequestBody): return Response(json.dumps(response_dict), status_code=status.HTTP_200_OK) except Exception as err: print(f"{err=}") - return Response(f"{err=}", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) @app.get("/tg/api/v1/file/msg") @apiutils.atimeit async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request): - async def get_msg_media_range_requests(client: TgFileSystemClient, msg: types.Message, start: int, end: int): - MAX_CHUNK_SIZE = 1024 * 1024 - pos = start - async for chunk in client.client.iter_download(msg, offset=pos, chunk_size=min(end + 1 - pos, MAX_CHUNK_SIZE)): - pos = pos + len(chunk) - yield chunk.tobytes() msg_id = mid chat_id = cid headers = { @@ -109,26 +121,28 @@ async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Requ status_code = status.HTTP_200_OK mime_type = msg.media.document.mime_type headers["content-type"] = mime_type + # headers["content-length"] = str(file_size) file_name = apiutils.get_message_media_name(msg) if file_name == "": maybe_file_type = mime_type.split("/")[-1] file_name = f"{chat_id}.{msg_id}.{maybe_file_type}" - headers["Content-Disposition"] = f'Content-Disposition: inline; filename="{file_name}"' + headers[ + "Content-Disposition"] = f'Content-Disposition: inline; filename="{file_name.encode("utf-8")}"' if range_header is not None: start, end = apiutils.get_range_header(range_header, file_size) size = end - start + 1 - headers["content-length"] = str(size) + # headers["content-length"] = str(size) headers["content-range"] = f"bytes {start}-{end}/{file_size}" status_code = status.HTTP_206_PARTIAL_CONTENT return StreamingResponse( - get_msg_media_range_requests(client, msg, start, end), + client.streaming_get_iter(msg, start, end), headers=headers, status_code=status_code, ) except Exception as err: print(f"{err=}") - return Response(f"{err=}", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) if __name__ == "__main__": diff --git a/test.py b/test.py index a983274..6adff0d 100644 --- a/test.py +++ b/test.py @@ -1,3 +1,6 @@ +import time +import asyncio + from telethon import TelegramClient import configParse @@ -6,15 +9,22 @@ param = configParse.get_TgToFileSystemParameter() # Remember to use your own values from my.telegram.org! api_id = param.tgApi.api_id api_hash = param.tgApi.api_hash -client = TelegramClient('anon', api_id, api_hash, proxy={ +client1 = TelegramClient('anon', api_id, api_hash, proxy={ + # 'proxy_type': 'socks5', + # 'addr': '172.25.32.1', + # 'port': 7890, +}) +client2 = TelegramClient('anon1', api_id, api_hash, proxy={ 'proxy_type': 'socks5', 'addr': '172.25.32.1', 'port': 7890, }) +# client.session.set_dc(2, "91.108.56.198", 443) # client = TelegramClient('anon', api_id, api_hash, proxy=("socks5", '127.0.0.1', 7890)) # proxy=("socks5", '127.0.0.1', 4444) -async def main(): + +async def main(client): # Getting information about yourself me = await client.get_me() @@ -71,7 +81,23 @@ async def main(): # print(message.stringify()) # print(message.to_json()) # print(message.to_dict()) - # await client.download_media(message) + async def download_task(s: int): + last_p = 0 + last_t = time.time() + def progress_callback(p, file_size): + nonlocal last_p, last_t + t = time.time() + bd = p-last_p + td = t-last_t + print(f"{s}:avg:{bd/td/1024:>10.2f}kbps,{p/1024/1024:>7.2f}/{file_size/1024/1024:>7.2f}/{p/file_size:>5.2%}") + last_p = p + last_t = time.time() + await client.download_media(message, progress_callback=progress_callback ) + t_list = [] + # for i in range(4): + # ti = client.loop.create_task(download_task(i)) + # t_list.append(ti) + await asyncio.gather(*t_list) # You can download media from messages, too! # The method will return the path where the file was saved. @@ -79,8 +105,16 @@ async def main(): # path = await message.download_media() # print('File saved to', path) # printed after download is done -with client: - client.loop.run_until_complete(main()) +# with client: +# client.loop.run_until_complete(main()) +try: + client1.start() + # client2.start() + client1.loop.run_until_complete(main(client1)) + # client2.loop.run_until_complete(main(client2)) +finally: + client1.disconnect() + # client2.disconnect() async def start_tg_client(param: configParse.TgToFileSystemParameter):