chore: adjust code

This commit is contained in:
Hehesheng 2024-06-15 16:43:13 +08:00
parent ca66e251dd
commit 2210edf4ca
8 changed files with 253 additions and 86 deletions

View File

@ -1 +1,19 @@
自用脚本,随缘更新
# TgToolBox
A Telegram toolbox
## Run the project
### Install requirements.txt
### Config
### Run
## TODO
- [ ] Support photo
- [ ] chat search
- [ ] token encrypt
- [ ] The encryption key has an expiration time
- [ ] Photo gallery

View File

@ -2,7 +2,6 @@ import asyncio
import json
import time
import re
import rsa
import os
import functools
import traceback
@ -35,10 +34,6 @@ class TgFileSystemClient(object):
worker_routines: list[asyncio.Task]
qr_login: QRLogin | None = None
login_task: asyncio.Task | None = None
# rsa key
sign: str
public_key: rsa.PublicKey
private_key: rsa.PrivateKey
# task should: (task_id, callabledFunc)
task_queue: asyncio.Queue
task_id: int = 0
@ -51,6 +46,7 @@ class TgFileSystemClient(object):
session_name: str,
param: configParse.TgToFileSystemParameter,
db: UserManager,
chunk_manager: MediaChunkHolderManager,
) -> None:
self.api_id = param.tgApi.api_id
self.api_hash = param.tgApi.api_hash
@ -64,12 +60,10 @@ class TgFileSystemClient(object):
if param.proxy.enable
else {}
)
self.public_key, self.private_key = rsa.newkeys(1024)
self.client_param = next(
(client_param for client_param in param.clients if client_param.token == session_name),
configParse.TgToFileSystemParameter.ClientConfigPatameter(),
)
self.sign = self.client_param.token
self.task_queue = asyncio.Queue()
self.client = TelegramClient(
f"{os.path.dirname(__file__)}/db/{self.session_name}.session",
@ -77,7 +71,7 @@ class TgFileSystemClient(object):
self.api_hash,
proxy=self.proxy_param,
)
self.media_chunk_manager = MediaChunkHolderManager()
self.media_chunk_manager = chunk_manager
self.db = db
self.worker_routines = []

View File

@ -1,13 +1,14 @@
from typing import Any
import asyncio
import time
import hashlib
import rsa
import os
import traceback
import logging
from backend.TgFileSystemClient import TgFileSystemClient
from backend.UserManager import UserManager
from backend.MediaCacheManager import MediaChunkHolderManager
import configParse
logger = logging.getLogger(__file__.split("/")[-1])
@ -18,6 +19,10 @@ class TgFileSystemClientManager(object):
is_init: bool = False
param: configParse.TgToFileSystemParameter
clients: dict[str, TgFileSystemClient] = {}
# rsa key
cache_sign: str
public_key: rsa.PublicKey
private_key: rsa.PrivateKey
@classmethod
def get_instance(cls):
@ -29,6 +34,8 @@ class TgFileSystemClientManager(object):
self.param = param
self.db = UserManager()
self.loop = asyncio.get_running_loop()
self.media_chunk_manager = MediaChunkHolderManager()
self.public_key, self.private_key = rsa.newkeys(1024)
if self.loop.is_running():
self.loop.create_task(self._start_clients())
else:
@ -76,7 +83,7 @@ class TgFileSystemClientManager(object):
def create_client(self, client_id: str = None) -> TgFileSystemClient:
if client_id is None:
client_id = self.generate_client_id()
client = TgFileSystemClient(client_id, self.param, self.db)
client = TgFileSystemClient(client_id, self.param, self.db, self.media_chunk_manager)
return client
def _register_client(self, client: TgFileSystemClient) -> bool:

View File

@ -1,15 +1,16 @@
import asyncio
import json
import os
import sys
import logging
import traceback
from typing import Annotated
from urllib.parse import quote
import uvicorn
from fastapi import FastAPI, status, Request
from fastapi import FastAPI, status, Request, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from contextlib import asynccontextmanager
from telethon import types, hints, utils
from pydantic import BaseModel
@ -21,7 +22,6 @@ from backend.TgFileSystemClientManager import TgFileSystemClientManager
logger = logging.getLogger(__file__.split("/")[-1])
@asynccontextmanager
async def lifespan(app: FastAPI):
clients_mgr = TgFileSystemClientManager.get_instance()
res = await clients_mgr.get_status()
@ -131,48 +131,8 @@ async def get_tg_file_list(body: TgToFileListRequestBody):
@app.get("/tg/api/v1/file/msg")
@apiutils.atimeit
async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request):
msg_id = mid
chat_id = cid
headers = {
# "content-type": "video/mp4",
"accept-ranges": "bytes",
"content-encoding": "identity",
# "content-length": stream_file_size,
"access-control-expose-headers": ("content-type, accept-ranges, content-length, " "content-range, content-encoding"),
}
range_header = request.headers.get("range")
try:
clients_mgr = TgFileSystemClientManager.get_instance()
client = await clients_mgr.get_client_force(token)
msg = await client.get_message(chat_id, msg_id)
file_size = msg.media.document.size
start = 0
end = file_size - 1
status_code = status.HTTP_200_OK
mime_type = msg.media.document.mime_type
headers["content-type"] = mime_type
# headers["content-length"] = str(file_size)
file_name = apiutils.get_message_media_name(msg)
if file_name == "":
maybe_file_type = mime_type.split("/")[-1]
file_name = f"{chat_id}.{msg_id}.{maybe_file_type}"
headers["Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'"
if range_header is not None:
start, end = apiutils.get_range_header(range_header, file_size)
size = end - start + 1
# headers["content-length"] = str(size)
headers["content-range"] = f"bytes {start}-{end}/{file_size}"
status_code = status.HTTP_206_PARTIAL_CONTENT
else:
headers["content-length"] = str(file_size)
headers["content-range"] = f"bytes 0-{file_size-1}/{file_size}"
return StreamingResponse(
client.streaming_get_iter(msg, start, end, request),
headers=headers,
media_type=mime_type,
status_code=status_code,
)
return api.get_media_file_stream(token, cid, mid, request)
except Exception as err:
logger.error(f"{err=},{traceback.format_exc()}")
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
@ -263,6 +223,29 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
async def get_verify(q: str | None, skip: int = 0):
logger.info("run common param")
if skip < 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{q=},{skip=}")
@app.get("/tg/api/v1/test", dependencies=[Depends(get_verify)])
async def test_get_depends_verify_method(other: str = ""):
return Response()
async def post_verify(body: TgToChatListRequestBody | None = None):
if not body or not body.token:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return body
@app.post("/tg/api/v1/test", dependencies=[Depends(post_verify)])
async def test_get_depends_verify_method(body: TgToChatListRequestBody):
return Response()
if __name__ == "__main__":
param = configParse.get_TgToFileSystemParameter()
uvicorn.run(app, host="0.0.0.0", port=param.base.port)
isDebug = True if sys.gettrace() else False
uvicorn.run(app, host="0.0.0.0", port=param.base.port, reload=isDebug)

View File

@ -1,8 +1,12 @@
import traceback
import json
import logging
from urllib.parse import quote
from telethon import types, hints, utils
import fastapi
from fastapi import Request
from fastapi.responses import StreamingResponse, Response
import configParse
from backend import apiutils
@ -57,3 +61,50 @@ async def get_clients_manager_status(detail: bool) -> dict[str, any]:
return ret
ret["clist"] = await get_chat_details(clients_mgr)
return ret
async def get_media_file_stream(token: str, cid: int, mid: int, request: Request) -> StreamingResponse:
msg_id = mid
chat_id = cid
headers = {
# "content-type": "video/mp4",
"accept-ranges": "bytes",
"content-encoding": "identity",
# "content-length": stream_file_size,
"access-control-expose-headers": ("content-type, accept-ranges, content-length, " "content-range, content-encoding"),
}
range_header = request.headers.get("range")
clients_mgr = TgFileSystemClientManager.get_instance()
client = await clients_mgr.get_client_force(token)
msg = await client.get_message(chat_id, msg_id)
if not isinstance(msg.media, types.MessageMediaDocument) and not isinstance(msg.media, types.MessageMediaPhoto):
raise RuntimeError(f"request don't support: {msg.media=}")
file_size = msg.media.document.size
start = 0
end = file_size - 1
status_code = fastapi.status.HTTP_200_OK
mime_type = msg.media.document.mime_type
headers["content-type"] = mime_type
# headers["content-length"] = str(file_size)
file_name = apiutils.get_message_media_name(msg)
if file_name == "":
maybe_file_type = mime_type.split("/")[-1]
file_name = f"{chat_id}.{msg_id}.{maybe_file_type}"
headers["Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'"
if range_header is not None:
start, end = apiutils.get_range_header(range_header, file_size)
size = end - start + 1
# headers["content-length"] = str(size)
headers["content-range"] = f"bytes {start}-{end}/{file_size}"
status_code = fastapi.status.HTTP_206_PARTIAL_CONTENT
else:
headers["content-length"] = str(file_size)
headers["content-range"] = f"bytes 0-{file_size-1}/{file_size}"
return StreamingResponse(
client.streaming_get_iter(msg, start, end, request),
headers=headers,
media_type=mime_type,
status_code=status_code,
)

View File

@ -2,13 +2,14 @@ import time
import logging
from fastapi import status, HTTPException
from telethon import types
from telethon import types, utils
from functools import wraps
import configParse
logger = logging.getLogger(__file__.split("/")[-1])
def get_range_header(range_header: str, file_size: int) -> tuple[int, int]:
def _invalid_range():
return HTTPException(
@ -28,87 +29,201 @@ def get_range_header(range_header: str, file_size: int) -> tuple[int, int]:
return start, end
def get_message_media_name(msg: types.Message) -> str:
if msg.media is None or msg.media.document is None:
return ""
for attr in msg.media.document.attributes:
def _get_message_media_document_kind_and_names(document: types.MessageMediaDocument) -> tuple[str, str]:
"""Gets kind and possible names for :tl:`DocumentAttribute`."""
kind = "document"
possible_names = []
for attr in document.attributes:
if isinstance(attr, types.DocumentAttributeFilename):
return attr.file_name
return ""
possible_names.insert(0, attr.file_name)
elif isinstance(attr, types.DocumentAttributeAudio):
kind = "audio"
if attr.performer and attr.title:
possible_names.append("{} - {}".format(attr.performer, attr.title))
elif attr.performer:
possible_names.append(attr.performer)
elif attr.title:
possible_names.append(attr.title)
elif attr.voice:
kind = "voice"
return kind, possible_names
def get_message_media_name(msg: types.Message) -> str:
if msg.media is None:
return ""
match type(msg.media):
case types.MessageMediaPhoto:
return f"{msg.media.photo.id}.jpg"
case types.MessageMediaDocument:
kind, possible_names = _get_message_media_document_kind_and_names(msg.media.document)
try:
name = None if possible_names is None else next(x for x in possible_names if x)
except StopIteration:
name = None
if name:
return name
extension = utils.get_extension(msg.media)
peer_id = utils.get_peer_id(msg)
return f"{kind}_{peer_id}-{msg.id}{extension}"
case _:
return ""
def _get_message_media_valid_photo(msg: types.Message) -> types.Photo | None:
if msg.media is None:
return None
photo = msg.media
if isinstance(photo, types.MessageMediaPhoto):
photo = photo.photo
if not isinstance(photo, types.Photo):
return None
return photo
def _sort_message_media_photo_thumbs(thumbs: list[any]) -> list[any]:
def sort_thumbs(thumb):
if isinstance(thumb, types.PhotoStrippedSize):
return 1, len(thumb.bytes)
if isinstance(thumb, types.PhotoCachedSize):
return 1, len(thumb.bytes)
if isinstance(thumb, types.PhotoSize):
return 1, thumb.size
if isinstance(thumb, types.PhotoSizeProgressive):
return 1, max(thumb.sizes)
if isinstance(thumb, types.VideoSize):
return 2, thumb.size
# Empty size or invalid should go last
return 0, 0
thumbs = list(sorted(thumbs), key=sort_thumbs)
for i in reversed(range(len(thumbs))):
if isinstance(thumbs[i], types.PhotoPathSize):
thumbs.pop(i)
return thumbs
def _get_message_media_photo_file_last_photo_size(thumbs: list[any]):
thumbs = _sort_message_media_photo_thumbs(thumbs)
size = thumbs[-1] if thumbs else None
if not size or isinstance(size, types.PhotoSizeEmpty):
return None
return size
def get_message_media_photo_file_name(msg: types.Message) -> str:
photo = _get_message_media_valid_photo(msg)
if not photo:
return ""
size = _get_message_media_photo_file_last_photo_size(photo.sizes + (photo.video_sizes or []))
if not size:
return ""
if isinstance(size, types.VideoSize):
return f"{photo.id}.mp4"
return f"{photo.id}.jpg"
def get_message_media_photo_file_size(msg: types.Message) -> int:
photo = _get_message_media_valid_photo(msg)
if not photo:
return 0
size = _get_message_media_photo_file_last_photo_size(photo.sizes + (photo.video_sizes or []))
if not size:
return 0
if isinstance(size, types.PhotoStrippedSize):
return len(utils.stripped_photo_to_jpg(size.bytes))
elif isinstance(size, types.PhotoCachedSize):
return len(size.bytes)
if isinstance(size, types.PhotoSizeProgressive):
return max(size.sizes)
return size.size
def get_message_media_name_from_dict(msg: dict[str, any]) -> str:
doc = None
try:
doc = msg['media']['document']
doc = msg["media"]["document"]
except:
pass
file_name = None
if doc is not None:
for attr in doc['attributes']:
file_name = attr.get('file_name')
for attr in doc["attributes"]:
file_name = attr.get("file_name")
if file_name != "" and file_name is not None:
break
if file_name == "" or file_name is None:
file_name = "unknown.tmp"
return file_name
def get_message_chat_id_from_dict(msg: dict[str, any]) -> int:
try:
return msg['peer_id']['channel_id']
return msg["peer_id"]["channel_id"]
except:
pass
return 0
def get_message_msg_id_from_dict(msg: dict[str, any]) -> int:
try:
return msg['id']
return msg["id"]
except:
pass
return 0
def timeit_sec(func):
@wraps(func)
def timeit_wrapper(*args, **kwargs):
logger.debug(
f'Function called {func.__name__}{args} {kwargs}')
logger.debug(f"Function called {func.__name__}{args} {kwargs}")
start_time = time.perf_counter()
result = func(*args, **kwargs)
end_time = time.perf_counter()
total_time = end_time - start_time
logger.debug(
f'Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
logger.debug(f"Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds")
return result
return timeit_wrapper
def timeit(func):
if configParse.get_TgToFileSystemParameter().base.timeit_enable:
@wraps(func)
def timeit_wrapper(*args, **kwargs):
logger.debug(
f'Function called {func.__name__}{args} {kwargs}')
logger.debug(f"Function called {func.__name__}{args} {kwargs}")
start_time = time.perf_counter()
result = func(*args, **kwargs)
end_time = time.perf_counter()
total_time = end_time - start_time
logger.debug(
f'Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
logger.debug(f"Function quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds")
return result
return timeit_wrapper
return func
def atimeit(func):
if configParse.get_TgToFileSystemParameter().base.timeit_enable:
@wraps(func)
async def timeit_wrapper(*args, **kwargs):
logger.debug(
f'AFunction called {func.__name__}{args} {kwargs}')
logger.debug(f"AFunction called {func.__name__}{args} {kwargs}")
start_time = time.perf_counter()
result = await func(*args, **kwargs)
end_time = time.perf_counter()
total_time = end_time - start_time
logger.debug(
f'AFunction quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
logger.debug(f"AFunction quited {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds")
return result
return timeit_wrapper
return func

View File

@ -80,7 +80,6 @@ def convert_tg_link_to_proxy_link(link: str) -> str:
request_url = background_server_url + link_convert_api_route + f"?link={link}"
response = requests.get(request_url)
if response.status_code != 200:
print(f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}")
return ""
return f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}"
response_js = json.loads(response.content.decode("utf-8"))
return response_js["url"]

View File

@ -31,7 +31,7 @@ def loop():
wait_client_ready.status("Server Initializing")
st.session_state.chat_dict = api.get_white_list_chat_dict()
wait_client_ready.empty()
st.query_params.search_key = st.text_input("**搜索🔎**", value=keyword)
st.query_params.search_key = st.text_input("**Search🔎**", value=keyword)
chat_list = []
for _, chat_info in st.session_state.chat_dict.items():
chat_list.append(chat_info["title"])
@ -39,13 +39,13 @@ def loop():
columns = st.columns([4, 4, 1])
with columns[0]:
st.query_params.search_res_limit = str(
st.number_input("**每页结果**", min_value=1, max_value=100, value=res_limit, format="%d")
st.number_input("**Results per page**", min_value=1, max_value=100, value=res_limit, format="%d")
)
with columns[1]:
st.session_state.chat_select_list = st.multiselect("**Search in**", chat_list, default=chat_list)
with columns[2]:
st.text("排序")
st.query_params.is_order = st.toggle("顺序", value=isorder)
st.text("Sort")
st.query_params.is_order = st.toggle("Time🔼", value=isorder)
search_limit_container = st.container()
with search_limit_container: