chore: adjust code

This commit is contained in:
Hehesheng 2024-06-15 16:43:13 +08:00
parent ca66e251dd
commit 2210edf4ca
8 changed files with 253 additions and 86 deletions

View File

@ -1 +1,19 @@
自用脚本,随缘更新 # TgToolBox
A Telegram toolbox
## Run the project
### Install requirements.txt
### Config
### Run
## TODO
- [ ] Support photo
- [ ] chat search
- [ ] token encrypt
- [ ] The encryption key has an expiration time
- [ ] Photo gallery

View File

@ -2,7 +2,6 @@ import asyncio
import json import json
import time import time
import re import re
import rsa
import os import os
import functools import functools
import traceback import traceback
@ -35,10 +34,6 @@ class TgFileSystemClient(object):
worker_routines: list[asyncio.Task] worker_routines: list[asyncio.Task]
qr_login: QRLogin | None = None qr_login: QRLogin | None = None
login_task: asyncio.Task | None = None login_task: asyncio.Task | None = None
# rsa key
sign: str
public_key: rsa.PublicKey
private_key: rsa.PrivateKey
# task should: (task_id, callabledFunc) # task should: (task_id, callabledFunc)
task_queue: asyncio.Queue task_queue: asyncio.Queue
task_id: int = 0 task_id: int = 0
@ -51,6 +46,7 @@ class TgFileSystemClient(object):
session_name: str, session_name: str,
param: configParse.TgToFileSystemParameter, param: configParse.TgToFileSystemParameter,
db: UserManager, db: UserManager,
chunk_manager: MediaChunkHolderManager,
) -> None: ) -> None:
self.api_id = param.tgApi.api_id self.api_id = param.tgApi.api_id
self.api_hash = param.tgApi.api_hash self.api_hash = param.tgApi.api_hash
@ -64,12 +60,10 @@ class TgFileSystemClient(object):
if param.proxy.enable if param.proxy.enable
else {} else {}
) )
self.public_key, self.private_key = rsa.newkeys(1024)
self.client_param = next( self.client_param = next(
(client_param for client_param in param.clients if client_param.token == session_name), (client_param for client_param in param.clients if client_param.token == session_name),
configParse.TgToFileSystemParameter.ClientConfigPatameter(), configParse.TgToFileSystemParameter.ClientConfigPatameter(),
) )
self.sign = self.client_param.token
self.task_queue = asyncio.Queue() self.task_queue = asyncio.Queue()
self.client = TelegramClient( self.client = TelegramClient(
f"{os.path.dirname(__file__)}/db/{self.session_name}.session", f"{os.path.dirname(__file__)}/db/{self.session_name}.session",
@ -77,7 +71,7 @@ class TgFileSystemClient(object):
self.api_hash, self.api_hash,
proxy=self.proxy_param, proxy=self.proxy_param,
) )
self.media_chunk_manager = MediaChunkHolderManager() self.media_chunk_manager = chunk_manager
self.db = db self.db = db
self.worker_routines = [] self.worker_routines = []

View File

@ -1,13 +1,14 @@
from typing import Any
import asyncio import asyncio
import time import time
import hashlib import hashlib
import rsa
import os import os
import traceback import traceback
import logging import logging
from backend.TgFileSystemClient import TgFileSystemClient from backend.TgFileSystemClient import TgFileSystemClient
from backend.UserManager import UserManager from backend.UserManager import UserManager
from backend.MediaCacheManager import MediaChunkHolderManager
import configParse import configParse
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
@ -18,6 +19,10 @@ class TgFileSystemClientManager(object):
is_init: bool = False is_init: bool = False
param: configParse.TgToFileSystemParameter param: configParse.TgToFileSystemParameter
clients: dict[str, TgFileSystemClient] = {} clients: dict[str, TgFileSystemClient] = {}
# rsa key
cache_sign: str
public_key: rsa.PublicKey
private_key: rsa.PrivateKey
@classmethod @classmethod
def get_instance(cls): def get_instance(cls):
@ -29,6 +34,8 @@ class TgFileSystemClientManager(object):
self.param = param self.param = param
self.db = UserManager() self.db = UserManager()
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.media_chunk_manager = MediaChunkHolderManager()
self.public_key, self.private_key = rsa.newkeys(1024)
if self.loop.is_running(): if self.loop.is_running():
self.loop.create_task(self._start_clients()) self.loop.create_task(self._start_clients())
else: else:
@ -76,7 +83,7 @@ class TgFileSystemClientManager(object):
def create_client(self, client_id: str = None) -> TgFileSystemClient: def create_client(self, client_id: str = None) -> TgFileSystemClient:
if client_id is None: if client_id is None:
client_id = self.generate_client_id() client_id = self.generate_client_id()
client = TgFileSystemClient(client_id, self.param, self.db) client = TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager)
return client return client
def _register_client(self, client: TgFileSystemClient) -> bool: def _register_client(self, client: TgFileSystemClient) -> bool:

View File

@ -1,15 +1,16 @@
import asyncio import asyncio
import json import json
import os import os
import sys
import logging import logging
import traceback import traceback
from typing import Annotated
from urllib.parse import quote from urllib.parse import quote
import uvicorn import uvicorn
from fastapi import FastAPI, status, Request from fastapi import FastAPI, status, Request, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from contextlib import asynccontextmanager
from telethon import types, hints, utils from telethon import types, hints, utils
from pydantic import BaseModel from pydantic import BaseModel
@ -21,7 +22,6 @@ from backend.TgFileSystemClientManager import TgFileSystemClientManager
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
@asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
clients_mgr = TgFileSystemClientManager.get_instance() clients_mgr = TgFileSystemClientManager.get_instance()
res = await clients_mgr.get_status() res = await clients_mgr.get_status()
@ -131,48 +131,8 @@ async def get_tg_file_list(body: TgToFileListRequestBody):
@app.get("/tg/api/v1/file/msg") @app.get("/tg/api/v1/file/msg")
@apiutils.atimeit @apiutils.atimeit
async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request): async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request):
msg_id = mid
chat_id = cid
headers = {
# "content-type": "video/mp4",
"accept-ranges": "bytes",
"content-encoding": "identity",
# "content-length": stream_file_size,
"access-control-expose-headers": ("content-type, accept-ranges, content-length, " "content-range, content-encoding"),
}
range_header = request.headers.get("range")
try: try:
clients_mgr = TgFileSystemClientManager.get_instance() return api.get_media_file_stream(token, cid, mid, request)
client = await clients_mgr.get_client_force(token)
msg = await client.get_message(chat_id, msg_id)
file_size = msg.media.document.size
start = 0
end = file_size - 1
status_code = status.HTTP_200_OK
mime_type = msg.media.document.mime_type
headers["content-type"] = mime_type
# headers["content-length"] = str(file_size)
file_name = apiutils.get_message_media_name(msg)
if file_name == "":
maybe_file_type = mime_type.split("/")[-1]
file_name = f"{chat_id}.{msg_id}.{maybe_file_type}"
headers["Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'"
if range_header is not None:
start, end = apiutils.get_range_header(range_header, file_size)
size = end - start + 1
# headers["content-length"] = str(size)
headers["content-range"] = f"bytes {start}-{end}/{file_size}"
status_code = status.HTTP_206_PARTIAL_CONTENT
else:
headers["content-length"] = str(file_size)
headers["content-range"] = f"bytes 0-{file_size-1}/{file_size}"
return StreamingResponse(
client.streaming_get_iter(msg, start, end, request),
headers=headers,
media_type=mime_type,
status_code=status_code,
)
except Exception as err: except Exception as err:
logger.error(f"{err=},{traceback.format_exc()}") logger.error(f"{err=},{traceback.format_exc()}")
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
@ -263,6 +223,29 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
async def get_verify(q: str | None, skip: int = 0):
logger.info("run common param")
if skip < 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{q=},{skip=}")
@app.get("/tg/api/v1/test", dependencies=[Depends(get_verify)])
async def test_get_depends_verify_method(other: str = ""):
return Response()
async def post_verify(body: TgToChatListRequestBody | None = None):
if not body or not body.token:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return body
@app.post("/tg/api/v1/test", dependencies=[Depends(post_verify)])
async def test_get_depends_verify_method(body: TgToChatListRequestBody):
return Response()
if __name__ == "__main__": if __name__ == "__main__":
param = configParse.get_TgToFileSystemParameter() param = configParse.get_TgToFileSystemParameter()
uvicorn.run(app, host="0.0.0.0", port=param.base.port) isDebug = True if sys.gettrace() else False
uvicorn.run(app, host="0.0.0.0", port=param.base.port, reload=isDebug)

View File

@ -1,8 +1,12 @@
import traceback import traceback
import json import json
import logging import logging
from urllib.parse import quote
from telethon import types, hints, utils from telethon import types, hints, utils
import fastapi
from fastapi import Request
from fastapi.responses import StreamingResponse, Response
import configParse import configParse
from backend import apiutils from backend import apiutils
@ -57,3 +61,50 @@ async def get_clients_manager_status(detail: bool) -> dict[str, any]:
return ret return ret
ret["clist"] = await get_chat_details(clients_mgr) ret["clist"] = await get_chat_details(clients_mgr)
return ret return ret
async def get_media_file_stream(token: str, cid: int, mid: int, request: Request) -> StreamingResponse:
msg_id = mid
chat_id = cid
headers = {
# "content-type": "video/mp4",
"accept-ranges": "bytes",
"content-encoding": "identity",
# "content-length": stream_file_size,
"access-control-expose-headers": ("content-type, accept-ranges, content-length, " "content-range, content-encoding"),
}
range_header = request.headers.get("range")
clients_mgr = TgFileSystemClientManager.get_instance()
client = await clients_mgr.get_client_force(token)
msg = await client.get_message(chat_id, msg_id)
if not isinstance(msg.media, types.MessageMediaDocument) and not isinstance(msg.media, types.MessageMediaPhoto):
raise RuntimeError(f"request don't support: {msg.media=}")
file_size = msg.media.document.size
start = 0
end = file_size - 1
status_code = fastapi.status.HTTP_200_OK
mime_type = msg.media.document.mime_type
headers["content-type"] = mime_type
# headers["content-length"] = str(file_size)
file_name = apiutils.get_message_media_name(msg)
if file_name == "":
maybe_file_type = mime_type.split("/")[-1]
file_name = f"{chat_id}.{msg_id}.{maybe_file_type}"
headers["Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'"
if range_header is not None:
start, end = apiutils.get_range_header(range_header, file_size)
size = end - start + 1
# headers["content-length"] = str(size)
headers["content-range"] = f"bytes {start}-{end}/{file_size}"
status_code = fastapi.status.HTTP_206_PARTIAL_CONTENT
else:
headers["content-length"] = str(file_size)
headers["content-range"] = f"bytes 0-{file_size-1}/{file_size}"
return StreamingResponse(
client.streaming_get_iter(msg, start, end, request),
headers=headers,
media_type=mime_type,
status_code=status_code,
)

View File

@ -2,13 +2,14 @@ import time
import logging import logging
from fastapi import status, HTTPException from fastapi import status, HTTPException
from telethon import types from telethon import types, utils
from functools import wraps from functools import wraps
import configParse import configParse
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
def get_range_header(range_header: str, file_size: int) -> tuple[int, int]: def get_range_header(range_header: str, file_size: int) -> tuple[int, int]:
def _invalid_range(): def _invalid_range():
return HTTPException( return HTTPException(
@ -28,87 +29,201 @@ def get_range_header(range_header: str, file_size: int) -> tuple[int, int]:
return start, end return start, end
def get_message_media_name(msg: types.Message) -> str: def _get_message_media_document_kind_and_names(document: types.MessageMediaDocument) -> tuple[str, str]:
if msg.media is None or msg.media.document is None: """Gets kind and possible names for :tl:`DocumentAttribute`."""
return "" kind = "document"
for attr in msg.media.document.attributes: possible_names = []
for attr in document.attributes:
if isinstance(attr, types.DocumentAttributeFilename): if isinstance(attr, types.DocumentAttributeFilename):
return attr.file_name possible_names.insert(0, attr.file_name)
elif isinstance(attr, types.DocumentAttributeAudio):
kind = "audio"
if attr.performer and attr.title:
possible_names.append("{} - {}".format(attr.performer, attr.title))
elif attr.performer:
possible_names.append(attr.performer)
elif attr.title:
possible_names.append(attr.title)
elif attr.voice:
kind = "voice"
return kind, possible_names
def get_message_media_name(msg: types.Message) -> str:
if msg.media is None:
return "" return ""
match type(msg.media):
case types.MessageMediaPhoto:
return f"{msg.media.photo.id}.jpg"
case types.MessageMediaDocument:
kind, possible_names = _get_message_media_document_kind_and_names(msg.media.document)
try:
name = None if possible_names is None else next(x for x in possible_names if x)
except StopIteration:
name = None
if name:
return name
extension = utils.get_extension(msg.media)
peer_id = utils.get_peer_id(msg)
return f"{kind}_{peer_id}-{msg.id}{extension}"
case _:
return ""
def _get_message_media_valid_photo(msg: types.Message) -> types.Photo | None:
if msg.media is None:
return None
photo = msg.media
if isinstance(photo, types.MessageMediaPhoto):
photo = photo.photo
if not isinstance(photo, types.Photo):
return None
return photo
def _sort_message_media_photo_thumbs(thumbs: list[any]) -> list[any]:
def sort_thumbs(thumb):
if isinstance(thumb, types.PhotoStrippedSize):
return 1, len(thumb.bytes)
if isinstance(thumb, types.PhotoCachedSize):
return 1, len(thumb.bytes)
if isinstance(thumb, types.PhotoSize):
return 1, thumb.size
if isinstance(thumb, types.PhotoSizeProgressive):
return 1, max(thumb.sizes)
if isinstance(thumb, types.VideoSize):
return 2, thumb.size
# Empty size or invalid should go last
return 0, 0
thumbs = list(sorted(thumbs), key=sort_thumbs)
for i in reversed(range(len(thumbs))):
if isinstance(thumbs[i], types.PhotoPathSize):
thumbs.pop(i)
return thumbs
def _get_message_media_photo_file_last_photo_size(thumbs: list[any]):
thumbs = _sort_message_media_photo_thumbs(thumbs)
size = thumbs[-1] if thumbs else None
if not size or isinstance(size, types.PhotoSizeEmpty):
return None
return size
def get_message_media_photo_file_name(msg: types.Message) -> str:
photo = _get_message_media_valid_photo(msg)
if not photo:
return ""
size = _get_message_media_photo_file_last_photo_size(photo.sizes + (photo.video_sizes or []))
if not size:
return ""
if isinstance(size, types.VideoSize):
return f"{photo.id}.mp4"
return f"{photo.id}.jpg"
def get_message_media_photo_file_size(msg: types.Message) -> int:
photo = _get_message_media_valid_photo(msg)
if not photo:
return 0
size = _get_message_media_photo_file_last_photo_size(photo.sizes + (photo.video_sizes or []))
if not size:
return 0
if isinstance(size, types.PhotoStrippedSize):
return len(utils.stripped_photo_to_jpg(size.bytes))
elif isinstance(size, types.PhotoCachedSize):
return len(size.bytes)
if isinstance(size, types.PhotoSizeProgressive):
return max(size.sizes)
return size.size
def get_message_media_name_from_dict(msg: dict[str, any]) -> str: def get_message_media_name_from_dict(msg: dict[str, any]) -> str:
doc = None doc = None
try: try:
doc = msg['media']['document'] doc = msg["media"]["document"]
except: except:
pass pass
file_name = None file_name = None
if doc is not None: if doc is not None:
for attr in doc['attributes']: for attr in doc["attributes"]:
file_name = attr.get('file_name') file_name = attr.get("file_name")
if file_name != "" and file_name is not None: if file_name != "" and file_name is not None:
break break
if file_name == "" or file_name is None: if file_name == "" or file_name is None:
file_name = "unknown.tmp" file_name = "unknown.tmp"
return file_name return file_name
def get_message_chat_id_from_dict(msg: dict[str, any]) -> int: def get_message_chat_id_from_dict(msg: dict[str, any]) -> int:
try: try:
return msg['peer_id']['channel_id'] return msg["peer_id"]["channel_id"]
except: except:
pass pass
return 0 return 0
def get_message_msg_id_from_dict(msg: dict[str, any]) -> int: def get_message_msg_id_from_dict(msg: dict[str, any]) -> int:
try: try:
return msg['id'] return msg["id"]
except: except:
pass pass
return 0 return 0
def timeit_sec(func): def timeit_sec(func):
@wraps(func) @wraps(func)
def timeit_wrapper(*args, **kwargs): def timeit_wrapper(*args, **kwargs):
logger.debug( logger.debug(f"Function called {func.__name__}{args} {kwargs}")
f'Function called {func.__name__}{args} {kwargs}')
start_time = time.perf_counter() start_time = time.perf_counter()
result = func(*args, **kwargs) result = func(*args, **kwargs)
end_time = time.perf_counter() end_time = time.perf_counter()
total_time = end_time - start_time total_time = end_time - start_time
logger.debug( logger.debug(f"Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds")
f'Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
return result return result
return timeit_wrapper return timeit_wrapper
def timeit(func): def timeit(func):
if configParse.get_TgToFileSystemParameter().base.timeit_enable: if configParse.get_TgToFileSystemParameter().base.timeit_enable:
@wraps(func) @wraps(func)
def timeit_wrapper(*args, **kwargs): def timeit_wrapper(*args, **kwargs):
logger.debug( logger.debug(f"Function called {func.__name__}{args} {kwargs}")
f'Function called {func.__name__}{args} {kwargs}')
start_time = time.perf_counter() start_time = time.perf_counter()
result = func(*args, **kwargs) result = func(*args, **kwargs)
end_time = time.perf_counter() end_time = time.perf_counter()
total_time = end_time - start_time total_time = end_time - start_time
logger.debug( logger.debug(f"Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds")
f'Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
return result return result
return timeit_wrapper return timeit_wrapper
return func return func
def atimeit(func): def atimeit(func):
if configParse.get_TgToFileSystemParameter().base.timeit_enable: if configParse.get_TgToFileSystemParameter().base.timeit_enable:
@wraps(func) @wraps(func)
async def timeit_wrapper(*args, **kwargs): async def timeit_wrapper(*args, **kwargs):
logger.debug( logger.debug(f"AFunction called {func.__name__}{args} {kwargs}")
f'AFunction called {func.__name__}{args} {kwargs}')
start_time = time.perf_counter() start_time = time.perf_counter()
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
end_time = time.perf_counter() end_time = time.perf_counter()
total_time = end_time - start_time total_time = end_time - start_time
logger.debug( logger.debug(f"AFunction quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds")
f'AFunction quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
return result return result
return timeit_wrapper return timeit_wrapper
return func return func

View File

@ -80,7 +80,6 @@ def convert_tg_link_to_proxy_link(link: str) -> str:
request_url = background_server_url + link_convert_api_route + f"?link={link}" request_url = background_server_url + link_convert_api_route + f"?link={link}"
response = requests.get(request_url) response = requests.get(request_url)
if response.status_code != 200: if response.status_code != 200:
print(f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}") return f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}"
return ""
response_js = json.loads(response.content.decode("utf-8")) response_js = json.loads(response.content.decode("utf-8"))
return response_js["url"] return response_js["url"]

View File

@ -31,7 +31,7 @@ def loop():
wait_client_ready.status("Server Initializing") wait_client_ready.status("Server Initializing")
st.session_state.chat_dict = api.get_white_list_chat_dict() st.session_state.chat_dict = api.get_white_list_chat_dict()
wait_client_ready.empty() wait_client_ready.empty()
st.query_params.search_key = st.text_input("**搜索🔎**", value=keyword) st.query_params.search_key = st.text_input("**Search🔎**", value=keyword)
chat_list = [] chat_list = []
for _, chat_info in st.session_state.chat_dict.items(): for _, chat_info in st.session_state.chat_dict.items():
chat_list.append(chat_info["title"]) chat_list.append(chat_info["title"])
@ -39,13 +39,13 @@ def loop():
columns = st.columns([4, 4, 1]) columns = st.columns([4, 4, 1])
with columns[0]: with columns[0]:
st.query_params.search_res_limit = str( st.query_params.search_res_limit = str(
st.number_input("**每页结果**", min_value=1, max_value=100, value=res_limit, format="%d") st.number_input("**Results per page**", min_value=1, max_value=100, value=res_limit, format="%d")
) )
with columns[1]: with columns[1]:
st.session_state.chat_select_list = st.multiselect("**Search in**", chat_list, default=chat_list) st.session_state.chat_select_list = st.multiselect("**Search in**", chat_list, default=chat_list)
with columns[2]: with columns[2]:
st.text("排序") st.text("Sort")
st.query_params.is_order = st.toggle("顺序", value=isorder) st.query_params.is_order = st.toggle("Time🔼", value=isorder)
search_limit_container = st.container() search_limit_container = st.container()
with search_limit_container: with search_limit_container: