feat: support login

parent cab27becf1
commit 7acd0f4712
.gitignore (vendored)
@@ -1,13 +1,15 @@
 __pycache__
 .venv
+.idea
 .vscode
 *.session
 *.session-journal
+*-journal
 *.session.sql
 *.toml
 *.db
 *.service
 log
-cacheTest
+cache_media
 tmp
 logs
backend/MediaCacheManager.py
@@ -1,3 +1,4 @@
+import os
 import functools
 import logging
 import bisect
@@ -18,6 +19,11 @@ class MediaChunkHolder(object):
     waiters: collections.deque[asyncio.Future]
     requester: list[Request] = []
     chunk_id: int = 0
+    unique_id: str = ""
+
+    @staticmethod
+    def generate_id(chat_id: int, msg_id: int, start: int) -> str:
+        return f"{chat_id}:{msg_id}:{start}"

     def __init__(self, chat_id: int, msg_id: int, start: int, target_len: int, mem: Optional[bytes] = None) -> None:
         self.chat_id = chat_id
@@ -27,6 +33,7 @@ class MediaChunkHolder(object):
         self.mem = mem or bytes()
         self.length = len(self.mem)
         self.waiters = collections.deque()
+        self.unique_id = MediaChunkHolder.generate_id(chat_id, msg_id, start)

     def __repr__(self) -> str:
         return f"MediaChunk,start:{self.start},len:{self.length}"
@@ -114,8 +121,11 @@ class MediaChunkHolderManager(object):
     unique_chunk_id: int = 0
     chunk_lru: collections.OrderedDict[int, MediaChunkHolder]

+    disk_chunk_cache: diskcache.Cache
+
     def __init__(self) -> None:
         self.chunk_lru = collections.OrderedDict()
+        self.disk_chunk_cache = diskcache.Cache(f"{os.path.dirname(__file__)}/cache_media", size_limit=2**30)

     def _get_media_msg_cache(self, msg: types.Message) -> Optional[list[MediaChunkHolder]]:
         chat_cache = self.chunk_cache.get(msg.chat_id)
@@ -160,14 +170,8 @@ class MediaChunkHolderManager(object):
         return res

     def set_media_chunk(self, chunk: MediaChunkHolder) -> None:
-        cache_chat = self.chunk_cache.get(chunk.chat_id)
-        if cache_chat is None:
-            self.chunk_cache[chunk.chat_id] = {}
-            cache_chat = self.chunk_cache[chunk.chat_id]
-        cache_msg = cache_chat.get(chunk.msg_id)
-        if cache_msg is None:
-            cache_chat[chunk.msg_id] = []
-            cache_msg = cache_chat[chunk.msg_id]
+        cache_chat = self.chunk_cache.setdefault(chunk.chat_id, {})
+        cache_msg = cache_chat.setdefault(chunk.msg_id, [])
         chunk.chunk_id = self.unique_chunk_id
         self.unique_chunk_id += 1
         bisect.insort(cache_msg, chunk)
@@ -283,4 +287,3 @@ class MediaBlockHolderManager(object):

     def __init__(self, limit_size: int = DEFAULT_MAX_CACHE_SIZE, dir: str = 'cache') -> None:
         pass
-
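For orientation, the additions above give each cached chunk a stable identity and a size-bounded on-disk backing store. A minimal standalone sketch of the same pattern, not part of the commit (the cache directory below is a placeholder; the real manager uses the module directory's cache_media folder):

import diskcache

# Same "chat:msg:start" key scheme as MediaChunkHolder.generate_id().
def chunk_key(chat_id: int, msg_id: int, start: int) -> str:
    return f"{chat_id}:{msg_id}:{start}"

# diskcache evicts old entries once the directory grows past size_limit (1 GiB here).
cache = diskcache.Cache("/tmp/media_chunk_demo", size_limit=2**30)
cache[chunk_key(10001, 42, 0)] = b"\x00" * 1024  # store one chunk's bytes
assert cache.get(chunk_key(10001, 42, 0)) is not None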
backend/TgFileSystemClient.py
@@ -10,6 +10,7 @@ import logging
 from typing import Union, Optional

 from telethon import TelegramClient, types, hints, events
+from telethon.custom import QRLogin
 from fastapi import Request

 import configParse
@@ -19,6 +20,7 @@ from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager

 logger = logging.getLogger(__file__.split("/")[-1])

+
 class TgFileSystemClient(object):
     MAX_WORKER_ROUTINE = 4
     SINGLE_NET_CHUNK_SIZE = 512 * 1024  # 512kb
@@ -32,6 +34,8 @@ class TgFileSystemClient(object):
     dialogs_cache: Optional[hints.TotalList] = None
     msg_cache: list[types.Message] = []
     worker_routines: list[asyncio.Task] = []
+    qr_login: QRLogin | None = None
+    login_task: asyncio.Task | None = None
     # task should: (task_id, callabledFunc)
     task_queue: asyncio.Queue
     task_id: int = 0
@@ -39,19 +43,35 @@ class TgFileSystemClient(object):
     # client config
     client_param: configParse.TgToFileSystemParameter.ClientConfigPatameter

-    def __init__(self, session_name: str, param: configParse.TgToFileSystemParameter, db: UserManager) -> None:
+    def __init__(
+        self,
+        session_name: str,
+        param: configParse.TgToFileSystemParameter,
+        db: UserManager,
+    ) -> None:
         self.api_id = param.tgApi.api_id
         self.api_hash = param.tgApi.api_hash
         self.session_name = session_name
-        self.proxy_param = {
-            'proxy_type': param.proxy.proxy_type,
-            'addr': param.proxy.addr,
-            'port': param.proxy.port,
-        } if param.proxy.enable else {}
-        self.client_param = next((client_param for client_param in param.clients if client_param.token == session_name), configParse.TgToFileSystemParameter.ClientConfigPatameter())
+        self.proxy_param = (
+            {
+                "proxy_type": param.proxy.proxy_type,
+                "addr": param.proxy.addr,
+                "port": param.proxy.port,
+            }
+            if param.proxy.enable
+            else {}
+        )
+        self.client_param = next(
+            (client_param for client_param in param.clients if client_param.token == session_name),
+            configParse.TgToFileSystemParameter.ClientConfigPatameter(),
+        )
         self.task_queue = asyncio.Queue()
         self.client = TelegramClient(
-            f"{os.path.dirname(__file__)}/db/{self.session_name}.session", self.api_id, self.api_hash, proxy=self.proxy_param)
+            f"{os.path.dirname(__file__)}/db/{self.session_name}.session",
+            self.api_id,
+            self.api_hash,
+            proxy=self.proxy_param,
+        )
         self.media_chunk_manager = MediaChunkHolderManager()
         self.db = db

@@ -72,6 +92,7 @@ class TgFileSystemClient(object):
                 raise RuntimeError("Client does not run.")
             result = func(self, *args, **kwargs)
             return result
+
         return call_check_wrapper

     def _acheck_before_call(func):
@@ -80,6 +101,7 @@ class TgFileSystemClient(object):
                 raise RuntimeError("Client does not run.")
             result = await func(self, *args, **kwargs)
             return result
+
         return call_check_wrapper

     @_check_before_call
@@ -100,6 +122,28 @@ class TgFileSystemClient(object):
             msg: types.Message = event.message
             self.db.insert_by_message(self.me, msg)

+    async def login(self, mode: Union["phone", "qrcode"] = "qrcode") -> str:
+        if self.is_valid():
+            return ""
+        if mode == "phone":
+            raise NotImplementedError
+        if self.qr_login is not None:
+            return self.qr_login.url
+        self.qr_login = await self.client.qr_login()
+
+        async def wait_for_qr_login():
+            try:
+                await self.qr_login.wait()
+                await self.start()
+            except Exception as err:
+                logger.warning(f"wait for login, {err=}, {traceback.format_exc()}")
+            finally:
+                self.login_task = None
+                self.qr_login = None
+
+        self.login_task = self.client.loop.create_task(wait_for_qr_login())
+        return self.qr_login.url
+
     async def start(self) -> None:
         if self.is_valid():
             return
@@ -107,11 +151,9 @@ class TgFileSystemClient(object):
         await self.client.connect()
         self.me = await self.client.get_me()
         if self.me is None:
-            raise RuntimeError(
-                f"The {self.session_name} Client Does Not Login")
+            raise RuntimeError(f"The {self.session_name} Client Does Not Login")
         for _ in range(self.MAX_WORKER_ROUTINE):
-            worker_routine = self.client.loop.create_task(
-                self._worker_routine_handler())
+            worker_routine = self.client.loop.create_task(self._worker_routine_handler())
             self.worker_routines.append(worker_routine)
         if len(self.client_param.whitelist_chat) > 0:
             self._register_update_event(from_users=self.client_param.whitelist_chat)
@@ -163,7 +205,6 @@ class TgFileSystemClient(object):
             self.db.insert_by_message(self.me, msg)
         logger.info(f"{chat_id} quit cache task.")

-
     @_acheck_before_call
     async def get_message(self, chat_id: int, msg_id: int) -> types.Message:
         msg = await self.client.get_messages(chat_id, ids=msg_id)
@@ -207,7 +248,15 @@ class TgFileSystemClient(object):
         return res_list

     @_acheck_before_call
-    async def get_messages_by_search(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inner_search: bool = False, ignore_case: bool = False) -> hints.TotalList:
+    async def get_messages_by_search(
+        self,
+        chat_id: int,
+        search_word: str,
+        limit: int = 10,
+        offset: int = 0,
+        inner_search: bool = False,
+        ignore_case: bool = False,
+    ) -> hints.TotalList:
         offset = await self._get_offset_msg_id(chat_id, offset)
         if inner_search:
             res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset, search=search_word)
@@ -226,10 +275,25 @@ class TgFileSystemClient(object):
                 break
         return res_list

-    async def get_messages_by_search_db(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]:
+    async def get_messages_by_search_db(
+        self,
+        chat_id: int,
+        search_word: str,
+        limit: int = 10,
+        offset: int = 0,
+        inc: bool = False,
+        ignore_case: bool = False,
+    ) -> list[any]:
         if chat_id not in self.client_param.whitelist_chat:
             return []
-        res = self.db.get_msg_by_chat_id_and_keyword(chat_id, search_word, limit=limit, offset=offset, inc=inc, ignore_case=ignore_case)
+        res = self.db.get_msg_by_chat_id_and_keyword(
+            chat_id,
+            search_word,
+            limit=limit,
+            offset=offset,
+            inc=inc,
+            ignore_case=ignore_case,
+        )
         res = [self.db.get_column_msg_js(v) for v in res]
         return res

@@ -243,8 +307,7 @@ class TgFileSystemClient(object):
                 chunk = chunk.tobytes()
             remain_size -= len(chunk)
             if remain_size <= 0:
-                media_holder.append_chunk_mem(
-                    chunk[:len(chunk)+remain_size])
+                media_holder.append_chunk_mem(chunk[: len(chunk) + remain_size])
             else:
                 media_holder.append_chunk_mem(chunk)
                 if media_holder.is_completed():
@@ -256,30 +319,26 @@ class TgFileSystemClient(object):
                 self.media_chunk_manager.cancel_media_chunk(media_holder)
         except Exception as err:
             logger.error(
-                f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}")
+                f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}"
+            )
         finally:
-            logger.debug(
-                f"downloaded chunk:{time.time()}.{offset=},{target_size=},{media_holder}")
+            logger.debug(f"downloaded chunk:{time.time()}.{offset=},{target_size=},{media_holder}")

     async def streaming_get_iter(self, msg: types.Message, start: int, end: int, req: Request):
         try:
-            logger.debug(
-                f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]")
+            logger.debug(f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]")
             cur_task_id = self._get_unique_task_id()
             pos = start
             while not await req.is_disconnected() and pos <= end:
-                cache_chunk = self.media_chunk_manager.get_media_chunk(
-                    msg, pos)
+                cache_chunk = self.media_chunk_manager.get_media_chunk(msg, pos)
                 if cache_chunk is None:
                     # post download task
                     # align pos download task
                     file_size = msg.media.document.size
                     # align_pos = pos // self.SINGLE_MEDIA_SIZE * self.SINGLE_MEDIA_SIZE
                     align_pos = pos
-                    align_size = min(self.SINGLE_MEDIA_SIZE,
-                                     file_size - align_pos)
-                    holder = MediaChunkHolder(
-                        msg.chat_id, msg.id, align_pos, align_size)
+                    align_size = min(self.SINGLE_MEDIA_SIZE, file_size - align_pos)
+                    holder = MediaChunkHolder(msg.chat_id, msg.id, align_pos, align_size)
                     holder.add_chunk_requester(req)
                     self.media_chunk_manager.set_media_chunk(holder)
                     self.task_queue.put_nowait((cur_task_id, self._download_media_chunk(msg, holder)))
@@ -294,15 +353,13 @@ class TgFileSystemClient(object):
                     if offset >= cache_chunk.length:
                         await cache_chunk.wait_chunk_update()
                         continue
-                    need_len = min(cache_chunk.length -
-                                   offset, end - pos + 1)
+                    need_len = min(cache_chunk.length - offset, end - pos + 1)
                     pos = pos + need_len
                     yield cache_chunk.mem[offset : offset + need_len]
                 else:
                     offset = pos - cache_chunk.start
                     if offset >= cache_chunk.length:
-                        raise RuntimeError(
-                            f"lru cache missed!{pos=},{cache_chunk=}")
+                        raise RuntimeError(f"lru cache missed!{pos=},{cache_chunk=}")
                     need_len = min(cache_chunk.length - offset, end - pos + 1)
                     pos = pos + need_len
                     yield cache_chunk.mem[offset : offset + need_len]
@@ -310,12 +367,14 @@ class TgFileSystemClient(object):
             logger.error(f"stream iter:{err=}")
             logger.error(traceback.format_exc())
         finally:
+
             async def _cancel_task_by_id(task_id: int):
                 for _ in range(self.task_queue.qsize()):
                     task = self.task_queue.get_nowait()
                     self.task_queue.task_done()
                     if task[0] != task_id:
                         self.task_queue.put_nowait(task)
+
             await self.client.loop.create_task(_cancel_task_by_id(cur_task_id))
             logger.debug(f"yield quit,{msg.chat_id=},{msg.id=},[{start}:{end}]")

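For context, the login() coroutine added above is built on Telethon's QR-login primitives (client.qr_login(), QRLogin.url, QRLogin.wait()). A minimal standalone sketch of that flow, with placeholder credentials instead of the configParse values the real client reads; error handling and QR re-generation on expiry are omitted:

import asyncio
from telethon import TelegramClient

async def main():
    # Placeholder api_id/api_hash and session name.
    client = TelegramClient("demo.session", api_id=12345, api_hash="0123456789abcdef")
    await client.connect()
    if await client.is_user_authorized():
        print("already logged in")
        return
    qr_login = await client.qr_login()   # same call used by TgFileSystemClient.login()
    print("render this URL as a QR code and scan it in Telegram:", qr_login.url)
    await qr_login.wait()                # resolves once the login is confirmed on the phone
    me = await client.get_me()
    print("logged in as:", me.username)

asyncio.run(main())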
@@ -3,6 +3,7 @@ import asyncio
 import time
 import hashlib
 import os
+import traceback
 import logging

 from backend.TgFileSystemClient import TgFileSystemClient
@@ -11,14 +12,17 @@ import configParse

 logger = logging.getLogger(__file__.split("/")[-1])

+
 class TgFileSystemClientManager(object):
     MAX_MANAGE_CLIENTS: int = 10
+    is_init: asyncio.Future
     param: configParse.TgToFileSystemParameter
     clients: dict[str, TgFileSystemClient] = {}

     def __init__(self, param: configParse.TgToFileSystemParameter) -> None:
         self.param = param
         self.db = UserManager()
+        self.is_init = asyncio.Future()
         self.loop = asyncio.get_running_loop()
         if self.loop.is_running():
             self.loop.create_task(self._start_clients())
@@ -32,17 +36,37 @@ class TgFileSystemClientManager(object):
         # init cache clients
         for client_config in self.param.clients:
             client = self.create_client(client_id=client_config.token)
+            self._register_client(client)
+        for _, client in self.clients.items():
+            try:
                 if not client.is_valid():
                     await client.start()
-            self._register_client(client)
+            except Exception as err:
+                logger.warning(f"start client: {err=}, {traceback.format_exc()}")
+        self.is_init.set_result(True)
+
+    async def get_status(self) -> dict[str, any]:
+        clients_status = [
+            {
+                "status": client.is_valid(),
+            }
+            for _, client in self.clients.items()
+        ]
+        return {"init": await self.is_init, "clients": clients_status}
+
+    async def login_clients(self) -> str:
+        for _, client in self.clients.items():
+            login_url = await client.login()
+            if login_url != "":
+                return login_url
+        return ""

     def check_client_session_exist(self, client_id: str) -> bool:
         session_db_file = f"{os.path.dirname(__file__)}/db/{client_id}.session"
         return os.path.isfile(session_db_file)

     def generate_client_id(self) -> str:
-        return hashlib.md5(
-            (str(time.perf_counter()) + self.param.base.salt).encode('utf-8')).hexdigest()
+        return hashlib.md5((str(time.perf_counter()) + self.param.base.salt).encode("utf-8")).hexdigest()

     def create_client(self, client_id: str = None) -> TgFileSystemClient:
         if client_id is None:
@@ -72,4 +96,3 @@ class TgFileSystemClientManager(object):
             await client.start()
         self._register_client(client)
         return client
-
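For orientation, a short sketch of how the manager's new coroutines could be driven from the surrounding async code; clients_mgr stands in for an already constructed TgFileSystemClientManager, and nothing here is part of the commit itself:

async def ensure_logged_in(clients_mgr) -> None:
    # get_status() awaits is_init, i.e. it returns only after _start_clients() has finished.
    status = await clients_mgr.get_status()
    print("client states:", status["clients"])

    # login_clients() returns a QR-login URL for the first client that is not logged in, or "".
    url = await clients_mgr.login_clients()
    if url:
        print("scan this URL as a QR code to log in:", url)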
@@ -9,6 +9,7 @@ from telethon import types

 logger = logging.getLogger(__file__.split("/")[-1])

+
 class UserUpdateParam(BaseModel):
     client_id: str
     username: str
@@ -31,6 +32,8 @@ class MessageUpdateParam(BaseModel):

 class UserManager(object):
     def __init__(self) -> None:
+        if not os.path.exists(os.path.dirname(__file__) + "/db"):
+            os.mkdir(os.path.dirname(__file__) + "/db")
         self.con = sqlite3.connect(f"{os.path.dirname(__file__)}/db/user.db")
         self.cur = self.con.cursor()
         if not self._table_has_been_inited():
@@ -55,10 +58,20 @@ class UserManager(object):

     def get_all_msg_by_chat_id(self, chat_id: int) -> list[any]:
         res = self.cur.execute(
-            "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC", (chat_id,))
+            "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC",
+            (chat_id,),
+        )
         return res.fetchall()

-    def get_msg_by_chat_id_and_keyword(self, chat_id: int, keyword: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]:
+    def get_msg_by_chat_id_and_keyword(
+        self,
+        chat_id: int,
+        keyword: str,
+        limit: int = 10,
+        offset: int = 0,
+        inc: bool = False,
+        ignore_case: bool = False,
+    ) -> list[any]:
         keyword_condition = "msg_ctx LIKE '%{key}%' OR file_name LIKE '%{key}%'"
         if ignore_case:
             keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')"
@@ -70,17 +83,23 @@ class UserManager(object):

     def get_oldest_msg_by_chat_id(self, chat_id: int) -> list[any]:
         res = self.cur.execute(
-            "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time LIMIT 1", (chat_id,))
+            "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time LIMIT 1",
+            (chat_id,),
+        )
         return res.fetchall()

     def get_newest_msg_by_chat_id(self, chat_id: int) -> list[any]:
         res = self.cur.execute(
-            "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC LIMIT 1", (chat_id,))
+            "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC LIMIT 1",
+            (chat_id,),
+        )
         return res.fetchall()

     def get_msg_by_unique_id(self, unique_id: str) -> list[any]:
         res = self.cur.execute(
-            "SELECT * FROM message WHERE unique_id = ? ORDER BY date_time DESC LIMIT 1", (unique_id,))
+            "SELECT * FROM message WHERE unique_id = ? ORDER BY date_time DESC LIMIT 1",
+            (unique_id,),
+        )
         return res.fetchall()

     @unique
@@ -128,8 +147,18 @@ class UserManager(object):
                 msg_type = UserManager.MessageTypeEnum.FILE.value
         except Exception as err:
             logger.error(f"{err=}")
-        insert_data = (unique_id, user_id, chat_id, msg_id,
-                       msg_type, msg_ctx, mime_type, file_name, msg_js, date_time)
+        insert_data = (
+            unique_id,
+            user_id,
+            chat_id,
+            msg_id,
+            msg_type,
+            msg_ctx,
+            mime_type,
+            file_name,
+            msg_js,
+            date_time,
+        )
         execute_script = "INSERT INTO message (unique_id, user_id, chat_id, msg_id, msg_type, msg_ctx, mime_type, file_name, msg_js, date_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
         try:
             self.cur.execute(execute_script, insert_data)
@@ -175,11 +204,11 @@ class UserManager(object):

     def _first_runtime_run_once(self) -> None:
         if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='user'").fetchall()) == 0:
-            self.cur.execute(
-                "CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)")
+            self.cur.execute("CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)")
         if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='message'").fetchall()) == 0:
             self.cur.execute(
-                "CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx text, mime_type text, file_name text, msg_js text, date_time int NOT NULL)")
+                "CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx text, mime_type text, file_name text, msg_js text, date_time int NOT NULL)"
+            )


 if __name__ == "__main__":
@@ -185,10 +185,16 @@ async def get_tg_file_media(chat_id: int|str, msg_id: int, file_name: str, sign:
     return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)


-@app.post("/tg/api/v1/client/login")
+@app.get("/tg/api/v1/client/login")
 @apiutils.atimeit
 async def login_new_tg_file_client():
-    raise NotImplementedError
+    url = await clients_mgr.login_clients()
+    return {"url": url}


+@app.get("/tg/api/v1/client/status")
+async def get_tg_file_client_status(request: Request):
+    return await clients_mgr.get_status()
+
+
 @app.get("/tg/api/v1/client/link_convert")
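With these routes in place, the login and status checks can be exercised over plain HTTP. A small sketch using requests; the base URL is an assumption, the deployed value comes from param.base.exposed_url:

import requests

base_url = "http://127.0.0.1:8000"  # assumed host/port

# An empty "url" means every configured client is already logged in.
login_url = requests.get(f"{base_url}/tg/api/v1/client/login").json()["url"]
print("QR login URL:", login_url or "<already logged in>")

# "init" resolves once the backend has finished starting its clients.
print(requests.get(f"{base_url}/tg/api/v1/client/status").json())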
frontend/home.py
@@ -1,171 +1,17 @@
 [the old in-page search UI previously in home.py (imports, search_container, do_search_req, page_switch_render, media_file_res_container, show_search_res) is removed; the same code is re-added below as the new file frontend/search.py]
+import streamlit as st
+
+import remote_api as api
+
+st.set_page_config(page_title="TgToolbox", page_icon="🕹️", layout="wide", initial_sidebar_state="collapsed")
+
+backend_status = api.get_backend_client_status()
+need_login = False
+for v in backend_status["clients"]:
+    if not v["status"]:
+        need_login = True
+
+if need_login:
+    import login
+else:
+    import search
frontend/login.py (new file)
@@ -0,0 +1,24 @@
+import sys
+import os
+
+import streamlit as st
+import qrcode
+
+sys.path.append(os.getcwd() + "/../")
+import configParse
+import utils
+import remote_api as api
+
+url = api.login_client_by_qr_code_url()
+
+if url is None or url == "":
+    st.text("Something wrong, no login url got.")
+    st.stop()
+
+st.markdown("### Please scan the qr code by telegram client.")
+qr = qrcode.make(url)
+st.image(qr.get_image())
+
+st.markdown("**Click the Refrash button if you have been scaned**")
+if st.button("Refresh"):
+    st.rerun()
frontend/remote_api.py (new file)
@@ -0,0 +1,58 @@
+import sys
+import os
+import json
+import logging
+
+import requests
+
+sys.path.append(os.getcwd() + "/../")
+import configParse
+
+logger = logging.getLogger(__file__.split("/")[-1])
+
+param = configParse.get_TgToFileSystemParameter()
+
+background_server_url = f"{param.base.exposed_url}"
+search_api_route = "/tg/api/v1/file/search"
+status_api_route = "/tg/api/v1/client/status"
+login_api_route = "/tg/api/v1/client/login"
+
+
+def login_client_by_qr_code_url() -> str:
+    request_url = background_server_url + login_api_route
+    response = requests.get(request_url)
+    if response.status_code != 200:
+        logger.warning(f"Could not login, err:{response.status_code}, {response.content.decode('utf-8')}")
+        return None
+    url_info = json.loads(response.content.decode("utf-8"))
+    return url_info.get("url")
+
+
+def get_backend_client_status() -> dict[str, any]:
+    request_url = background_server_url + status_api_route
+    response = requests.get(request_url)
+    if response.status_code != 200:
+        logger.warning(f"get_status, backend is running? err:{response.status_code}, {response.content.decode('utf-8')}")
+        return None
+    return json.loads(response.content.decode("utf-8"))
+
+
+def search_database_by_keyword(keyword: str, offset: int, limit: int, is_order: bool) -> list[any] | None:
+    request_url = background_server_url + search_api_route
+    req_body = {
+        "token": param.web.token,
+        "search": keyword,
+        "chat_id": param.web.chat_id[0],
+        "index": offset,
+        "length": limit,
+        "refresh": False,
+        "inner": False,
+        "inc": is_order,
+    }
+
+    response = requests.post(request_url, data=json.dumps(req_body))
+    if response.status_code != 200:
+        logger.warning(f"search_database_by_keyword err:{response.status_code}, {response.content.decode('utf-8')}")
+        return None
+    search_res = json.loads(response.content.decode("utf-8"))
+    return search_res
frontend/search.py (new file)
@@ -0,0 +1,151 @@
+import sys
+import os
+
+import streamlit as st
+
+sys.path.append(os.getcwd() + "/../")
+import configParse
+import utils
+import remote_api as api
+
+param = configParse.get_TgToFileSystemParameter()
+
+if 'page_index' not in st.session_state:
+    st.session_state.page_index = 1
+if 'force_skip' not in st.session_state:
+    st.session_state.force_skip = False
+
+if 'search_key' not in st.query_params:
+    st.query_params.search_key = ""
+if 'is_order' not in st.query_params:
+    st.query_params.is_order = False
+if 'search_res_limit' not in st.query_params:
+    st.query_params.search_res_limit = "10"
+
+
+@st.experimental_fragment
+def search_container():
+    st.query_params.search_key = st.text_input("**搜索🔎**", value=st.query_params.search_key)
+    columns = st.columns([7, 1])
+    with columns[0]:
+        st.query_params.search_res_limit = str(st.number_input(
+            "**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d"))
+    with columns[1]:
+        st.text("排序")
+        st.query_params.is_order = st.toggle("顺序", value=utils.strtobool(st.query_params.is_order))
+
+
+search_container()
+
+search_clicked = st.button('Search', type='primary', use_container_width=True)
+if not st.session_state.force_skip and (not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None):
+    st.stop()
+
+if not st.session_state.force_skip:
+    st.session_state.page_index = 1
+if st.session_state.force_skip:
+    st.session_state.force_skip = False
+
+
+@st.experimental_fragment
+def do_search_req():
+    search_limit = int(st.query_params.search_res_limit)
+    offset_index = (st.session_state.page_index - 1) * search_limit
+    is_order = utils.strtobool(st.query_params.is_order)
+
+    search_res = api.search_database_by_keyword(st.query_params.search_key, offset_index, search_limit, is_order)
+    if search_res is None:
+        st.stop()
+
+    def page_switch_render():
+        columns = st.columns(3)
+        with columns[0]:
+            if st.button("Prev", use_container_width=True):
+                st.session_state.page_index = st.session_state.page_index - 1
+                st.session_state.page_index = max(
+                    st.session_state.page_index, 1)
+                st.session_state.force_skip = True
+                st.rerun()
+        with columns[1]:
+            # st.text(f"{st.session_state.page_index}")
+            st.markdown(
+                f"<p style='text-align: center;'>{st.session_state.page_index}</p>", unsafe_allow_html=True)
+            # st.markdown(f"<input type='number' style='text-align: center;' value={st.session_state.page_index}>", unsafe_allow_html=True)
+        with columns[2]:
+            if st.button("Next", use_container_width=True):
+                st.session_state.page_index = st.session_state.page_index + 1
+                st.session_state.force_skip = True
+                st.rerun()
+
+    def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str):
+        file_size_str = f"{file_size/1024/1024:.2f}MB"
+        container = st.container()
+        container_columns = container.columns([1, 99])
+
+        st.session_state.search_res_select_list[index] = container_columns[0].checkbox(
+            "search_res_checkbox_" + str(index), label_visibility='collapsed')
+
+        expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} — *{file_size_str}*"
+        popover = container_columns[1].popover(expender_title, use_container_width=True)
+        popover_columns = popover.columns([1, 3, 1])
+        if url:
+            popover_columns[0].video(url)
+        else:
+            popover_columns[0].video('./static/404.webm', format="video/webm")
+        popover_columns[1].markdown(f'{msg_ctx}')
+        popover_columns[1].markdown(f'**{file_name}**')
+        popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
+        popover_columns[2].link_button('⬇️Download Link', url, use_container_width=True)
+        popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True)
+
+    @st.experimental_fragment
+    def show_search_res(res: dict[str, any]):
+        search_res_list = res.get("list")
+        if search_res_list is None or len(search_res_list) == 0:
+            st.info("No result")
+            page_switch_render()
+            st.stop()
+        sign_token = ""
+        try:
+            sign_token = res['client']['sign']
+        except Exception as err:
+            pass
+        st.session_state.search_res_select_list = [False] * len(search_res_list)
+        url_list = []
+        for i in range(len(search_res_list)):
+            v = search_res_list[i]
+            msg_ctx = ""
+            file_name = None
+            file_size = 0
+            download_url = ""
+            src_link = ""
+            try:
+                src_link = v['src_tg_link']
+                msg_ctx = v['message']
+                msg_id = str(v['id'])
+                doc = v['media']['document']
+                file_size = doc['size']
+                if doc is not None:
+                    for attr in doc['attributes']:
+                        file_name = attr.get('file_name')
+                        if file_name != "" and file_name is not None:
+                            break
+                if file_name == "" or file_name is None:
+                    file_name = "Can not get file name"
+                download_url = v['download_url']
+                download_url += f'?sign={sign_token}'
+                url_list.append(download_url)
+            except Exception as err:
+                msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
+            media_file_res_container(
+                i, msg_ctx, file_name, file_size, download_url, src_link)
+        page_switch_render()
+
+        show_text = ""
+        select_list = st.session_state.search_res_select_list
+        for i in range(len(select_list)):
+            if select_list[i]:
+                show_text = show_text + url_list[i] + '\n'
+        st.text_area("链接", value=show_text)
+
+    show_search_res(search_res)
+
+
+do_search_req()
@@ -1,6 +1,7 @@
 toml
 telethon
 # python-socks[asyncio]
+diskcache
 fastapi
 uvicorn[standard]
 streamlit