diff --git a/.gitignore b/.gitignore index 48de330..c9ee82c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ .vscode *.session *.session-journal +*.session.sql *.toml *.db *.service diff --git a/backend/TgFileSystemClient.py b/backend/TgFileSystemClient.py index b547820..29394e8 100644 --- a/backend/TgFileSystemClient.py +++ b/backend/TgFileSystemClient.py @@ -188,11 +188,13 @@ class TgFileSystemClient(object): media_chunk_manager: MediaChunkHolderManager dialogs_cache: Optional[hints.TotalList] = None msg_cache: list[types.Message] = [] - download_routines: list[asyncio.Task] = [] + worker_routines: list[asyncio.Task] = [] # task should: (task_id, callabledFunc) task_queue: asyncio.Queue task_id: int = 0 me: Union[types.User, types.InputPeerUser] + # client config + client_param: configParse.TgToFileSystemParameter.ClientConfigPatameter def __init__(self, session_name: str, param: configParse.TgToFileSystemParameter, db: UserManager) -> None: self.api_id = param.tgApi.api_id @@ -203,6 +205,7 @@ class TgFileSystemClient(object): 'addr': param.proxy.addr, 'port': param.proxy.port, } if param.proxy.enable else {} + self.client_param = next((client_param for client_param in param.clients if client_param.token == session_name), configParse.TgToFileSystemParameter.ClientConfigPatameter()) self.task_queue = asyncio.Queue() self.client = TelegramClient( f"{os.path.dirname(__file__)}/db/{self.session_name}.session", self.api_id, self.api_hash, proxy=self.proxy_param) @@ -248,11 +251,11 @@ class TgFileSystemClient(object): return self.client.is_connected() and self.me is not None @_call_before_check - def _register_update_event(self) -> None: - @self.client.on(events.NewMessage(incoming=True, from_users=[666462447])) + def _register_update_event(self, from_users: list[int] = []) -> None: + @self.client.on(events.NewMessage(incoming=True, from_users=from_users)) async def _incoming_new_message_handler(event) -> None: msg: types.Message = event.message - print(f"message: {msg.to_json()}") + self.db.insert_by_message(self.me, msg) async def start(self) -> None: if self.is_valid(): @@ -266,8 +269,10 @@ class TgFileSystemClient(object): for _ in range(self.MAX_WORKER_ROUTINE): worker_routine = self.client.loop.create_task( self._worker_routine_handler()) - self.download_routines.append(worker_routine) - self._register_update_event() + self.worker_routines.append(worker_routine) + if len(self.client_param.whitelist_chat) > 0: + self._register_update_event(from_users=self.client_param.whitelist_chat) + await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat())) async def stop(self) -> None: await self.client.loop.create_task(self._cancel_tasks()) @@ -277,12 +282,34 @@ class TgFileSystemClient(object): await self.client.disconnect() async def _cancel_tasks(self) -> None: - for t in self.download_routines: + for t in self.worker_routines: try: t.cancel() except Exception as err: print(f"{err=}") + async def _cache_whitelist_chat(self): + for chat_id in self.client_param.whitelist_chat: + # update newest msg + newest_msg = self.db.get_newest_msg_by_chat_id(chat_id) + if len(newest_msg) > 0: + newest_msg = newest_msg[0] + async for msg in self.client.iter_messages(chat_id): + if msg.id <= self.db.get_column_msg_id(newest_msg): + break + self.db.insert_by_message(self.me, msg) + # update oldest msg + oldest_msg = self.db.get_oldest_msg_by_chat_id(chat_id) + if len(oldest_msg) > 0: + oldest_msg = oldest_msg[0] + offset = self.db.get_column_msg_id(oldest_msg) + async for msg in self.client.iter_messages(chat_id, offset_id=offset): + self.db.insert_by_message(self.me, msg) + else: + async for msg in self.client.iter_messages(chat_id): + self.db.insert_by_message(self.me, msg) + + @_acall_before_check async def get_message(self, chat_id: int, msg_id: int) -> types.Message: msg = await self.client.get_messages(chat_id, ids=msg_id) @@ -297,9 +324,17 @@ class TgFileSystemClient(object): async def _worker_routine_handler(self) -> None: while self.client.is_connected(): - task = await self.task_queue.get() - await task[1] - self.task_queue.task_done() + try: + task = await self.task_queue.get() + await task[1] + except Exception as err: + print(f"{err=}") + finally: + self.task_queue.task_done() + + def _get_unique_task_id(self) -> int: + self.task_id += 1 + return self.task_id async def _get_offset_msg_id(self, chat_id: int, offset: int) -> int: if offset != 0: @@ -317,7 +352,7 @@ class TgFileSystemClient(object): return res_list @_acall_before_check - async def get_messages_by_search(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inner_search: bool = False) -> hints.TotalList: + async def get_messages_by_search(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inner_search: bool = False, ignore_case: bool = False) -> hints.TotalList: offset = await self._get_offset_msg_id(chat_id, offset) if inner_search: res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset, search=search_word) @@ -326,7 +361,7 @@ class TgFileSystemClient(object): res_list = hints.TotalList() cnt = 0 async for msg in self.client.iter_messages(chat_id, offset_id=offset): - if cnt >= 10_000: + if cnt >= 1_000: break cnt += 1 if msg.text.find(search_word) == -1 and apiutils.get_message_media_name(msg).find(search_word) == -1: @@ -335,6 +370,13 @@ class TgFileSystemClient(object): if len(res_list) >= limit: break return res_list + + async def get_messages_by_search_db(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, ignore_case: bool = False) -> list[any]: + if chat_id not in self.client_param.whitelist_chat: + return [] + res = self.db.get_msg_by_chat_id_and_keyword(chat_id, search_word, limit=limit, offset=offset, ignore_case=ignore_case) + res = [self.db.get_column_msg_js(v) for v in res] + return res async def _download_media_chunk(self, msg: types.Message, media_holder: MediaChunkHolder) -> None: try: @@ -364,8 +406,7 @@ class TgFileSystemClient(object): try: # print( # f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]") - self.task_id += 1 - cur_task_id = self.task_id + cur_task_id = self._get_unique_task_id() pos = start while pos <= end: cache_chunk = self.media_chunk_manager.get_media_chunk( diff --git a/backend/TgFileSystemClientManager.py b/backend/TgFileSystemClientManager.py index fd48ca9..1b45a09 100644 --- a/backend/TgFileSystemClientManager.py +++ b/backend/TgFileSystemClientManager.py @@ -1,4 +1,5 @@ from typing import Any +import asyncio import time import hashlib import os @@ -16,9 +17,22 @@ class TgFileSystemClientManager(object): def __init__(self, param: configParse.TgToFileSystemParameter) -> None: self.param = param self.db = UserManager() + self.loop = asyncio.get_running_loop() + if self.loop.is_running(): + self.loop.create_task(self._start_clients()) + else: + self.loop.run_until_complete(self._start_clients()) def __del__(self) -> None: - pass + self.clients.clear() + + async def _start_clients(self) -> None: + # init cache clients + for client_config in self.param.clients: + client = self.create_client(client_id=client_config.token) + if not client.is_valid(): + await client.start() + self._register_client(client) def check_client_session_exist(self, client_id: str) -> bool: session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session" @@ -34,11 +48,11 @@ class TgFileSystemClientManager(object): client = TgFileSystemClient(client_id, self.param, self.db) return client - def register_client(self, client: TgFileSystemClient) -> bool: + def _register_client(self, client: TgFileSystemClient) -> bool: self.clients[client.session_name] = client return True - def deregister_client(self, client_id: str) -> bool: + def _unregister_client(self, client_id: str) -> bool: self.clients.pop(client_id) return True @@ -54,6 +68,6 @@ class TgFileSystemClientManager(object): client = self.create_client(client_id=client_id) if not client.is_valid(): await client.start() - self.register_client(client) + self._register_client(client) return client diff --git a/backend/UserManager.py b/backend/UserManager.py index 53c46a7..f801673 100644 --- a/backend/UserManager.py +++ b/backend/UserManager.py @@ -1,4 +1,5 @@ import os +from enum import Enum, IntEnum, unique, auto import sqlite3 from pydantic import BaseModel @@ -42,27 +43,61 @@ class UserManager(object): def update_message(self) -> None: raise NotImplementedError + def get_all_msg_by_chat_id(self, chat_id: int) -> list[any]: + res = self.cur.execute( + "SELECT * FROM message WHERE chat_id == ? ORDER BY msg_id DESC", (chat_id,)) + return res.fetchall() + + def get_msg_by_chat_id_and_keyword(self, chat_id: int, keyword: str, limit: int = 10, offset: int = 0, ignore_case: bool = False) -> list[any]: + keyword_condition = "msg_ctx LIKE '%{key}%' OR file_name LIKE '%{key}%'" + if ignore_case: + keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')" + keyword_condition = keyword_condition.format(key=keyword) + execute_script = f"SELECT * FROM message WHERE chat_id == {chat_id} AND ({keyword_condition}) ORDER BY msg_id DESC LIMIT {limit} OFFSET {offset}" + res = self.cur.execute(execute_script) + return res + + def get_oldest_msg_by_chat_id(self, chat_id: int) -> list[any]: + res = self.cur.execute( + "SELECT * FROM message WHERE chat_id == ? ORDER BY msg_id LIMIT 1", (chat_id,)) + return res.fetchall() + + def get_newest_msg_by_chat_id(self, chat_id: int) -> list[any]: + res = self.cur.execute( + "SELECT * FROM message WHERE chat_id == ? ORDER BY msg_id DESC LIMIT 1", (chat_id,)) + return res.fetchall() + + @unique + class MessageTypeEnum(Enum): + OTHERS = "others" + TEXT = "text" + PHOTO = "photo" + FILE = "file" + def insert_by_message(self, me: types.User, msg: types.Message): user_id = me.id chat_id = msg.chat_id msg_id = msg.id unique_id = str(user_id) + str(chat_id) + str(msg_id) - msg_type = "others" + msg_type = UserManager.MessageTypeEnum.OTHERS.value mime_type = "" file_name = "" msg_ctx = msg.message msg_js = msg.to_json() - if msg.media is None: - msg_type = "text" - elif isinstance(msg.media, types.MessageMediaPhoto): - msg_type = "photo" - elif isinstance(msg.media, types.MessageMediaDocument): - msg_type = "file" - document = msg.media.document - mime_type = document.mime_type - for attr in document.attributes: - if isinstance(attr, types.DocumentAttributeFilename): - file_name = attr.file_name + try: + if msg.media is None: + msg_type = UserManager.MessageTypeEnum.TEXT.value + elif isinstance(msg.media, types.MessageMediaPhoto): + msg_type = UserManager.MessageTypeEnum.PHOTO.value + elif isinstance(msg.media, types.MessageMediaDocument): + document = msg.media.document + mime_type = document.mime_type + for attr in document.attributes: + if isinstance(attr, types.DocumentAttributeFilename): + file_name = attr.file_name + msg_type = UserManager.MessageTypeEnum.FILE.value + except Exception as err: + print(f"{err=}") insert_data = (unique_id, user_id, chat_id, msg_id, msg_type, msg_ctx, mime_type, file_name, msg_js) execute_script = "INSERT INTO message (unique_id, user_id, chat_id, msg_id, msg_type, msg_ctx, mime_type, file_name, msg_js) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" @@ -72,6 +107,34 @@ class UserManager(object): except Exception as err: print(f"{err=}") + @unique + class ColumnEnum(IntEnum): + UNIQUE_ID = 0 + USER_ID = auto() + CHAT_ID = auto() + MSG_ID = auto() + MSG_TYPE = auto() + MSG_CTX = auto() + MIME_TYPE = auto() + FILE_NAME = auto() + MSG_JS = auto() + COLUMN_LEN = auto() + + def get_column_by_enum(self, column: tuple[any], index: ColumnEnum) -> any: + if len(column) == UserManager.ColumnEnum.COLUMN_LEN: + return column[index] + return None + + def get_column_msg_id(self, column: tuple[any]) -> int | None: + if len(column) == UserManager.ColumnEnum.COLUMN_LEN: + return column[UserManager.ColumnEnum.MSG_ID] + return None + + def get_column_msg_js(self, column: tuple[any]) -> str | None: + if len(column) == UserManager.ColumnEnum.COLUMN_LEN: + return column[UserManager.ColumnEnum.MSG_JS] + return None + def get_user_info() -> None: raise NotImplementedError @@ -94,5 +157,11 @@ if __name__ == "__main__": "UPDATE user SET (client_id, username, phone) = (123, 'hehe', 66666) WHERE client_id == 123") res = db.cur.execute("SELECT name FROM sqlite_master") print(res.fetchall()) - res = db.cur.execute("SELECT msg_ctx FROM message WHERE true AND msg_ctx like '%Cyan%'") + res = db.cur.execute( + "SELECT msg_id, msg_ctx, file_name FROM message WHERE chat_id == -1001216816802") + # res.execute("SELECT * FROM message WHERE chat_id == ? ORDER BY msg_id DESC LIMIT 1", (-1001216816802,)) + # print(res.fetchall()) + # print("\n\n\n\n\n\n") + res.execute("SELECT COUNT(msg_id) FROM message") + # res = db.cur.execute("SELECT DISTINCT chat_id FROM message") print(res.fetchall()) diff --git a/backend/api.py b/backend/api.py index 19ffd74..31a9c83 100644 --- a/backend/api.py +++ b/backend/api.py @@ -50,6 +50,28 @@ class TgToFileListRequestBody(BaseModel): refresh: bool = False inner: bool = False +@app.post("/tg/api/v1/file/search") +@apiutils.atimeit +async def search_tg_file_list(body: TgToFileListRequestBody): + try: + res = hints.TotalList() + res_type = "msg" + client = await clients_mgr.get_client_force(body.token) + res_dict = {} + res = await client.get_messages_by_search_db(body.chat_id, body.search, limit=body.length, offset=body.index) + res_dict = [json.loads(item) for item in res] + + response_dict = { + "client": json.loads(client.to_json()), + "type": res_type, + "length": len(res_dict), + "list": res_dict, + } + return Response(json.dumps(response_dict), status_code=status.HTTP_200_OK) + except Exception as err: + print(f"{err=}") + return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) + @app.post("/tg/api/v1/file/list") @apiutils.atimeit @@ -133,6 +155,17 @@ async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Requ return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) +@app.get("/tg/api/v1/file/msg/{file_name}") +@apiutils.atimeit +async def get_tg_file_media_stream2(file_name: str, sign: str, req: Request): + raise NotImplementedError + + +@app.get("/tg/api/v1/file/msg_convert") +@apiutils.atimeit +async def convert_tg_msg_link_media_stream(link: str, token: str): + raise NotImplementedError + if __name__ == "__main__": param = configParse.get_TgToFileSystemParameter() uvicorn.run(app, host="0.0.0.0", port=param.base.port) diff --git a/backend/apiutils.py b/backend/apiutils.py index ece7e3d..20160c9 100644 --- a/backend/apiutils.py +++ b/backend/apiutils.py @@ -32,6 +32,7 @@ def get_message_media_name(msg: types.Message) -> str: for attr in msg.media.document.attributes: if isinstance(attr, types.DocumentAttributeFilename): return attr.file_name + return "" def timeit(func): diff --git a/configParse.py b/configParse.py index 7cfb25d..8c28805 100644 --- a/configParse.py +++ b/configParse.py @@ -15,6 +15,7 @@ class TgToFileSystemParameter(BaseModel): class ClientConfigPatameter(BaseModel): token: str = "" interval: float = 0.1 + whitelist_chat: list[int] = [] clients: list[ClientConfigPatameter] class ApiParameter(BaseModel): diff --git a/frontend/home.py b/frontend/home.py index 6cccf51..8cc1a10 100644 --- a/frontend/home.py +++ b/frontend/home.py @@ -2,83 +2,139 @@ import sys import os import json +sys.path.append(os.getcwd()) + import streamlit import qrcode import pandas import requests -sys.path.append(os.getcwd()) - import configParse # qr = qrcode.make("https://www.baidu.com") # streamlit.image(qrcode.make("https://www.baidu.com").get_image()) +if streamlit.session_state.get('page_index') is None: + streamlit.session_state.page_index = 0 +if streamlit.session_state.get('search_key') is None: + streamlit.session_state.search_key = "" + param = configParse.get_TgToFileSystemParameter() -background_server_url = f"{param.web.base_url}:{param.base.port}/tg/api/v1/file/list" +background_server_url = f"{param.web.base_url}:{param.base.port}/tg/api/v1/file/search" download_server_url = f"{param.web.base_url}:{param.base.port}/tg/api/v1/file/msg?token={param.web.token}&cid={param.web.chat_id[0]}&mid=" -search_input = streamlit.text_input("输入想搜的:") +search_input = streamlit.text_input("搜索关键字:") col1, col2 = streamlit.columns(2) -search_clicked = False -search_res_limit = streamlit.number_input("限制搜索量", min_value=1, max_value=100, value=10, format="%d") +search_res_limit = streamlit.number_input( + "搜索结果数", min_value=1, max_value=100, value=10, format="%d") search_clicked = streamlit.button("Search") -if not search_clicked or search_input == "": +if (not search_clicked or search_input == "") and search_input != streamlit.session_state.search_input: + streamlit.session_state.page_index = 0 streamlit.stop() +streamlit.session_state.search_input = search_input -test_body = { - "token": param.web.token, - "search": f"{search_input}", - "chat_id": param.web.chat_id[0], - "index": 0, - "length": search_res_limit, - "refresh": False, - "inner": False, -} +@streamlit.experimental_fragment +def show_search_res(): + offset_index = streamlit.session_state.page_index * search_res_limit -req = requests.post(background_server_url, data=json.dumps(test_body)) -if req.status_code != 200: - streamlit.stop() -search_res = json.loads(req.content.decode("utf-8")) - - -message_list = [] -file_name_list = [] -file_size_list = [] -download_url_list = [] -message_id_list = [] -for v in search_res['list']: - message_list.append(v['message']) - doc = v['media']['document'] - file_size = doc['size'] or 0 - file_size_list.append(f"{file_size/1024/1024:.2f}MB") - file_name = "" - for attr in doc['attributes']: - file_name = attr.get('file_name') - if file_name is not None: - file_name_list.append(file_name) - break - if file_name == "": - file_name_list.append("Not A File") - msg_id = str(v['id']) - message_id_list.append(msg_id) - download_url_list.append(download_server_url+msg_id) - -df = pandas.DataFrame( - { - "message": message_list, - "file name": file_name_list, - "file size": file_size_list, - "url": download_url_list, - "id": message_id_list, + req_body = { + "token": param.web.token, + "search": f"{search_input}", + "chat_id": param.web.chat_id[0], + "index": offset_index, + "length": search_res_limit, + "refresh": False, + "inner": False, } -) -streamlit.dataframe( - df, - column_config={ - "url": streamlit.column_config.LinkColumn("URL"), - }, - hide_index=True, -) + req = requests.post(background_server_url, data=json.dumps(req_body)) + if req.status_code != 200: + streamlit.stop() + search_res = json.loads(req.content.decode("utf-8")) + + + message_list = [] + file_name_list = [] + file_size_list = [] + download_url_list = [] + message_id_list = [] + select_box_list = [] + for v in search_res['list']: + message_list.append(v['message']) + doc = None + file_size = 0 + try: + doc = v['media']['document'] + file_size = doc['size'] + except: + pass + file_size_list.append(f"{file_size/1024/1024:.2f}MB") + file_name = None + for attr in doc['attributes']: + file_name = attr.get('file_name') + if file_name is not None: + file_name_list.append(file_name) + break + if file_name is None: + file_name_list.append("Not A File") + msg_id = str(v['id']) + message_id_list.append(msg_id) + download_url_list.append(download_server_url+msg_id) + select_box_list.append(False) + + df = pandas.DataFrame( + { + "select_box": select_box_list, + "message": message_list, + "file name": file_name_list, + "file size": file_size_list, + "url": download_url_list, + "id": message_id_list, + } + ) + + # streamlit.text_area("debug", value=f'{df}') + if df.empty: + streamlit.info("No result") + streamlit.stop() + data = streamlit.data_editor( + df, + column_config={ + "select_box": streamlit.column_config.CheckboxColumn("✅", default=False), + "url": streamlit.column_config.LinkColumn("URL"), + }, + disabled=["message", + "file name", + "file size", + "url", + "id",], + hide_index=True, + ) + columns = streamlit.columns(3) + with columns[0]: + pre_button = streamlit.button("Prev", use_container_width=True) + if pre_button: + streamlit.session_state.page_index = max(streamlit.session_state.page_index - 1, 0) + streamlit.rerun() + with columns[1]: + # streamlit.text(f"{streamlit.session_state.page_index + 1}") + streamlit.markdown(f"

{streamlit.session_state.page_index + 1}

", unsafe_allow_html=True) + # streamlit.markdown(f"", unsafe_allow_html=True) + with columns[2]: + next_button = streamlit.button("Next", use_container_width=True) + if next_button: + streamlit.session_state.page_index = streamlit.session_state.page_index + 1 + streamlit.rerun() + + show_text = "" + select_list = data['select_box'] + url_list = data['url'] + for i in range(len(select_list)): + if select_list[i]: + show_text = show_text + url_list[i] + '\n' + if show_text != "": + streamlit.text_area("链接", value=show_text) + +show_search_res() +