feat: sign generate

This commit is contained in:
Hehesheng 2024-06-16 22:38:40 +08:00
parent 7511d4ad46
commit ff997c7434
10 changed files with 172 additions and 45 deletions

1
.gitignore vendored
View File

@ -9,6 +9,7 @@ __pycache__
*.toml *.toml
*.db *.db
*.service *.service
*.pem
log log
cache_media cache_media
tmp tmp

View File

@ -162,7 +162,7 @@ class MediaChunkHolderManager(object):
def __init__(self) -> None: def __init__(self) -> None:
self.chunk_lru = collections.OrderedDict() self.chunk_lru = collections.OrderedDict()
self.disk_chunk_cache = diskcache.Cache( self.disk_chunk_cache = diskcache.Cache(
f"{os.path.dirname(__file__)}/cache_media", size_limit=MediaChunkHolderManager.MAX_CACHE_SIZE * 2 f"{os.path.dirname(__file__)}/db/cache_media", size_limit=MediaChunkHolderManager.MAX_CACHE_SIZE * 2
) )
self._restore_cache() self._restore_cache()

View File

@ -61,8 +61,8 @@ class TgFileSystemClient(object):
else {} else {}
) )
self.client_param = next( self.client_param = next(
(client_param for client_param in param.clients if client_param.token == session_name), (client_param for client_param in param.clients if client_param.name == session_name),
configParse.TgToFileSystemParameter.ClientConfigPatameter(), configParse.TgToFileSystemParameter.ClientConfigPatameter(name="__tmp__"),
) )
self.task_queue = asyncio.Queue() self.task_queue = asyncio.Queue()
self.client = TelegramClient( self.client = TelegramClient(

View File

@ -1,8 +1,11 @@
import asyncio import asyncio
import time import time
import base64
import hashlib import hashlib
import rsa import rsa
import os import os
from enum import IntEnum, unique, auto
import time
import traceback import traceback
import logging import logging
@ -14,7 +17,16 @@ import configParse
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
@unique
class EnumSignLevel(IntEnum):
ADMIN = auto()
NORMAL = auto()
VIST = auto()
NONE = auto()
class TgFileSystemClientManager(object): class TgFileSystemClientManager(object):
TIME_MS_24HOURS: int = 24 * 60 * 60 * 1000
MAX_MANAGE_CLIENTS: int = 10 MAX_MANAGE_CLIENTS: int = 10
is_init: bool = False is_init: bool = False
param: configParse.TgToFileSystemParameter param: configParse.TgToFileSystemParameter
@ -35,7 +47,7 @@ class TgFileSystemClientManager(object):
self.db = UserManager() self.db = UserManager()
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.media_chunk_manager = MediaChunkHolderManager() self.media_chunk_manager = MediaChunkHolderManager()
self.public_key, self.private_key = rsa.newkeys(1024) self._init_rsa_keys()
if self.loop.is_running(): if self.loop.is_running():
self.loop.create_task(self._start_clients()) self.loop.create_task(self._start_clients())
else: else:
@ -47,7 +59,7 @@ class TgFileSystemClientManager(object):
async def _start_clients(self) -> None: async def _start_clients(self) -> None:
# init cache clients # init cache clients
for client_config in self.param.clients: for client_config in self.param.clients:
client = self.create_client(client_id=client_config.token) client = self.create_client(client_config.name)
self._register_client(client) self._register_client(client)
for _, client in self.clients.items(): for _, client in self.clients.items():
try: try:
@ -57,11 +69,97 @@ class TgFileSystemClientManager(object):
logger.warning(f"start client: {err=}, {traceback.format_exc()}") logger.warning(f"start client: {err=}, {traceback.format_exc()}")
self.is_init = True self.is_init = True
def _init_rsa_keys(self):
key_dir = f"{os.path.dirname(__file__)}/db"
pub_key_path = f"{key_dir}/pub.pem"
pri_key_path = f"{key_dir}/pri.pem"
if not os.path.isfile(pub_key_path) or not os.path.isfile(pri_key_path):
self.public_key, self.private_key = rsa.newkeys(512)
with open(pub_key_path, "wb") as f:
f.write(self.public_key.save_pkcs1())
with open(pri_key_path, "wb") as f:
f.write(self.private_key.save_pkcs1())
else:
with open(pub_key_path, "rb") as f:
self.public_key = rsa.PublicKey.load_pkcs1(f.read())
with open(pri_key_path, "rb") as f:
self.private_key = rsa.PrivateKey.load_pkcs1(f.read())
def generate_sign(
self, client_id: str, sign_type: EnumSignLevel = EnumSignLevel.NORMAL, salt: str = None, valid_time: int = -1
) -> str:
timestamp = int(time.time())
if valid_time == -1:
timestamp += self.TIME_MS_24HOURS
elif valid_time == 0:
timestamp = 0
else:
timestamp += valid_time * 1000
need_encrypt_str = f"ts={timestamp};l={sign_type.value};"
if salt:
need_encrypt_str += f"s={hashlib.md5(salt).hexdigest()[:8]};"
# rsa 512 bits only
valid_len = 512 // 8 - 11
valid_len -= len(need_encrypt_str)
# id=xxxxx;
valid_len -= len("id=;")
if valid_len < 0:
logger.error(f"{need_encrypt_str=},{traceback.format_exc()}")
raise RuntimeError(f"generate sign too big")
real_client_id = client_id[:valid_len]
if len(real_client_id) != len(client_id):
logger.warning(f"client id too long: {client_id} -> {real_client_id}")
need_encrypt_str += f"id={real_client_id};"
need_encrypt_bin = need_encrypt_str.encode()
sign_bin = rsa.encrypt(need_encrypt_bin, self.public_key)
sign = base64.b64encode(sign_bin).decode()
logger.info(f"generate {sign_type.name} sign: {sign}")
return sign
def parse_sign(self, sign: str) -> dict[str, any] | None:
try:
res_dict = {}
sign_bin = base64.b64decode(sign)
decrypt_bin = rsa.decrypt(sign_bin, self.private_key)
decrypt_str = decrypt_bin.decode()
for key_value_str in decrypt_str.split(";"):
if key_value_str == "":
continue
key, value = key_value_str.split("=")
res_dict[key] = value
except Exception as err:
logger.warning(f"verify sign {err=}, {traceback.format_exc()}")
return None
return res_dict
@staticmethod
def get_sign_client_id(key_map: dict[str, any]) -> str:
return key_map.get("id")
def verify_sign(
self,
sign: str,
client_id: str = None,
v_ts: bool = True,
target_level: EnumSignLevel = EnumSignLevel.NONE,
salt: str = None,
) -> bool:
key_map = self.parse_sign(sign)
if not key_map:
return False
if client_id and (not key_map.get("id") or not client_id.startswith(key_map.get("id"))):
return False
if not key_map.get("l") or target_level.value < int(key_map.get("l")):
return False
if v_ts and int(key_map.get("ts", 0)) > 0 and (int(time.time()) - int(key_map.get("ts", 0)) > 0):
return False
if salt and hashlib.md5(key_map.get("s", "")).hexdigest() != salt:
return False
return True
async def get_status(self) -> dict[str, any]: async def get_status(self) -> dict[str, any]:
clients_status = [ clients_status = [
{ {"status": client.is_valid(), "name": client.session_name, "sign": self.generate_sign(client.session_name)}
"status": client.is_valid(),
}
for _, client in self.clients.items() for _, client in self.clients.items()
] ]
return {"init": self.is_init, "clients": clients_status} return {"init": self.is_init, "clients": clients_status}
@ -77,12 +175,7 @@ class TgFileSystemClientManager(object):
session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session" session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session"
return os.path.isfile(session_db_file) return os.path.isfile(session_db_file)
def generate_client_id(self) -> str: def create_client(self, client_id: str) -> TgFileSystemClient:
return hashlib.md5((str(time.perf_counter()) + self.param.base.salt).encode("utf-8")).hexdigest()
def create_client(self, client_id: str = None) -> TgFileSystemClient:
if client_id is None:
client_id = self.generate_client_id()
client = TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager) client = TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager)
return client return client

View File

@ -41,7 +41,7 @@ app.add_middleware(
class TgToFileListRequestBody(BaseModel): class TgToFileListRequestBody(BaseModel):
token: str sign: str
search: str = "" search: str = ""
chat_ids: list[int] = [] chat_ids: list[int] = []
index: int = 0 index: int = 0
@ -51,15 +51,31 @@ class TgToFileListRequestBody(BaseModel):
inc: bool = False inc: bool = False
@app.post("/tg/api/v1/file/search") async def verify_post_sign(body: TgToFileListRequestBody):
clients_mgr = TgFileSystemClientManager.get_instance()
if not clients_mgr.verify_sign(body.sign):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{body}")
async def verify_get_sign(sign: str):
clients_mgr = TgFileSystemClientManager.get_instance()
sign = sign.replace(" ", "+")
if not clients_mgr.verify_sign(sign):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{sign}")
return sign
@app.post("/tg/api/v1/file/search", dependencies=[Depends(verify_post_sign)])
@apiutils.atimeit @apiutils.atimeit
async def search_tg_file_list(body: TgToFileListRequestBody): async def search_tg_file_list(body: TgToFileListRequestBody):
try: try:
param = configParse.get_TgToFileSystemParameter()
clients_mgr = TgFileSystemClientManager.get_instance() clients_mgr = TgFileSystemClientManager.get_instance()
param = configParse.get_TgToFileSystemParameter()
res = hints.TotalList() res = hints.TotalList()
res_type = "msg" res_type = "msg"
client = await clients_mgr.get_client_force(body.token) sign_info = clients_mgr.parse_sign(body.sign)
client_id = TgFileSystemClientManager.get_sign_client_id(sign_info)
client = await clients_mgr.get_client_force(client_id)
res_dict = [] res_dict = []
res = await client.get_messages_by_search_db( res = await client.get_messages_by_search_db(
body.chat_ids, body.search, limit=body.length, inc=body.inc, offset=body.index body.chat_ids, body.search, limit=body.length, inc=body.inc, offset=body.index
@ -75,7 +91,7 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
res_dict.append(msg_info) res_dict.append(msg_info)
client_dict = json.loads(client.to_json()) client_dict = json.loads(client.to_json())
client_dict["sign"] = body.token client_dict["sign"] = body.sign
response_dict = { response_dict = {
"client": client_dict, "client": client_dict,
@ -128,17 +144,18 @@ async def get_tg_file_list(body: TgToFileListRequestBody):
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
@app.get("/tg/api/v1/file/msg") @app.get("/tg/api/v1/file/msg", deprecated=[Depends(verify_get_sign)])
@apiutils.atimeit @apiutils.atimeit
async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request): async def get_tg_file_media_stream(sign: str, cid: int, mid: int, request: Request):
try: try:
return await api.get_media_file_stream(token, cid, mid, request) sign = sign.replace(" ", "+")
return await api.get_media_file_stream(sign, cid, mid, request)
except Exception as err: except Exception as err:
logger.error(f"{err=},{traceback.format_exc()}") logger.error(f"{err=},{traceback.format_exc()}")
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
@app.get("/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}") @app.get("/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}", dependencies=[Depends(verify_get_sign)])
@apiutils.atimeit @apiutils.atimeit
async def get_tg_file_media(chat_id: int | str, msg_id: int, file_name: str, sign: str, req: Request): async def get_tg_file_media(chat_id: int | str, msg_id: int, file_name: str, sign: str, req: Request):
try: try:
@ -223,15 +240,20 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
async def get_verify(q: str | None, skip: int = 0): async def get_verify(id: str = None):
logger.info("run common param") if id is None:
if skip < 0: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{id=}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{q=},{skip=}") client_mgr = TgFileSystemClientManager.get_instance()
client = await client_mgr.get_client_force(id)
if not client.is_valid():
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{id=}")
@app.get("/tg/api/v1/test", dependencies=[Depends(get_verify)]) @app.get("/tg/api/v1/test", dependencies=[Depends(get_verify)])
async def test_get_depends_verify_method(other: str = ""): async def test_get_depends_verify_method(id: str, other: str = ""):
return Response() client_mgr = TgFileSystemClientManager.get_instance()
client = await client_mgr.get_client_force(id)
return Response((await client.client.get_me()).stringify())
async def post_verify(body: TgToChatListRequestBody | None = None): async def post_verify(body: TgToChatListRequestBody | None = None):

View File

@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse, Response
import configParse import configParse
from backend import apiutils from backend import apiutils
from backend.TgFileSystemClientManager import TgFileSystemClientManager from backend.TgFileSystemClientManager import TgFileSystemClientManager, EnumSignLevel
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
@ -38,9 +38,8 @@ async def link_convert(link: str) -> str:
msg = await client.get_message(chat_id_or_name, msg_id) msg = await client.get_message(chat_id_or_name, msg_id)
file_name = apiutils.get_message_media_name(msg) file_name = apiutils.get_message_media_name(msg)
param = configParse.get_TgToFileSystemParameter() param = configParse.get_TgToFileSystemParameter()
url = ( sign = clients_mgr.generate_sign(client.session_name, EnumSignLevel.VIST)
f"{param.base.exposed_url}/tg/api/v1/file/get/{utils.get_peer_id(msg.peer_id)}/{msg.id}/{file_name}?sign={client.sign}" url = f"{param.base.exposed_url}/tg/api/v1/file/get/{utils.get_peer_id(msg.peer_id)}/{msg.id}/{file_name}?sign={sign}"
)
return url return url
@ -63,7 +62,7 @@ async def get_clients_manager_status(detail: bool) -> dict[str, any]:
return ret return ret
async def get_media_file_stream(token: str, cid: int, mid: int, request: Request) -> StreamingResponse: async def get_media_file_stream(sign: str, cid: int, mid: int, request: Request) -> StreamingResponse:
msg_id = mid msg_id = mid
chat_id = cid chat_id = cid
headers = { headers = {
@ -76,7 +75,9 @@ async def get_media_file_stream(token: str, cid: int, mid: int, request: Request
range_header = request.headers.get("range") range_header = request.headers.get("range")
clients_mgr = TgFileSystemClientManager.get_instance() clients_mgr = TgFileSystemClientManager.get_instance()
client = await clients_mgr.get_client_force(token) sign_info = clients_mgr.parse_sign(sign)
client_id = TgFileSystemClientManager.get_sign_client_id(sign_info)
client = await clients_mgr.get_client_force(client_id)
msg = await client.get_message(chat_id, msg_id) msg = await client.get_message(chat_id, msg_id)
if not isinstance(msg.media, types.MessageMediaDocument) and not isinstance(msg.media, types.MessageMediaPhoto): if not isinstance(msg.media, types.MessageMediaDocument) and not isinstance(msg.media, types.MessageMediaPhoto):
raise RuntimeError(f"request don't support: {msg.media=}") raise RuntimeError(f"request don't support: {msg.media=}")

View File

@ -6,15 +6,16 @@ from pydantic import BaseModel
class TgToFileSystemParameter(BaseModel): class TgToFileSystemParameter(BaseModel):
class BaseParameter(BaseModel): class BaseParameter(BaseModel):
salt: str = ""
exposed_url: str = "http://127.0.0.1:7777" exposed_url: str = "http://127.0.0.1:7777"
port: int = 7777 port: int = 7777
timeit_enable: bool = False timeit_enable: bool = False
base: BaseParameter base: BaseParameter
class ClientConfigPatameter(BaseModel): class ClientConfigPatameter(BaseModel):
token: str = "" name: str
interval: float = 0.1 interval: float = 0.1
whitelist_chat: list[int] = [] whitelist_chat: list[int] = []
clients: list[ClientConfigPatameter] clients: list[ClientConfigPatameter]
@ -33,7 +34,7 @@ class TgToFileSystemParameter(BaseModel):
class TgWebParameter(BaseModel): class TgWebParameter(BaseModel):
enable: bool = False enable: bool = False
token: str = "" name: str = ""
port: int = 2000 port: int = 2000
web: TgWebParameter web: TgWebParameter

View File

@ -8,6 +8,7 @@ st.set_page_config(page_title="TgToolbox", page_icon="🕹️", layout="wide", i
backend_status = api.get_backend_client_status() backend_status = api.get_backend_client_status()
need_login = False need_login = False
sign = ""
if backend_status is None or not backend_status["init"]: if backend_status is None or not backend_status["init"]:
st.status("Server not ready") st.status("Server not ready")
@ -15,8 +16,10 @@ if backend_status is None or not backend_status["init"]:
st.rerun() st.rerun()
for v in backend_status["clients"]: for v in backend_status["clients"]:
if not v["status"]: if v["name"] != api.get_config_default_name():
need_login = True continue
need_login = not v["status"]
sign = v["sign"]
if need_login: if need_login:
import login import login
@ -28,7 +31,7 @@ search_tab, link_convert_tab = st.tabs(["Search", "Link Convert"])
with search_tab: with search_tab:
import search import search
search.loop() search.loop(sign)
with link_convert_tab: with link_convert_tab:
import link_convert import link_convert

View File

@ -51,10 +51,12 @@ def get_white_list_chat_dict() -> dict[str, any]:
search_api_route = "/tg/api/v1/file/search" search_api_route = "/tg/api/v1/file/search"
def search_database_by_keyword(keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool) -> list[any] | None: def search_database_by_keyword(
sign: str, keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool
) -> list[any] | None:
request_url = background_server_url + search_api_route request_url = background_server_url + search_api_route
req_body = { req_body = {
"token": param.web.token, "sign": sign,
"search": keyword, "search": keyword,
"chat_ids": chat_list, "chat_ids": chat_list,
"index": offset, "index": offset,
@ -83,3 +85,7 @@ def convert_tg_link_to_proxy_link(link: str) -> str:
return f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}" return f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}"
response_js = json.loads(response.content.decode("utf-8")) response_js = json.loads(response.content.decode("utf-8"))
return response_js["url"] return response_js["url"]
def get_config_default_name() -> str:
return param.web.name

View File

@ -9,7 +9,7 @@ import remote_api as api
@st.experimental_fragment @st.experimental_fragment
def loop(): def loop(sign: str):
if "page_index" not in st.session_state: if "page_index" not in st.session_state:
st.session_state.page_index = 1 st.session_state.page_index = 1
if "force_skip" not in st.session_state: if "force_skip" not in st.session_state:
@ -81,7 +81,7 @@ def loop():
except Exception as err: except Exception as err:
print(f"{err=},{traceback.format_exc()}") print(f"{err=},{traceback.format_exc()}")
search_res = api.search_database_by_keyword( search_res = api.search_database_by_keyword(
st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order sign, st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order
) )
status_bar.empty() status_bar.empty()
if search_res is None: if search_res is None: