feat: support login

This commit is contained in:
Hehesheng 2024-06-02 14:59:09 +08:00
parent cab27becf1
commit 7acd0f4712
12 changed files with 438 additions and 236 deletions

4
.gitignore vendored
View File

@ -1,13 +1,15 @@
__pycache__ __pycache__
.venv .venv
.idea
.vscode .vscode
*.session *.session
*.session-journal *.session-journal
*-journal
*.session.sql *.session.sql
*.toml *.toml
*.db *.db
*.service *.service
log log
cacheTest cache_media
tmp tmp
logs logs

View File

@ -1,3 +1,4 @@
import os
import functools import functools
import logging import logging
import bisect import bisect
@ -18,6 +19,11 @@ class MediaChunkHolder(object):
waiters: collections.deque[asyncio.Future] waiters: collections.deque[asyncio.Future]
requester: list[Request] = [] requester: list[Request] = []
chunk_id: int = 0 chunk_id: int = 0
unique_id: str = ""
@staticmethod
def generate_id(chat_id: int, msg_id: int, start: int) -> str:
return f"{chat_id}:{msg_id}:{start}"
def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None: def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
self.chat_id = chat_id self.chat_id = chat_id
@ -27,6 +33,7 @@ class MediaChunkHolder(object):
self.mem = mem or bytes() self.mem = mem or bytes()
self.length = len(self.mem) self.length = len(self.mem)
self.waiters = collections.deque() self.waiters = collections.deque()
self.unique_id = MediaChunkHolder.generate_id(chat_id, msg_id, start)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"MediaChunk,start:{self.start},len:{self.length}" return f"MediaChunk,start:{self.start},len:{self.length}"
@ -114,8 +121,11 @@ class MediaChunkHolderManager(object):
unique_chunk_id: int = 0 unique_chunk_id: int = 0
chunk_lru: collections.OrderedDict[int, MediaChunkHolder] chunk_lru: collections.OrderedDict[int, MediaChunkHolder]
disk_chunk_cache: diskcache.Cache
def __init__(self) -> None: def __init__(self) -> None:
self.chunk_lru = collections.OrderedDict() self.chunk_lru = collections.OrderedDict()
self.disk_chunk_cache = diskcache.Cache(f"{os.path.dirname(__file__)}/cache_media", size_limit=2**30)
def _get_media_msg_cache(self, msg: types.Message) -> Optional[list[MediaChunkHolder]]: def _get_media_msg_cache(self, msg: types.Message) -> Optional[list[MediaChunkHolder]]:
chat_cache = self.chunk_cache.get(msg.chat_id) chat_cache = self.chunk_cache.get(msg.chat_id)
@ -160,14 +170,8 @@ class MediaChunkHolderManager(object):
return res return res
def set_media_chunk(self, chunk: MediaChunkHolder) -> None: def set_media_chunk(self, chunk: MediaChunkHolder) -> None:
cache_chat = self.chunk_cache.get(chunk.chat_id) cache_chat = self.chunk_cache.setdefault(chunk.chat_id, {})
if cache_chat is None: cache_msg = cache_chat.setdefault(chunk.msg_id, [])
self.chunk_cache[chunk.chat_id] = {}
cache_chat = self.chunk_cache[chunk.chat_id]
cache_msg = cache_chat.get(chunk.msg_id)
if cache_msg is None:
cache_chat[chunk.msg_id] = []
cache_msg = cache_chat[chunk.msg_id]
chunk.chunk_id = self.unique_chunk_id chunk.chunk_id = self.unique_chunk_id
self.unique_chunk_id += 1 self.unique_chunk_id += 1
bisect.insort(cache_msg, chunk) bisect.insort(cache_msg, chunk)
@ -283,4 +287,3 @@ class MediaBlockHolderManager(object):
def __init__(self, limit_size: int = DEFAULT_MAX_CACHE_SIZE, dir: str = 'cache') -> None: def __init__(self, limit_size: int = DEFAULT_MAX_CACHE_SIZE, dir: str = 'cache') -> None:
pass pass

View File

@ -10,6 +10,7 @@ import logging
from typing import Union, Optional from typing import Union, Optional
from telethon import TelegramClient, types, hints, events from telethon import TelegramClient, types, hints, events
from telethon.custom import QRLogin
from fastapi import Request from fastapi import Request
import configParse import configParse
@ -19,6 +20,7 @@ from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
class TgFileSystemClient(object): class TgFileSystemClient(object):
MAX_WORKER_ROUTINE = 4 MAX_WORKER_ROUTINE = 4
SINGLE_NET_CHUNK_SIZE = 512 * 1024 # 512kb SINGLE_NET_CHUNK_SIZE = 512 * 1024 # 512kb
@ -32,6 +34,8 @@ class TgFileSystemClient(object):
dialogs_cache: Optional[hints.TotalList] = None dialogs_cache: Optional[hints.TotalList] = None
msg_cache: list[types.Message] = [] msg_cache: list[types.Message] = []
worker_routines: list[asyncio.Task] = [] worker_routines: list[asyncio.Task] = []
qr_login: QRLogin | None = None
login_task: asyncio.Task | None = None
# task should: (task_id, callabledFunc) # task should: (task_id, callabledFunc)
task_queue: asyncio.Queue task_queue: asyncio.Queue
task_id: int = 0 task_id: int = 0
@ -39,19 +43,35 @@ class TgFileSystemClient(object):
# client config # client config
client_param: configParse.TgToFileSystemParameter.ClientConfigPatameter client_param: configParse.TgToFileSystemParameter.ClientConfigPatameter
def __init__(self, session_name: str, param: configParse.TgToFileSystemParameter, db: UserManager) -> None: def __init__(
self,
session_name: str,
param: configParse.TgToFileSystemParameter,
db: UserManager,
) -> None:
self.api_id = param.tgApi.api_id self.api_id = param.tgApi.api_id
self.api_hash = param.tgApi.api_hash self.api_hash = param.tgApi.api_hash
self.session_name = session_name self.session_name = session_name
self.proxy_param = { self.proxy_param = (
'proxy_type': param.proxy.proxy_type, {
'addr': param.proxy.addr, "proxy_type": param.proxy.proxy_type,
'port': param.proxy.port, "addr": param.proxy.addr,
} if param.proxy.enable else {} "port": param.proxy.port,
self.client_param = next((client_param for client_param in param.clients if client_param.token == session_name), configParse.TgToFileSystemParameter.ClientConfigPatameter()) }
if param.proxy.enable
else {}
)
self.client_param = next(
(client_param for client_param in param.clients if client_param.token == session_name),
configParse.TgToFileSystemParameter.ClientConfigPatameter(),
)
self.task_queue = asyncio.Queue() self.task_queue = asyncio.Queue()
self.client = TelegramClient( self.client = TelegramClient(
f"{os.path.dirname(__file__)}/db/{self.session_name}.session", self.api_id, self.api_hash, proxy=self.proxy_param) f"{os.path.dirname(__file__)}/db/{self.session_name}.session",
self.api_id,
self.api_hash,
proxy=self.proxy_param,
)
self.media_chunk_manager = MediaChunkHolderManager() self.media_chunk_manager = MediaChunkHolderManager()
self.db = db self.db = db
@ -72,6 +92,7 @@ class TgFileSystemClient(object):
raise RuntimeError("Client does not run.") raise RuntimeError("Client does not run.")
result = func(self, *args, **kwargs) result = func(self, *args, **kwargs)
return result return result
return call_check_wrapper return call_check_wrapper
def _acheck_before_call(func): def _acheck_before_call(func):
@ -80,6 +101,7 @@ class TgFileSystemClient(object):
raise RuntimeError("Client does not run.") raise RuntimeError("Client does not run.")
result = await func(self, *args, **kwargs) result = await func(self, *args, **kwargs)
return result return result
return call_check_wrapper return call_check_wrapper
@_check_before_call @_check_before_call
@ -100,6 +122,28 @@ class TgFileSystemClient(object):
msg: types.Message = event.message msg: types.Message = event.message
self.db.insert_by_message(self.me, msg) self.db.insert_by_message(self.me, msg)
async def login(self, mode: str = "qrcode") -> str:
    """Begin a login flow and return the URL the user must visit/scan.

    Args:
        mode: Login mechanism, "qrcode" or "phone".  Only "qrcode" is
            implemented; "phone" raises NotImplementedError.  (The previous
            annotation ``Union["phone", "qrcode"]`` was invalid typing —
            ``Union`` of string literals is not a type; ``Literal`` or plain
            ``str`` is the correct spelling.)

    Returns:
        The QR-login URL to scan, or "" when the client is already
        authorized.
    """
    if self.is_valid():
        # Already logged in — nothing to do.
        return ""
    if mode == "phone":
        raise NotImplementedError
    if self.qr_login is not None:
        # A QR login is already pending; reuse its URL instead of
        # generating a second code.
        return self.qr_login.url
    self.qr_login = await self.client.qr_login()
    # Capture the URL now: the background task below resets
    # self.qr_login to None when it finishes, so reading it after
    # create_task would be fragile.
    login_url = self.qr_login.url

    async def wait_for_qr_login():
        # Background waiter: resolves once the QR code is scanned, then
        # brings the client fully online via start().
        try:
            await self.qr_login.wait()
            await self.start()
        except Exception as err:
            logger.warning(f"wait for login, {err=}, {traceback.format_exc()}")
        finally:
            # Clear pending-login state so a later call can retry.
            self.login_task = None
            self.qr_login = None

    self.login_task = self.client.loop.create_task(wait_for_qr_login())
    return login_url
async def start(self) -> None: async def start(self) -> None:
if self.is_valid(): if self.is_valid():
return return
@ -107,11 +151,9 @@ class TgFileSystemClient(object):
await self.client.connect() await self.client.connect()
self.me = await self.client.get_me() self.me = await self.client.get_me()
if self.me is None: if self.me is None:
raise RuntimeError( raise RuntimeError(f"The {self.session_name} Client Does Not Login")
f"The {self.session_name} Client Does Not Login")
for _ in range(self.MAX_WORKER_ROUTINE): for _ in range(self.MAX_WORKER_ROUTINE):
worker_routine = self.client.loop.create_task( worker_routine = self.client.loop.create_task(self._worker_routine_handler())
self._worker_routine_handler())
self.worker_routines.append(worker_routine) self.worker_routines.append(worker_routine)
if len(self.client_param.whitelist_chat) > 0: if len(self.client_param.whitelist_chat) > 0:
self._register_update_event(from_users=self.client_param.whitelist_chat) self._register_update_event(from_users=self.client_param.whitelist_chat)
@ -163,7 +205,6 @@ class TgFileSystemClient(object):
self.db.insert_by_message(self.me, msg) self.db.insert_by_message(self.me, msg)
logger.info(f"{chat_id} quit cache task.") logger.info(f"{chat_id} quit cache task.")
@_acheck_before_call @_acheck_before_call
async def get_message(self, chat_id: int, msg_id: int) -> types.Message: async def get_message(self, chat_id: int, msg_id: int) -> types.Message:
msg = await self.client.get_messages(chat_id, ids=msg_id) msg = await self.client.get_messages(chat_id, ids=msg_id)
@ -207,7 +248,15 @@ class TgFileSystemClient(object):
return res_list return res_list
@_acheck_before_call @_acheck_before_call
async def get_messages_by_search(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inner_search: bool = False, ignore_case: bool = False) -> hints.TotalList: async def get_messages_by_search(
self,
chat_id: int,
search_word: str,
limit: int = 10,
offset: int = 0,
inner_search: bool = False,
ignore_case: bool = False,
) -> hints.TotalList:
offset = await self._get_offset_msg_id(chat_id, offset) offset = await self._get_offset_msg_id(chat_id, offset)
if inner_search: if inner_search:
res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset, search=search_word) res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset, search=search_word)
@ -226,10 +275,25 @@ class TgFileSystemClient(object):
break break
return res_list return res_list
async def get_messages_by_search_db(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]: async def get_messages_by_search_db(
self,
chat_id: int,
search_word: str,
limit: int = 10,
offset: int = 0,
inc: bool = False,
ignore_case: bool = False,
) -> list[any]:
if chat_id not in self.client_param.whitelist_chat: if chat_id not in self.client_param.whitelist_chat:
return [] return []
res = self.db.get_msg_by_chat_id_and_keyword(chat_id, search_word, limit=limit, offset=offset, inc=inc, ignore_case=ignore_case) res = self.db.get_msg_by_chat_id_and_keyword(
chat_id,
search_word,
limit=limit,
offset=offset,
inc=inc,
ignore_case=ignore_case,
)
res = [self.db.get_column_msg_js(v) for v in res] res = [self.db.get_column_msg_js(v) for v in res]
return res return res
@ -243,8 +307,7 @@ class TgFileSystemClient(object):
chunk = chunk.tobytes() chunk = chunk.tobytes()
remain_size -= len(chunk) remain_size -= len(chunk)
if remain_size <= 0: if remain_size <= 0:
media_holder.append_chunk_mem( media_holder.append_chunk_mem(chunk[: len(chunk) + remain_size])
chunk[:len(chunk)+remain_size])
else: else:
media_holder.append_chunk_mem(chunk) media_holder.append_chunk_mem(chunk)
if media_holder.is_completed(): if media_holder.is_completed():
@ -256,30 +319,26 @@ class TgFileSystemClient(object):
self.media_chunk_manager.cancel_media_chunk(media_holder) self.media_chunk_manager.cancel_media_chunk(media_holder)
except Exception as err: except Exception as err:
logger.error( logger.error(
f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}") f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}"
)
finally: finally:
logger.debug( logger.debug(f"downloaded chunk:{time.time()}.{offset=},{target_size=},{media_holder}")
f"downloaded chunk:{time.time()}.{offset=},{target_size=},{media_holder}")
async def streaming_get_iter(self, msg: types.Message, start: int, end: int, req: Request): async def streaming_get_iter(self, msg: types.Message, start: int, end: int, req: Request):
try: try:
logger.debug( logger.debug(f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]")
f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]")
cur_task_id = self._get_unique_task_id() cur_task_id = self._get_unique_task_id()
pos = start pos = start
while not await req.is_disconnected() and pos <= end: while not await req.is_disconnected() and pos <= end:
cache_chunk = self.media_chunk_manager.get_media_chunk( cache_chunk = self.media_chunk_manager.get_media_chunk(msg, pos)
msg, pos)
if cache_chunk is None: if cache_chunk is None:
# post download task # post download task
# align pos download task # align pos download task
file_size = msg.media.document.size file_size = msg.media.document.size
# align_pos = pos // self.SINGLE_MEDIA_SIZE * self.SINGLE_MEDIA_SIZE # align_pos = pos // self.SINGLE_MEDIA_SIZE * self.SINGLE_MEDIA_SIZE
align_pos = pos align_pos = pos
align_size = min(self.SINGLE_MEDIA_SIZE, align_size = min(self.SINGLE_MEDIA_SIZE, file_size - align_pos)
file_size - align_pos) holder = MediaChunkHolder(msg.chat_id, msg.id, align_pos, align_size)
holder = MediaChunkHolder(
msg.chat_id, msg.id, align_pos, align_size)
holder.add_chunk_requester(req) holder.add_chunk_requester(req)
self.media_chunk_manager.set_media_chunk(holder) self.media_chunk_manager.set_media_chunk(holder)
self.task_queue.put_nowait((cur_task_id, self._download_media_chunk(msg, holder))) self.task_queue.put_nowait((cur_task_id, self._download_media_chunk(msg, holder)))
@ -294,15 +353,13 @@ class TgFileSystemClient(object):
if offset >= cache_chunk.length: if offset >= cache_chunk.length:
await cache_chunk.wait_chunk_update() await cache_chunk.wait_chunk_update()
continue continue
need_len = min(cache_chunk.length - need_len = min(cache_chunk.length - offset, end - pos + 1)
offset, end - pos + 1)
pos = pos + need_len pos = pos + need_len
yield cache_chunk.mem[offset : offset + need_len] yield cache_chunk.mem[offset : offset + need_len]
else: else:
offset = pos - cache_chunk.start offset = pos - cache_chunk.start
if offset >= cache_chunk.length: if offset >= cache_chunk.length:
raise RuntimeError( raise RuntimeError(f"lru cache missed!{pos=},{cache_chunk=}")
f"lru cache missed!{pos=},{cache_chunk=}")
need_len = min(cache_chunk.length - offset, end - pos + 1) need_len = min(cache_chunk.length - offset, end - pos + 1)
pos = pos + need_len pos = pos + need_len
yield cache_chunk.mem[offset : offset + need_len] yield cache_chunk.mem[offset : offset + need_len]
@ -310,12 +367,14 @@ class TgFileSystemClient(object):
logger.error(f"stream iter:{err=}") logger.error(f"stream iter:{err=}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
finally: finally:
async def _cancel_task_by_id(task_id: int): async def _cancel_task_by_id(task_id: int):
for _ in range(self.task_queue.qsize()): for _ in range(self.task_queue.qsize()):
task = self.task_queue.get_nowait() task = self.task_queue.get_nowait()
self.task_queue.task_done() self.task_queue.task_done()
if task[0] != task_id: if task[0] != task_id:
self.task_queue.put_nowait(task) self.task_queue.put_nowait(task)
await self.client.loop.create_task(_cancel_task_by_id(cur_task_id)) await self.client.loop.create_task(_cancel_task_by_id(cur_task_id))
logger.debug(f"yield quit,{msg.chat_id=},{msg.id=},[{start}:{end}]") logger.debug(f"yield quit,{msg.chat_id=},{msg.id=},[{start}:{end}]")

View File

@ -3,6 +3,7 @@ import asyncio
import time import time
import hashlib import hashlib
import os import os
import traceback
import logging import logging
from backend.TgFileSystemClient import TgFileSystemClient from backend.TgFileSystemClient import TgFileSystemClient
@ -11,14 +12,17 @@ import configParse
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
class TgFileSystemClientManager(object): class TgFileSystemClientManager(object):
MAX_MANAGE_CLIENTS: int = 10 MAX_MANAGE_CLIENTS: int = 10
is_init: asyncio.Future
param: configParse.TgToFileSystemParameter param: configParse.TgToFileSystemParameter
clients: dict[str, TgFileSystemClient] = {} clients: dict[str, TgFileSystemClient] = {}
def __init__(self, param: configParse.TgToFileSystemParameter) -> None: def __init__(self, param: configParse.TgToFileSystemParameter) -> None:
self.param = param self.param = param
self.db = UserManager() self.db = UserManager()
self.is_init = asyncio.Future()
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
if self.loop.is_running(): if self.loop.is_running():
self.loop.create_task(self._start_clients()) self.loop.create_task(self._start_clients())
@ -32,17 +36,37 @@ class TgFileSystemClientManager(object):
# init cache clients # init cache clients
for client_config in self.param.clients: for client_config in self.param.clients:
client = self.create_client(client_id=client_config.token) client = self.create_client(client_id=client_config.token)
self._register_client(client)
for _, client in self.clients.items():
try:
if not client.is_valid(): if not client.is_valid():
await client.start() await client.start()
self._register_client(client) except Exception as err:
logger.warning(f"start client: {err=}, {traceback.format_exc()}")
self.is_init.set_result(True)
async def get_status(self) -> dict:
    """Report manager readiness plus per-client login state.

    Returns:
        ``{"init": <startup result>, "clients": [{"status": bool}, ...]}``.
        Awaiting ``self.is_init`` blocks until ``_start_clients`` has
        completed, so callers only see post-startup state.

    Note: the previous return annotation ``dict[str, any]`` referenced the
    builtin ``any`` function rather than ``typing.Any``; ``dict`` avoids
    adding a typing import.
    """
    clients_status = [{"status": client.is_valid()} for client in self.clients.values()]
    return {"init": await self.is_init, "clients": clients_status}
async def login_clients(self) -> str:
    """Return the first pending login URL among managed clients, or ""."""
    for managed_client in self.clients.values():
        pending_url = await managed_client.login()
        if pending_url != "":
            return pending_url
    return ""
def check_client_session_exist(self, client_id: str) -> bool: def check_client_session_exist(self, client_id: str) -> bool:
session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session" session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session"
return os.path.isfile(session_db_file) return os.path.isfile(session_db_file)
def generate_client_id(self) -> str: def generate_client_id(self) -> str:
return hashlib.md5( return hashlib.md5((str(time.perf_counter()) + self.param.base.salt).encode("utf-8")).hexdigest()
(str(time.perf_counter()) + self.param.base.salt).encode('utf-8')).hexdigest()
def create_client(self, client_id: str = None) -> TgFileSystemClient: def create_client(self, client_id: str = None) -> TgFileSystemClient:
if client_id is None: if client_id is None:
@ -72,4 +96,3 @@ class TgFileSystemClientManager(object):
await client.start() await client.start()
self._register_client(client) self._register_client(client)
return client return client

View File

@ -9,6 +9,7 @@ from telethon import types
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
class UserUpdateParam(BaseModel): class UserUpdateParam(BaseModel):
client_id: str client_id: str
username: str username: str
@ -31,6 +32,8 @@ class MessageUpdateParam(BaseModel):
class UserManager(object): class UserManager(object):
def __init__(self) -> None: def __init__(self) -> None:
if not os.path.exists(os.path.dirname(__file__) + "/db"):
os.mkdir(os.path.dirname(__file__) + "/db")
self.con = sqlite3.connect(f"{os.path.dirname(__file__)}/db/user.db") self.con = sqlite3.connect(f"{os.path.dirname(__file__)}/db/user.db")
self.cur = self.con.cursor() self.cur = self.con.cursor()
if not self._table_has_been_inited(): if not self._table_has_been_inited():
@ -55,10 +58,20 @@ class UserManager(object):
def get_all_msg_by_chat_id(self, chat_id: int) -> list[any]: def get_all_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute( res = self.cur.execute(
"SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC", (chat_id,)) "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC",
(chat_id,),
)
return res.fetchall() return res.fetchall()
def get_msg_by_chat_id_and_keyword(self, chat_id: int, keyword: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]: def get_msg_by_chat_id_and_keyword(
self,
chat_id: int,
keyword: str,
limit: int = 10,
offset: int = 0,
inc: bool = False,
ignore_case: bool = False,
) -> list[any]:
keyword_condition = "msg_ctx LIKE '%{key}%' OR file_name LIKE '%{key}%'" keyword_condition = "msg_ctx LIKE '%{key}%' OR file_name LIKE '%{key}%'"
if ignore_case: if ignore_case:
keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')" keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')"
@ -70,17 +83,23 @@ class UserManager(object):
def get_oldest_msg_by_chat_id(self, chat_id: int) -> list[any]: def get_oldest_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute( res = self.cur.execute(
"SELECT * FROM message WHERE chat_id = ? ORDER BY date_time LIMIT 1", (chat_id,)) "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time LIMIT 1",
(chat_id,),
)
return res.fetchall() return res.fetchall()
def get_newest_msg_by_chat_id(self, chat_id: int) -> list[any]: def get_newest_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute( res = self.cur.execute(
"SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC LIMIT 1", (chat_id,)) "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC LIMIT 1",
(chat_id,),
)
return res.fetchall() return res.fetchall()
def get_msg_by_unique_id(self, unique_id: str) -> list[any]: def get_msg_by_unique_id(self, unique_id: str) -> list[any]:
res = self.cur.execute( res = self.cur.execute(
"SELECT * FROM message WHERE unique_id = ? ORDER BY date_time DESC LIMIT 1", (unique_id,)) "SELECT * FROM message WHERE unique_id = ? ORDER BY date_time DESC LIMIT 1",
(unique_id,),
)
return res.fetchall() return res.fetchall()
@unique @unique
@ -128,8 +147,18 @@ class UserManager(object):
msg_type = UserManager.MessageTypeEnum.FILE.value msg_type = UserManager.MessageTypeEnum.FILE.value
except Exception as err: except Exception as err:
logger.error(f"{err=}") logger.error(f"{err=}")
insert_data = (unique_id, user_id, chat_id, msg_id, insert_data = (
msg_type, msg_ctx, mime_type, file_name, msg_js, date_time) unique_id,
user_id,
chat_id,
msg_id,
msg_type,
msg_ctx,
mime_type,
file_name,
msg_js,
date_time,
)
execute_script = "INSERT INTO message (unique_id, user_id, chat_id, msg_id, msg_type, msg_ctx, mime_type, file_name, msg_js, date_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" execute_script = "INSERT INTO message (unique_id, user_id, chat_id, msg_id, msg_type, msg_ctx, mime_type, file_name, msg_js, date_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
try: try:
self.cur.execute(execute_script, insert_data) self.cur.execute(execute_script, insert_data)
@ -175,11 +204,11 @@ class UserManager(object):
def _first_runtime_run_once(self) -> None: def _first_runtime_run_once(self) -> None:
if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='user'").fetchall()) == 0: if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='user'").fetchall()) == 0:
self.cur.execute( self.cur.execute("CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)")
"CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)")
if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='message'").fetchall()) == 0: if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='message'").fetchall()) == 0:
self.cur.execute( self.cur.execute(
"CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx text, mime_type text, file_name text, msg_js text, date_time int NOT NULL)") "CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx text, mime_type text, file_name text, msg_js text, date_time int NOT NULL)"
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -185,10 +185,16 @@ async def get_tg_file_media(chat_id: int|str, msg_id: int, file_name: str, sign:
return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND) return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
@app.post("/tg/api/v1/client/login") @app.get("/tg/api/v1/client/login")
@apiutils.atimeit @apiutils.atimeit
async def login_new_tg_file_client(): async def login_new_tg_file_client():
raise NotImplementedError url = await clients_mgr.login_clients()
return {"url": url}
@app.get("/tg/api/v1/client/status")
async def get_tg_file_client_status(request: Request):
return await clients_mgr.get_status()
@app.get("/tg/api/v1/client/link_convert") @app.get("/tg/api/v1/client/link_convert")

View File

View File

@ -1,171 +1,17 @@
import sys
import os
import json
sys.path.append(os.getcwd() + "/../")
import streamlit as st import streamlit as st
import qrcode
import pandas
import requests
import configParse import remote_api as api
import utils
# qr = qrcode.make("https://www.baidu.com") st.set_page_config(page_title="TgToolbox", page_icon="🕹️", layout="wide", initial_sidebar_state="collapsed")
# st.image(qrcode.make("https://www.baidu.com").get_image())
param = configParse.get_TgToFileSystemParameter()
background_server_url = f"{param.base.exposed_url}/tg/api/v1/file/search"
st.set_page_config(page_title="TgToolbox", page_icon='🕹️', layout='wide') backend_status = api.get_backend_client_status()
need_login = False
if 'page_index' not in st.session_state: for v in backend_status["clients"]:
st.session_state.page_index = 1 if not v["status"]:
if 'force_skip' not in st.session_state: need_login = True
st.session_state.force_skip = False
if 'search_key' not in st.query_params: if need_login:
st.query_params.search_key = "" import login
if 'is_order' not in st.query_params:
st.query_params.is_order = False
if 'search_res_limit' not in st.query_params:
st.query_params.search_res_limit = "10"
@st.experimental_fragment
def search_container():
st.query_params.search_key = st.text_input("**搜索🔎**", value=st.query_params.search_key)
columns = st.columns([7, 1])
with columns[0]:
st.query_params.search_res_limit = str(st.number_input(
"**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d"))
with columns[1]:
st.text("排序")
st.query_params.is_order = st.toggle("顺序", value=utils.strtobool(st.query_params.is_order))
search_container()
search_clicked = st.button('Search', type='primary', use_container_width=True)
if not st.session_state.force_skip and (not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None):
st.stop()
if not st.session_state.force_skip:
st.session_state.page_index = 1
if st.session_state.force_skip:
st.session_state.force_skip = False
@st.experimental_fragment
def do_search_req():
search_limit = int(st.query_params.search_res_limit)
offset_index = (st.session_state.page_index - 1) * search_limit
is_order = utils.strtobool(st.query_params.is_order)
req_body = {
"token": param.web.token,
"search": f"{st.query_params.search_key}",
"chat_id": param.web.chat_id[0],
"index": offset_index,
"length": search_limit,
"refresh": False,
"inner": False,
"inc": is_order,
}
req = requests.post(background_server_url, data=json.dumps(req_body))
if req.status_code != 200:
st.stop()
search_res = json.loads(req.content.decode("utf-8"))
def page_switch_render():
columns = st.columns(3)
with columns[0]:
if st.button("Prev", use_container_width=True):
st.session_state.page_index = st.session_state.page_index - 1
st.session_state.page_index = max(
st.session_state.page_index, 1)
st.session_state.force_skip = True
st.rerun()
with columns[1]:
# st.text(f"{st.session_state.page_index}")
st.markdown(
f"<p style='text-align: center;'>{st.session_state.page_index}</p>", unsafe_allow_html=True)
# st.markdown(f"<input type='number' style='text-align: center;' value={st.session_state.page_index}>", unsafe_allow_html=True)
with columns[2]:
if st.button("Next", use_container_width=True):
st.session_state.page_index = st.session_state.page_index + 1
st.session_state.force_skip = True
st.rerun()
def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str):
file_size_str = f"{file_size/1024/1024:.2f}MB"
container = st.container()
container_columns = container.columns([1, 99])
st.session_state.search_res_select_list[index] = container_columns[0].checkbox(
"search_res_checkbox_" + str(index), label_visibility='collapsed')
expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size_str}*"
popover = container_columns[1].popover(expender_title, use_container_width=True)
popover_columns = popover.columns([1, 3, 1])
if url:
popover_columns[0].video(url)
else: else:
popover_columns[0].video('./static/404.webm', format="video/webm") import search
popover_columns[1].markdown(f'{msg_ctx}')
popover_columns[1].markdown(f'**{file_name}**')
popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
popover_columns[2].link_button('Download Link', url, use_container_width=True)
popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True)
@st.experimental_fragment
def show_search_res(res: dict[str, any]):
sign_token = ""
try:
sign_token = res['client']['sign']
except Exception as err:
pass
search_res_list = res.get('list')
if search_res_list is None or len(search_res_list) == 0:
st.info("No result")
page_switch_render()
st.stop()
st.session_state.search_res_select_list = [False] * len(search_res_list)
url_list = []
for i in range(len(search_res_list)):
v = search_res_list[i]
msg_ctx= ""
file_name = None
file_size = 0
download_url = ""
src_link = ""
try:
src_link = v['src_tg_link']
msg_ctx = v['message']
msg_id = str(v['id'])
doc = v['media']['document']
file_size = doc['size']
if doc is not None:
for attr in doc['attributes']:
file_name = attr.get('file_name')
if file_name != "" and file_name is not None:
break
if file_name == "" or file_name is None:
file_name = "Can not get file name"
download_url = v['download_url']
download_url += f'?sign={sign_token}'
url_list.append(download_url)
except Exception as err:
msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
media_file_res_container(
i, msg_ctx, file_name, file_size, download_url, src_link)
page_switch_render()
show_text = ""
select_list = st.session_state.search_res_select_list
for i in range(len(select_list)):
if select_list[i]:
show_text = show_text + url_list[i] + '\n'
st.text_area("链接", value=show_text)
show_search_res(search_res)
do_search_req()

24
frontend/login.py Normal file
View File

@ -0,0 +1,24 @@
import sys
import os

import streamlit as st
import qrcode

sys.path.append(os.getcwd() + "/../")
import configParse
import utils
import remote_api as api

# Ask the backend to start a QR login; empty/None means nothing to log in
# to (or the backend is unreachable).
url = api.login_client_by_qr_code_url()
if url is None or url == "":
    st.text("Something wrong, no login url got.")
    st.stop()

st.markdown("### Please scan the qr code by telegram client.")
qr = qrcode.make(url)
st.image(qr.get_image())
# Typo fixes: the button below is labeled "Refresh", so the hint must
# match ("Refrash"/"scaned" in the original were misspelled).
st.markdown("**Click the Refresh button if you have scanned**")
if st.button("Refresh"):
    st.rerun()

58
frontend/remote_api.py Normal file
View File

@ -0,0 +1,58 @@
import sys
import os
import json
import logging
import requests
sys.path.append(os.getcwd() + "/../")
import configParse
# Module logger named after this source file (e.g. "remote_api.py").
logger = logging.getLogger(__file__.split("/")[-1])
# Backend connection settings loaded from the TOML configuration.
param = configParse.get_TgToFileSystemParameter()
background_server_url = f"{param.base.exposed_url}"
# REST routes exposed by the backend service.
search_api_route = "/tg/api/v1/file/search"
status_api_route = "/tg/api/v1/client/status"
login_api_route = "/tg/api/v1/client/login"
def login_client_by_qr_code_url() -> str | None:
    """Request a QR-code login URL from the backend.

    Returns:
        The login URL string (may be None if the response lacks a "url" key),
        or None when the backend answers with a non-200 status.
        BUG FIX: the original annotation claimed a plain ``str`` even though
        the error path returns ``None``.
    """
    request_url = background_server_url + login_api_route
    response = requests.get(request_url)
    if response.status_code != 200:
        logger.warning(f"Could not login, err:{response.status_code}, {response.content.decode('utf-8')}")
        return None
    url_info = json.loads(response.content.decode("utf-8"))
    return url_info.get("url")
def get_backend_client_status() -> dict | None:
    """Fetch the backend Telegram-client status.

    Returns:
        The decoded JSON status payload, or None on a non-200 response.
        BUG FIX: the original ``dict[str, any]`` annotation used the builtin
        ``any`` function as a type and omitted the ``None`` error return.
    """
    request_url = background_server_url + status_api_route
    response = requests.get(request_url)
    if response.status_code != 200:
        logger.warning(f"get_status, backend is running? err:{response.status_code}, {response.content.decode('utf-8')}")
        return None
    return json.loads(response.content.decode("utf-8"))
def search_database_by_keyword(keyword: str, offset: int, limit: int, is_order: bool) -> dict | None:
    """Search the backend database for media messages matching *keyword*.

    Args:
        keyword: Search text forwarded to the backend.
        offset: Zero-based index of the first result to return (paging).
        limit: Maximum number of results per page.
        is_order: True to return results in ascending order.

    Returns:
        The decoded JSON response — callers read it as a dict with a "list"
        key of results — or None on a non-200 response.
        BUG FIX: the original ``list[any] | None`` annotation did not match
        the dict actually returned (and misused the builtin ``any``).
    """
    request_url = background_server_url + search_api_route
    req_body = {
        "token": param.web.token,
        "search": keyword,
        # Only the first configured chat is searched.
        "chat_id": param.web.chat_id[0],
        "index": offset,
        "length": limit,
        "refresh": False,
        "inner": False,
        "inc": is_order,
    }
    response = requests.post(request_url, data=json.dumps(req_body))
    if response.status_code != 200:
        logger.warning(f"search_database_by_keyword err:{response.status_code}, {response.content.decode('utf-8')}")
        return None
    search_res = json.loads(response.content.decode("utf-8"))
    return search_res

151
frontend/search.py Normal file
View File

@ -0,0 +1,151 @@
import sys
import os
import streamlit as st
sys.path.append(os.getcwd() + "/../")
import configParse
import utils
import remote_api as api
param = configParse.get_TgToFileSystemParameter()
# Session defaults: current result page, and a one-shot flag that lets a
# pager-triggered rerun skip the "Search clicked?" gate below.
if 'page_index' not in st.session_state:
    st.session_state.page_index = 1
if 'force_skip' not in st.session_state:
    st.session_state.force_skip = False
# Query-param defaults so the search state survives reloads and link sharing.
# NOTE(review): query params are stored as strings — is_order is round-tripped
# through utils.strtobool when read back.
if 'search_key' not in st.query_params:
    st.query_params.search_key = ""
if 'is_order' not in st.query_params:
    st.query_params.is_order = False
if 'search_res_limit' not in st.query_params:
    st.query_params.search_res_limit = "10"
@st.experimental_fragment
def search_container():
    """Render the search controls: keyword box, page-size input, order toggle."""
    # Bind the text input to the query params so the keyword persists in the URL.
    st.query_params.search_key = st.text_input("**搜索🔎**", value=st.query_params.search_key)
    columns = st.columns([7, 1])
    with columns[0]:
        # Page size is kept as a string in the query params; convert for the widget.
        st.query_params.search_res_limit = str(st.number_input(
            "**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d"))
    with columns[1]:
        st.text("排序")
        # strtobool turns the string-valued query param back into a bool.
        st.query_params.is_order = st.toggle("顺序", value=utils.strtobool(st.query_params.is_order))
search_container()
search_clicked = st.button('Search', type='primary', use_container_width=True)
# Gate: stop rendering unless the user clicked Search with a non-empty keyword,
# or force_skip was set by a pager button to repeat the previous search.
if not st.session_state.force_skip and (not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None):
    st.stop()
if not st.session_state.force_skip:
    # A fresh search always starts from the first result page.
    st.session_state.page_index = 1
if st.session_state.force_skip:
    # Consume the one-shot skip flag.
    st.session_state.force_skip = False
@st.experimental_fragment
def do_search_req():
    """Run the search for the current page and render results with paging controls."""
    search_limit = int(st.query_params.search_res_limit)
    offset_index = (st.session_state.page_index - 1) * search_limit
    is_order = utils.strtobool(st.query_params.is_order)
    search_res = api.search_database_by_keyword(st.query_params.search_key, offset_index, search_limit, is_order)
    if search_res is None:
        st.stop()

    def page_switch_render():
        # Prev / current-page / Next row. Pager buttons set force_skip so the
        # rerun repeats the search instead of stopping at the top-level gate.
        columns = st.columns(3)
        with columns[0]:
            if st.button("Prev", use_container_width=True):
                st.session_state.page_index = st.session_state.page_index - 1
                st.session_state.page_index = max(
                    st.session_state.page_index, 1)
                st.session_state.force_skip = True
                st.rerun()
        with columns[1]:
            # st.text(f"{st.session_state.page_index}")
            st.markdown(
                f"<p style='text-align: center;'>{st.session_state.page_index}</p>", unsafe_allow_html=True)
            # st.markdown(f"<input type='number' style='text-align: center;' value={st.session_state.page_index}>", unsafe_allow_html=True)
        with columns[2]:
            if st.button("Next", use_container_width=True):
                st.session_state.page_index = st.session_state.page_index + 1
                st.session_state.force_skip = True
                st.rerun()

    def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str):
        # One result row: selection checkbox + popover with preview and links.
        file_size_str = f"{file_size/1024/1024:.2f}MB"
        container = st.container()
        container_columns = container.columns([1, 99])
        st.session_state.search_res_select_list[index] = container_columns[0].checkbox(
            "search_res_checkbox_" + str(index), label_visibility='collapsed')
        expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size_str}*"
        popover = container_columns[1].popover(expender_title, use_container_width=True)
        popover_columns = popover.columns([1, 3, 1])
        if url:
            popover_columns[0].video(url)
        else:
            # Placeholder clip when no playable URL could be built.
            popover_columns[0].video('./static/404.webm', format="video/webm")
        popover_columns[1].markdown(f'{msg_ctx}')
        popover_columns[1].markdown(f'**{file_name}**')
        popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
        popover_columns[2].link_button('Download Link', url, use_container_width=True)
        popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True)

    @st.experimental_fragment
    def show_search_res(res: dict[str, any]):
        # Render the result list plus a text area of the selected download URLs.
        search_res_list = res.get("list")
        if search_res_list is None or len(search_res_list) == 0:
            st.info("No result")
            page_switch_render()
            st.stop()
        # The sign token is optional; missing keys just leave it empty.
        sign_token = ""
        try:
            sign_token = res['client']['sign']
        except Exception:
            pass
        st.session_state.search_res_select_list = [False] * len(search_res_list)
        url_list = []
        for i in range(len(search_res_list)):
            v = search_res_list[i]
            msg_ctx = ""
            file_name = None
            file_size = 0
            download_url = ""
            src_link = ""
            try:
                src_link = v['src_tg_link']
                msg_ctx = v['message']
                msg_id = str(v['id'])
                doc = v['media']['document']
                # BUG FIX: read 'size'/'attributes' only after the None check —
                # the original subscripted doc before checking it, making the
                # check unreachable for a None document.
                if doc is not None:
                    file_size = doc['size']
                    for attr in doc['attributes']:
                        file_name = attr.get('file_name')
                        if file_name != "" and file_name is not None:
                            break
                if file_name == "" or file_name is None:
                    file_name = "Can not get file name"
                download_url = v['download_url']
                download_url += f'?sign={sign_token}'
                url_list.append(download_url)
            except Exception as err:
                # Surface the parse error inline rather than breaking the page.
                msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
            media_file_res_container(
                i, msg_ctx, file_name, file_size, download_url, src_link)
        page_switch_render()
        show_text = ""
        select_list = st.session_state.search_res_select_list
        for i in range(len(select_list)):
            if select_list[i]:
                show_text = show_text + url_list[i] + '\n'
        st.text_area("链接", value=show_text)

    show_search_res(search_res)


do_search_req()

View File

@ -1,6 +1,7 @@
toml
telethon
# python-socks[asyncio]
diskcache
fastapi
uvicorn[standard]
streamlit