# TgToFileSystem/backend/TgFileSystemClient.py
import asyncio
import json
import time
import re
import rsa
import os
import functools
import traceback
import logging
from typing import Any, Literal, Optional, Union
from telethon import TelegramClient, types, hints, events
from telethon.custom import QRLogin
from fastapi import Request
import configParse
from backend import apiutils
from backend.UserManager import UserManager
from backend.MediaCacheManager import MediaChunkHolder, MediaChunkHolderManager
logger = logging.getLogger(__file__.split("/")[-1])
class TgFileSystemClient(object):
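    """Per-session Telethon client wrapper: handles QR-code login, runs a background
    task queue, caches whitelisted chats into the local database, and streams media
    in cached chunks for HTTP range requests."""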
MAX_WORKER_ROUTINE = 4
SINGLE_NET_CHUNK_SIZE = 512 * 1024 # 512kb
SINGLE_MEDIA_SIZE = 5 * 1024 * 1024 # 5mb
api_id: int
api_hash: str
session_name: str
    proxy_param: dict[str, Any]
client: TelegramClient
media_chunk_manager: MediaChunkHolderManager
dialogs_cache: Optional[hints.TotalList] = None
msg_cache: list[types.Message] = []
worker_routines: list[asyncio.Task] = []
qr_login: QRLogin | None = None
login_task: asyncio.Task | None = None
# rsa key
public_key: rsa.PublicKey
private_key: rsa.PrivateKey
    # task queue items: (task_id, coroutine)
task_queue: asyncio.Queue
task_id: int = 0
    me: Optional[Union[types.User, types.InputPeerUser]] = None
# client config
client_param: configParse.TgToFileSystemParameter.ClientConfigPatameter
def __init__(
self,
session_name: str,
param: configParse.TgToFileSystemParameter,
db: UserManager,
) -> None:
self.api_id = param.tgApi.api_id
self.api_hash = param.tgApi.api_hash
self.session_name = session_name
self.proxy_param = (
{
"proxy_type": param.proxy.proxy_type,
"addr": param.proxy.addr,
"port": param.proxy.port,
}
if param.proxy.enable
else {}
)
self.public_key, self.private_key = rsa.newkeys(1024)
self.client_param = next(
(client_param for client_param in param.clients if client_param.token == session_name),
configParse.TgToFileSystemParameter.ClientConfigPatameter(),
)
self.task_queue = asyncio.Queue()
self.client = TelegramClient(
f"{os.path.dirname(__file__)}/db/{self.session_name}.session",
self.api_id,
self.api_hash,
proxy=self.proxy_param,
)
self.media_chunk_manager = MediaChunkHolderManager()
        self.db = db
        # Per-instance state; avoids sharing the class-level mutable defaults across clients.
        self.dialogs_cache = None
        self.msg_cache = []
        self.worker_routines = []
def __del__(self) -> None:
if self.client.loop.is_running():
self.client.loop.create_task(self.stop())
else:
self.client.loop.run_until_complete(self.stop())
def __repr__(self) -> str:
        if not self.client.is_connected():
return f"client disconnected, session_name:{self.session_name}"
return f"client connected, session_name:{self.session_name}, username:{self.me.username}, phone:{self.me.phone}, detail:{self.me.stringify()}"
def _check_before_call(func):
        @functools.wraps(func)
        def call_check_wrapper(self, *args, **kwargs):
            if not self.is_valid():
                raise RuntimeError("Client is not running.")
result = func(self, *args, **kwargs)
return result
return call_check_wrapper
def _acheck_before_call(func):
        @functools.wraps(func)
        async def call_check_wrapper(self, *args, **kwargs):
            if not self.is_valid():
                raise RuntimeError("Client is not running.")
result = await func(self, *args, **kwargs)
return result
return call_check_wrapper
@_check_before_call
def to_dict(self) -> dict:
return self.me.to_dict()
@_check_before_call
def to_json(self) -> str:
return self.me.to_json()
def is_valid(self) -> bool:
return self.client.is_connected() and self.me is not None
@_check_before_call
def _register_update_event(self, from_users: list[int] = []) -> None:
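        # Persist every incoming message from whitelisted senders into the local database.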
@self.client.on(events.NewMessage(incoming=True, from_users=from_users))
async def _incoming_new_message_handler(event) -> None:
msg: types.Message = event.message
self.db.insert_by_message(self.me, msg)
    async def login(self, mode: Literal["phone", "qrcode"] = "qrcode") -> str:
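        # Only QR-code login is implemented: the QR URL is returned immediately and a
        # background task completes the login once the code has been scanned.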
if self.is_valid():
return ""
if mode == "phone":
raise NotImplementedError
if self.qr_login is not None:
return self.qr_login.url
self.qr_login = await self.client.qr_login()
async def wait_for_qr_login():
try:
await self.qr_login.wait()
await self.start()
except Exception as err:
logger.warning(f"wait for login, {err=}, {traceback.format_exc()}")
finally:
self.login_task = None
self.qr_login = None
self.login_task = self.client.loop.create_task(wait_for_qr_login())
return self.qr_login.url
async def start(self) -> None:
if self.is_valid():
return
if not self.client.is_connected():
await self.client.connect()
self.me = await self.client.get_me()
if self.me is None:
raise RuntimeError(f"The {self.session_name} Client Does Not Login")
for _ in range(self.MAX_WORKER_ROUTINE):
worker_routine = self.client.loop.create_task(self._worker_routine_handler())
self.worker_routines.append(worker_routine)
if len(self.client_param.whitelist_chat) > 0:
self._register_update_event(from_users=self.client_param.whitelist_chat)
await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat()))
# await self.task_queue.put((self._get_unique_task_id(), self._cache_whitelist_chat2()))
async def stop(self) -> None:
await self.client.loop.create_task(self._cancel_tasks())
while not self.task_queue.empty():
self.task_queue.get_nowait()
self.task_queue.task_done()
await self.client.disconnect()
async def _cancel_tasks(self) -> None:
for t in self.worker_routines:
try:
t.cancel()
except Exception as err:
logger.error(f"{err=}")
logger.error(traceback.format_exc())
async def _cache_whitelist_chat2(self):
for chat_id in self.client_param.whitelist_chat:
async for msg in self.client.iter_messages(chat_id):
if len(self.db.get_msg_by_unique_id(self.db.generate_unique_id_by_msg(self.me, msg))) != 0:
continue
self.db.insert_by_message(self.me, msg)
logger.info(f"{chat_id} quit cache task.")
async def _cache_whitelist_chat(self):
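        # For each whitelisted chat, first pull messages newer than the newest cached entry,
        # then continue backwards from the oldest cached entry so the full history ends up in the database.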
for chat_id in self.client_param.whitelist_chat:
# update newest msg
newest_msg = self.db.get_newest_msg_by_chat_id(chat_id)
if len(newest_msg) > 0:
newest_msg = newest_msg[0]
async for msg in self.client.iter_messages(chat_id):
if msg.id <= self.db.get_column_msg_id(newest_msg):
break
self.db.insert_by_message(self.me, msg)
# update oldest msg
oldest_msg = self.db.get_oldest_msg_by_chat_id(chat_id)
if len(oldest_msg) > 0:
oldest_msg = oldest_msg[0]
offset = self.db.get_column_msg_id(oldest_msg)
async for msg in self.client.iter_messages(chat_id, offset_id=offset):
self.db.insert_by_message(self.me, msg)
else:
async for msg in self.client.iter_messages(chat_id):
self.db.insert_by_message(self.me, msg)
logger.info(f"{chat_id} quit cache task.")
@_acheck_before_call
async def get_message(self, chat_id: int, msg_id: int) -> types.Message:
msg = await self.client.get_messages(chat_id, ids=msg_id)
return msg
@_acheck_before_call
async def get_dialogs(self, limit: int = 10, offset: int = 0, refresh: bool = False) -> hints.TotalList:
if self.dialogs_cache is not None and refresh is False:
return self.dialogs_cache[offset : offset + limit]
self.dialogs_cache = await self.client.get_dialogs()
return self.dialogs_cache[offset : offset + limit]
async def _worker_routine_handler(self) -> None:
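        # Consume (task_id, coroutine) pairs from the task queue and await them while the client stays connected.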
while self.client.is_connected():
try:
task = await self.task_queue.get()
await task[1]
except Exception as err:
logger.error(f"{err=}")
logger.error(traceback.format_exc())
finally:
self.task_queue.task_done()
def _get_unique_task_id(self) -> int:
self.task_id += 1
return self.task_id
async def _get_offset_msg_id(self, chat_id: int, offset: int) -> int:
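        # Convert a caller-supplied offset into an offset_id anchored at the chat's newest message id.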
if offset != 0:
begin = await self.client.get_messages(chat_id, limit=1)
if len(begin) == 0:
                return 0
first_id = begin[0].id
offset = first_id + offset
return offset
@_acheck_before_call
async def get_messages(self, chat_id: int, limit: int = 10, offset: int = 0) -> hints.TotalList:
offset = await self._get_offset_msg_id(chat_id, offset)
res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset)
return res_list
@_acheck_before_call
async def get_messages_by_search(
self,
chat_id: int,
search_word: str,
limit: int = 10,
offset: int = 0,
inner_search: bool = False,
ignore_case: bool = False,
) -> hints.TotalList:
offset = await self._get_offset_msg_id(chat_id, offset)
if inner_search:
res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset, search=search_word)
return res_list
# search by myself
res_list = hints.TotalList()
cnt = 0
async for msg in self.client.iter_messages(chat_id, offset_id=offset):
if cnt >= 1_000:
break
cnt += 1
            if (msg.text or "").find(search_word) == -1 and apiutils.get_message_media_name(msg).find(search_word) == -1:
continue
res_list.append(msg)
if len(res_list) >= limit:
break
return res_list
async def get_messages_by_search_db(
self,
chat_id: int,
search_word: str,
limit: int = 10,
offset: int = 0,
inc: bool = False,
ignore_case: bool = False,
    ) -> list[Any]:
if chat_id not in self.client_param.whitelist_chat:
return []
res = self.db.get_msg_by_chat_id_and_keyword(
chat_id,
search_word,
limit=limit,
offset=offset,
inc=inc,
ignore_case=ignore_case,
)
res = [self.db.get_column_msg_js(v) for v in res]
return res
async def _download_media_chunk(self, msg: types.Message, media_holder: MediaChunkHolder) -> None:
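        # Fill the holder's byte range by downloading SINGLE_NET_CHUNK_SIZE pieces, trimming the
        # final chunk to the target length, and stop early once every requester has disconnected.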
try:
offset = media_holder.start + media_holder.length
target_size = media_holder.target_len - media_holder.length
remain_size = target_size
async for chunk in self.client.iter_download(msg, offset=offset, chunk_size=self.SINGLE_NET_CHUNK_SIZE):
if not isinstance(chunk, bytes):
chunk = chunk.tobytes()
remain_size -= len(chunk)
if remain_size <= 0:
media_holder.append_chunk_mem(chunk[: len(chunk) + remain_size])
else:
media_holder.append_chunk_mem(chunk)
if media_holder.is_completed():
media_holder.set_done()
break
if await media_holder.is_disconneted():
raise asyncio.CancelledError("all requester canceled.")
except asyncio.CancelledError as err:
logger.info(f"cancel holder:{media_holder}")
self.media_chunk_manager.cancel_media_chunk(media_holder)
except Exception as err:
            logger.error(
                f"_download_media_chunk err:{err=},{offset=},{target_size=},{media_holder}\r\n{traceback.format_exc()}"
            )
finally:
logger.debug(f"downloaded chunk:{time.time()}.{offset=},{target_size=},{media_holder}")
async def streaming_get_iter(self, msg: types.Message, start: int, end: int, req: Request):
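        # Yield the bytes of [start, end]: serve hits from the chunk cache, schedule a background
        # download on a miss, and stream partially downloaded chunks as they grow.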
try:
logger.debug(f"new steaming request:{msg.chat_id=},{msg.id=},[{start}:{end}]")
cur_task_id = self._get_unique_task_id()
pos = start
while not await req.is_disconnected() and pos <= end:
cache_chunk = self.media_chunk_manager.get_media_chunk(msg, pos)
if cache_chunk is None:
# post download task
# align pos download task
file_size = msg.media.document.size
# align_pos = pos // self.SINGLE_MEDIA_SIZE * self.SINGLE_MEDIA_SIZE
align_pos = pos
align_size = min(self.SINGLE_MEDIA_SIZE, file_size - align_pos)
holder = self.media_chunk_manager.create_media_chunk_holder(msg.chat_id, msg.id, align_pos, align_size)
holder.add_chunk_requester(req)
self.media_chunk_manager.set_media_chunk(holder)
self.task_queue.put_nowait((cur_task_id, self._download_media_chunk(msg, holder)))
elif not cache_chunk.is_completed():
# yield return completed part
# await untill completed or pos > end
cache_chunk.add_chunk_requester(req)
while pos < cache_chunk.start + cache_chunk.target_len and pos <= end:
if await req.is_disconnected():
break
offset = pos - cache_chunk.start
if offset >= cache_chunk.length:
await cache_chunk.wait_chunk_update()
continue
need_len = min(cache_chunk.length - offset, end - pos + 1)
pos = pos + need_len
yield cache_chunk.mem[offset : offset + need_len]
else:
offset = pos - cache_chunk.start
if offset >= cache_chunk.length:
raise RuntimeError(f"lru cache missed!{pos=},{cache_chunk=}")
need_len = min(cache_chunk.length - offset, end - pos + 1)
pos = pos + need_len
yield cache_chunk.mem[offset : offset + need_len]
except Exception as err:
logger.error(f"stream iter:{err=}")
logger.error(traceback.format_exc())
finally:
async def _cancel_task_by_id(task_id: int):
for _ in range(self.task_queue.qsize()):
task = self.task_queue.get_nowait()
self.task_queue.task_done()
if task[0] != task_id:
self.task_queue.put_nowait(task)
await self.client.loop.create_task(_cancel_task_by_id(cur_task_id))
logger.debug(f"yield quit,{msg.chat_id=},{msg.id=},[{start}:{end}]")
    def __enter__(self):
        raise NotImplementedError
    def __exit__(self, exc_type, exc_value, exc_tb):
        raise NotImplementedError
    async def __aenter__(self):
        await self.start()
        return self
    async def __aexit__(self, exc_type, exc_value, exc_tb):
        await self.stop()