From ff997c7434f395d0c71ab23661d01317828ef68e Mon Sep 17 00:00:00 2001 From: Hehesheng Date: Sun, 16 Jun 2024 22:38:40 +0800 Subject: [PATCH] feat: sign generate --- .gitignore | 1 + backend/MediaCacheManager.py | 2 +- backend/TgFileSystemClient.py | 4 +- backend/TgFileSystemClientManager.py | 115 ++++++++++++++++++++++++--- backend/api.py | 52 ++++++++---- backend/api_implement.py | 13 +-- configParse.py | 7 +- frontend/home.py | 9 ++- frontend/remote_api.py | 10 ++- frontend/search.py | 4 +- 10 files changed, 172 insertions(+), 45 deletions(-) diff --git a/.gitignore b/.gitignore index 200e571..5b374f5 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__ *.toml *.db *.service +*.pem log cache_media tmp diff --git a/backend/MediaCacheManager.py b/backend/MediaCacheManager.py index cb3b7c1..474c191 100644 --- a/backend/MediaCacheManager.py +++ b/backend/MediaCacheManager.py @@ -162,7 +162,7 @@ class MediaChunkHolderManager(object): def __init__(self) -> None: self.chunk_lru = collections.OrderedDict() self.disk_chunk_cache = diskcache.Cache( - f"{os.path.dirname(__file__)}/cache_media", size_limit=MediaChunkHolderManager.MAX_CACHE_SIZE * 2 + f"{os.path.dirname(__file__)}/db/cache_media", size_limit=MediaChunkHolderManager.MAX_CACHE_SIZE * 2 ) self._restore_cache() diff --git a/backend/TgFileSystemClient.py b/backend/TgFileSystemClient.py index b70469c..e1d91c1 100644 --- a/backend/TgFileSystemClient.py +++ b/backend/TgFileSystemClient.py @@ -61,8 +61,8 @@ class TgFileSystemClient(object): else {} ) self.client_param = next( - (client_param for client_param in param.clients if client_param.token == session_name), - configParse.TgToFileSystemParameter.ClientConfigPatameter(), + (client_param for client_param in param.clients if client_param.name == session_name), + configParse.TgToFileSystemParameter.ClientConfigPatameter(name="__tmp__"), ) self.task_queue = asyncio.Queue() self.client = TelegramClient( diff --git a/backend/TgFileSystemClientManager.py b/backend/TgFileSystemClientManager.py index a61f1a1..22c54a9 100644 --- a/backend/TgFileSystemClientManager.py +++ b/backend/TgFileSystemClientManager.py @@ -1,8 +1,11 @@ import asyncio import time +import base64 import hashlib import rsa import os +from enum import IntEnum, unique, auto +import time import traceback import logging @@ -14,7 +17,16 @@ import configParse logger = logging.getLogger(__file__.split("/")[-1]) +@unique +class EnumSignLevel(IntEnum): + ADMIN = auto() + NORMAL = auto() + VIST = auto() + NONE = auto() + + class TgFileSystemClientManager(object): + TIME_MS_24HOURS: int = 24 * 60 * 60 * 1000 MAX_MANAGE_CLIENTS: int = 10 is_init: bool = False param: configParse.TgToFileSystemParameter @@ -35,7 +47,7 @@ class TgFileSystemClientManager(object): self.db = UserManager() self.loop = asyncio.get_running_loop() self.media_chunk_manager = MediaChunkHolderManager() - self.public_key, self.private_key = rsa.newkeys(1024) + self._init_rsa_keys() if self.loop.is_running(): self.loop.create_task(self._start_clients()) else: @@ -47,7 +59,7 @@ class TgFileSystemClientManager(object): async def _start_clients(self) -> None: # init cache clients for client_config in self.param.clients: - client = self.create_client(client_id=client_config.token) + client = self.create_client(client_config.name) self._register_client(client) for _, client in self.clients.items(): try: @@ -57,11 +69,97 @@ class TgFileSystemClientManager(object): logger.warning(f"start client: {err=}, {traceback.format_exc()}") self.is_init = True + def _init_rsa_keys(self): + key_dir = f"{os.path.dirname(__file__)}/db" + pub_key_path = f"{key_dir}/pub.pem" + pri_key_path = f"{key_dir}/pri.pem" + if not os.path.isfile(pub_key_path) or not os.path.isfile(pri_key_path): + self.public_key, self.private_key = rsa.newkeys(512) + with open(pub_key_path, "wb") as f: + f.write(self.public_key.save_pkcs1()) + with open(pri_key_path, "wb") as f: + f.write(self.private_key.save_pkcs1()) + else: + with open(pub_key_path, "rb") as f: + self.public_key = rsa.PublicKey.load_pkcs1(f.read()) + with open(pri_key_path, "rb") as f: + self.private_key = rsa.PrivateKey.load_pkcs1(f.read()) + + def generate_sign( + self, client_id: str, sign_type: EnumSignLevel = EnumSignLevel.NORMAL, salt: str = None, valid_time: int = -1 + ) -> str: + timestamp = int(time.time()) + if valid_time == -1: + timestamp += self.TIME_MS_24HOURS + elif valid_time == 0: + timestamp = 0 + else: + timestamp += valid_time * 1000 + need_encrypt_str = f"ts={timestamp};l={sign_type.value};" + if salt: + need_encrypt_str += f"s={hashlib.md5(salt).hexdigest()[:8]};" + # rsa 512 bits only + valid_len = 512 // 8 - 11 + valid_len -= len(need_encrypt_str) + # id=xxxxx; + valid_len -= len("id=;") + if valid_len < 0: + logger.error(f"{need_encrypt_str=},{traceback.format_exc()}") + raise RuntimeError(f"generate sign too big") + real_client_id = client_id[:valid_len] + if len(real_client_id) != len(client_id): + logger.warning(f"client id too long: {client_id} -> {real_client_id}") + need_encrypt_str += f"id={real_client_id};" + need_encrypt_bin = need_encrypt_str.encode() + sign_bin = rsa.encrypt(need_encrypt_bin, self.public_key) + sign = base64.b64encode(sign_bin).decode() + logger.info(f"generate {sign_type.name} sign: {sign}") + return sign + + def parse_sign(self, sign: str) -> dict[str, any] | None: + try: + res_dict = {} + sign_bin = base64.b64decode(sign) + decrypt_bin = rsa.decrypt(sign_bin, self.private_key) + decrypt_str = decrypt_bin.decode() + for key_value_str in decrypt_str.split(";"): + if key_value_str == "": + continue + key, value = key_value_str.split("=") + res_dict[key] = value + except Exception as err: + logger.warning(f"verify sign {err=}, {traceback.format_exc()}") + return None + return res_dict + + @staticmethod + def get_sign_client_id(key_map: dict[str, any]) -> str: + return key_map.get("id") + + def verify_sign( + self, + sign: str, + client_id: str = None, + v_ts: bool = True, + target_level: EnumSignLevel = EnumSignLevel.NONE, + salt: str = None, + ) -> bool: + key_map = self.parse_sign(sign) + if not key_map: + return False + if client_id and (not key_map.get("id") or not client_id.startswith(key_map.get("id"))): + return False + if not key_map.get("l") or target_level.value < int(key_map.get("l")): + return False + if v_ts and int(key_map.get("ts", 0)) > 0 and (int(time.time()) - int(key_map.get("ts", 0)) > 0): + return False + if salt and hashlib.md5(key_map.get("s", "")).hexdigest() != salt: + return False + return True + async def get_status(self) -> dict[str, any]: clients_status = [ - { - "status": client.is_valid(), - } + {"status": client.is_valid(), "name": client.session_name, "sign": self.generate_sign(client.session_name)} for _, client in self.clients.items() ] return {"init": self.is_init, "clients": clients_status} @@ -77,12 +175,7 @@ class TgFileSystemClientManager(object): session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session" return os.path.isfile(session_db_file) - def generate_client_id(self) -> str: - return hashlib.md5((str(time.perf_counter()) + self.param.base.salt).encode("utf-8")).hexdigest() - - def create_client(self, client_id: str = None) -> TgFileSystemClient: - if client_id is None: - client_id = self.generate_client_id() + def create_client(self, client_id: str) -> TgFileSystemClient: client = TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager) return client diff --git a/backend/api.py b/backend/api.py index 1b8f3dd..ee31508 100644 --- a/backend/api.py +++ b/backend/api.py @@ -41,7 +41,7 @@ app.add_middleware( class TgToFileListRequestBody(BaseModel): - token: str + sign: str search: str = "" chat_ids: list[int] = [] index: int = 0 @@ -51,15 +51,31 @@ class TgToFileListRequestBody(BaseModel): inc: bool = False -@app.post("/tg/api/v1/file/search") +async def verify_post_sign(body: TgToFileListRequestBody): + clients_mgr = TgFileSystemClientManager.get_instance() + if not clients_mgr.verify_sign(body.sign): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{body}") + + +async def verify_get_sign(sign: str): + clients_mgr = TgFileSystemClientManager.get_instance() + sign = sign.replace(" ", "+") + if not clients_mgr.verify_sign(sign): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{sign}") + return sign + + +@app.post("/tg/api/v1/file/search", dependencies=[Depends(verify_post_sign)]) @apiutils.atimeit async def search_tg_file_list(body: TgToFileListRequestBody): try: - param = configParse.get_TgToFileSystemParameter() clients_mgr = TgFileSystemClientManager.get_instance() + param = configParse.get_TgToFileSystemParameter() res = hints.TotalList() res_type = "msg" - client = await clients_mgr.get_client_force(body.token) + sign_info = clients_mgr.parse_sign(body.sign) + client_id = TgFileSystemClientManager.get_sign_client_id(sign_info) + client = await clients_mgr.get_client_force(client_id) res_dict = [] res = await client.get_messages_by_search_db( body.chat_ids, body.search, limit=body.length, inc=body.inc, offset=body.index @@ -75,7 +91,7 @@ async def search_tg_file_list(body: TgToFileListRequestBody): res_dict.append(msg_info) client_dict = json.loads(client.to_json()) - client_dict["sign"] = body.token + client_dict["sign"] = body.sign response_dict = { "client": client_dict, @@ -128,17 +144,18 @@ async def get_tg_file_list(body: TgToFileListRequestBody): return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) -@app.get("/tg/api/v1/file/msg") +@app.get("/tg/api/v1/file/msg", deprecated=[Depends(verify_get_sign)]) @apiutils.atimeit -async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request): +async def get_tg_file_media_stream(sign: str, cid: int, mid: int, request: Request): try: - return await api.get_media_file_stream(token, cid, mid, request) + sign = sign.replace(" ", "+") + return await api.get_media_file_stream(sign, cid, mid, request) except Exception as err: logger.error(f"{err=},{traceback.format_exc()}") return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) -@app.get("/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}") +@app.get("/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}", dependencies=[Depends(verify_get_sign)]) @apiutils.atimeit async def get_tg_file_media(chat_id: int | str, msg_id: int, file_name: str, sign: str, req: Request): try: @@ -223,15 +240,20 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) -async def get_verify(q: str | None, skip: int = 0): - logger.info("run common param") - if skip < 0: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{q=},{skip=}") +async def get_verify(id: str = None): + if id is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{id=}") + client_mgr = TgFileSystemClientManager.get_instance() + client = await client_mgr.get_client_force(id) + if not client.is_valid(): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{id=}") @app.get("/tg/api/v1/test", dependencies=[Depends(get_verify)]) -async def test_get_depends_verify_method(other: str = ""): - return Response() +async def test_get_depends_verify_method(id: str, other: str = ""): + client_mgr = TgFileSystemClientManager.get_instance() + client = await client_mgr.get_client_force(id) + return Response((await client.client.get_me()).stringify()) async def post_verify(body: TgToChatListRequestBody | None = None): diff --git a/backend/api_implement.py b/backend/api_implement.py index 0f57ea3..063dac9 100644 --- a/backend/api_implement.py +++ b/backend/api_implement.py @@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse, Response import configParse from backend import apiutils -from backend.TgFileSystemClientManager import TgFileSystemClientManager +from backend.TgFileSystemClientManager import TgFileSystemClientManager, EnumSignLevel logger = logging.getLogger(__file__.split("/")[-1]) @@ -38,9 +38,8 @@ async def link_convert(link: str) -> str: msg = await client.get_message(chat_id_or_name, msg_id) file_name = apiutils.get_message_media_name(msg) param = configParse.get_TgToFileSystemParameter() - url = ( - f"{param.base.exposed_url}/tg/api/v1/file/get/{utils.get_peer_id(msg.peer_id)}/{msg.id}/{file_name}?sign={client.sign}" - ) + sign = clients_mgr.generate_sign(client.session_name, EnumSignLevel.VIST) + url = f"{param.base.exposed_url}/tg/api/v1/file/get/{utils.get_peer_id(msg.peer_id)}/{msg.id}/{file_name}?sign={sign}" return url @@ -63,7 +62,7 @@ async def get_clients_manager_status(detail: bool) -> dict[str, any]: return ret -async def get_media_file_stream(token: str, cid: int, mid: int, request: Request) -> StreamingResponse: +async def get_media_file_stream(sign: str, cid: int, mid: int, request: Request) -> StreamingResponse: msg_id = mid chat_id = cid headers = { @@ -76,7 +75,9 @@ async def get_media_file_stream(token: str, cid: int, mid: int, request: Request range_header = request.headers.get("range") clients_mgr = TgFileSystemClientManager.get_instance() - client = await clients_mgr.get_client_force(token) + sign_info = clients_mgr.parse_sign(sign) + client_id = TgFileSystemClientManager.get_sign_client_id(sign_info) + client = await clients_mgr.get_client_force(client_id) msg = await client.get_message(chat_id, msg_id) if not isinstance(msg.media, types.MessageMediaDocument) and not isinstance(msg.media, types.MessageMediaPhoto): raise RuntimeError(f"request don't support: {msg.media=}") diff --git a/configParse.py b/configParse.py index e3d77aa..1910842 100644 --- a/configParse.py +++ b/configParse.py @@ -6,15 +6,16 @@ from pydantic import BaseModel class TgToFileSystemParameter(BaseModel): + class BaseParameter(BaseModel): - salt: str = "" exposed_url: str = "http://127.0.0.1:7777" port: int = 7777 timeit_enable: bool = False + base: BaseParameter class ClientConfigPatameter(BaseModel): - token: str = "" + name: str interval: float = 0.1 whitelist_chat: list[int] = [] clients: list[ClientConfigPatameter] @@ -33,7 +34,7 @@ class TgToFileSystemParameter(BaseModel): class TgWebParameter(BaseModel): enable: bool = False - token: str = "" + name: str = "" port: int = 2000 web: TgWebParameter diff --git a/frontend/home.py b/frontend/home.py index feec3f5..e49b9fa 100644 --- a/frontend/home.py +++ b/frontend/home.py @@ -8,6 +8,7 @@ st.set_page_config(page_title="TgToolbox", page_icon="🕹️", layout="wide", i backend_status = api.get_backend_client_status() need_login = False +sign = "" if backend_status is None or not backend_status["init"]: st.status("Server not ready") @@ -15,8 +16,10 @@ if backend_status is None or not backend_status["init"]: st.rerun() for v in backend_status["clients"]: - if not v["status"]: - need_login = True + if v["name"] != api.get_config_default_name(): + continue + need_login = not v["status"] + sign = v["sign"] if need_login: import login @@ -28,7 +31,7 @@ search_tab, link_convert_tab = st.tabs(["Search", "Link Convert"]) with search_tab: import search - search.loop() + search.loop(sign) with link_convert_tab: import link_convert diff --git a/frontend/remote_api.py b/frontend/remote_api.py index 7ec0811..e6dbc1b 100644 --- a/frontend/remote_api.py +++ b/frontend/remote_api.py @@ -51,10 +51,12 @@ def get_white_list_chat_dict() -> dict[str, any]: search_api_route = "/tg/api/v1/file/search" -def search_database_by_keyword(keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool) -> list[any] | None: +def search_database_by_keyword( + sign: str, keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool +) -> list[any] | None: request_url = background_server_url + search_api_route req_body = { - "token": param.web.token, + "sign": sign, "search": keyword, "chat_ids": chat_list, "index": offset, @@ -83,3 +85,7 @@ def convert_tg_link_to_proxy_link(link: str) -> str: return f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}" response_js = json.loads(response.content.decode("utf-8")) return response_js["url"] + + +def get_config_default_name() -> str: + return param.web.name diff --git a/frontend/search.py b/frontend/search.py index c4897be..056c4b4 100644 --- a/frontend/search.py +++ b/frontend/search.py @@ -9,7 +9,7 @@ import remote_api as api @st.experimental_fragment -def loop(): +def loop(sign: str): if "page_index" not in st.session_state: st.session_state.page_index = 1 if "force_skip" not in st.session_state: @@ -81,7 +81,7 @@ def loop(): except Exception as err: print(f"{err=},{traceback.format_exc()}") search_res = api.search_database_by_keyword( - st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order + sign, st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order ) status_bar.empty() if search_res is None: