diff --git a/.gitignore b/.gitignore index 5b89af8..200e571 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,15 @@ __pycache__ .venv +.idea .vscode *.session *.session-journal +*-journal *.session.sql *.toml *.db *.service log -cacheTest +cache_media tmp logs diff --git a/backend/MediaCacheManager.py b/backend/MediaCacheManager.py index 67260ae..6f007ff 100644 --- a/backend/MediaCacheManager.py +++ b/backend/MediaCacheManager.py @@ -1,3 +1,4 @@ +import os import functools import logging import bisect @@ -18,6 +19,11 @@ class MediaChunkHolder(object): waiters: collections.deque[asyncio.Future] requester: list[Request] = [] chunk_id: int = 0 + unique_id: str = "" + + @staticmethod + def generate_id(chat_id: int, msg_id: int, start: int) -> str: + return f"{chat_id}:{msg_id}:{start}" def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None: self.chat_id = chat_id @@ -27,6 +33,7 @@ class MediaChunkHolder(object): self.mem = mem or bytes() self.length = len(self.mem) self.waiters = collections.deque() + self.unique_id = MediaChunkHolder.generate_id(chat_id, msg_id, start) def __repr__(self) -> str: return f"MediaChunk,start:{self.start},len:{self.length}" @@ -53,7 +60,7 @@ class MediaChunkHolder(object): def is_completed(self) -> bool: return self.length >= self.target_len - + def notify_waiters(self) -> None: while self.waiters: waiter = self.waiters.popleft() @@ -114,8 +121,11 @@ class MediaChunkHolderManager(object): unique_chunk_id: int = 0 chunk_lru: collections.OrderedDict[int, MediaChunkHolder] + disk_chunk_cache: diskcache.Cache + def __init__(self) -> None: self.chunk_lru = collections.OrderedDict() + self.disk_chunk_cache = diskcache.Cache(f"{os.path.dirname(__file__)}/cache_media", size_limit=2**30) def _get_media_msg_cache(self, msg: types.Message) -> Optional[list[MediaChunkHolder]]: chat_cache = self.chunk_cache.get(msg.chat_id) @@ -160,14 +170,8 @@ class MediaChunkHolderManager(object): return res def set_media_chunk(self, chunk: MediaChunkHolder) -> None: - cache_chat = self.chunk_cache.get(chunk.chat_id) - if cache_chat is None: - self.chunk_cache[chunk.chat_id] = {} - cache_chat = self.chunk_cache[chunk.chat_id] - cache_msg = cache_chat.get(chunk.msg_id) - if cache_msg is None: - cache_chat[chunk.msg_id] = [] - cache_msg = cache_chat[chunk.msg_id] + cache_chat = self.chunk_cache.setdefault(chunk.chat_id, {}) + cache_msg = cache_chat.setdefault(chunk.msg_id, []) chunk.chunk_id = self.unique_chunk_id self.unique_chunk_id += 1 bisect.insort(cache_msg, chunk) @@ -283,4 +287,3 @@ class MediaBlockHolderManager(object): def __init__(self, limit_size: int = DEFAULT_MAX_CACHE_SIZE, dir: str = 'cache') -> None: pass - diff --git a/backend/TgFileSystemClient.py b/backend/TgFileSystemClient.py index 4112d36..a489a7e 100644 --- a/backend/TgFileSystemClient.py +++ b/backend/TgFileSystemClient.py @@ -10,6 +10,7 @@ import logging from typing import Union, Optional from telethon import TelegramClient, types, hints, events +from telethon.custom import QRLogin from fastapi import Request import configParse @@ -19,6 +20,7 @@ from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager logger = logging.getLogger(__file__.split("/")[-1]) + class TgFileSystemClient(object): MAX_WORKER_ROUTINE = 4 SINGLE_NET_CHUNK_SIZE = 512 * 1024 # 512kb @@ -32,6 +34,8 @@ class TgFileSystemClient(object): dialogs_cache: Optional[hints.TotalList] = None msg_cache: list[types.Message] = [] worker_routines: list[asyncio.Task] = [] + qr_login: QRLogin | None = None + login_task: asyncio.Task | None = None # task should: (task_id, callabledFunc) task_queue: asyncio.Queue task_id: int = 0 @@ -39,19 +43,35 @@ class TgFileSystemClient(object): # client config client_param: configParse.TgToFileSystemParameter.ClientConfigPatameter - def __init__(self, session_name: str, param: configParse.TgToFileSystemParameter, db: UserManager) -> None: + def __init__( + self, + session_name: str, + param: configParse.TgToFileSystemParameter, + db: UserManager, + ) -> None: self.api_id = param.tgApi.api_id self.api_hash = param.tgApi.api_hash self.session_name = session_name - self.proxy_param = { - 'proxy_type': param.proxy.proxy_type, - 'addr': param.proxy.addr, - 'port': param.proxy.port, - } if param.proxy.enable else {} - self.client_param = next((client_param for client_param in param.clients if client_param.token == session_name), configParse.TgToFileSystemParameter.ClientConfigPatameter()) + self.proxy_param = ( + { + "proxy_type": param.proxy.proxy_type, + "addr": param.proxy.addr, + "port": param.proxy.port, + } + if param.proxy.enable + else {} + ) + self.client_param = next( + (client_param for client_param in param.clients if client_param.token == session_name), + configParse.TgToFileSystemParameter.ClientConfigPatameter(), + ) self.task_queue = asyncio.Queue() self.client = TelegramClient( - f"{os.path.dirname(__file__)}/db/{self.session_name}.session", self.api_id, self.api_hash, proxy=self.proxy_param) + f"{os.path.dirname(__file__)}/db/{self.session_name}.session", + self.api_id, + self.api_hash, + proxy=self.proxy_param, + ) self.media_chunk_manager = MediaChunkHolderManager() self.db = db @@ -72,6 +92,7 @@ class TgFileSystemClient(object): raise RuntimeError("Client does not run.") result = func(self, *args, **kwargs) return result + return call_check_wrapper def _acheck_before_call(func): @@ -80,6 +101,7 @@ class TgFileSystemClient(object): raise RuntimeError("Client does not run.") result = await func(self, *args, **kwargs) return result + return call_check_wrapper @_check_before_call @@ -100,6 +122,28 @@ class TgFileSystemClient(object): msg: types.Message = event.message self.db.insert_by_message(self.me, msg) + async def login(self, mode: Union["phone", "qrcode"] = "qrcode") -> str: + if self.is_valid(): + return "" + if mode == "phone": + raise NotImplementedError + if self.qr_login is not None: + return self.qr_login.url + self.qr_login = await self.client.qr_login() + + async def wait_for_qr_login(): + try: + await self.qr_login.wait() + await self.start() + except Exception as err: + logger.warning(f"wait for login, {err=}, {traceback.format_exc()}") + finally: + self.login_task = None + self.qr_login = None + + self.login_task = self.client.loop.create_task(wait_for_qr_login()) + return self.qr_login.url + async def start(self) -> None: if self.is_valid(): return @@ -107,11 +151,9 @@ class TgFileSystemClient(object): await self.client.connect() self.me = await self.client.get_me() if self.me is None: - raise RuntimeError( - f"The {self.session_name} Client Does Not Login") + raise RuntimeError(f"The {self.session_name} Client Does Not Login") for _ in range(self.MAX_WORKER_ROUTINE): - worker_routine = self.client.loop.create_task( - self._worker_routine_handler()) + worker_routine = self.client.loop.create_task(self._worker_routine_handler()) self.worker_routines.append(worker_routine) if len(self.client_param.whitelist_chat) > 0: self._register_update_event(from_users=self.client_param.whitelist_chat) @@ -162,7 +204,6 @@ class TgFileSystemClient(object): async for msg in self.client.iter_messages(chat_id): self.db.insert_by_message(self.me, msg) logger.info(f"{chat_id} quit cache task.") - @_acheck_before_call async def get_message(self, chat_id: int, msg_id: int) -> types.Message: @@ -172,9 +213,9 @@ class TgFileSystemClient(object): @_acheck_before_call async def get_dialogs(self, limit: int = 10, offset: int = 0, refresh: bool = False) -> hints.TotalList: if self.dialogs_cache is not None and refresh is False: - return self.dialogs_cache[offset:offset+limit] + return self.dialogs_cache[offset : offset + limit] self.dialogs_cache = await self.client.get_dialogs() - return self.dialogs_cache[offset:offset+limit] + return self.dialogs_cache[offset : offset + limit] async def _worker_routine_handler(self) -> None: while self.client.is_connected(): @@ -186,7 +227,7 @@ class TgFileSystemClient(object): logger.error(traceback.format_exc()) finally: self.task_queue.task_done() - + def _get_unique_task_id(self) -> int: self.task_id += 1 return self.task_id @@ -207,7 +248,15 @@ class TgFileSystemClient(object): return res_list @_acheck_before_call - async def get_messages_by_search(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inner_search: bool = False, ignore_case: bool = False) -> hints.TotalList: + async def get_messages_by_search( + self, + chat_id: int, + search_word: str, + limit: int = 10, + offset: int = 0, + inner_search: bool = False, + ignore_case: bool = False, + ) -> hints.TotalList: offset = await self._get_offset_msg_id(chat_id, offset) if inner_search: res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset, search=search_word) @@ -225,11 +274,26 @@ class TgFileSystemClient(object): if len(res_list) >= limit: break return res_list - - async def get_messages_by_search_db(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]: + + async def get_messages_by_search_db( + self, + chat_id: int, + search_word: str, + limit: int = 10, + offset: int = 0, + inc: bool = False, + ignore_case: bool = False, + ) -> list[any]: if chat_id not in self.client_param.whitelist_chat: return [] - res = self.db.get_msg_by_chat_id_and_keyword(chat_id, search_word, limit=limit, offset=offset, inc=inc, ignore_case=ignore_case) + res = self.db.get_msg_by_chat_id_and_keyword( + chat_id, + search_word, + limit=limit, + offset=offset, + inc=inc, + ignore_case=ignore_case, + ) res = [self.db.get_column_msg_js(v) for v in res] return res @@ -243,43 +307,38 @@ class TgFileSystemClient(object): chunk = chunk.tobytes() remain_size -= len(chunk) if remain_size <= 0: - media_holder.append_chunk_mem( - chunk[:len(chunk)+remain_size]) + media_holder.append_chunk_mem(chunk[: len(chunk) + remain_size]) else: media_holder.append_chunk_mem(chunk) if media_holder.is_completed(): break if await media_holder.is_disconneted(): - raise asyncio.CancelledError("all requester canceled.") + raise asyncio.CancelledError("all requester canceled.") except asyncio.CancelledError as err: logger.info(f"cancel holder:{media_holder}") self.media_chunk_manager.cancel_media_chunk(media_holder) except Exception as err: logger.error( - f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}") + f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}" + ) finally: - logger.debug( - f"downloaded chunk:{time.time()}.{offset=},{target_size=},{media_holder}") + logger.debug(f"downloaded chunk:{time.time()}.{offset=},{target_size=},{media_holder}") async def streaming_get_iter(self, msg: types.Message, start: int, end: int, req: Request): try: - logger.debug( - f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]") + logger.debug(f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]") cur_task_id = self._get_unique_task_id() pos = start while not await req.is_disconnected() and pos <= end: - cache_chunk = self.media_chunk_manager.get_media_chunk( - msg, pos) + cache_chunk = self.media_chunk_manager.get_media_chunk(msg, pos) if cache_chunk is None: # post download task # align pos download task file_size = msg.media.document.size # align_pos = pos // self.SINGLE_MEDIA_SIZE * self.SINGLE_MEDIA_SIZE align_pos = pos - align_size = min(self.SINGLE_MEDIA_SIZE, - file_size - align_pos) - holder = MediaChunkHolder( - msg.chat_id, msg.id, align_pos, align_size) + align_size = min(self.SINGLE_MEDIA_SIZE, file_size - align_pos) + holder = MediaChunkHolder(msg.chat_id, msg.id, align_pos, align_size) holder.add_chunk_requester(req) self.media_chunk_manager.set_media_chunk(holder) self.task_queue.put_nowait((cur_task_id, self._download_media_chunk(msg, holder))) @@ -294,28 +353,28 @@ class TgFileSystemClient(object): if offset >= cache_chunk.length: await cache_chunk.wait_chunk_update() continue - need_len = min(cache_chunk.length - - offset, end - pos + 1) + need_len = min(cache_chunk.length - offset, end - pos + 1) pos = pos + need_len - yield cache_chunk.mem[offset:offset+need_len] + yield cache_chunk.mem[offset : offset + need_len] else: offset = pos - cache_chunk.start if offset >= cache_chunk.length: - raise RuntimeError( - f"lru cache missed!{pos=},{cache_chunk=}") + raise RuntimeError(f"lru cache missed!{pos=},{cache_chunk=}") need_len = min(cache_chunk.length - offset, end - pos + 1) pos = pos + need_len - yield cache_chunk.mem[offset:offset+need_len] + yield cache_chunk.mem[offset : offset + need_len] except Exception as err: logger.error(f"stream iter:{err=}") logger.error(traceback.format_exc()) finally: + async def _cancel_task_by_id(task_id: int): for _ in range(self.task_queue.qsize()): task = self.task_queue.get_nowait() self.task_queue.task_done() if task[0] != task_id: self.task_queue.put_nowait(task) + await self.client.loop.create_task(_cancel_task_by_id(cur_task_id)) logger.debug(f"yield quit,{msg.chat_id=},{msg.id=},[{start}:{end}]") diff --git a/backend/TgFileSystemClientManager.py b/backend/TgFileSystemClientManager.py index 0bbf08f..71cbfba 100644 --- a/backend/TgFileSystemClientManager.py +++ b/backend/TgFileSystemClientManager.py @@ -3,6 +3,7 @@ import asyncio import time import hashlib import os +import traceback import logging from backend.TgFileSystemClient import TgFileSystemClient @@ -11,14 +12,17 @@ import configParse logger = logging.getLogger(__file__.split("/")[-1]) + class TgFileSystemClientManager(object): MAX_MANAGE_CLIENTS: int = 10 + is_init: asyncio.Future param: configParse.TgToFileSystemParameter clients: dict[str, TgFileSystemClient] = {} def __init__(self, param: configParse.TgToFileSystemParameter) -> None: self.param = param self.db = UserManager() + self.is_init = asyncio.Future() self.loop = asyncio.get_running_loop() if self.loop.is_running(): self.loop.create_task(self._start_clients()) @@ -27,22 +31,42 @@ class TgFileSystemClientManager(object): def __del__(self) -> None: self.clients.clear() - + async def _start_clients(self) -> None: # init cache clients for client_config in self.param.clients: client = self.create_client(client_id=client_config.token) - if not client.is_valid(): - await client.start() self._register_client(client) + for _, client in self.clients.items(): + try: + if not client.is_valid(): + await client.start() + except Exception as err: + logger.warning(f"start client: {err=}, {traceback.format_exc()}") + self.is_init.set_result(True) + + async def get_status(self) -> dict[str, any]: + clients_status = [ + { + "status": client.is_valid(), + } + for _, client in self.clients.items() + ] + return {"init": await self.is_init, "clients": clients_status} + + async def login_clients(self) -> str: + for _, client in self.clients.items(): + login_url = await client.login() + if login_url != "": + return login_url + return "" def check_client_session_exist(self, client_id: str) -> bool: session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session" return os.path.isfile(session_db_file) def generate_client_id(self) -> str: - return hashlib.md5( - (str(time.perf_counter()) + self.param.base.salt).encode('utf-8')).hexdigest() + return hashlib.md5((str(time.perf_counter()) + self.param.base.salt).encode("utf-8")).hexdigest() def create_client(self, client_id: str = None) -> TgFileSystemClient: if client_id is None: @@ -61,7 +85,7 @@ class TgFileSystemClientManager(object): def get_client(self, client_id: str) -> TgFileSystemClient: client = self.clients.get(client_id) return client - + async def get_client_force(self, client_id: str) -> TgFileSystemClient: client = self.get_client(client_id) if client is None: @@ -72,4 +96,3 @@ class TgFileSystemClientManager(object): await client.start() self._register_client(client) return client - diff --git a/backend/UserManager.py b/backend/UserManager.py index 3c2c5d3..d8612d9 100644 --- a/backend/UserManager.py +++ b/backend/UserManager.py @@ -9,6 +9,7 @@ from telethon import types logger = logging.getLogger(__file__.split("/")[-1]) + class UserUpdateParam(BaseModel): client_id: str username: str @@ -31,6 +32,8 @@ class MessageUpdateParam(BaseModel): class UserManager(object): def __init__(self) -> None: + if not os.path.exists(os.path.dirname(__file__) + "/db"): + os.mkdir(os.path.dirname(__file__) + "/db") self.con = sqlite3.connect(f"{os.path.dirname(__file__)}/db/user.db") self.cur = self.con.cursor() if not self._table_has_been_inited(): @@ -45,7 +48,7 @@ class UserManager(object): def update_message(self) -> None: raise NotImplementedError - + def generate_unique_id_by_msg(self, me: types.User, msg: types.Message) -> str: user_id = me.id chat_id = msg.chat_id @@ -55,10 +58,20 @@ class UserManager(object): def get_all_msg_by_chat_id(self, chat_id: int) -> list[any]: res = self.cur.execute( - "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC", (chat_id,)) + "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC", + (chat_id,), + ) return res.fetchall() - def get_msg_by_chat_id_and_keyword(self, chat_id: int, keyword: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]: + def get_msg_by_chat_id_and_keyword( + self, + chat_id: int, + keyword: str, + limit: int = 10, + offset: int = 0, + inc: bool = False, + ignore_case: bool = False, + ) -> list[any]: keyword_condition = "msg_ctx LIKE '%{key}%' OR file_name LIKE '%{key}%'" if ignore_case: keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')" @@ -70,17 +83,23 @@ class UserManager(object): def get_oldest_msg_by_chat_id(self, chat_id: int) -> list[any]: res = self.cur.execute( - "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time LIMIT 1", (chat_id,)) + "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time LIMIT 1", + (chat_id,), + ) return res.fetchall() def get_newest_msg_by_chat_id(self, chat_id: int) -> list[any]: res = self.cur.execute( - "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC LIMIT 1", (chat_id,)) + "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC LIMIT 1", + (chat_id,), + ) return res.fetchall() def get_msg_by_unique_id(self, unique_id: str) -> list[any]: res = self.cur.execute( - "SELECT * FROM message WHERE unique_id = ? ORDER BY date_time DESC LIMIT 1", (unique_id,)) + "SELECT * FROM message WHERE unique_id = ? ORDER BY date_time DESC LIMIT 1", + (unique_id,), + ) return res.fetchall() @unique @@ -128,8 +147,18 @@ class UserManager(object): msg_type = UserManager.MessageTypeEnum.FILE.value except Exception as err: logger.error(f"{err=}") - insert_data = (unique_id, user_id, chat_id, msg_id, - msg_type, msg_ctx, mime_type, file_name, msg_js, date_time) + insert_data = ( + unique_id, + user_id, + chat_id, + msg_id, + msg_type, + msg_ctx, + mime_type, + file_name, + msg_js, + date_time, + ) execute_script = "INSERT INTO message (unique_id, user_id, chat_id, msg_id, msg_type, msg_ctx, mime_type, file_name, msg_js, date_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" try: self.cur.execute(execute_script, insert_data) @@ -175,11 +204,11 @@ class UserManager(object): def _first_runtime_run_once(self) -> None: if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='user'").fetchall()) == 0: - self.cur.execute( - "CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)") + self.cur.execute("CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)") if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='message'").fetchall()) == 0: self.cur.execute( - "CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx text, mime_type text, file_name text, msg_js text, date_time int NOT NULL)") + "CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx text, mime_type text, file_name text, msg_js text, date_time int NOT NULL)" + ) if __name__ == "__main__": diff --git a/backend/api.py b/backend/api.py index 8ba3cc8..16e26a3 100644 --- a/backend/api.py +++ b/backend/api.py @@ -185,10 +185,16 @@ async def get_tg_file_media(chat_id: int|str, msg_id: int, file_name: str, sign: return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) -@app.post("/tg/api/v1/client/login") +@app.get("/tg/api/v1/client/login") @apiutils.atimeit async def login_new_tg_file_client(): - raise NotImplementedError + url = await clients_mgr.login_clients() + return {"url": url} + + +@app.get("/tg/api/v1/client/status") +async def get_tg_file_client_status(request: Request): + return await clients_mgr.get_status() @app.get("/tg/api/v1/client/link_convert") diff --git a/frontend/api.py b/frontend/api.py deleted file mode 100644 index e69de29..0000000 diff --git a/frontend/home.py b/frontend/home.py index 884d50b..c2a66ff 100644 --- a/frontend/home.py +++ b/frontend/home.py @@ -1,171 +1,17 @@ -import sys -import os -import json - -sys.path.append(os.getcwd() + "/../") import streamlit as st -import qrcode -import pandas -import requests -import configParse -import utils +import remote_api as api -# qr = qrcode.make("https://www.baidu.com") -# st.image(qrcode.make("https://www.baidu.com").get_image()) -param = configParse.get_TgToFileSystemParameter() -background_server_url = f"{param.base.exposed_url}/tg/api/v1/file/search" +st.set_page_config(page_title="TgToolbox", page_icon="🕹️", layout="wide", initial_sidebar_state="collapsed") -st.set_page_config(page_title="TgToolbox", page_icon='🕹️', layout='wide') +backend_status = api.get_backend_client_status() +need_login = False -if 'page_index' not in st.session_state: - st.session_state.page_index = 1 -if 'force_skip' not in st.session_state: - st.session_state.force_skip = False +for v in backend_status["clients"]: + if not v["status"]: + need_login = True -if 'search_key' not in st.query_params: - st.query_params.search_key = "" -if 'is_order' not in st.query_params: - st.query_params.is_order = False -if 'search_res_limit' not in st.query_params: - st.query_params.search_res_limit = "10" - -@st.experimental_fragment -def search_container(): - st.query_params.search_key = st.text_input("**搜索🔎**", value=st.query_params.search_key) - columns = st.columns([7, 1]) - with columns[0]: - st.query_params.search_res_limit = str(st.number_input( - "**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d")) - with columns[1]: - st.text("排序") - st.query_params.is_order = st.toggle("顺序", value=utils.strtobool(st.query_params.is_order)) - -search_container() - -search_clicked = st.button('Search', type='primary', use_container_width=True) -if not st.session_state.force_skip and (not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None): - st.stop() - -if not st.session_state.force_skip: - st.session_state.page_index = 1 -if st.session_state.force_skip: - st.session_state.force_skip = False - -@st.experimental_fragment -def do_search_req(): - search_limit = int(st.query_params.search_res_limit) - offset_index = (st.session_state.page_index - 1) * search_limit - is_order = utils.strtobool(st.query_params.is_order) - - req_body = { - "token": param.web.token, - "search": f"{st.query_params.search_key}", - "chat_id": param.web.chat_id[0], - "index": offset_index, - "length": search_limit, - "refresh": False, - "inner": False, - "inc": is_order, - } - - req = requests.post(background_server_url, data=json.dumps(req_body)) - if req.status_code != 200: - st.stop() - search_res = json.loads(req.content.decode("utf-8")) - - def page_switch_render(): - columns = st.columns(3) - with columns[0]: - if st.button("Prev", use_container_width=True): - st.session_state.page_index = st.session_state.page_index - 1 - st.session_state.page_index = max( - st.session_state.page_index, 1) - st.session_state.force_skip = True - st.rerun() - with columns[1]: - # st.text(f"{st.session_state.page_index}") - st.markdown( - f"
{st.session_state.page_index}
", unsafe_allow_html=True) - # st.markdown(f"", unsafe_allow_html=True) - with columns[2]: - if st.button("Next", use_container_width=True): - st.session_state.page_index = st.session_state.page_index + 1 - st.session_state.force_skip = True - st.rerun() - - def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str): - file_size_str = f"{file_size/1024/1024:.2f}MB" - container = st.container() - container_columns = container.columns([1, 99]) - - st.session_state.search_res_select_list[index] = container_columns[0].checkbox( - "search_res_checkbox_" + str(index), label_visibility='collapsed') - - expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} — *{file_size_str}*" - popover = container_columns[1].popover(expender_title, use_container_width=True) - popover_columns = popover.columns([1, 3, 1]) - if url: - popover_columns[0].video(url) - else: - popover_columns[0].video('./static/404.webm', format="video/webm") - popover_columns[1].markdown(f'{msg_ctx}') - popover_columns[1].markdown(f'**{file_name}**') - popover_columns[1].markdown(f'文件大小:*{file_size_str}*') - popover_columns[2].link_button('⬇️Download Link', url, use_container_width=True) - popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True) - - @st.experimental_fragment - def show_search_res(res: dict[str, any]): - sign_token = "" - try: - sign_token = res['client']['sign'] - except Exception as err: - pass - search_res_list = res.get('list') - if search_res_list is None or len(search_res_list) == 0: - st.info("No result") - page_switch_render() - st.stop() - st.session_state.search_res_select_list = [False] * len(search_res_list) - url_list = [] - for i in range(len(search_res_list)): - v = search_res_list[i] - msg_ctx= "" - file_name = None - file_size = 0 - download_url = "" - src_link = "" - try: - src_link = v['src_tg_link'] - msg_ctx = v['message'] - msg_id = str(v['id']) - doc = v['media']['document'] - file_size = doc['size'] - if doc is not None: - for attr in doc['attributes']: - file_name = attr.get('file_name') - if file_name != "" and file_name is not None: - break - if file_name == "" or file_name is None: - file_name = "Can not get file name" - download_url = v['download_url'] - download_url += f'?sign={sign_token}' - url_list.append(download_url) - except Exception as err: - msg_ctx = f"{err=}\r\n\r\n" + msg_ctx - media_file_res_container( - i, msg_ctx, file_name, file_size, download_url, src_link) - page_switch_render() - - show_text = "" - select_list = st.session_state.search_res_select_list - for i in range(len(select_list)): - if select_list[i]: - show_text = show_text + url_list[i] + '\n' - st.text_area("链接", value=show_text) - - show_search_res(search_res) - - -do_search_req() +if need_login: + import login +else: + import search diff --git a/frontend/login.py b/frontend/login.py new file mode 100644 index 0000000..20bdebe --- /dev/null +++ b/frontend/login.py @@ -0,0 +1,24 @@ +import sys +import os + +import streamlit as st +import qrcode + +sys.path.append(os.getcwd() + "/../") +import configParse +import utils +import remote_api as api + +url = api.login_client_by_qr_code_url() + +if url is None or url == "": + st.text("Something wrong, no login url got.") + st.stop() + +st.markdown("### Please scan the qr code by telegram client.") +qr = qrcode.make(url) +st.image(qr.get_image()) + +st.markdown("**Click the Refrash button if you have been scaned**") +if st.button("Refresh"): + st.rerun() diff --git a/frontend/remote_api.py b/frontend/remote_api.py new file mode 100644 index 0000000..89c2c97 --- /dev/null +++ b/frontend/remote_api.py @@ -0,0 +1,58 @@ +import sys +import os +import json +import logging + +import requests + +sys.path.append(os.getcwd() + "/../") +import configParse + +logger = logging.getLogger(__file__.split("/")[-1]) + +param = configParse.get_TgToFileSystemParameter() + +background_server_url = f"{param.base.exposed_url}" +search_api_route = "/tg/api/v1/file/search" +status_api_route = "/tg/api/v1/client/status" +login_api_route = "/tg/api/v1/client/login" + + +def login_client_by_qr_code_url() -> str: + request_url = background_server_url + login_api_route + response = requests.get(request_url) + if response.status_code != 200: + logger.warning(f"Could not login, err:{response.status_code}, {response.content.decode('utf-8')}") + return None + url_info = json.loads(response.content.decode("utf-8")) + return url_info.get("url") + + +def get_backend_client_status() -> dict[str, any]: + request_url = background_server_url + status_api_route + response = requests.get(request_url) + if response.status_code != 200: + logger.warning(f"get_status, backend is running? err:{response.status_code}, {response.content.decode('utf-8')}") + return None + return json.loads(response.content.decode("utf-8")) + + +def search_database_by_keyword(keyword: str, offset: int, limit: int, is_order: bool) -> list[any] | None: + request_url = background_server_url + search_api_route + req_body = { + "token": param.web.token, + "search": keyword, + "chat_id": param.web.chat_id[0], + "index": offset, + "length": limit, + "refresh": False, + "inner": False, + "inc": is_order, + } + + response = requests.post(request_url, data=json.dumps(req_body)) + if response.status_code != 200: + logger.warning(f"search_database_by_keyword err:{response.status_code}, {response.content.decode('utf-8')}") + return None + search_res = json.loads(response.content.decode("utf-8")) + return search_res diff --git a/frontend/search.py b/frontend/search.py new file mode 100644 index 0000000..a3c1273 --- /dev/null +++ b/frontend/search.py @@ -0,0 +1,151 @@ +import sys +import os + +import streamlit as st + +sys.path.append(os.getcwd() + "/../") +import configParse +import utils +import remote_api as api + +param = configParse.get_TgToFileSystemParameter() + +if 'page_index' not in st.session_state: + st.session_state.page_index = 1 +if 'force_skip' not in st.session_state: + st.session_state.force_skip = False + +if 'search_key' not in st.query_params: + st.query_params.search_key = "" +if 'is_order' not in st.query_params: + st.query_params.is_order = False +if 'search_res_limit' not in st.query_params: + st.query_params.search_res_limit = "10" + +@st.experimental_fragment +def search_container(): + st.query_params.search_key = st.text_input("**搜索🔎**", value=st.query_params.search_key) + columns = st.columns([7, 1]) + with columns[0]: + st.query_params.search_res_limit = str(st.number_input( + "**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d")) + with columns[1]: + st.text("排序") + st.query_params.is_order = st.toggle("顺序", value=utils.strtobool(st.query_params.is_order)) + +search_container() + +search_clicked = st.button('Search', type='primary', use_container_width=True) +if not st.session_state.force_skip and (not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None): + st.stop() + +if not st.session_state.force_skip: + st.session_state.page_index = 1 +if st.session_state.force_skip: + st.session_state.force_skip = False + +@st.experimental_fragment +def do_search_req(): + search_limit = int(st.query_params.search_res_limit) + offset_index = (st.session_state.page_index - 1) * search_limit + is_order = utils.strtobool(st.query_params.is_order) + + search_res = api.search_database_by_keyword(st.query_params.search_key, offset_index, search_limit, is_order) + if search_res is None: + st.stop() + + def page_switch_render(): + columns = st.columns(3) + with columns[0]: + if st.button("Prev", use_container_width=True): + st.session_state.page_index = st.session_state.page_index - 1 + st.session_state.page_index = max( + st.session_state.page_index, 1) + st.session_state.force_skip = True + st.rerun() + with columns[1]: + # st.text(f"{st.session_state.page_index}") + st.markdown( + f"{st.session_state.page_index}
", unsafe_allow_html=True) + # st.markdown(f"", unsafe_allow_html=True) + with columns[2]: + if st.button("Next", use_container_width=True): + st.session_state.page_index = st.session_state.page_index + 1 + st.session_state.force_skip = True + st.rerun() + + def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str): + file_size_str = f"{file_size/1024/1024:.2f}MB" + container = st.container() + container_columns = container.columns([1, 99]) + + st.session_state.search_res_select_list[index] = container_columns[0].checkbox( + "search_res_checkbox_" + str(index), label_visibility='collapsed') + + expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} — *{file_size_str}*" + popover = container_columns[1].popover(expender_title, use_container_width=True) + popover_columns = popover.columns([1, 3, 1]) + if url: + popover_columns[0].video(url) + else: + popover_columns[0].video('./static/404.webm', format="video/webm") + popover_columns[1].markdown(f'{msg_ctx}') + popover_columns[1].markdown(f'**{file_name}**') + popover_columns[1].markdown(f'文件大小:*{file_size_str}*') + popover_columns[2].link_button('⬇️Download Link', url, use_container_width=True) + popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True) + + @st.experimental_fragment + def show_search_res(res: dict[str, any]): + search_res_list = res.get("list") + if search_res_list is None or len(search_res_list) == 0: + st.info("No result") + page_switch_render() + st.stop() + sign_token = "" + try: + sign_token = res['client']['sign'] + except Exception as err: + pass + st.session_state.search_res_select_list = [False] * len(search_res_list) + url_list = [] + for i in range(len(search_res_list)): + v = search_res_list[i] + msg_ctx= "" + file_name = None + file_size = 0 + download_url = "" + src_link = "" + try: + src_link = v['src_tg_link'] + msg_ctx = v['message'] + msg_id = str(v['id']) + doc = v['media']['document'] + file_size = doc['size'] + if doc is not None: + for attr in doc['attributes']: + file_name = attr.get('file_name') + if file_name != "" and file_name is not None: + break + if file_name == "" or file_name is None: + file_name = "Can not get file name" + download_url = v['download_url'] + download_url += f'?sign={sign_token}' + url_list.append(download_url) + except Exception as err: + msg_ctx = f"{err=}\r\n\r\n" + msg_ctx + media_file_res_container( + i, msg_ctx, file_name, file_size, download_url, src_link) + page_switch_render() + + show_text = "" + select_list = st.session_state.search_res_select_list + for i in range(len(select_list)): + if select_list[i]: + show_text = show_text + url_list[i] + '\n' + st.text_area("链接", value=show_text) + + show_search_res(search_res) + + +do_search_req() diff --git a/requirements.txt b/requirements.txt index 2dfa0c7..4576a5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ toml telethon # python-socks[asyncio] +diskcache fastapi uvicorn[standard] streamlit