feat: sign generate

This commit is contained in:
Hehesheng 2024-06-16 22:38:40 +08:00
parent 7511d4ad46
commit ff997c7434
10 changed files with 172 additions and 45 deletions

1
.gitignore vendored
View File

@ -9,6 +9,7 @@ __pycache__
*.toml
*.db
*.service
*.pem
log
cache_media
tmp

View File

@ -162,7 +162,7 @@ class MediaChunkHolderManager(object):
def __init__(self) -> None:
self.chunk_lru = collections.OrderedDict()
self.disk_chunk_cache = diskcache.Cache(
f"{os.path.dirname(__file__)}/cache_media", size_limit=MediaChunkHolderManager.MAX_CACHE_SIZE * 2
f"{os.path.dirname(__file__)}/db/cache_media", size_limit=MediaChunkHolderManager.MAX_CACHE_SIZE * 2
)
self._restore_cache()

View File

@ -61,8 +61,8 @@ class TgFileSystemClient(object):
else {}
)
self.client_param = next(
(client_param for client_param in param.clients if client_param.token == session_name),
configParse.TgToFileSystemParameter.ClientConfigPatameter(),
(client_param for client_param in param.clients if client_param.name == session_name),
configParse.TgToFileSystemParameter.ClientConfigPatameter(name="__tmp__"),
)
self.task_queue = asyncio.Queue()
self.client = TelegramClient(

View File

@ -1,8 +1,11 @@
import asyncio
import time
import base64
import hashlib
import rsa
import os
from enum import IntEnum, unique, auto
import time
import traceback
import logging
@ -14,7 +17,16 @@ import configParse
logger = logging.getLogger(__file__.split("/")[-1])
@unique
class EnumSignLevel(IntEnum):
ADMIN = auto()
NORMAL = auto()
VIST = auto()
NONE = auto()
class TgFileSystemClientManager(object):
TIME_MS_24HOURS: int = 24 * 60 * 60 * 1000
MAX_MANAGE_CLIENTS: int = 10
is_init: bool = False
param: configParse.TgToFileSystemParameter
@ -35,7 +47,7 @@ class TgFileSystemClientManager(object):
self.db = UserManager()
self.loop = asyncio.get_running_loop()
self.media_chunk_manager = MediaChunkHolderManager()
self.public_key, self.private_key = rsa.newkeys(1024)
self._init_rsa_keys()
if self.loop.is_running():
self.loop.create_task(self._start_clients())
else:
@ -47,7 +59,7 @@ class TgFileSystemClientManager(object):
async def _start_clients(self) -> None:
# init cache clients
for client_config in self.param.clients:
client = self.create_client(client_id=client_config.token)
client = self.create_client(client_config.name)
self._register_client(client)
for _, client in self.clients.items():
try:
@ -57,11 +69,97 @@ class TgFileSystemClientManager(object):
logger.warning(f"start client: {err=}, {traceback.format_exc()}")
self.is_init = True
def _init_rsa_keys(self):
key_dir = f"{os.path.dirname(__file__)}/db"
pub_key_path = f"{key_dir}/pub.pem"
pri_key_path = f"{key_dir}/pri.pem"
if not os.path.isfile(pub_key_path) or not os.path.isfile(pri_key_path):
self.public_key, self.private_key = rsa.newkeys(512)
with open(pub_key_path, "wb") as f:
f.write(self.public_key.save_pkcs1())
with open(pri_key_path, "wb") as f:
f.write(self.private_key.save_pkcs1())
else:
with open(pub_key_path, "rb") as f:
self.public_key = rsa.PublicKey.load_pkcs1(f.read())
with open(pri_key_path, "rb") as f:
self.private_key = rsa.PrivateKey.load_pkcs1(f.read())
def generate_sign(
self, client_id: str, sign_type: EnumSignLevel = EnumSignLevel.NORMAL, salt: str = None, valid_time: int = -1
) -> str:
timestamp = int(time.time())
if valid_time == -1:
timestamp += self.TIME_MS_24HOURS
elif valid_time == 0:
timestamp = 0
else:
timestamp += valid_time * 1000
need_encrypt_str = f"ts={timestamp};l={sign_type.value};"
if salt:
need_encrypt_str += f"s={hashlib.md5(salt).hexdigest()[:8]};"
# rsa 512 bits only
valid_len = 512 // 8 - 11
valid_len -= len(need_encrypt_str)
# id=xxxxx;
valid_len -= len("id=;")
if valid_len < 0:
logger.error(f"{need_encrypt_str=},{traceback.format_exc()}")
raise RuntimeError(f"generate sign too big")
real_client_id = client_id[:valid_len]
if len(real_client_id) != len(client_id):
logger.warning(f"client id too long: {client_id} -> {real_client_id}")
need_encrypt_str += f"id={real_client_id};"
need_encrypt_bin = need_encrypt_str.encode()
sign_bin = rsa.encrypt(need_encrypt_bin, self.public_key)
sign = base64.b64encode(sign_bin).decode()
logger.info(f"generate {sign_type.name} sign: {sign}")
return sign
def parse_sign(self, sign: str) -> dict[str, any] | None:
try:
res_dict = {}
sign_bin = base64.b64decode(sign)
decrypt_bin = rsa.decrypt(sign_bin, self.private_key)
decrypt_str = decrypt_bin.decode()
for key_value_str in decrypt_str.split(";"):
if key_value_str == "":
continue
key, value = key_value_str.split("=")
res_dict[key] = value
except Exception as err:
logger.warning(f"verify sign {err=}, {traceback.format_exc()}")
return None
return res_dict
@staticmethod
def get_sign_client_id(key_map: dict[str, any]) -> str:
return key_map.get("id")
def verify_sign(
self,
sign: str,
client_id: str = None,
v_ts: bool = True,
target_level: EnumSignLevel = EnumSignLevel.NONE,
salt: str = None,
) -> bool:
key_map = self.parse_sign(sign)
if not key_map:
return False
if client_id and (not key_map.get("id") or not client_id.startswith(key_map.get("id"))):
return False
if not key_map.get("l") or target_level.value < int(key_map.get("l")):
return False
if v_ts and int(key_map.get("ts", 0)) > 0 and (int(time.time()) - int(key_map.get("ts", 0)) > 0):
return False
if salt and hashlib.md5(key_map.get("s", "")).hexdigest() != salt:
return False
return True
async def get_status(self) -> dict[str, any]:
clients_status = [
{
"status": client.is_valid(),
}
{"status": client.is_valid(), "name": client.session_name, "sign": self.generate_sign(client.session_name)}
for _, client in self.clients.items()
]
return {"init": self.is_init, "clients": clients_status}
@ -77,12 +175,7 @@ class TgFileSystemClientManager(object):
session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session"
return os.path.isfile(session_db_file)
def generate_client_id(self) -> str:
return hashlib.md5((str(time.perf_counter()) + self.param.base.salt).encode("utf-8")).hexdigest()
def create_client(self, client_id: str = None) -> TgFileSystemClient:
if client_id is None:
client_id = self.generate_client_id()
def create_client(self, client_id: str) -> TgFileSystemClient:
client = TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager)
return client

View File

@ -41,7 +41,7 @@ app.add_middleware(
class TgToFileListRequestBody(BaseModel):
token: str
sign: str
search: str = ""
chat_ids: list[int] = []
index: int = 0
@ -51,15 +51,31 @@ class TgToFileListRequestBody(BaseModel):
inc: bool = False
@app.post("/tg/api/v1/file/search")
async def verify_post_sign(body: TgToFileListRequestBody):
clients_mgr = TgFileSystemClientManager.get_instance()
if not clients_mgr.verify_sign(body.sign):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{body}")
async def verify_get_sign(sign: str):
clients_mgr = TgFileSystemClientManager.get_instance()
sign = sign.replace(" ", "+")
if not clients_mgr.verify_sign(sign):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{sign}")
return sign
@app.post("/tg/api/v1/file/search", dependencies=[Depends(verify_post_sign)])
@apiutils.atimeit
async def search_tg_file_list(body: TgToFileListRequestBody):
try:
param = configParse.get_TgToFileSystemParameter()
clients_mgr = TgFileSystemClientManager.get_instance()
param = configParse.get_TgToFileSystemParameter()
res = hints.TotalList()
res_type = "msg"
client = await clients_mgr.get_client_force(body.token)
sign_info = clients_mgr.parse_sign(body.sign)
client_id = TgFileSystemClientManager.get_sign_client_id(sign_info)
client = await clients_mgr.get_client_force(client_id)
res_dict = []
res = await client.get_messages_by_search_db(
body.chat_ids, body.search, limit=body.length, inc=body.inc, offset=body.index
@ -75,7 +91,7 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
res_dict.append(msg_info)
client_dict = json.loads(client.to_json())
client_dict["sign"] = body.token
client_dict["sign"] = body.sign
response_dict = {
"client": client_dict,
@ -128,17 +144,18 @@ async def get_tg_file_list(body: TgToFileListRequestBody):
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
@app.get("/tg/api/v1/file/msg")
@app.get("/tg/api/v1/file/msg", deprecated=[Depends(verify_get_sign)])
@apiutils.atimeit
async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request):
async def get_tg_file_media_stream(sign: str, cid: int, mid: int, request: Request):
try:
return await api.get_media_file_stream(token, cid, mid, request)
sign = sign.replace(" ", "+")
return await api.get_media_file_stream(sign, cid, mid, request)
except Exception as err:
logger.error(f"{err=},{traceback.format_exc()}")
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
@app.get("/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}")
@app.get("/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}", dependencies=[Depends(verify_get_sign)])
@apiutils.atimeit
async def get_tg_file_media(chat_id: int | str, msg_id: int, file_name: str, sign: str, req: Request):
try:
@ -223,15 +240,20 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
async def get_verify(q: str | None, skip: int = 0):
logger.info("run common param")
if skip < 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{q=},{skip=}")
async def get_verify(id: str = None):
if id is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{id=}")
client_mgr = TgFileSystemClientManager.get_instance()
client = await client_mgr.get_client_force(id)
if not client.is_valid():
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{id=}")
@app.get("/tg/api/v1/test", dependencies=[Depends(get_verify)])
async def test_get_depends_verify_method(other: str = ""):
return Response()
async def test_get_depends_verify_method(id: str, other: str = ""):
client_mgr = TgFileSystemClientManager.get_instance()
client = await client_mgr.get_client_force(id)
return Response((await client.client.get_me()).stringify())
async def post_verify(body: TgToChatListRequestBody | None = None):

View File

@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse, Response
import configParse
from backend import apiutils
from backend.TgFileSystemClientManager import TgFileSystemClientManager
from backend.TgFileSystemClientManager import TgFileSystemClientManager, EnumSignLevel
logger = logging.getLogger(__file__.split("/")[-1])
@ -38,9 +38,8 @@ async def link_convert(link: str) -> str:
msg = await client.get_message(chat_id_or_name, msg_id)
file_name = apiutils.get_message_media_name(msg)
param = configParse.get_TgToFileSystemParameter()
url = (
f"{param.base.exposed_url}/tg/api/v1/file/get/{utils.get_peer_id(msg.peer_id)}/{msg.id}/{file_name}?sign={client.sign}"
)
sign = clients_mgr.generate_sign(client.session_name, EnumSignLevel.VIST)
url = f"{param.base.exposed_url}/tg/api/v1/file/get/{utils.get_peer_id(msg.peer_id)}/{msg.id}/{file_name}?sign={sign}"
return url
@ -63,7 +62,7 @@ async def get_clients_manager_status(detail: bool) -> dict[str, any]:
return ret
async def get_media_file_stream(token: str, cid: int, mid: int, request: Request) -> StreamingResponse:
async def get_media_file_stream(sign: str, cid: int, mid: int, request: Request) -> StreamingResponse:
msg_id = mid
chat_id = cid
headers = {
@ -76,7 +75,9 @@ async def get_media_file_stream(token: str, cid: int, mid: int, request: Request
range_header = request.headers.get("range")
clients_mgr = TgFileSystemClientManager.get_instance()
client = await clients_mgr.get_client_force(token)
sign_info = clients_mgr.parse_sign(sign)
client_id = TgFileSystemClientManager.get_sign_client_id(sign_info)
client = await clients_mgr.get_client_force(client_id)
msg = await client.get_message(chat_id, msg_id)
if not isinstance(msg.media, types.MessageMediaDocument) and not isinstance(msg.media, types.MessageMediaPhoto):
raise RuntimeError(f"request don't support: {msg.media=}")

View File

@ -6,15 +6,16 @@ from pydantic import BaseModel
class TgToFileSystemParameter(BaseModel):
class BaseParameter(BaseModel):
salt: str = ""
exposed_url: str = "http://127.0.0.1:7777"
port: int = 7777
timeit_enable: bool = False
base: BaseParameter
class ClientConfigPatameter(BaseModel):
token: str = ""
name: str
interval: float = 0.1
whitelist_chat: list[int] = []
clients: list[ClientConfigPatameter]
@ -33,7 +34,7 @@ class TgToFileSystemParameter(BaseModel):
class TgWebParameter(BaseModel):
enable: bool = False
token: str = ""
name: str = ""
port: int = 2000
web: TgWebParameter

View File

@ -8,6 +8,7 @@ st.set_page_config(page_title="TgToolbox", page_icon="🕹️", layout="wide", i
backend_status = api.get_backend_client_status()
need_login = False
sign = ""
if backend_status is None or not backend_status["init"]:
st.status("Server not ready")
@ -15,8 +16,10 @@ if backend_status is None or not backend_status["init"]:
st.rerun()
for v in backend_status["clients"]:
if not v["status"]:
need_login = True
if v["name"] != api.get_config_default_name():
continue
need_login = not v["status"]
sign = v["sign"]
if need_login:
import login
@ -28,7 +31,7 @@ search_tab, link_convert_tab = st.tabs(["Search", "Link Convert"])
with search_tab:
import search
search.loop()
search.loop(sign)
with link_convert_tab:
import link_convert

View File

@ -51,10 +51,12 @@ def get_white_list_chat_dict() -> dict[str, any]:
search_api_route = "/tg/api/v1/file/search"
def search_database_by_keyword(keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool) -> list[any] | None:
def search_database_by_keyword(
sign: str, keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool
) -> list[any] | None:
request_url = background_server_url + search_api_route
req_body = {
"token": param.web.token,
"sign": sign,
"search": keyword,
"chat_ids": chat_list,
"index": offset,
@ -83,3 +85,7 @@ def convert_tg_link_to_proxy_link(link: str) -> str:
return f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}"
response_js = json.loads(response.content.decode("utf-8"))
return response_js["url"]
def get_config_default_name() -> str:
return param.web.name

View File

@ -9,7 +9,7 @@ import remote_api as api
@st.experimental_fragment
def loop():
def loop(sign: str):
if "page_index" not in st.session_state:
st.session_state.page_index = 1
if "force_skip" not in st.session_state:
@ -81,7 +81,7 @@ def loop():
except Exception as err:
print(f"{err=},{traceback.format_exc()}")
search_res = api.search_database_by_keyword(
st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order
sign, st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order
)
status_bar.empty()
if search_res is None: