TgToFileSystem/backend/TgFileSystemClientManager.py
2024-06-18 22:52:26 +08:00

213 lines
7.7 KiB
Python

import asyncio
import time
import base64
import hashlib
import rsa
import os
from enum import IntEnum, unique, auto
import time
import traceback
import logging
from backend.TgFileSystemClient import TgFileSystemClient
from backend.UserManager import UserManager
from backend.MediaCacheManager import MediaChunkHolderManager
import configParse
# Module-level logger named after this source file (e.g. "TgFileSystemClientManager.py").
# os.path.basename is used instead of splitting on "/" so the name is also just the
# file name on platforms whose path separator is not "/".
logger = logging.getLogger(os.path.basename(__file__))
@unique
class EnumSignLevel(IntEnum):
    """Privilege level embedded in a sign as its ``l`` field.

    The numeric order matters: ``TgFileSystemClientManager.verify_sign`` rejects
    a sign whose level value is greater than the requested target level, so a
    smaller value passes more checks.
    """

    # Values are spelled out (they match what ``auto()`` assigns, starting at 1)
    # because they are serialized into signs and compared numerically.
    ADMIN = 1
    NORMAL = 2
    VIST = 3  # presumably "visitor" -- TODO confirm
    NONE = 4
class TgFileSystemClientManager(object):
    """Registry of ``TgFileSystemClient`` objects and issuer/verifier of access signs.

    A sign is a short ``key=value;`` payload encrypted with the RSA public key
    stored under ``db/`` and encoded as URL-safe base64. Payload fields:

    * ``ts`` -- expiry as a Unix timestamp in seconds, ``0`` meaning "never expires"
    * ``l``  -- numeric ``EnumSignLevel`` value
    * ``s``  -- optional, first 8 hex characters of the MD5 of a caller-supplied salt
    * ``id`` -- client id, possibly truncated to fit the RSA block

    NOTE(review): the sign is produced by public-key *encryption*, not by a
    signature, and new keys are only 512 bits. Anyone who obtains ``pub.pem`` can
    mint valid signs. Changing the scheme would invalidate issued signs, so it
    is only flagged here.
    """

    # Default sign lifetime (24 hours) in milliseconds; converted to seconds where used.
    TIME_MS_24HOURS: int = 24 * 60 * 60 * 1000
    # Not referenced inside this class.
    MAX_MANAGE_CLIENTS: int = 10
    # Set to True once ``_start_clients`` has tried to start every configured client.
    is_init: bool = False
    param: configParse.TgToFileSystemParameter
    # Per-instance mapping session name -> client; created in ``__init__`` so that
    # instances never share (or clear) each other's registry.
    clients: dict[str, TgFileSystemClient]
    # rsa key
    cache_sign: str
    public_key: rsa.PublicKey
    private_key: rsa.PrivateKey

    @classmethod
    def get_instance(cls) -> "TgFileSystemClientManager":
        """Return the process-wide manager, creating it on first use.

        The instance is built from ``configParse.get_TgToFileSystemParameter()``
        and cached on ``TgFileSystemClientManager`` itself.
        """
        if not hasattr(TgFileSystemClientManager, "_instance"):
            TgFileSystemClientManager._instance = TgFileSystemClientManager(configParse.get_TgToFileSystemParameter())
        return TgFileSystemClientManager._instance

    def __init__(self, param: configParse.TgToFileSystemParameter) -> None:
        """Create the manager and start the clients listed in ``param.clients``.

        When called from inside a running event loop the clients are started by a
        background task. Otherwise a new event loop is created, installed as the
        current loop, and run until every client has been started.
        """
        self.param = param
        self.clients = {}
        self.db = UserManager()
        try:
            self.loop = asyncio.get_running_loop()
            loop_running = True
        except RuntimeError:
            # get_running_loop() raises instead of returning an idle loop, so the
            # synchronous start-up path needs a loop of its own.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            loop_running = False
        self.media_chunk_manager = MediaChunkHolderManager()
        self._init_rsa_keys()
        if loop_running:
            # Keep a reference: the event loop only holds weak references to tasks,
            # so an unreferenced task may be garbage-collected before it finishes.
            self._start_task = self.loop.create_task(self._start_clients())
        else:
            self.loop.run_until_complete(self._start_clients())

    def __del__(self) -> None:
        """Drop all client references when the manager is collected."""
        # ``clients`` may be missing if ``__init__`` never ran.
        clients = getattr(self, "clients", None)
        if clients is not None:
            clients.clear()

    async def _start_clients(self) -> None:
        """Create, register and start every client named in the configuration.

        A client that fails to start is logged and skipped; ``is_init`` is set
        once all clients have been attempted.
        """
        # init cache clients
        for client_config in self.param.clients:
            client = self.create_client(client_config.name)
            self._register_client(client)
        # Iterate over a snapshot: ``get_client_force`` may register a new client
        # while this coroutine is suspended in ``await client.start()``.
        for client in list(self.clients.values()):
            try:
                if not client.is_valid():
                    await client.start()
            except Exception as err:
                logger.warning(f"start client: {err=}, {traceback.format_exc()}")
        self.is_init = True

    def _init_rsa_keys(self) -> None:
        """Load the RSA key pair from ``db/`` next to this file, generating it if absent.

        A new 512-bit pair is generated and saved as ``pub.pem`` / ``pri.pem``
        whenever either file is missing.
        """
        key_dir = f"{os.path.dirname(__file__)}/db"
        pub_key_path = f"{key_dir}/pub.pem"
        pri_key_path = f"{key_dir}/pri.pem"
        if os.path.isfile(pub_key_path) and os.path.isfile(pri_key_path):
            with open(pub_key_path, "rb") as f:
                self.public_key = rsa.PublicKey.load_pkcs1(f.read())
            with open(pri_key_path, "rb") as f:
                self.private_key = rsa.PrivateKey.load_pkcs1(f.read())
            return
        # The directory may not exist yet on a fresh checkout.
        os.makedirs(key_dir, exist_ok=True)
        self.public_key, self.private_key = rsa.newkeys(512)
        with open(pub_key_path, "wb") as f:
            f.write(self.public_key.save_pkcs1())
        with open(pri_key_path, "wb") as f:
            f.write(self.private_key.save_pkcs1())

    @staticmethod
    def _salt_digest(salt: str | bytes) -> str:
        """Return the 8-hex-character MD5 prefix stored in a sign's ``s`` field."""
        # hashlib needs bytes; a str salt is encoded as UTF-8, bytes are used as-is.
        raw = salt.encode() if isinstance(salt, str) else salt
        return hashlib.md5(raw).hexdigest()[:8]

    def generate_sign(
        self,
        client_id: str,
        sign_type: EnumSignLevel = EnumSignLevel.NORMAL,
        salt: str | None = None,
        valid_time: int = -1,
    ) -> str:
        """Build an encrypted, URL-safe sign for ``client_id``.

        Args:
            client_id: Client the sign is issued for. It is truncated (with a
                warning) when it does not fit into the RSA block.
            sign_type: Privilege level stored in the sign.
            salt: Optional secret; only a short digest of it is stored.
            valid_time: Lifetime in seconds. ``-1`` means 24 hours and ``0``
                produces a sign that never expires.

        Returns:
            The sign as URL-safe base64 text.

        Raises:
            RuntimeError: If the fixed fields alone exceed the RSA block capacity.
        """
        # ``verify_sign`` compares ``ts`` with ``int(time.time())``, so the expiry
        # is kept in seconds; the millisecond constant is converted accordingly.
        if valid_time == 0:
            timestamp = 0
        elif valid_time == -1:
            timestamp = int(time.time()) + self.TIME_MS_24HOURS // 1000
        else:
            timestamp = int(time.time()) + valid_time
        need_encrypt_str = f"ts={timestamp};l={sign_type.value};"
        if salt:
            need_encrypt_str += f"s={self._salt_digest(salt)};"
        # PKCS#1 v1.5 encryption accepts at most k - 11 bytes, where k is the byte
        # length of the modulus (64 for the 512-bit keys generated here).
        key_bytes = (self.public_key.n.bit_length() + 7) // 8
        valid_len = key_bytes - 11
        valid_len -= len(need_encrypt_str.encode())
        # id=xxxxx;
        valid_len -= len("id=;")
        if valid_len < 0:
            logger.error(f"{need_encrypt_str=},{traceback.format_exc()}")
            raise RuntimeError("generate sign too big")
        # Truncate on encoded bytes, since that is what the RSA limit counts;
        # errors="ignore" drops a multi-byte character cut in half by the slice.
        real_client_id = client_id.encode()[:valid_len].decode(errors="ignore")
        if real_client_id != client_id:
            logger.warning(f"client id too long: {client_id} -> {real_client_id}")
        need_encrypt_str += f"id={real_client_id};"
        sign_bin = rsa.encrypt(need_encrypt_str.encode(), self.public_key)
        # URL-safe alphabet: "+" -> "-" and "/" -> "_".
        sign = base64.urlsafe_b64encode(sign_bin).decode()
        logger.info(f"generate {sign_type.name} sign: {sign}")
        return sign

    def parse_sign(self, sign: str) -> dict[str, str] | None:
        """Decrypt ``sign`` and return its fields, or ``None`` if it is not valid.

        Any failure (bad base64, failed decryption, undecodable or malformed
        payload) is logged as a warning and reported as ``None``.
        """
        try:
            sign_bin = base64.urlsafe_b64decode(sign)
            decrypt_bin = rsa.decrypt(sign_bin, self.private_key)
            decrypt_str = decrypt_bin.decode()
            res_dict = {}
            for key_value_str in decrypt_str.split(";"):
                if key_value_str == "":
                    continue
                # Split on the first "=" only so a value may itself contain "=".
                key, value = key_value_str.split("=", 1)
                res_dict[key] = value
        except Exception as err:
            # Deliberately broad: the sign is untrusted input and every failure
            # mode means the same thing to callers.
            logger.warning(f"verify sign {err=}, {traceback.format_exc()}")
            return None
        return res_dict

    @staticmethod
    def get_sign_client_id(key_map: dict[str, str]) -> str | None:
        """Return the ``id`` field of a parsed sign, or ``None`` if it is absent."""
        return key_map.get("id")

    def verify_sign(
        self,
        sign: str,
        client_id: str | None = None,
        v_ts: bool = True,
        target_level: EnumSignLevel = EnumSignLevel.NONE,
        salt: str | None = None,
    ) -> bool:
        """Check that ``sign`` is genuine and satisfies the given constraints.

        Args:
            sign: Sign text as produced by ``generate_sign``.
            client_id: When given, must start with the id stored in the sign (the
                stored id may have been truncated).
            v_ts: Whether to enforce the expiry timestamp.
            target_level: The sign's level value must not exceed this level's value.
            salt: When given, its digest must equal the sign's ``s`` field.

        Returns:
            ``True`` only if every requested check passes.
        """
        key_map = self.parse_sign(sign)
        if not key_map:
            return False
        signed_id = key_map.get("id")
        if client_id and (not signed_id or not client_id.startswith(signed_id)):
            return False
        level_str = key_map.get("l")
        if not level_str:
            return False
        try:
            if target_level.value < int(level_str):
                return False
            if v_ts:
                expire_ts = int(key_map.get("ts", 0))
                # A timestamp of 0 marks a sign that never expires.
                if expire_ts > 0 and int(time.time()) - expire_ts > 0:
                    return False
        except ValueError:
            # Non-numeric level or timestamp: treat the sign as invalid.
            return False
        # The sign stores only a digest of the salt, so compare digests.
        if salt and key_map.get("s") != self._salt_digest(salt):
            return False
        return True

    async def get_status(self) -> dict[str, object]:
        """Return the init flag plus validity, name and a fresh sign for each client."""
        clients_status = [
            {"status": client.is_valid(), "name": client.session_name, "sign": self.generate_sign(client.session_name)}
            for client in self.clients.values()
        ]
        return {"init": self.is_init, "clients": clients_status}

    async def login_clients(self) -> str:
        """Call ``login()`` on each client and return the first non-empty result.

        Returns an empty string when no client returns a login URL.
        """
        # Snapshot for the same reason as in ``_start_clients``: the registry may
        # change while this coroutine is suspended.
        for client in list(self.clients.values()):
            login_url = await client.login()
            if login_url != "":
                return login_url
        return ""

    def check_client_session_exist(self, client_id: str) -> bool:
        """Return whether ``db/<client_id>.session`` exists next to this file."""
        session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session"
        return os.path.isfile(session_db_file)

    def create_client(self, client_id: str) -> TgFileSystemClient:
        """Construct (but neither start nor register) a client for ``client_id``."""
        return TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager)

    def _register_client(self, client: TgFileSystemClient) -> bool:
        """Store ``client`` under its session name, replacing any previous entry."""
        self.clients[client.session_name] = client
        return True

    def _unregister_client(self, client_id: str) -> bool:
        """Remove the client registered as ``client_id``.

        Returns ``True`` if a client was removed and ``False`` if the id was unknown.
        """
        return self.clients.pop(client_id, None) is not None

    def get_client(self, client_id: str) -> TgFileSystemClient | None:
        """Return the registered client for ``client_id``, or ``None``."""
        return self.clients.get(client_id)

    def get_first_client(self) -> TgFileSystemClient | None:
        """Return the first registered client (insertion order), or ``None`` if empty."""
        return next(iter(self.clients.values()), None)

    async def get_client_force(self, client_id: str) -> TgFileSystemClient:
        """Return the client for ``client_id``, creating and starting it if needed.

        A client that is not registered yet is created only when its session file
        exists; it is registered after it has been started.

        Raises:
            RuntimeError: If the client is not registered and has no session file.
        """
        client = self.get_client(client_id)
        if client is None:
            if not self.check_client_session_exist(client_id):
                raise RuntimeError("Client session does not found.")
            client = self.create_client(client_id=client_id)
        if not client.is_valid():
            await client.start()
            self._register_client(client)
        return client