feat: sign generate
This commit is contained in:
parent
7511d4ad46
commit
ff997c7434
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,6 +9,7 @@ __pycache__
|
||||
*.toml
|
||||
*.db
|
||||
*.service
|
||||
*.pem
|
||||
log
|
||||
cache_media
|
||||
tmp
|
||||
|
@ -162,7 +162,7 @@ class MediaChunkHolderManager(object):
|
||||
def __init__(self) -> None:
|
||||
self.chunk_lru = collections.OrderedDict()
|
||||
self.disk_chunk_cache = diskcache.Cache(
|
||||
f"{os.path.dirname(__file__)}/cache_media", size_limit=MediaChunkHolderManager.MAX_CACHE_SIZE * 2
|
||||
f"{os.path.dirname(__file__)}/db/cache_media", size_limit=MediaChunkHolderManager.MAX_CACHE_SIZE * 2
|
||||
)
|
||||
self._restore_cache()
|
||||
|
||||
|
@ -61,8 +61,8 @@ class TgFileSystemClient(object):
|
||||
else {}
|
||||
)
|
||||
self.client_param = next(
|
||||
(client_param for client_param in param.clients if client_param.token == session_name),
|
||||
configParse.TgToFileSystemParameter.ClientConfigPatameter(),
|
||||
(client_param for client_param in param.clients if client_param.name == session_name),
|
||||
configParse.TgToFileSystemParameter.ClientConfigPatameter(name="__tmp__"),
|
||||
)
|
||||
self.task_queue = asyncio.Queue()
|
||||
self.client = TelegramClient(
|
||||
|
@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
import time
|
||||
import base64
|
||||
import hashlib
|
||||
import rsa
|
||||
import os
|
||||
from enum import IntEnum, unique, auto
|
||||
import time
|
||||
import traceback
|
||||
import logging
|
||||
|
||||
@ -14,7 +17,16 @@ import configParse
|
||||
logger = logging.getLogger(__file__.split("/")[-1])
|
||||
|
||||
|
||||
@unique
|
||||
class EnumSignLevel(IntEnum):
|
||||
ADMIN = auto()
|
||||
NORMAL = auto()
|
||||
VIST = auto()
|
||||
NONE = auto()
|
||||
|
||||
|
||||
class TgFileSystemClientManager(object):
|
||||
TIME_MS_24HOURS: int = 24 * 60 * 60 * 1000
|
||||
MAX_MANAGE_CLIENTS: int = 10
|
||||
is_init: bool = False
|
||||
param: configParse.TgToFileSystemParameter
|
||||
@ -35,7 +47,7 @@ class TgFileSystemClientManager(object):
|
||||
self.db = UserManager()
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.media_chunk_manager = MediaChunkHolderManager()
|
||||
self.public_key, self.private_key = rsa.newkeys(1024)
|
||||
self._init_rsa_keys()
|
||||
if self.loop.is_running():
|
||||
self.loop.create_task(self._start_clients())
|
||||
else:
|
||||
@ -47,7 +59,7 @@ class TgFileSystemClientManager(object):
|
||||
async def _start_clients(self) -> None:
|
||||
# init cache clients
|
||||
for client_config in self.param.clients:
|
||||
client = self.create_client(client_id=client_config.token)
|
||||
client = self.create_client(client_config.name)
|
||||
self._register_client(client)
|
||||
for _, client in self.clients.items():
|
||||
try:
|
||||
@ -57,11 +69,97 @@ class TgFileSystemClientManager(object):
|
||||
logger.warning(f"start client: {err=}, {traceback.format_exc()}")
|
||||
self.is_init = True
|
||||
|
||||
def _init_rsa_keys(self):
|
||||
key_dir = f"{os.path.dirname(__file__)}/db"
|
||||
pub_key_path = f"{key_dir}/pub.pem"
|
||||
pri_key_path = f"{key_dir}/pri.pem"
|
||||
if not os.path.isfile(pub_key_path) or not os.path.isfile(pri_key_path):
|
||||
self.public_key, self.private_key = rsa.newkeys(512)
|
||||
with open(pub_key_path, "wb") as f:
|
||||
f.write(self.public_key.save_pkcs1())
|
||||
with open(pri_key_path, "wb") as f:
|
||||
f.write(self.private_key.save_pkcs1())
|
||||
else:
|
||||
with open(pub_key_path, "rb") as f:
|
||||
self.public_key = rsa.PublicKey.load_pkcs1(f.read())
|
||||
with open(pri_key_path, "rb") as f:
|
||||
self.private_key = rsa.PrivateKey.load_pkcs1(f.read())
|
||||
|
||||
def generate_sign(
|
||||
self, client_id: str, sign_type: EnumSignLevel = EnumSignLevel.NORMAL, salt: str = None, valid_time: int = -1
|
||||
) -> str:
|
||||
timestamp = int(time.time())
|
||||
if valid_time == -1:
|
||||
timestamp += self.TIME_MS_24HOURS
|
||||
elif valid_time == 0:
|
||||
timestamp = 0
|
||||
else:
|
||||
timestamp += valid_time * 1000
|
||||
need_encrypt_str = f"ts={timestamp};l={sign_type.value};"
|
||||
if salt:
|
||||
need_encrypt_str += f"s={hashlib.md5(salt).hexdigest()[:8]};"
|
||||
# rsa 512 bits only
|
||||
valid_len = 512 // 8 - 11
|
||||
valid_len -= len(need_encrypt_str)
|
||||
# id=xxxxx;
|
||||
valid_len -= len("id=;")
|
||||
if valid_len < 0:
|
||||
logger.error(f"{need_encrypt_str=},{traceback.format_exc()}")
|
||||
raise RuntimeError(f"generate sign too big")
|
||||
real_client_id = client_id[:valid_len]
|
||||
if len(real_client_id) != len(client_id):
|
||||
logger.warning(f"client id too long: {client_id} -> {real_client_id}")
|
||||
need_encrypt_str += f"id={real_client_id};"
|
||||
need_encrypt_bin = need_encrypt_str.encode()
|
||||
sign_bin = rsa.encrypt(need_encrypt_bin, self.public_key)
|
||||
sign = base64.b64encode(sign_bin).decode()
|
||||
logger.info(f"generate {sign_type.name} sign: {sign}")
|
||||
return sign
|
||||
|
||||
def parse_sign(self, sign: str) -> dict[str, any] | None:
|
||||
try:
|
||||
res_dict = {}
|
||||
sign_bin = base64.b64decode(sign)
|
||||
decrypt_bin = rsa.decrypt(sign_bin, self.private_key)
|
||||
decrypt_str = decrypt_bin.decode()
|
||||
for key_value_str in decrypt_str.split(";"):
|
||||
if key_value_str == "":
|
||||
continue
|
||||
key, value = key_value_str.split("=")
|
||||
res_dict[key] = value
|
||||
except Exception as err:
|
||||
logger.warning(f"verify sign {err=}, {traceback.format_exc()}")
|
||||
return None
|
||||
return res_dict
|
||||
|
||||
@staticmethod
|
||||
def get_sign_client_id(key_map: dict[str, any]) -> str:
|
||||
return key_map.get("id")
|
||||
|
||||
def verify_sign(
|
||||
self,
|
||||
sign: str,
|
||||
client_id: str = None,
|
||||
v_ts: bool = True,
|
||||
target_level: EnumSignLevel = EnumSignLevel.NONE,
|
||||
salt: str = None,
|
||||
) -> bool:
|
||||
key_map = self.parse_sign(sign)
|
||||
if not key_map:
|
||||
return False
|
||||
if client_id and (not key_map.get("id") or not client_id.startswith(key_map.get("id"))):
|
||||
return False
|
||||
if not key_map.get("l") or target_level.value < int(key_map.get("l")):
|
||||
return False
|
||||
if v_ts and int(key_map.get("ts", 0)) > 0 and (int(time.time()) - int(key_map.get("ts", 0)) > 0):
|
||||
return False
|
||||
if salt and hashlib.md5(key_map.get("s", "")).hexdigest() != salt:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_status(self) -> dict[str, any]:
|
||||
clients_status = [
|
||||
{
|
||||
"status": client.is_valid(),
|
||||
}
|
||||
{"status": client.is_valid(), "name": client.session_name, "sign": self.generate_sign(client.session_name)}
|
||||
for _, client in self.clients.items()
|
||||
]
|
||||
return {"init": self.is_init, "clients": clients_status}
|
||||
@ -77,12 +175,7 @@ class TgFileSystemClientManager(object):
|
||||
session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session"
|
||||
return os.path.isfile(session_db_file)
|
||||
|
||||
def generate_client_id(self) -> str:
|
||||
return hashlib.md5((str(time.perf_counter()) + self.param.base.salt).encode("utf-8")).hexdigest()
|
||||
|
||||
def create_client(self, client_id: str = None) -> TgFileSystemClient:
|
||||
if client_id is None:
|
||||
client_id = self.generate_client_id()
|
||||
def create_client(self, client_id: str) -> TgFileSystemClient:
|
||||
client = TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager)
|
||||
return client
|
||||
|
||||
|
@ -41,7 +41,7 @@ app.add_middleware(
|
||||
|
||||
|
||||
class TgToFileListRequestBody(BaseModel):
|
||||
token: str
|
||||
sign: str
|
||||
search: str = ""
|
||||
chat_ids: list[int] = []
|
||||
index: int = 0
|
||||
@ -51,15 +51,31 @@ class TgToFileListRequestBody(BaseModel):
|
||||
inc: bool = False
|
||||
|
||||
|
||||
@app.post("/tg/api/v1/file/search")
|
||||
async def verify_post_sign(body: TgToFileListRequestBody):
|
||||
clients_mgr = TgFileSystemClientManager.get_instance()
|
||||
if not clients_mgr.verify_sign(body.sign):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{body}")
|
||||
|
||||
|
||||
async def verify_get_sign(sign: str):
|
||||
clients_mgr = TgFileSystemClientManager.get_instance()
|
||||
sign = sign.replace(" ", "+")
|
||||
if not clients_mgr.verify_sign(sign):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{sign}")
|
||||
return sign
|
||||
|
||||
|
||||
@app.post("/tg/api/v1/file/search", dependencies=[Depends(verify_post_sign)])
|
||||
@apiutils.atimeit
|
||||
async def search_tg_file_list(body: TgToFileListRequestBody):
|
||||
try:
|
||||
param = configParse.get_TgToFileSystemParameter()
|
||||
clients_mgr = TgFileSystemClientManager.get_instance()
|
||||
param = configParse.get_TgToFileSystemParameter()
|
||||
res = hints.TotalList()
|
||||
res_type = "msg"
|
||||
client = await clients_mgr.get_client_force(body.token)
|
||||
sign_info = clients_mgr.parse_sign(body.sign)
|
||||
client_id = TgFileSystemClientManager.get_sign_client_id(sign_info)
|
||||
client = await clients_mgr.get_client_force(client_id)
|
||||
res_dict = []
|
||||
res = await client.get_messages_by_search_db(
|
||||
body.chat_ids, body.search, limit=body.length, inc=body.inc, offset=body.index
|
||||
@ -75,7 +91,7 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
|
||||
res_dict.append(msg_info)
|
||||
|
||||
client_dict = json.loads(client.to_json())
|
||||
client_dict["sign"] = body.token
|
||||
client_dict["sign"] = body.sign
|
||||
|
||||
response_dict = {
|
||||
"client": client_dict,
|
||||
@ -128,17 +144,18 @@ async def get_tg_file_list(body: TgToFileListRequestBody):
|
||||
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
@app.get("/tg/api/v1/file/msg")
|
||||
@app.get("/tg/api/v1/file/msg", deprecated=[Depends(verify_get_sign)])
|
||||
@apiutils.atimeit
|
||||
async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request):
|
||||
async def get_tg_file_media_stream(sign: str, cid: int, mid: int, request: Request):
|
||||
try:
|
||||
return await api.get_media_file_stream(token, cid, mid, request)
|
||||
sign = sign.replace(" ", "+")
|
||||
return await api.get_media_file_stream(sign, cid, mid, request)
|
||||
except Exception as err:
|
||||
logger.error(f"{err=},{traceback.format_exc()}")
|
||||
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
@app.get("/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}")
|
||||
@app.get("/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}", dependencies=[Depends(verify_get_sign)])
|
||||
@apiutils.atimeit
|
||||
async def get_tg_file_media(chat_id: int | str, msg_id: int, file_name: str, sign: str, req: Request):
|
||||
try:
|
||||
@ -223,15 +240,20 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques
|
||||
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
async def get_verify(q: str | None, skip: int = 0):
|
||||
logger.info("run common param")
|
||||
if skip < 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{q=},{skip=}")
|
||||
async def get_verify(id: str = None):
|
||||
if id is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{id=}")
|
||||
client_mgr = TgFileSystemClientManager.get_instance()
|
||||
client = await client_mgr.get_client_force(id)
|
||||
if not client.is_valid():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{id=}")
|
||||
|
||||
|
||||
@app.get("/tg/api/v1/test", dependencies=[Depends(get_verify)])
|
||||
async def test_get_depends_verify_method(other: str = ""):
|
||||
return Response()
|
||||
async def test_get_depends_verify_method(id: str, other: str = ""):
|
||||
client_mgr = TgFileSystemClientManager.get_instance()
|
||||
client = await client_mgr.get_client_force(id)
|
||||
return Response((await client.client.get_me()).stringify())
|
||||
|
||||
|
||||
async def post_verify(body: TgToChatListRequestBody | None = None):
|
||||
|
@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse, Response
|
||||
|
||||
import configParse
|
||||
from backend import apiutils
|
||||
from backend.TgFileSystemClientManager import TgFileSystemClientManager
|
||||
from backend.TgFileSystemClientManager import TgFileSystemClientManager, EnumSignLevel
|
||||
|
||||
|
||||
logger = logging.getLogger(__file__.split("/")[-1])
|
||||
@ -38,9 +38,8 @@ async def link_convert(link: str) -> str:
|
||||
msg = await client.get_message(chat_id_or_name, msg_id)
|
||||
file_name = apiutils.get_message_media_name(msg)
|
||||
param = configParse.get_TgToFileSystemParameter()
|
||||
url = (
|
||||
f"{param.base.exposed_url}/tg/api/v1/file/get/{utils.get_peer_id(msg.peer_id)}/{msg.id}/{file_name}?sign={client.sign}"
|
||||
)
|
||||
sign = clients_mgr.generate_sign(client.session_name, EnumSignLevel.VIST)
|
||||
url = f"{param.base.exposed_url}/tg/api/v1/file/get/{utils.get_peer_id(msg.peer_id)}/{msg.id}/{file_name}?sign={sign}"
|
||||
return url
|
||||
|
||||
|
||||
@ -63,7 +62,7 @@ async def get_clients_manager_status(detail: bool) -> dict[str, any]:
|
||||
return ret
|
||||
|
||||
|
||||
async def get_media_file_stream(token: str, cid: int, mid: int, request: Request) -> StreamingResponse:
|
||||
async def get_media_file_stream(sign: str, cid: int, mid: int, request: Request) -> StreamingResponse:
|
||||
msg_id = mid
|
||||
chat_id = cid
|
||||
headers = {
|
||||
@ -76,7 +75,9 @@ async def get_media_file_stream(token: str, cid: int, mid: int, request: Request
|
||||
range_header = request.headers.get("range")
|
||||
|
||||
clients_mgr = TgFileSystemClientManager.get_instance()
|
||||
client = await clients_mgr.get_client_force(token)
|
||||
sign_info = clients_mgr.parse_sign(sign)
|
||||
client_id = TgFileSystemClientManager.get_sign_client_id(sign_info)
|
||||
client = await clients_mgr.get_client_force(client_id)
|
||||
msg = await client.get_message(chat_id, msg_id)
|
||||
if not isinstance(msg.media, types.MessageMediaDocument) and not isinstance(msg.media, types.MessageMediaPhoto):
|
||||
raise RuntimeError(f"request don't support: {msg.media=}")
|
||||
|
@ -6,15 +6,16 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class TgToFileSystemParameter(BaseModel):
|
||||
|
||||
class BaseParameter(BaseModel):
|
||||
salt: str = ""
|
||||
exposed_url: str = "http://127.0.0.1:7777"
|
||||
port: int = 7777
|
||||
timeit_enable: bool = False
|
||||
|
||||
base: BaseParameter
|
||||
|
||||
class ClientConfigPatameter(BaseModel):
|
||||
token: str = ""
|
||||
name: str
|
||||
interval: float = 0.1
|
||||
whitelist_chat: list[int] = []
|
||||
clients: list[ClientConfigPatameter]
|
||||
@ -33,7 +34,7 @@ class TgToFileSystemParameter(BaseModel):
|
||||
|
||||
class TgWebParameter(BaseModel):
|
||||
enable: bool = False
|
||||
token: str = ""
|
||||
name: str = ""
|
||||
port: int = 2000
|
||||
web: TgWebParameter
|
||||
|
||||
|
@ -8,6 +8,7 @@ st.set_page_config(page_title="TgToolbox", page_icon="🕹️", layout="wide", i
|
||||
|
||||
backend_status = api.get_backend_client_status()
|
||||
need_login = False
|
||||
sign = ""
|
||||
|
||||
if backend_status is None or not backend_status["init"]:
|
||||
st.status("Server not ready")
|
||||
@ -15,8 +16,10 @@ if backend_status is None or not backend_status["init"]:
|
||||
st.rerun()
|
||||
|
||||
for v in backend_status["clients"]:
|
||||
if not v["status"]:
|
||||
need_login = True
|
||||
if v["name"] != api.get_config_default_name():
|
||||
continue
|
||||
need_login = not v["status"]
|
||||
sign = v["sign"]
|
||||
|
||||
if need_login:
|
||||
import login
|
||||
@ -28,7 +31,7 @@ search_tab, link_convert_tab = st.tabs(["Search", "Link Convert"])
|
||||
with search_tab:
|
||||
import search
|
||||
|
||||
search.loop()
|
||||
search.loop(sign)
|
||||
with link_convert_tab:
|
||||
import link_convert
|
||||
|
||||
|
@ -51,10 +51,12 @@ def get_white_list_chat_dict() -> dict[str, any]:
|
||||
search_api_route = "/tg/api/v1/file/search"
|
||||
|
||||
|
||||
def search_database_by_keyword(keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool) -> list[any] | None:
|
||||
def search_database_by_keyword(
|
||||
sign: str, keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool
|
||||
) -> list[any] | None:
|
||||
request_url = background_server_url + search_api_route
|
||||
req_body = {
|
||||
"token": param.web.token,
|
||||
"sign": sign,
|
||||
"search": keyword,
|
||||
"chat_ids": chat_list,
|
||||
"index": offset,
|
||||
@ -83,3 +85,7 @@ def convert_tg_link_to_proxy_link(link: str) -> str:
|
||||
return f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}"
|
||||
response_js = json.loads(response.content.decode("utf-8"))
|
||||
return response_js["url"]
|
||||
|
||||
|
||||
def get_config_default_name() -> str:
|
||||
return param.web.name
|
||||
|
@ -9,7 +9,7 @@ import remote_api as api
|
||||
|
||||
|
||||
@st.experimental_fragment
|
||||
def loop():
|
||||
def loop(sign: str):
|
||||
if "page_index" not in st.session_state:
|
||||
st.session_state.page_index = 1
|
||||
if "force_skip" not in st.session_state:
|
||||
@ -81,7 +81,7 @@ def loop():
|
||||
except Exception as err:
|
||||
print(f"{err=},{traceback.format_exc()}")
|
||||
search_res = api.search_database_by_keyword(
|
||||
st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order
|
||||
sign, st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order
|
||||
)
|
||||
status_bar.empty()
|
||||
if search_res is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user