fix: fix some bugs & optimze ui

This commit is contained in:
hehesheng 2024-06-01 15:15:33 +08:00
parent f6c8f406cc
commit f48a35ad17
7 changed files with 135 additions and 92 deletions

View File

@ -299,8 +299,8 @@ class TgFileSystemClient(object):
self.worker_routines.append(worker_routine)
if len(self.client_param.whitelist_chat) > 0:
self._register_update_event(from_users=self.client_param.whitelist_chat)
# await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat()))
await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat2()))
await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat()))
# await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat2()))
async def stop(self) -> None:
await self.client.loop.create_task(self._cancel_tasks())
@ -322,6 +322,7 @@ class TgFileSystemClient(object):
if len(self.db.get_msg_by_unique_id(self.db.generate_unique_id_by_msg(self.me, msg))) != 0:
continue
self.db.insert_by_message(self.me, msg)
logger.info(f"{chat_id} quit cache task.")
async def _cache_whitelist_chat(self):
for chat_id in self.client_param.whitelist_chat:
@ -343,6 +344,7 @@ class TgFileSystemClient(object):
else:
async for msg in self.client.iter_messages(chat_id):
self.db.insert_by_message(self.me, msg)
logger.info(f"{chat_id} quit cache task.")
@_acheck_before_call

View File

@ -55,7 +55,7 @@ class UserManager(object):
def get_all_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute(
"SELECT * FROM message WHERE chat_id == ? ORDER BY date_time DESC", (chat_id,))
"SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC", (chat_id,))
return res.fetchall()
def get_msg_by_chat_id_and_keyword(self, chat_id: int, keyword: str, limit: int = 10, offset: int = 0, inc: bool = False, ignore_case: bool = False) -> list[any]:
@ -63,23 +63,24 @@ class UserManager(object):
if ignore_case:
keyword_condition = "LOWER(msg_ctx) LIKE LOWER('%{key}%') OR LOWER(file_name) LIKE LOWER('%{key}%')"
keyword_condition = keyword_condition.format(key=keyword)
execute_script = f"SELECT * FROM message WHERE chat_id == {chat_id} AND ({keyword_condition}) ORDER BY date_time {'' if inc else 'DESC '}LIMIT {limit} OFFSET {offset}"
execute_script = f"SELECT * FROM message WHERE chat_id = {chat_id} AND ({keyword_condition}) ORDER BY date_time {'' if inc else 'DESC '}LIMIT {limit} OFFSET {offset}"
logger.info(f"{execute_script=}")
res = self.cur.execute(execute_script)
return res
def get_oldest_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute(
"SELECT * FROM message WHERE chat_id == ? ORDER BY date_time LIMIT 1", (chat_id,))
"SELECT * FROM message WHERE chat_id = ? ORDER BY date_time LIMIT 1", (chat_id,))
return res.fetchall()
def get_newest_msg_by_chat_id(self, chat_id: int) -> list[any]:
res = self.cur.execute(
"SELECT * FROM message WHERE chat_id == ? ORDER BY date_time DESC LIMIT 1", (chat_id,))
"SELECT * FROM message WHERE chat_id = ? ORDER BY date_time DESC LIMIT 1", (chat_id,))
return res.fetchall()
def get_msg_by_unique_id(self, unique_id: str) -> list[any]:
res = self.cur.execute(
"SELECT * FROM message WHERE unique_id == ? ORDER BY date_time DESC LIMIT 1", (unique_id,))
"SELECT * FROM message WHERE unique_id = ? ORDER BY date_time DESC LIMIT 1", (unique_id,))
return res.fetchall()
@unique
@ -178,7 +179,7 @@ class UserManager(object):
"CREATE TABLE user(client_id primary key, username, phone, tg_user_id, last_login_time)")
if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='message'").fetchall()) == 0:
self.cur.execute(
"CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx, mime_type, file_name, msg_js, date_time int)")
"CREATE TABLE message(unique_id varchar(64) primary key, user_id int NOT NULL, chat_id int NOT NULL, msg_id int NOT NULL, msg_type varchar(64), msg_ctx text, mime_type text, file_name text, msg_js text, date_time int NOT NULL)")
if __name__ == "__main__":

View File

@ -2,6 +2,7 @@ import asyncio
import json
import os
import logging
from urllib.parse import quote
import uvicorn
from fastapi import FastAPI, status, Request
@ -22,6 +23,9 @@ clients_mgr: TgFileSystemClientManager = None
@asynccontextmanager
async def lifespan(app: FastAPI):
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.handlers.TimedRotatingFileHandler):
handler.suffix = "%Y-%m-%d"
global clients_mgr
param = configParse.get_TgToFileSystemParameter()
clients_mgr = TgFileSystemClientManager(param)
@ -61,11 +65,15 @@ async def search_tg_file_list(body: TgToFileListRequestBody):
msg_info = json.loads(item)
file_name = apiutils.get_message_media_name_from_dict(msg_info)
msg_info['file_name'] = file_name
msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{msg_info.get('id')}/{file_name}?sign={body.token}"
msg_info['download_url'] = f"{param.base.exposed_url}/tg/api/v1/file/get/{body.chat_id}/{msg_info.get('id')}/{file_name}"
msg_info['src_tg_link'] = f"https://t.me/c/1216816802/21206"
res_dict.append(msg_info)
client_dict = json.loads(client.to_json())
client_dict['sign'] = body.token
response_dict = {
"client": json.loads(client.to_json()),
"client": client_dict,
"type": res_type,
"length": len(res_dict),
"list": res_dict,
@ -141,7 +149,7 @@ async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Requ
maybe_file_type = mime_type.split("/")[-1]
file_name = f"{chat_id}.{msg_id}.{maybe_file_type}"
headers[
"Content-Disposition"] = f'inline; filename="{file_name}"'
"Content-Disposition"] = f"inline; filename*=utf-8'{quote(file_name)}'"
if range_header is not None:
start, end = apiutils.get_range_header(range_header, file_size)
@ -183,7 +191,13 @@ async def login_new_tg_file_client():
@app.get("/tg/api/v1/client/link_convert")
@apiutils.atimeit
async def convert_tg_msg_link_media_stream(link: str, token: str):
async def convert_tg_msg_link_media_stream(link: str, sign: str):
raise NotImplementedError
@app.get("/tg/api/v1/client/profile_photo")
@apiutils.atimeit
async def get_tg_chat_profile_photo(chat_id: int, sign: str):
raise NotImplementedError

View File

@ -2,72 +2,68 @@ import sys
import os
import json
sys.path.append(os.getcwd())
sys.path.append(os.getcwd() + "/../")
import streamlit as st
import qrcode
import pandas
import requests
import configParse
import utils
# qr = qrcode.make("https://www.baidu.com")
# st.image(qrcode.make("https://www.baidu.com").get_image())
param = configParse.get_TgToFileSystemParameter()
background_server_url = f"{param.base.exposed_url}/tg/api/v1/file/search"
st.set_page_config(page_title="TgToolbox", page_icon='🕹️', layout='wide')
if 'page_index' not in st.session_state:
st.session_state.page_index = 1
if 'search_input' not in st.session_state:
st.session_state.search_input = ""
if 'last_search_input' not in st.session_state:
st.session_state.last_search_input = ""
if 'search_clicked' not in st.session_state:
st.session_state.search_clicked = False
if 'is_order' not in st.session_state:
st.session_state.is_order = False
param = configParse.get_TgToFileSystemParameter()
background_server_url = f"{param.base.exposed_url}/tg/api/v1/file/search"
if 'force_skip' not in st.session_state:
st.session_state.force_skip = False
if 'search_key' not in st.query_params:
st.query_params.search_key = ""
if 'is_order' not in st.query_params:
st.query_params.is_order = False
if 'search_res_limit' not in st.query_params:
st.query_params.search_res_limit = "10"
@st.experimental_fragment
def search_input_container():
st.session_state.search_input = st.text_input("搜索🔎", value=st.query_params.get(
'search') if st.query_params.get('search') is not None else "")
def search_container():
st.query_params.search_key = st.text_input("**搜索🔎**", value=st.query_params.search_key)
columns = st.columns([7, 1])
with columns[0]:
st.query_params.search_res_limit = str(st.number_input(
"**每页结果**", min_value=1, max_value=100, value=int(st.query_params.search_res_limit), format="%d"))
with columns[1]:
st.text("排序")
st.query_params.is_order = st.toggle("顺序", value=utils.strtobool(st.query_params.is_order))
search_container()
search_input_container()
col1, col2 = st.columns(2)
search_res_limit = st.number_input(
"每页结果", min_value=1, max_value=100, value=10, format="%d")
columns = st.columns([7, 1])
with columns[0]:
if st.button("Search"):
st.session_state.page_index = 1
st.session_state.search_clicked = True
with columns[1]:
st.session_state.is_order = st.checkbox("顺序")
if st.session_state.search_input == "" or (st.session_state.search_input != st.session_state.last_search_input and not st.session_state.search_clicked):
st.session_state.search_clicked = False
search_clicked = st.button('Search', type='primary', use_container_width=True)
if not st.session_state.force_skip and (not search_clicked or st.query_params.search_key == "" or st.query_params.search_key is None):
st.stop()
st.session_state.last_search_input = st.session_state.search_input
st.query_params.search = st.session_state.search_input
if not st.session_state.force_skip:
st.session_state.page_index = 1
if st.session_state.force_skip:
st.session_state.force_skip = False
@st.experimental_fragment
def do_search_req():
offset_index = (st.session_state.page_index - 1) * search_res_limit
is_order = st.session_state.is_order
search_limit = int(st.query_params.search_res_limit)
offset_index = (st.session_state.page_index - 1) * search_limit
is_order = utils.strtobool(st.query_params.is_order)
req_body = {
"token": param.web.token,
"search": f"{st.session_state.search_input}",
"search": f"{st.query_params.search_key}",
"chat_id": param.web.chat_id[0],
"index": offset_index,
"length": search_res_limit,
"length": search_limit,
"refresh": False,
"inner": False,
"inc": is_order,
@ -81,11 +77,11 @@ def do_search_req():
def page_switch_render():
columns = st.columns(3)
with columns[0]:
pre_button = st.button("Prev", use_container_width=True)
if pre_button:
if st.button("Prev", use_container_width=True):
st.session_state.page_index = st.session_state.page_index - 1
st.session_state.page_index = max(
st.session_state.page_index, 1)
st.session_state.force_skip = True
st.rerun()
with columns[1]:
# st.text(f"{st.session_state.page_index}")
@ -93,31 +89,40 @@ def do_search_req():
f"<p style='text-align: center;'>{st.session_state.page_index}</p>", unsafe_allow_html=True)
# st.markdown(f"<input type='number' style='text-align: center;' value={st.session_state.page_index}>", unsafe_allow_html=True)
with columns[2]:
next_button = st.button("Next", use_container_width=True)
if next_button:
if st.button("Next", use_container_width=True):
st.session_state.page_index = st.session_state.page_index + 1
st.session_state.force_skip = True
st.rerun()
def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: str, url: str):
def media_file_res_container(index: int, msg_ctx: str, file_name: str, file_size: int, url: str):
file_size_str = f"{file_size/1024/1024:.2f}MB"
container = st.container()
container_columns = container.columns([1, 99])
st.session_state.search_res_select_list[index] = container_columns[0].checkbox(
url, label_visibility='collapsed')
"search_res_checkbox_" + str(index), label_visibility='collapsed')
expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size}*"
expender_title = f"{(msg_ctx if len(msg_ctx) < 103 else msg_ctx[:100] + '...')} &mdash; *{file_size_str}*"
popover = container_columns[1].popover(expender_title, use_container_width=True)
popover_columns = popover.columns([1, 3])
popover_columns[0].video(url)
if url:
popover_columns[0].video(url)
else:
popover_columns[0].video('./static/404.webm', format="video/webm")
popover_columns[1].markdown(f'{msg_ctx}')
popover_columns[1].markdown(f'**{file_name}**')
popover_columns[1].markdown(f'文件大小:*{file_size}*')
popover_columns[1].page_link(url, label='Download Link', icon='⬇️')
popover_columns[1].markdown(f'文件大小:*{file_size_str}*')
popover_columns[1].link_button('Download Link', url)
@st.experimental_fragment
def show_search_res():
search_res_list = search_res['list']
if len(search_res_list) == 0:
def show_search_res(res: dict[str, any]):
sign_token = ""
try:
sign_token = res['client']['sign']
except Exception as err:
pass
search_res_list = res.get('list')
if search_res_list is None or len(search_res_list) == 0:
st.info("No result")
page_switch_render()
st.stop()
@ -125,28 +130,29 @@ def do_search_req():
url_list = []
for i in range(len(search_res_list)):
v = search_res_list[i]
msg_ctx = v['message']
doc = None
msg_ctx= ""
file_name = None
file_size = 0
msg_id = str(v['id'])
download_url = v['download_url']
url_list.append(download_url)
download_url = ""
try:
msg_ctx = v['message']
msg_id = str(v['id'])
doc = v['media']['document']
file_size = doc['size']
except:
pass
file_size_str = f"{file_size/1024/1024:.2f}MB"
file_name = None
if doc is not None:
for attr in doc['attributes']:
file_name = attr.get('file_name')
if file_name != "" and file_name is not None:
break
if file_name == "" or file_name is None:
file_name = "Can not get file name"
if doc is not None:
for attr in doc['attributes']:
file_name = attr.get('file_name')
if file_name != "" and file_name is not None:
break
if file_name == "" or file_name is None:
file_name = "Can not get file name"
download_url = v['download_url']
download_url += f'?sign={sign_token}'
url_list.append(download_url)
except Exception as err:
msg_ctx = f"{err=}\r\n\r\n" + msg_ctx
media_file_res_container(
i, msg_ctx, file_name, file_size_str, download_url)
i, msg_ctx, file_name, file_size, download_url)
page_switch_render()
show_text = ""
@ -156,7 +162,7 @@ def do_search_req():
show_text = show_text + url_list[i] + '\n'
st.text_area("链接", value=show_text)
show_search_res()
show_search_res(search_res)
do_search_req()

BIN
frontend/static/404.webm Normal file

Binary file not shown.

14
frontend/utils.py Normal file
View File

@ -0,0 +1,14 @@
def strtobool (val):
"""Convert a string representation of truth to true (1) or false (0).
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
"""
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return 1
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return 0
else:
raise ValueError("invalid truth value %r" % (val,))

View File

@ -14,13 +14,10 @@ if not os.path.exists(os.path.dirname(__file__) + '/logs'):
os.mkdir(os.path.dirname(__file__) + '/logs')
with open('logging_config.yaml', 'r') as f:
logging.config.dictConfig(yaml.safe_load(f.read()))
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.handlers.TimedRotatingFileHandler):
handler.suffix = "%Y-%m-%d"
LOGGING_CONFIG["formatters"]["default"]["fmt"] = "[%(levelname)s] %(asctime)s [uvicorn.default]:%(message)s"
LOGGING_CONFIG["formatters"]["access"]["fmt"] = '[%(levelname)s]%(asctime)s [uvicorn.access]:%(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["handlers"]["timed_rotating_file"] = {
LOGGING_CONFIG["handlers"]["timed_rotating_api_file"] = {
"class": "logging.handlers.TimedRotatingFileHandler",
"filename": "logs/app.log",
"when": "midnight",
@ -30,8 +27,12 @@ LOGGING_CONFIG["handlers"]["timed_rotating_file"] = {
"formatter": "default",
"encoding": "utf-8",
}
LOGGING_CONFIG["loggers"]["uvicorn"]["handlers"].append("timed_rotating_file")
LOGGING_CONFIG["loggers"]["uvicorn.access"]["handlers"].append("timed_rotating_file")
LOGGING_CONFIG["loggers"]["uvicorn"]["handlers"].append("timed_rotating_api_file")
LOGGING_CONFIG["loggers"]["uvicorn.access"]["handlers"].append("timed_rotating_api_file")
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.handlers.TimedRotatingFileHandler):
handler.suffix = "%Y-%m-%d"
logger = logging.getLogger(__file__.split("/")[-1])
@ -39,14 +40,19 @@ if __name__ == "__main__":
param = configParse.get_TgToFileSystemParameter()
async def run_web_server():
cmd = f"streamlit run {os.getcwd()}/frontend/home.py --server.port {param.web.port}"
proc = await asyncio.create_subprocess_shell(cmd, stdout=asyncio.subprocess.PIPE,
proc = await asyncio.create_subprocess_shell(cmd, cwd=f"{os.path.dirname(__file__)}/frontend", stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE)
stdout, stderr = await proc.communicate()
async def loop_get_cli_pipe(p, suffix = ""):
while True:
stdp = await p.readline()
if stdp:
logger.info(f"[web:{suffix}]{stdp.decode()[:-1]}")
else:
break
stdout_task = asyncio.create_task(loop_get_cli_pipe(proc.stdout, "out"))
stderr_task = asyncio.create_task(loop_get_cli_pipe(proc.stderr, "err"))
await asyncio.gather(*[stdout_task, stderr_task])
logger.info(f'[{cmd!r} exited with {proc.returncode}]')
if stdout:
logger.info(f'[stdout]\n{stdout.decode()}')
if stderr:
logger.info(f'[stderr]\n{stderr.decode()}')
if param.web.enable:
ret = os.fork()
if ret == 0: