feat: support multi search

Hehesheng 2024-06-09 15:41:18 +08:00
parent 50af6974da
commit 9c12ae6e79
8 changed files with 204 additions and 136 deletions

View File

@@ -101,9 +101,9 @@ class MediaChunkHolder(object):
         self.mem = self.mem + mem
         self.length = len(self.mem)
         if self.length > self.target_len:
-            logger.warning(RuntimeWarning(
-                f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}"))
+            logger.warning(RuntimeWarning(f"MeidaChunk Overflow:start:{self.start},len:{self.length},tlen:{self.target_len}"))
         self.notify_waiters()
+        self.try_clear_waiter_and_requester()
 
     def add_chunk_requester(self, req: Request) -> None:
         if self.is_completed():

View File

@@ -7,7 +7,7 @@ import os
 import functools
 import traceback
 import logging
-from typing import Union, Optional, Literal
+from typing import Union, Optional, Literal, Callable
 from telethon import TelegramClient, types, hints, events
 from telethon.custom import QRLogin
 
@@ -32,8 +32,7 @@ class TgFileSystemClient(object):
     client: TelegramClient
     media_chunk_manager: MediaChunkHolderManager
     dialogs_cache: Optional[hints.TotalList] = None
-    msg_cache: list[types.Message] = []
-    worker_routines: list[asyncio.Task] = []
+    worker_routines: list[asyncio.Task]
     qr_login: QRLogin | None = None
     login_task: asyncio.Task | None = None
     # rsa key
@@ -80,6 +79,7 @@ class TgFileSystemClient(object):
         )
         self.media_chunk_manager = MediaChunkHolderManager()
         self.db = db
+        self.worker_routines = []
 
     def __del__(self) -> None:
         if self.client.loop.is_running():
@@ -164,7 +164,6 @@ class TgFileSystemClient(object):
         if len(self.client_param.whitelist_chat) > 0:
             self._register_update_event(from_users=self.client_param.whitelist_chat)
             await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat()))
-            # await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat2()))
 
     async def stop(self) -> None:
         await self.client.loop.create_task(self._cancel_tasks())
@@ -181,16 +180,16 @@ class TgFileSystemClient(object):
             logger.error(f"{err=}")
             logger.error(traceback.format_exc())
 
-    async def _cache_whitelist_chat2(self):
-        for chat_id in self.client_param.whitelist_chat:
+    async def _cache_whitelist_chat_full_policy(self, chat_id: int, callback: Callable = None):
         async for msg in self.client.iter_messages(chat_id):
             if len(self.db.get_msg_by_unique_id(UserManager.generate_unique_id_by_msg(self.me, msg))) != 0:
                 continue
             self.db.insert_by_message(self.me, msg)
+        if callback is not None:
+            callback()
         logger.info(f"{chat_id} quit cache task.")
 
-    async def _cache_whitelist_chat(self):
-        for chat_id in self.client_param.whitelist_chat:
+    async def _cache_whitelist_chat_lazy_policy(self, chat_id: int, callback: Callable = None):
         # update newest msg
         newest_msg = self.db.get_newest_msg_by_chat_id(chat_id)
         if len(newest_msg) > 0:
@@ -209,8 +208,23 @@ class TgFileSystemClient(object):
             else:
                 async for msg in self.client.iter_messages(chat_id):
                     self.db.insert_by_message(self.me, msg)
+        if callback is not None:
+            callback()
         logger.info(f"{chat_id} quit cache task.")
 
+    async def _cache_whitelist_chat(self):
+        max_cache_tasks_num = TgFileSystemClient.MAX_WORKER_ROUTINE // 2
+        tasks_sem = asyncio.Semaphore(value=max_cache_tasks_num)
+
+        def _sem_release_callback():
+            tasks_sem.release()
+
+        for chat_id in self.client_param.whitelist_chat:
+            await tasks_sem.acquire()
+            await self.task_queue.put(
+                (self._get_unique_task_id(), self._cache_whitelist_chat_lazy_policy(chat_id, callback=_sem_release_callback))
+            )
+
     @_acheck_before_call
     async def get_message(self, chat_id: int | str, msg_id: int) -> types.Message:
         msg = await self.client.get_messages(chat_id, ids=msg_id)
@@ -287,17 +301,15 @@ class TgFileSystemClient(object):
 
     async def get_messages_by_search_db(
         self,
-        chat_id: int,
+        chat_ids: list[int],
         search_word: str,
         limit: int = 10,
         offset: int = 0,
         inc: bool = False,
         ignore_case: bool = False,
     ) -> list[any]:
-        if chat_id not in self.client_param.whitelist_chat:
-            return []
         res = self.db.get_msg_by_chat_id_and_keyword(
-            chat_id,
+            chat_ids,
             search_word,
             limit=limit,
             offset=offset,
@@ -333,8 +345,6 @@ class TgFileSystemClient(object):
                 f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder},\r\n{err=}\r\n{traceback.format_exc()}"
             )
         else:
-            if not media_holder.try_clear_waiter_and_requester():
-                logger.error("I think never run here.")
             if not self.media_chunk_manager.move_media_chunk_to_disk(media_holder):
                 logger.warning(f"move to disk failed, {media_holder=}")
             logger.debug(f"downloaded chunk:{offset=},{target_size=},{media_holder}")

View File

@@ -67,18 +67,26 @@ class UserManager(object):
     def get_msg_by_chat_id_and_keyword(
         self,
-        chat_id: int,
+        chat_ids: list[int],
         keyword: str,
         limit: int = 10,
         offset: int = 0,
         inc: bool = False,
         ignore_case: bool = False,
     ) -> list[any]:
+        if not chat_ids:
+            logger.warning(f"chat_ids is empty.")
+            return []
         keyword_condition = "msg_ctx LIKE '%{key}%' OR file_name LIKE '%{key}%'"
         if ignore_case:
             keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')"
         keyword_condition = keyword_condition.format(key=keyword)
-        execute_script = f"SELECT * FROM message WHERE chat_id = {chat_id} AND ({keyword_condition}) ORDER BY date_time {'' if inc else 'DESC '}LIMIT {limit} OFFSET {offset}"
+        chat_ids_str = ""
+        if len(chat_ids) > 1:
+            chat_ids_str = f"{tuple(chat_ids)}"
+        else:
+            chat_ids_str = f"({chat_ids[0]})"
+        execute_script = f"SELECT * FROM message WHERE chat_id in {chat_ids_str} AND ({keyword_condition}) ORDER BY date_time {'' if inc else 'DESC '}LIMIT {limit} OFFSET {offset}"
         logger.info(f"{execute_script=}")
         res = self.cur.execute(execute_script)
         return res
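The `len(chat_ids) > 1` branch above exists because a single-element Python tuple renders as `(123,)`, which SQLite's IN clause rejects. For comparison, a hedged sketch of a parameterized variant that binds the ids and keyword instead of formatting them into the SQL string; `search_messages` is an illustrative name, not part of this project, and it assumes the same `message` table and columns:

import sqlite3

def search_messages(cur: sqlite3.Cursor, chat_ids: list[int], keyword: str,
                    limit: int = 10, offset: int = 0, inc: bool = False) -> list:
    # One "?" placeholder per chat id, plus bound LIKE patterns for the keyword.
    if not chat_ids:
        return []
    placeholders = ",".join("?" for _ in chat_ids)
    order = "" if inc else "DESC"
    sql = (
        f"SELECT * FROM message WHERE chat_id IN ({placeholders}) "
        f"AND (msg_ctx LIKE ? OR file_name LIKE ?) "
        f"ORDER BY date_time {order} LIMIT ? OFFSET ?"
    )
    pattern = f"%{keyword}%"
    return cur.execute(sql, (*chat_ids, pattern, pattern, limit, offset)).fetchall()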

View File

@@ -39,16 +39,18 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 class TgToFileListRequestBody(BaseModel):
     token: str
     search: str = ""
-    chat_id: int = 0
+    chat_ids: list[int] = []
     index: int = 0
     length: int = 10
     refresh: bool = False
     inner: bool = False
     inc: bool = False
 
+
 @app.post("/tg/api/v1/file/search")
 @apiutils.atimeit
 async def search_tg_file_list(body: TgToFileListRequestBody):
@@ -59,19 +61,21 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
         res_type = "msg"
         client = await clients_mgr.get_client_force(body.token)
         res_dict = []
-        res = await client.get_messages_by_search_db(body.chat_id, body.search, limit=body.length, inc=body.inc, offset=body.index)
+        res = await client.get_messages_by_search_db(
+            body.chat_ids, body.search, limit=body.length, inc=body.inc, offset=body.index
+        )
         for item in res:
             msg_info = json.loads(item)
             file_name = apiutils.get_message_media_name_from_dict(msg_info)
             chat_id = apiutils.get_message_chat_id_from_dict(msg_info)
             msg_id = apiutils.get_message_msg_id_from_dict(msg_info)
-            msg_info['file_name'] = file_name
-            msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}"
-            msg_info['src_tg_link'] = f"https://t.me/c/{chat_id}/{msg_id}"
+            msg_info["file_name"] = file_name
+            msg_info["download_url"] = f"{param.base.exposed_url}/tg/api/v1/file/get/{chat_id}/{msg_id}/{file_name}"
+            msg_info["src_tg_link"] = f"https://t.me/c/{chat_id}/{msg_id}"
             res_dict.append(msg_info)
 
         client_dict = json.loads(client.to_json())
-        client_dict['sign'] = body.token
+        client_dict["sign"] = body.token
 
         response_dict = {
             "client": client_dict,
@@ -95,7 +99,9 @@ async def get_tg_file_list(body: TgToFileListRequestBody):
         client = await clients_mgr.get_client_force(body.token)
         res_dict = []
         if body.search != "":
-            res = await client.get_messages_by_search(body.chat_id, search_word=body.search, limit=body.length, offset=body.index, inner_search=body.inner)
+            res = await client.get_messages_by_search(
+                body.chat_id, search_word=body.search, limit=body.length, offset=body.index, inner_search=body.inner
+            )
         else:
             res = await client.get_messages(body.chat_id, limit=body.length, offset=body.index)
         res_type = "msg"
@@ -104,8 +110,10 @@ async def get_tg_file_list(body: TgToFileListRequestBody):
             if file_name == "":
                 file_name = "unknown.tmp"
             msg_info = json.loads(item.to_json())
-            msg_info['file_name'] = file_name
-            msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{item.id}/{file_name}?sign={body.token}"
+            msg_info["file_name"] = file_name
+            msg_info["download_url"] = (
+                f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{item.id}/{file_name}?sign={body.token}"
+            )
             res_dict.append(msg_info)
 
         response_dict = {
@@ -130,10 +138,7 @@ async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Requ
         "accept-ranges": "bytes",
         "content-encoding": "identity",
         # "content-length": stream_file_size,
-        "access-control-expose-headers": (
-            "content-type, accept-ranges, content-length, "
-            "content-range, content-encoding"
-        ),
+        "access-control-expose-headers": ("content-type, accept-ranges, content-length, " "content-range, content-encoding"),
     }
     range_header = request.headers.get("range")
     try:
@@ -151,8 +156,7 @@ async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Requ
         if file_name == "":
             maybe_file_type = mime_type.split("/")[-1]
             file_name = f"{chat_id}.{msg_id}.{maybe_file_type}"
-        headers[
-            "Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'"
+        headers["Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'"
 
         if range_header is not None:
             start, end = apiutils.get_range_header(range_header, file_size)
@@ -224,6 +228,7 @@ class TgToChatListRequestBody(BaseModel):
     length: int = 0
     refresh: bool = False
 
+
 @app.post("/tg/api/v1/client/chat")
 @apiutils.atimeit
 async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Request):
@@ -235,8 +240,16 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques
         res_dict = {}
         res = await client.get_dialogs(limit=body.length, offset=body.index, refresh=body.refresh)
-        res_dict = [{"id": item.id, "is_channel": item.is_channel,
-                     "is_group": item.is_group, "is_user": item.is_user, "name": item.name, } for item in res]
+        res_dict = [
+            {
+                "id": item.id,
+                "is_channel": item.is_channel,
+                "is_group": item.is_group,
+                "is_user": item.is_user,
+                "name": item.name,
+            }
+            for item in res
+        ]
 
         response_dict = {
             "client": json.loads(client.to_json()),
@@ -249,6 +262,7 @@ async def get_tg_client_chat_list(body: TgToChatListRequestBody, request: Reques
         logger.error(f"{err=}")
         return Response(json.dumps({"detail": f"{err=}"}), status_code=status.HTTP_404_NOT_FOUND)
 
+
 if __name__ == "__main__":
     param = configParse.get_TgToFileSystemParameter()
     uvicorn.run(app, host="0.0.0.0", port=param.base.port)
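With the body model change above, /tg/api/v1/file/search now takes a `chat_ids` list instead of a single `chat_id`. A minimal sketch of calling the endpoint with the new shape; the host, port, token, and ids below are placeholders:

import json
import requests

body = {
    "token": "<web-token>",        # placeholder token
    "search": "keyword",
    "chat_ids": [111111, 222222],  # search across several whitelisted chats at once
    "index": 0,
    "length": 10,
    "refresh": False,
    "inner": False,
    "inc": False,
}
resp = requests.post("http://127.0.0.1:2000/tg/api/v1/file/search", data=json.dumps(body))
resp.raise_for_status()
print(json.dumps(resp.json(), indent=2))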

View File

@@ -40,16 +40,20 @@ async def link_convert(link: str) -> str:
     return url
 
 
-async def get_clients_manager_status(detail) -> dict[str, any]:
-    clients_mgr = TgFileSystemClientManager.get_instance()
-    ret = await clients_mgr.get_status()
-    if not detail:
-        return ret
+async def get_chat_details(mgr: TgFileSystemClientManager) -> dict[int, any]:
     chat_details = {}
-    for _, client in clients_mgr.clients.items():
+    for _, client in mgr.clients.items():
         chat_list = client.client_param.whitelist_chat
         for chat_id in chat_list:
             chat_entity = await client.get_entity(chat_id)
             chat_details[chat_id] = json.loads(chat_entity.to_json())
-    ret["info"] = chat_details
+    return chat_details
+
+
+async def get_clients_manager_status(detail: bool) -> dict[str, any]:
+    clients_mgr = TgFileSystemClientManager.get_instance()
+    ret = await clients_mgr.get_status()
+    if not detail:
+        return ret
+    ret["clist"] = await get_chat_details(clients_mgr)
     return ret
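When `detail` is true, the status payload now carries a `clist` map from chat id to the serialized chat entity, and the web UI later reads each entry's `title`. A small sketch of pulling titles out of that shape; the sample payload and its entity fields are illustrative:

status = {
    "clist": {
        "111111": {"id": 111111, "title": "Example Channel"},  # illustrative entity fields
        "222222": {"id": 222222, "title": "Example Group"},
    }
}
chat_titles = {int(cid): info.get("title") for cid, info in status.get("clist", {}).items()}
print(chat_titles)  # {111111: 'Example Channel', 222222: 'Example Group'}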

View File

@@ -35,7 +35,6 @@ class TgToFileSystemParameter(BaseModel):
         enable: bool = False
         token: str = ""
         port: int = 2000
-        chat_id: list[int] = []
     web: TgWebParameter
 
 @functools.lru_cache

View File

@@ -2,6 +2,7 @@ import sys
 import os
 import json
 import logging
+import traceback
 from urllib.parse import quote
 
 import requests
@@ -9,8 +10,6 @@ import requests
 sys.path.append(os.getcwd() + "/../")
 
 import configParse
 
-logger = logging.getLogger(__file__.split("/")[-1])
-
 param = configParse.get_TgToFileSystemParameter()
 background_server_url = f"{param.base.exposed_url}"
@@ -23,7 +22,7 @@ def login_client_by_qr_code_url() -> str:
     request_url = background_server_url + login_api_route
     response = requests.get(request_url)
     if response.status_code != 200:
-        logger.warning(f"Could not login, err:{response.status_code}, {response.content.decode('utf-8')}")
+        logging.warning(f"Could not login, err:{response.status_code}, {response.content.decode('utf-8')}")
         return None
     url_info = json.loads(response.content.decode("utf-8"))
     return url_info.get("url")
@@ -32,24 +31,33 @@ def login_client_by_qr_code_url() -> str:
 status_api_route = "/tg/api/v1/client/status"
 
 
-def get_backend_client_status() -> dict[str, any]:
-    request_url = background_server_url + status_api_route
+def get_backend_client_status(flag: bool = False) -> dict[str, any]:
+    request_url = f"{background_server_url}{status_api_route}?flag={flag}"
     response = requests.get(request_url)
     if response.status_code != 200:
-        logger.warning(f"get_status, backend is running? err:{response.status_code}, {response.content.decode('utf-8')}")
+        logging.warning(f"get_status, backend is running? err:{response.status_code}, {response.content.decode('utf-8')}")
         return None
     return json.loads(response.content.decode("utf-8"))
 
 
+def get_white_list_chat_dict() -> dict[str, any]:
+    backend_status = get_backend_client_status(True)
+    try:
+        return backend_status["clist"]
+    except Exception as err:
+        logging.warning(f"{err=},{traceback.format_exc()}")
+    return {}
+
+
 search_api_route = "/tg/api/v1/file/search"
 
 
-def search_database_by_keyword(keyword: str, offset: int, limit: int, is_order: bool) -> list[any] | None:
+def search_database_by_keyword(keyword: str, chat_list: list[int], offset: int, limit: int, is_order: bool) -> list[any] | None:
     request_url = background_server_url + search_api_route
     req_body = {
         "token": param.web.token,
         "search": keyword,
-        "chat_id": param.web.chat_id[0],
+        "chat_ids": chat_list,
         "index": offset,
         "length": limit,
         "refresh": False,
@@ -59,7 +67,7 @@ def search_database_by_keyword(keyword: str, offset: int, limit: int, is_order:
     response = requests.post(request_url, data=json.dumps(req_body))
     if response.status_code != 200:
-        logger.warning(f"search_database_by_keyword err:{response.status_code}, {response.content.decode('utf-8')}")
+        logging.warning(f"search_database_by_keyword err:{response.status_code}, {response.content.decode('utf-8')}")
         return None
     search_res = json.loads(response.content.decode("utf-8"))
     return search_res
@@ -73,7 +81,7 @@ def convert_tg_link_to_proxy_link(link: str) -> str:
     request_url = background_server_url + link_convert_api_route + f"?link={link}"
     response = requests.get(request_url)
     if response.status_code != 200:
-        logger.warning(f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}")
+        logging.warning(f"link convert fail: {response.status_code}, {response.content.decode('utf-8')}")
         return ""
     response_js = json.loads(response.content.decode("utf-8"))
     return response_js["url"]
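Putting the two helpers together, the web side can fetch the whitelist chat map, pick the ids to search, and pass them straight to `search_database_by_keyword`. A minimal usage sketch, assuming this module is importable as `remote_api` and selecting every whitelisted chat:

import json
import remote_api as api

chat_dict = api.get_white_list_chat_dict()      # {chat_id: serialized chat entity, ...}
selected_ids = [int(cid) for cid in chat_dict]  # here: search every whitelisted chat
res = api.search_database_by_keyword("keyword", selected_ids, offset=0, limit=10, is_order=False)
if res is not None:
    print(json.dumps(res, indent=2))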

View File

@@ -1,40 +1,58 @@
 import sys
 import os
+import logging
+import traceback
 
 import streamlit as st
 
 import utils
 import remote_api as api
 
 
 @st.experimental_fragment
 def loop():
-    if 'page_index' not in st.session_state:
+    if "page_index" not in st.session_state:
         st.session_state.page_index = 1
-    if 'force_skip' not in st.session_state:
+    if "force_skip" not in st.session_state:
         st.session_state.force_skip = False
+    if "chat_select_list" not in st.session_state:
+        st.session_state.chat_select_list = []
 
-    if 'search_key' not in st.query_params:
+    if "search_key" not in st.query_params:
         st.query_params.search_key = ""
-    if 'is_order' not in st.query_params:
+    if "is_order" not in st.query_params:
         st.query_params.is_order = False
-    if 'search_res_limit' not in st.query_params:
+    if "search_res_limit" not in st.query_params:
         st.query_params.search_res_limit = "10"
 
     @st.experimental_fragment
     def search_container():
-        st.query_params.search_key = st.text_input("**搜索🔎**", value=st.query_params.search_key)
+        if "chat_dict" not in st.session_state:
+            st.session_state.chat_dict = api.get_white_list_chat_dict()
+        columns = st.columns([1, 1])
+        st.query_params.search_key = columns[0].text_input("**搜索🔎**", value=st.query_params.search_key)
+        chat_list = []
+        for _, chat_info in st.session_state.chat_dict.items():
+            chat_list.append(chat_info["title"])
+        st.session_state.chat_select_list = columns[1].multiselect("**Search in**", chat_list, default=chat_list)
 
         columns = st.columns([7, 1])
         with columns[0]:
-            st.query_params.search_res_limit = str(st.number_input(
-                "**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d"))
+            st.query_params.search_res_limit = str(
+                st.number_input(
+                    "**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d"
+                )
+            )
         with columns[1]:
             st.text("排序")
             st.query_params.is_order = st.toggle("顺序", value=utils.strtobool(st.query_params.is_order))
 
     search_container()
 
-    search_clicked = st.button('Search', type='primary', use_container_width=True)
-    if not st.session_state.force_skip and (not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None):
+    search_clicked = st.button("Search", type="primary", use_container_width=True)
+    if not st.session_state.force_skip and (
+        not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None
+    ):
         return
 
     if not st.session_state.force_skip:
@@ -48,56 +66,62 @@ def loop():
     offset_index = (st.session_state.page_index - 1) * search_limit
     is_order = utils.strtobool(st.query_params.is_order)
 
-    search_res = api.search_database_by_keyword(st.query_params.search_key, offset_index, search_limit, is_order)
+    status_bar = st.empty()
+    status_bar.status("Searching......")
+    search_chat_id_list = []
+    for chat_id, chat_info in st.session_state.chat_dict.items():
+        try:
+            if chat_info["title"] in st.session_state.chat_select_list:
+                search_chat_id_list.append(int(chat_id))
+        except Exception as err:
+            logging.warning(f"{err=},{traceback.format_exc()}")
+    search_res = api.search_database_by_keyword(
+        st.query_params.search_key, search_chat_id_list, offset_index, search_limit, is_order
+    )
+    status_bar.empty()
     if search_res is None:
         return
 
     def page_switch_render():
-        columns = st.columns(3)
-        with columns[0]:
-            if st.button("Prev", use_container_width=True):
-                st.session_state.page_index = st.session_state.page_index - 1
-                st.session_state.page_index = max(
-                    st.session_state.page_index, 1)
-                st.session_state.force_skip = True
-                st.rerun()
-        with columns[1]:
-            # st.text(f"{st.session_state.page_index}")
-            st.markdown(
-                f"<p style='text-align: center;'>{st.session_state.page_index}</p>", unsafe_allow_html=True)
-            # st.markdown(f"<input type='number' style='text-align: center;' value={st.session_state.page_index}>", unsafe_allow_html=True)
-        with columns[2]:
-            if st.button("Next", use_container_width=True):
-                st.session_state.page_index = st.session_state.page_index + 1
-                st.session_state.force_skip = True
-                st.rerun()
-        st.session_state.page_index = st.number_input(
-            "page_index", key="page_index_input", min_value=1, max_value=100, value=st.session_state.page_index, format="%d", label_visibility="hidden")
+        page_index = st.number_input(
+            "Page number:",
+            key="page_index_input",
+            min_value=1,
+            max_value=100,
+            value=st.session_state.page_index,
+            format="%d",
+        )
+        if page_index != st.session_state.page_index:
+            st.session_state.page_index = page_index
+            st.session_state.force_skip = True
+            st.rerun()
 
-    def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str):
+    def media_file_res_container(
+        index: int, msg_ctx: str, file_name: str, file_size: int, url: str, src_link: str, mime_type: str
+    ):
         file_size_str = f"{file_size/1024/1024:.2f}MB"
         container = st.container()
         container_columns = container.columns([1, 99])
         st.session_state.search_res_select_list[index] = container_columns[0].checkbox(
-            "search_res_checkbox_" + str(index), label_visibility='collapsed')
+            "search_res_checkbox_" + str(index), label_visibility="collapsed"
+        )
         expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size_str}*"
         popover = container_columns[1].popover(expender_title, use_container_width=True)
         # media_file_popover_container(popover, url, msg_ctx, file_name, file_size_str, src_link)
         popover_columns = popover.columns([1, 3, 1])
         video_holder = popover_columns[0].empty()
-        if video_holder.button("Preview", key=f"videoBtn{url}", use_container_width=True):
+        if video_holder.button("Preview", key=f"videoBtn{url}{index}", use_container_width=True):
             video_holder.empty()
-            if url:
-                video_holder.video(url)
-            else:
-                video_holder.video('./static/404.webm', format="video/webm")
-        popover_columns[1].markdown(f'{msg_ctx}')
-        popover_columns[1].markdown(f'**{file_name}**')
-        popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
-        popover_columns[2].link_button('Download Link', url, use_container_width=True)
-        popover_columns[2].link_button('🔗Telegram Link', src_link, use_container_width=True)
+            p_url = url if url else "./static/404.webm"
+            mime_type = mime_type if mime_type else "video/webm"
+            video_holder.video(p_url, autoplay=True, format=mime_type)
+        popover_columns[1].markdown(f"{msg_ctx}")
+        popover_columns[1].markdown(f"**{file_name}**")
+        popover_columns[1].markdown(f"File Size: *{file_size_str}*")
+        popover_columns[2].link_button("Download Link", url, use_container_width=True)
+        popover_columns[2].link_button("🔗Telegram Link", src_link, use_container_width=True)
 
     @st.experimental_fragment
     def show_search_res(res: dict[str, any]):
@@ -108,7 +132,7 @@ def loop():
             return
         sign_token = ""
         try:
-            sign_token = res['client']['sign']
+            sign_token = res["client"]["sign"]
         except Exception as err:
             pass
         st.session_state.search_res_select_list = [False] * len(search_res_list)
@@ -120,36 +144,37 @@ def loop():
             file_size = 0
             download_url = ""
            src_link = ""
+            mime_type = ""
             try:
-                src_link = v['src_tg_link']
-                msg_ctx = v['message']
-                msg_id = str(v['id'])
-                doc = v['media']['document']
-                file_size = doc['size']
+                src_link = v["src_tg_link"]
+                msg_ctx = v["message"]
+                msg_id = str(v["id"])
+                doc = v["media"]["document"]
+                mime_type = doc["mime_type"]
+                file_size = doc["size"]
                 if doc is not None:
-                    for attr in doc['attributes']:
-                        file_name = attr.get('file_name')
+                    for attr in doc["attributes"]:
+                        file_name = attr.get("file_name")
                         if file_name != "" and file_name is not None:
                             break
                 if file_name == "" or file_name is None:
                     file_name = "Can not get file name"
-                download_url = v['download_url']
-                download_url += f'?sign={sign_token}'
-                url_list.append(download_url)
+                download_url = v["download_url"]
+                download_url += f"?sign={sign_token}"
             except Exception as err:
                 msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
-            media_file_res_container(
-                i, msg_ctx, file_name, file_size, download_url, src_link)
+                logging.warning(f"{err=},{traceback.format_exc()}")
+            url_list.append(download_url)
+            media_file_res_container(i, msg_ctx, file_name, file_size, download_url, src_link, mime_type)
 
         page_switch_render()
 
         show_text = ""
         select_list = st.session_state.search_res_select_list
         for i in range(len(select_list)):
             if select_list[i]:
-                show_text = show_text + url_list[i] + '\n'
+                show_text = show_text + url_list[i] + "\n"
-        st.text_area("链接", value=show_text)
+        st.text_area("Links", value=show_text)
 
     show_search_res(search_res)
 
     do_search_req()