fix: fix some bugs & optimze ui

This commit is contained in:
hehesheng 2024-06-01 15:15:33 +08:00
parent f6c8f406cc
commit f48a35ad17
7 changed files with 135 additions and 92 deletions

View File

@ -299,8 +299,8 @@ class TgFileSystemClient(object):
self.worker_routines.append(worker_routine) self.worker_routines.append(worker_routine)
if len(self.client_param.whitelist_chat) > 0: if len(self.client_param.whitelist_chat) > 0:
self._register_update_event(from_users=self.client_param.whitelist_chat) self._register_update_event(from_users=self.client_param.whitelist_chat)
# await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat())) await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat()))
await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat2())) # await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat2()))
async def stop(self) -> None: async def stop(self) -> None:
await self.client.loop.create_task(self._cancel_tasks()) await self.client.loop.create_task(self._cancel_tasks())
@ -322,6 +322,7 @@ class TgFileSystemClient(object):
if len(self.db.get_msg_by_unique_id(self.db.generate_unique_id_by_msg(self.me, msg))) != 0: if len(self.db.get_msg_by_unique_id(self.db.generate_unique_id_by_msg(self.me, msg))) != 0:
continue continue
self.db.insert_by_message(self.me, msg) self.db.insert_by_message(self.me, msg)
logger.info(f"{chat_id} quit cache task.")
async def _cache_whitelist_chat(self): async def _cache_whitelist_chat(self):
for chat_id in self.client_param.whitelist_chat: for chat_id in self.client_param.whitelist_chat:
@ -343,6 +344,7 @@ class TgFileSystemClient(object):
else: else:
async for msg in self.client.iter_messages(chat_id): async for msg in self.client.iter_messages(chat_id):
self.db.insert_by_message(self.me, msg) self.db.insert_by_message(self.me, msg)
logger.info(f"{chat_id} quit cache task.")
@_acheck_before_call @_acheck_before_call

View File

@ -55,7 +55,7 @@ class UserManager(object):
def get_all_msg_by_chat_id(self, chat_id: int) -> list[any]: def get_all_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute( res = self.cur.execute(
"SELECT * FROM message WHERE chat_id == ? ORDER BY date_time DESC", (chat_id,)) "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC", (chat_id,))
return res.fetchall() return res.fetchall()
def get_msg_by_chat_id_and_keyword(self, chat_id: int, keyword: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]: def get_msg_by_chat_id_and_keyword(self, chat_id: int, keyword: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]:
@ -63,23 +63,24 @@ class UserManager(object):
if ignore_case: if ignore_case:
keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')" keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')"
keyword_condition = keyword_condition.format(key=keyword) keyword_condition = keyword_condition.format(key=keyword)
execute_script = f"SELECT * FROM message WHERE chat_id == {chat_id} AND ({keyword_condition}) ORDER BY date_time {'' if inc else 'DESC '}LIMIT {limit} OFFSET {offset}" execute_script = f"SELECT * FROM message WHERE chat_id = {chat_id} AND ({keyword_condition}) ORDER BY date_time {'' if inc else 'DESC '}LIMIT {limit} OFFSET {offset}"
logger.info(f"{execute_script=}")
res = self.cur.execute(execute_script) res = self.cur.execute(execute_script)
return res return res
def get_oldest_msg_by_chat_id(self, chat_id: int) -> list[any]: def get_oldest_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute( res = self.cur.execute(
"SELECT * FROM message WHERE chat_id == ? ORDER BY date_time LIMIT 1", (chat_id,)) "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time LIMIT 1", (chat_id,))
return res.fetchall() return res.fetchall()
def get_newest_msg_by_chat_id(self, chat_id: int) -> list[any]: def get_newest_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute( res = self.cur.execute(
"SELECT * FROM message WHERE chat_id == ? ORDER BY date_time DESC LIMIT 1", (chat_id,)) "SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC LIMIT 1", (chat_id,))
return res.fetchall() return res.fetchall()
def get_msg_by_unique_id(self, unique_id: str) -> list[any]: def get_msg_by_unique_id(self, unique_id: str) -> list[any]:
res = self.cur.execute( res = self.cur.execute(
"SELECT * FROM message WHERE unique_id == ? ORDER BY date_time DESC LIMIT 1", (unique_id,)) "SELECT * FROM message WHERE unique_id = ? ORDER BY date_time DESC LIMIT 1", (unique_id,))
return res.fetchall() return res.fetchall()
@unique @unique
@ -178,7 +179,7 @@ class UserManager(object):
"CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)") "CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)")
if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='message'").fetchall()) == 0: if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='message'").fetchall()) == 0:
self.cur.execute( self.cur.execute(
"CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx, mime_type, file_name, msg_js, date_time int)") "CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx text, mime_type text, file_name text, msg_js text, date_time int NOT NULL)")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2,6 +2,7 @@ import asyncio
import json import json
import os import os
import logging import logging
from urllib.parse import quote
import uvicorn import uvicorn
from fastapi import FastAPI, status, Request from fastapi import FastAPI, status, Request
@ -22,6 +23,9 @@ clients_mgr: TgFileSystemClientManager = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.handlers.TimedRotatingFileHandler):
handler.suffix = "%Y-%m-%d"
global clients_mgr global clients_mgr
param = configParse.get_TgToFileSystemParameter() param = configParse.get_TgToFileSystemParameter()
clients_mgr = TgFileSystemClientManager(param) clients_mgr = TgFileSystemClientManager(param)
@ -61,11 +65,15 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
msg_info = json.loads(item) msg_info = json.loads(item)
file_name = apiutils.get_message_media_name_from_dict(msg_info) file_name = apiutils.get_message_media_name_from_dict(msg_info)
msg_info['file_name'] = file_name msg_info['file_name'] = file_name
msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{msg_info.get('id')}/{file_name}?sign={body.token}" msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{msg_info.get('id')}/{file_name}"
msg_info['src_tg_link'] = f"https://t.me/c/1216816802/21206"
res_dict.append(msg_info) res_dict.append(msg_info)
client_dict = json.loads(client.to_json())
client_dict['sign'] = body.token
response_dict = { response_dict = {
"client": json.loads(client.to_json()), "client": client_dict,
"type": res_type, "type": res_type,
"length": len(res_dict), "length": len(res_dict),
"list": res_dict, "list": res_dict,
@ -141,7 +149,7 @@ async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Requ
maybe_file_type = mime_type.split("/")[-1] maybe_file_type = mime_type.split("/")[-1]
file_name = f"{chat_id}.{msg_id}.{maybe_file_type}" file_name = f"{chat_id}.{msg_id}.{maybe_file_type}"
headers[ headers[
"Content-Disposition"] = f'inline; filename="{file_name}"' "Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'"
if range_header is not None: if range_header is not None:
start, end = apiutils.get_range_header(range_header, file_size) start, end = apiutils.get_range_header(range_header, file_size)
@ -183,7 +191,13 @@ async def login_new_tg_file_client():
@app.get("/tg/api/v1/client/link_convert") @app.get("/tg/api/v1/client/link_convert")
@apiutils.atimeit @apiutils.atimeit
async def convert_tg_msg_link_media_stream(link: str, token: str): async def convert_tg_msg_link_media_stream(link: str, sign: str):
raise NotImplementedError
@app.get("/tg/api/v1/client/profile_photo")
@apiutils.atimeit
async def get_tg_chat_profile_photo(chat_id: int, sign: str):
raise NotImplementedError raise NotImplementedError

View File

@ -2,72 +2,68 @@ import sys
import os import os
import json import json
sys.path.append(os.getcwd()) sys.path.append(os.getcwd() + "/../")
import streamlit as st import streamlit as st
import qrcode import qrcode
import pandas import pandas
import requests import requests
import configParse import configParse
import utils
# qr = qrcode.make("https://www.baidu.com") # qr = qrcode.make("https://www.baidu.com")
# st.image(qrcode.make("https://www.baidu.com").get_image()) # st.image(qrcode.make("https://www.baidu.com").get_image())
param = configParse.get_TgToFileSystemParameter()
background_server_url = f"{param.base.exposed_url}/tg/api/v1/file/search"
st.set_page_config(page_title="TgToolbox", page_icon='🕹️', layout='wide') st.set_page_config(page_title="TgToolbox", page_icon='🕹️', layout='wide')
if 'page_index' not in st.session_state: if 'page_index' not in st.session_state:
st.session_state.page_index = 1 st.session_state.page_index = 1
if 'search_input' not in st.session_state: if 'force_skip' not in st.session_state:
st.session_state.search_input = "" st.session_state.force_skip = False
if 'last_search_input' not in st.session_state:
st.session_state.last_search_input = ""
if 'search_clicked' not in st.session_state:
st.session_state.search_clicked = False
if 'is_order' not in st.session_state:
st.session_state.is_order = False
param = configParse.get_TgToFileSystemParameter()
background_server_url = f"{param.base.exposed_url}/tg/api/v1/file/search"
if 'search_key' not in st.query_params:
st.query_params.search_key = ""
if 'is_order' not in st.query_params:
st.query_params.is_order = False
if 'search_res_limit' not in st.query_params:
st.query_params.search_res_limit = "10"
@st.experimental_fragment @st.experimental_fragment
def search_input_container(): def search_container():
st.session_state.search_input = st.text_input("搜索🔎", value=st.query_params.get( st.query_params.search_key = st.text_input("**搜索🔎**", value=st.query_params.search_key)
'search') if st.query_params.get('search') is not None else "") columns = st.columns([7, 1])
with columns[0]:
st.query_params.search_res_limit = str(st.number_input(
"**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d"))
with columns[1]:
st.text("排序")
st.query_params.is_order = st.toggle("顺序", value=utils.strtobool(st.query_params.is_order))
search_container()
search_input_container() search_clicked = st.button('Search', type='primary', use_container_width=True)
if not st.session_state.force_skip and (not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None):
col1, col2 = st.columns(2)
search_res_limit = st.number_input(
"每页结果", min_value=1, max_value=100, value=10, format="%d")
columns = st.columns([7, 1])
with columns[0]:
if st.button("Search"):
st.session_state.page_index = 1
st.session_state.search_clicked = True
with columns[1]:
st.session_state.is_order = st.checkbox("顺序")
if st.session_state.search_input == "" or (st.session_state.search_input != st.session_state.last_search_input and not st.session_state.search_clicked):
st.session_state.search_clicked = False
st.stop() st.stop()
st.session_state.last_search_input = st.session_state.search_input if not st.session_state.force_skip:
st.query_params.search = st.session_state.search_input st.session_state.page_index = 1
if st.session_state.force_skip:
st.session_state.force_skip = False
@st.experimental_fragment @st.experimental_fragment
def do_search_req(): def do_search_req():
offset_index = (st.session_state.page_index - 1) * search_res_limit search_limit = int(st.query_params.search_res_limit)
is_order = st.session_state.is_order offset_index = (st.session_state.page_index - 1) * search_limit
is_order = utils.strtobool(st.query_params.is_order)
req_body = { req_body = {
"token": param.web.token, "token": param.web.token,
"search": f"{st.session_state.search_input}", "search": f"{st.query_params.search_key}",
"chat_id": param.web.chat_id[0], "chat_id": param.web.chat_id[0],
"index": offset_index, "index": offset_index,
"length": search_res_limit, "length": search_limit,
"refresh": False, "refresh": False,
"inner": False, "inner": False,
"inc": is_order, "inc": is_order,
@ -81,11 +77,11 @@ def do_search_req():
def page_switch_render(): def page_switch_render():
columns = st.columns(3) columns = st.columns(3)
with columns[0]: with columns[0]:
pre_button = st.button("Prev", use_container_width=True) if st.button("Prev", use_container_width=True):
if pre_button:
st.session_state.page_index = st.session_state.page_index - 1 st.session_state.page_index = st.session_state.page_index - 1
st.session_state.page_index = max( st.session_state.page_index = max(
st.session_state.page_index, 1) st.session_state.page_index, 1)
st.session_state.force_skip = True
st.rerun() st.rerun()
with columns[1]: with columns[1]:
# st.text(f"{st.session_state.page_index}") # st.text(f"{st.session_state.page_index}")
@ -93,31 +89,40 @@ def do_search_req():
f"<p style='text-align: center;'>{st.session_state.page_index}</p>", unsafe_allow_html=True) f"<p style='text-align: center;'>{st.session_state.page_index}</p>", unsafe_allow_html=True)
# st.markdown(f"<input type='number' style='text-align: center;' value={st.session_state.page_index}>", unsafe_allow_html=True) # st.markdown(f"<input type='number' style='text-align: center;' value={st.session_state.page_index}>", unsafe_allow_html=True)
with columns[2]: with columns[2]:
next_button = st.button("Next", use_container_width=True) if st.button("Next", use_container_width=True):
if next_button:
st.session_state.page_index = st.session_state.page_index + 1 st.session_state.page_index = st.session_state.page_index + 1
st.session_state.force_skip = True
st.rerun() st.rerun()
def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: str, url: str): def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str):
file_size_str = f"{file_size/1024/1024:.2f}MB"
container = st.container() container = st.container()
container_columns = container.columns([1, 99]) container_columns = container.columns([1, 99])
st.session_state.search_res_select_list[index] = container_columns[0].checkbox( st.session_state.search_res_select_list[index] = container_columns[0].checkbox(
url, label_visibility='collapsed') "search_res_checkbox_" + str(index), label_visibility='collapsed')
expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size}*" expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size_str}*"
popover = container_columns[1].popover(expender_title, use_container_width=True) popover = container_columns[1].popover(expender_title, use_container_width=True)
popover_columns = popover.columns([1, 3]) popover_columns = popover.columns([1, 3])
popover_columns[0].video(url) if url:
popover_columns[0].video(url)
else:
popover_columns[0].video('./static/404.webm', format="video/webm")
popover_columns[1].markdown(f'{msg_ctx}') popover_columns[1].markdown(f'{msg_ctx}')
popover_columns[1].markdown(f'**{file_name}**') popover_columns[1].markdown(f'**{file_name}**')
popover_columns[1].markdown(f'文件大小:*{file_size}*') popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
popover_columns[1].page_link(url, label='Download Link', icon='⬇️') popover_columns[1].link_button('Download Link', url)
@st.experimental_fragment @st.experimental_fragment
def show_search_res(): def show_search_res(res: dict[str, any]):
search_res_list = search_res['list'] sign_token = ""
if len(search_res_list) == 0: try:
sign_token = res['client']['sign']
except Exception as err:
pass
search_res_list = res.get('list')
if search_res_list is None or len(search_res_list) == 0:
st.info("No result") st.info("No result")
page_switch_render() page_switch_render()
st.stop() st.stop()
@ -125,28 +130,29 @@ def do_search_req():
url_list = [] url_list = []
for i in range(len(search_res_list)): for i in range(len(search_res_list)):
v = search_res_list[i] v = search_res_list[i]
msg_ctx = v['message'] msg_ctx= ""
doc = None file_name = None
file_size = 0 file_size = 0
msg_id = str(v['id']) download_url = ""
download_url = v['download_url']
url_list.append(download_url)
try: try:
msg_ctx = v['message']
msg_id = str(v['id'])
doc = v['media']['document'] doc = v['media']['document']
file_size = doc['size'] file_size = doc['size']
except: if doc is not None:
pass for attr in doc['attributes']:
file_size_str = f"{file_size/1024/1024:.2f}MB" file_name = attr.get('file_name')
file_name = None if file_name != "" and file_name is not None:
if doc is not None: break
for attr in doc['attributes']: if file_name == "" or file_name is None:
file_name = attr.get('file_name') file_name = "Can not get file name"
if file_name != "" and file_name is not None: download_url = v['download_url']
break download_url += f'?sign={sign_token}'
if file_name == "" or file_name is None: url_list.append(download_url)
file_name = "Can not get file name" except Exception as err:
msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
media_file_res_container( media_file_res_container(
i, msg_ctx, file_name, file_size_str, download_url) i, msg_ctx, file_name, file_size, download_url)
page_switch_render() page_switch_render()
show_text = "" show_text = ""
@ -156,7 +162,7 @@ def do_search_req():
show_text = show_text + url_list[i] + '\n' show_text = show_text + url_list[i] + '\n'
st.text_area("链接", value=show_text) st.text_area("链接", value=show_text)
show_search_res() show_search_res(search_res)
do_search_req() do_search_req()

BIN
frontend/static/404.webm Normal file

Binary file not shown.

14
frontend/utils.py Normal file
View File

@ -0,0 +1,14 @@
def strtobool (val):
"""Convert a string representation of truth to true (1) or false (0).
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
"""
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return 1
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return 0
else:
raise ValueError("invalid truth value %r" % (val,))

View File

@ -14,13 +14,10 @@ if not os.path.exists(os.path.dirname(__file__) + '/logs'):
os.mkdir(os.path.dirname(__file__) + '/logs') os.mkdir(os.path.dirname(__file__) + '/logs')
with open('logging_config.yaml', 'r') as f: with open('logging_config.yaml', 'r') as f:
logging.config.dictConfig(yaml.safe_load(f.read())) logging.config.dictConfig(yaml.safe_load(f.read()))
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.handlers.TimedRotatingFileHandler):
handler.suffix = "%Y-%m-%d"
LOGGING_CONFIG["formatters"]["default"]["fmt"] = "[%(levelname)s] %(asctime)s [uvicorn.default]:%(message)s" LOGGING_CONFIG["formatters"]["default"]["fmt"] = "[%(levelname)s] %(asctime)s [uvicorn.default]:%(message)s"
LOGGING_CONFIG["formatters"]["access"]["fmt"] = '[%(levelname)s]%(asctime)s [uvicorn.access]:%(client_addr)s - "%(request_line)s" %(status_code)s' LOGGING_CONFIG["formatters"]["access"]["fmt"] = '[%(levelname)s]%(asctime)s [uvicorn.access]:%(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["handlers"]["timed_rotating_file"] = { LOGGING_CONFIG["handlers"]["timed_rotating_api_file"] = {
"class": "logging.handlers.TimedRotatingFileHandler", "class": "logging.handlers.TimedRotatingFileHandler",
"filename": "logs/app.log", "filename": "logs/app.log",
"when": "midnight", "when": "midnight",
@ -30,8 +27,12 @@ LOGGING_CONFIG["handlers"]["timed_rotating_file"] = {
"formatter": "default", "formatter": "default",
"encoding": "utf-8", "encoding": "utf-8",
} }
LOGGING_CONFIG["loggers"]["uvicorn"]["handlers"].append("timed_rotating_file") LOGGING_CONFIG["loggers"]["uvicorn"]["handlers"].append("timed_rotating_api_file")
LOGGING_CONFIG["loggers"]["uvicorn.access"]["handlers"].append("timed_rotating_file") LOGGING_CONFIG["loggers"]["uvicorn.access"]["handlers"].append("timed_rotating_api_file")
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.handlers.TimedRotatingFileHandler):
handler.suffix = "%Y-%m-%d"
logger = logging.getLogger(__file__.split("/")[-1]) logger = logging.getLogger(__file__.split("/")[-1])
@ -39,14 +40,19 @@ if __name__ == "__main__":
param = configParse.get_TgToFileSystemParameter() param = configParse.get_TgToFileSystemParameter()
async def run_web_server(): async def run_web_server():
cmd = f"streamlit run {os.getcwd()}/frontend/home.py --server.port {param.web.port}" cmd = f"streamlit run {os.getcwd()}/frontend/home.py --server.port {param.web.port}"
proc = await asyncio.create_subprocess_shell(cmd, stdout=asyncio.subprocess.PIPE, proc = await asyncio.create_subprocess_shell(cmd, cwd=f"{os.path.dirname(__file__)}/frontend", stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE) stderr=asyncio.subprocess.PIPE)
stdout, stderr = await proc.communicate() async def loop_get_cli_pipe(p, suffix = ""):
while True:
stdp = await p.readline()
if stdp:
logger.info(f"[web:{suffix}]{stdp.decode()[:-1]}")
else:
break
stdout_task = asyncio.create_task(loop_get_cli_pipe(proc.stdout, "out"))
stderr_task = asyncio.create_task(loop_get_cli_pipe(proc.stderr, "err"))
await asyncio.gather(*[stdout_task, stderr_task])
logger.info(f'[{cmd!r} exited with {proc.returncode}]') logger.info(f'[{cmd!r} exited with {proc.returncode}]')
if stdout:
logger.info(f'[stdout]\n{stdout.decode()}')
if stderr:
logger.info(f'[stderr]\n{stderr.decode()}')
if param.web.enable: if param.web.enable:
ret = os.fork() ret = os.fork()
if ret == 0: if ret == 0: