diff --git a/README.md b/README.md index 27722ea..2d4a340 100644 --- a/README.md +++ b/README.md @@ -1 +1,19 @@ -自用脚本,随缘更新 +# TgToolBox + +A Telegram toolbox + +## Run the project + +### Install requirements.txt + +### Config + +### Run + +## TODO + +- [ ] Support photo +- [ ] chat search +- [ ] token encrypt +- [ ] The encryption key has an expiration time +- [ ] Photo gallery diff --git a/backend/TgFileSystemClient.py b/backend/TgFileSystemClient.py index a93b587..b70469c 100644 --- a/backend/TgFileSystemClient.py +++ b/backend/TgFileSystemClient.py @@ -2,7 +2,6 @@ import asyncio import json import time import re -import rsa import os import functools import traceback @@ -35,10 +34,6 @@ class TgFileSystemClient(object): worker_routines: list[asyncio.Task] qr_login: QRLogin | None = None login_task: asyncio.Task | None = None - # rsa key - sign: str - public_key: rsa.PublicKey - private_key: rsa.PrivateKey # task should: (task_id, callabledFunc) task_queue: asyncio.Queue task_id: int = 0 @@ -51,6 +46,7 @@ class TgFileSystemClient(object): session_name: str, param: configParse.TgToFileSystemParameter, db: UserManager, + chunk_manager: MediaChunkHolderManager, ) -> None: self.api_id = param.tgApi.api_id self.api_hash = param.tgApi.api_hash @@ -64,12 +60,10 @@ class TgFileSystemClient(object): if param.proxy.enable else {} ) - self.public_key, self.private_key = rsa.newkeys(1024) self.client_param = next( (client_param for client_param in param.clients if client_param.token == session_name), configParse.TgToFileSystemParameter.ClientConfigPatameter(), ) - self.sign = self.client_param.token self.task_queue = asyncio.Queue() self.client = TelegramClient( f"{os.path.dirname(__file__)}/db/{self.session_name}.session", @@ -77,7 +71,7 @@ class TgFileSystemClient(object): self.api_hash, proxy=self.proxy_param, ) - self.media_chunk_manager = MediaChunkHolderManager() + self.media_chunk_manager = chunk_manager self.db = db self.worker_routines = [] diff --git 
a/backend/TgFileSystemClientManager.py b/backend/TgFileSystemClientManager.py index 52a145b..a61f1a1 100644 --- a/backend/TgFileSystemClientManager.py +++ b/backend/TgFileSystemClientManager.py @@ -1,13 +1,14 @@ -from typing import Any import asyncio import time import hashlib +import rsa import os import traceback import logging from backend.TgFileSystemClient import TgFileSystemClient from backend.UserManager import UserManager +from backend.MediaCacheManager import MediaChunkHolderManager import configParse logger = logging.getLogger(__file__.split("/")[-1]) @@ -18,6 +19,10 @@ class TgFileSystemClientManager(object): is_init: bool = False param: configParse.TgToFileSystemParameter clients: dict[str, TgFileSystemClient] = {} + # rsa key + cache_sign: str + public_key: rsa.PublicKey + private_key: rsa.PrivateKey @classmethod def get_instance(cls): @@ -29,6 +34,8 @@ class TgFileSystemClientManager(object): self.param = param self.db = UserManager() self.loop = asyncio.get_running_loop() + self.media_chunk_manager = MediaChunkHolderManager() + self.public_key, self.private_key = rsa.newkeys(1024) if self.loop.is_running(): self.loop.create_task(self._start_clients()) else: @@ -76,7 +83,7 @@ class TgFileSystemClientManager(object): def create_client(self, client_id: str = None) -> TgFileSystemClient: if client_id is None: client_id = self.generate_client_id() - client = TgFileSystemClient(client_id, self.param, self.db) + client = TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager) return client def _register_client(self, client: TgFileSystemClient) -> bool: diff --git a/backend/api.py b/backend/api.py index f31a993..a86a2a0 100644 --- a/backend/api.py +++ b/backend/api.py @@ -1,15 +1,16 @@ import asyncio import json import os +import sys import logging import traceback +from typing import Annotated from urllib.parse import quote import uvicorn -from fastapi import FastAPI, status, Request +from fastapi import FastAPI, status, Request, 
Depends, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response, StreamingResponse -from contextlib import asynccontextmanager from telethon import types, hints, utils from pydantic import BaseModel @@ -21,7 +22,6 @@ from backend.TgFileSystemClientManager import TgFileSystemClientManager logger = logging.getLogger(__file__.split("/")[-1]) -@asynccontextmanager async def lifespan(app: FastAPI): clients_mgr = TgFileSystemClientManager.get_instance() res = await clients_mgr.get_status() @@ -131,48 +131,8 @@ async def get_tg_file_list(body: TgToFileListRequestBody): @app.get("/tg/api/v1/file/msg") @apiutils.atimeit async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request): - msg_id = mid - chat_id = cid - headers = { - # "content-type": "video/mp4", - "accept-ranges": "bytes", - "content-encoding": "identity", - # "content-length": stream_file_size, - "access-control-expose-headers": ("content-type, accept-ranges, content-length, " "content-range, content-encoding"), - } - range_header = request.headers.get("range") try: - clients_mgr = TgFileSystemClientManager.get_instance() - client = await clients_mgr.get_client_force(token) - msg = await client.get_message(chat_id, msg_id) - file_size = msg.media.document.size - start = 0 - end = file_size - 1 - status_code = status.HTTP_200_OK - mime_type = msg.media.document.mime_type - headers["content-type"] = mime_type - # headers["content-length"] = str(file_size) - file_name = apiutils.get_message_media_name(msg) - if file_name == "": - maybe_file_type = mime_type.split("/")[-1] - file_name = f"{chat_id}.{msg_id}.{maybe_file_type}" - headers["Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'" - - if range_header is not None: - start, end = apiutils.get_range_header(range_header, file_size) - size = end - start + 1 - # headers["content-length"] = str(size) - headers["content-range"] = f"bytes {start}-{end}/{file_size}" - 
status_code = status.HTTP_206_PARTIAL_CONTENT - else: - headers["content-length"] = str(file_size) - headers["content-range"] = f"bytes 0-{file_size-1}/{file_size}" - return StreamingResponse( - client.streaming_get_iter(msg, start, end, request), - headers=headers, - media_type=mime_type, - status_code=status_code, - ) + return await api.get_media_file_stream(token, cid, mid, request) except Exception as err: logger.error(f"{err=},{traceback.format_exc()}") return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) @@ -263,6 +223,29 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) +async def get_verify(q: str | None, skip: int = 0): + logger.info("run common param") + if skip < 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{q=},{skip=}") + + +@app.get("/tg/api/v1/test", dependencies=[Depends(get_verify)]) +async def test_get_depends_verify_method(other: str = ""): + return Response() + + +async def post_verify(body: TgToChatListRequestBody | None = None): + if not body or not body.token: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + return body + + +@app.post("/tg/api/v1/test", dependencies=[Depends(post_verify)]) +async def test_get_depends_verify_method(body: TgToChatListRequestBody): + return Response() + + if __name__ == "__main__": param = configParse.get_TgToFileSystemParameter() - uvicorn.run(app, host="0.0.0.0", port=param.base.port) + isDebug = True if sys.gettrace() else False + uvicorn.run(app, host="0.0.0.0", port=param.base.port, reload=isDebug) diff --git a/backend/api_implement.py b/backend/api_implement.py index 4b50d7c..0f57ea3 100644 --- a/backend/api_implement.py +++ b/backend/api_implement.py @@ -1,8 +1,12 @@ import traceback import json import logging +from urllib.parse import quote from telethon import types, hints, utils +import fastapi +from
fastapi import Request +from fastapi.responses import StreamingResponse, Response import configParse from backend import apiutils @@ -57,3 +61,50 @@ async def get_clients_manager_status(detail: bool) -> dict[str, any]: return ret ret["clist"] = await get_chat_details(clients_mgr) return ret + + +async def get_media_file_stream(token: str, cid: int, mid: int, request: Request) -> StreamingResponse: + msg_id = mid + chat_id = cid + headers = { + # "content-type": "video/mp4", + "accept-ranges": "bytes", + "content-encoding": "identity", + # "content-length": stream_file_size, + "access-control-expose-headers": ("content-type, accept-ranges, content-length, " "content-range, content-encoding"), + } + range_header = request.headers.get("range") + + clients_mgr = TgFileSystemClientManager.get_instance() + client = await clients_mgr.get_client_force(token) + msg = await client.get_message(chat_id, msg_id) + if not isinstance(msg.media, types.MessageMediaDocument) and not isinstance(msg.media, types.MessageMediaPhoto): + raise RuntimeError(f"request don't support: {msg.media=}") + file_size = msg.media.document.size + start = 0 + end = file_size - 1 + status_code = fastapi.status.HTTP_200_OK + mime_type = msg.media.document.mime_type + headers["content-type"] = mime_type + # headers["content-length"] = str(file_size) + file_name = apiutils.get_message_media_name(msg) + if file_name == "": + maybe_file_type = mime_type.split("/")[-1] + file_name = f"{chat_id}.{msg_id}.{maybe_file_type}" + headers["Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'" + + if range_header is not None: + start, end = apiutils.get_range_header(range_header, file_size) + size = end - start + 1 + # headers["content-length"] = str(size) + headers["content-range"] = f"bytes {start}-{end}/{file_size}" + status_code = fastapi.status.HTTP_206_PARTIAL_CONTENT + else: + headers["content-length"] = str(file_size) + headers["content-range"] = f"bytes 0-{file_size-1}/{file_size}" + return 
StreamingResponse( + client.streaming_get_iter(msg, start, end, request), + headers=headers, + media_type=mime_type, + status_code=status_code, + ) diff --git a/backend/apiutils.py b/backend/apiutils.py index 1cfeba1..a75c958 100644 --- a/backend/apiutils.py +++ b/backend/apiutils.py @@ -2,13 +2,14 @@ import time import logging from fastapi import status, HTTPException -from telethon import types +from telethon import types, utils from functools import wraps import configParse logger = logging.getLogger(__file__.split("/")[-1]) + def get_range_header(range_header: str, file_size: int) -> tuple[int, int]: def _invalid_range(): return HTTPException( @@ -28,87 +29,201 @@ def get_range_header(range_header: str, file_size: int) -> tuple[int, int]: return start, end -def get_message_media_name(msg: types.Message) -> str: - if msg.media is None or msg.media.document is None: - return "" - for attr in msg.media.document.attributes: +def _get_message_media_document_kind_and_names(document: types.Document) -> tuple[str, list[str]]: + """Gets kind and possible names for :tl:`DocumentAttribute`.""" + kind = "document" + possible_names = [] + for attr in document.attributes: if isinstance(attr, types.DocumentAttributeFilename): - return attr.file_name - return "" + possible_names.insert(0, attr.file_name) + + elif isinstance(attr, types.DocumentAttributeAudio): + kind = "audio" + if attr.performer and attr.title: + possible_names.append("{} - {}".format(attr.performer, attr.title)) + elif attr.performer: + possible_names.append(attr.performer) + elif attr.title: + possible_names.append(attr.title) + elif attr.voice: + kind = "voice" + + return kind, possible_names + + +def get_message_media_name(msg: types.Message) -> str: + if msg.media is None: + return "" + match type(msg.media): + case types.MessageMediaPhoto: + return f"{msg.media.photo.id}.jpg" + case types.MessageMediaDocument: + kind, possible_names = _get_message_media_document_kind_and_names(msg.media.document)
+ try: + name = None if possible_names is None else next(x for x in possible_names if x) + except StopIteration: + name = None + if name: + return name + extension = utils.get_extension(msg.media) + peer_id = utils.get_peer_id(msg) + return f"{kind}_{peer_id}-{msg.id}{extension}" + case _: + return "" + + +def _get_message_media_valid_photo(msg: types.Message) -> types.Photo | None: + if msg.media is None: + return None + photo = msg.media + if isinstance(photo, types.MessageMediaPhoto): + photo = photo.photo + if not isinstance(photo, types.Photo): + return None + return photo + + +def _sort_message_media_photo_thumbs(thumbs: list[any]) -> list[any]: + def sort_thumbs(thumb): + if isinstance(thumb, types.PhotoStrippedSize): + return 1, len(thumb.bytes) + if isinstance(thumb, types.PhotoCachedSize): + return 1, len(thumb.bytes) + if isinstance(thumb, types.PhotoSize): + return 1, thumb.size + if isinstance(thumb, types.PhotoSizeProgressive): + return 1, max(thumb.sizes) + if isinstance(thumb, types.VideoSize): + return 2, thumb.size + + # Empty size or invalid should go last + return 0, 0 + + thumbs = list(sorted(thumbs, key=sort_thumbs)) + for i in reversed(range(len(thumbs))): + if isinstance(thumbs[i], types.PhotoPathSize): + thumbs.pop(i) + + return thumbs + + +def _get_message_media_photo_file_last_photo_size(thumbs: list[any]): + thumbs = _sort_message_media_photo_thumbs(thumbs) + + size = thumbs[-1] if thumbs else None + if not size or isinstance(size, types.PhotoSizeEmpty): + return None + return size + + +def get_message_media_photo_file_name(msg: types.Message) -> str: + photo = _get_message_media_valid_photo(msg) + if not photo: + return "" + + size = _get_message_media_photo_file_last_photo_size(photo.sizes + (photo.video_sizes or [])) + if not size: + return "" + if isinstance(size, types.VideoSize): + return f"{photo.id}.mp4" + return f"{photo.id}.jpg" + + +def get_message_media_photo_file_size(msg: types.Message) -> int: + photo =
_get_message_media_valid_photo(msg) + if not photo: + return 0 + + size = _get_message_media_photo_file_last_photo_size(photo.sizes + (photo.video_sizes or [])) + if not size: + return 0 + + if isinstance(size, types.PhotoStrippedSize): + return len(utils.stripped_photo_to_jpg(size.bytes)) + elif isinstance(size, types.PhotoCachedSize): + return len(size.bytes) + + if isinstance(size, types.PhotoSizeProgressive): + return max(size.sizes) + return size.size + def get_message_media_name_from_dict(msg: dict[str, any]) -> str: doc = None try: - doc = msg['media']['document'] + doc = msg["media"]["document"] except: pass file_name = None if doc is not None: - for attr in doc['attributes']: - file_name = attr.get('file_name') + for attr in doc["attributes"]: + file_name = attr.get("file_name") if file_name != "" and file_name is not None: break if file_name == "" or file_name is None: file_name = "unknown.tmp" return file_name + def get_message_chat_id_from_dict(msg: dict[str, any]) -> int: try: - return msg['peer_id']['channel_id'] + return msg["peer_id"]["channel_id"] except: pass return 0 + def get_message_msg_id_from_dict(msg: dict[str, any]) -> int: try: - return msg['id'] + return msg["id"] except: pass return 0 + def timeit_sec(func): @wraps(func) def timeit_wrapper(*args, **kwargs): - logger.debug( - f'Function called {func.__name__}{args} {kwargs}') + logger.debug(f"Function called {func.__name__}{args} {kwargs}") start_time = time.perf_counter() result = func(*args, **kwargs) end_time = time.perf_counter() total_time = end_time - start_time - logger.debug( - f'Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') + logger.debug(f"Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds") return result + return timeit_wrapper + def timeit(func): if configParse.get_TgToFileSystemParameter().base.timeit_enable: + @wraps(func) def timeit_wrapper(*args, **kwargs): - logger.debug( - f'Function called {func.__name__}{args} 
{kwargs}') + logger.debug(f"Function called {func.__name__}{args} {kwargs}") start_time = time.perf_counter() result = func(*args, **kwargs) end_time = time.perf_counter() total_time = end_time - start_time - logger.debug( - f'Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') + logger.debug(f"Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds") return result + return timeit_wrapper return func def atimeit(func): if configParse.get_TgToFileSystemParameter().base.timeit_enable: + @wraps(func) async def timeit_wrapper(*args, **kwargs): - logger.debug( - f'AFunction called {func.__name__}{args} {kwargs}') + logger.debug(f"AFunction called {func.__name__}{args} {kwargs}") start_time = time.perf_counter() result = await func(*args, **kwargs) end_time = time.perf_counter() total_time = end_time - start_time - logger.debug( - f'AFunction quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') + logger.debug(f"AFunction quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds") return result + return timeit_wrapper return func diff --git a/frontend/remote_api.py b/frontend/remote_api.py index 137e443..7ec0811 100644 --- a/frontend/remote_api.py +++ b/frontend/remote_api.py @@ -80,7 +80,6 @@ def convert_tg_link_to_proxy_link(link: str) -> str: request_url = background_server_url + link_convert_api_route + f"?link={link}" response = requests.get(request_url) if response.status_code != 200: - print(f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}") - return "" + return f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}" response_js = json.loads(response.content.decode("utf-8")) return response_js["url"] diff --git a/frontend/search.py b/frontend/search.py index f4c414c..c4897be 100644 --- a/frontend/search.py +++ b/frontend/search.py @@ -31,7 +31,7 @@ def loop(): wait_client_ready.status("Server Initializing") st.session_state.chat_dict = 
api.get_white_list_chat_dict() wait_client_ready.empty() - st.query_params.search_key = st.text_input("**搜索🔎**", value=keyword) + st.query_params.search_key = st.text_input("**Search🔎**", value=keyword) chat_list = [] for _, chat_info in st.session_state.chat_dict.items(): chat_list.append(chat_info["title"]) @@ -39,13 +39,13 @@ def loop(): columns = st.columns([4, 4, 1]) with columns[0]: st.query_params.search_res_limit = str( - st.number_input("**每页结果**", min_value=1, max_value=100, value=res_limit, format="%d") + st.number_input("**Results per page**", min_value=1, max_value=100, value=res_limit, format="%d") ) with columns[1]: st.session_state.chat_select_list = st.multiselect("**Search in**", chat_list, default=chat_list) with columns[2]: - st.text("排序") - st.query_params.is_order = st.toggle("顺序", value=isorder) + st.text("Sort") + st.query_params.is_order = st.toggle("Time🔼", value=isorder) search_limit_container = st.container() with search_limit_container: