diff --git a/.gitignore b/.gitignore index 3369b1d..cb2ad04 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ __pycache__ .vscode *.session *.toml +*.db +*.service diff --git a/TgFileSystemClient.py b/TgFileSystemClient.py index 93e324e..5a5354a 100644 --- a/TgFileSystemClient.py +++ b/TgFileSystemClient.py @@ -1,9 +1,11 @@ import asyncio -from typing import Union +import json +from typing import Union, Optional -from telethon import TelegramClient, types +from telethon import TelegramClient, types, hints import configParse +import apiutils class TgFileSystemClient(object): @@ -12,12 +14,13 @@ class TgFileSystemClient(object): session_name: str proxy_param: dict[str, any] client: TelegramClient + dialogs_cache: Optional[hints.TotalList] = None me: Union[types.User, types.InputPeerUser] - def __init__(self, param: configParse.TgToFileSystemParameter) -> None: + def __init__(self, session_name: str, param: configParse.TgToFileSystemParameter) -> None: self.api_id = param.tgApi.api_id self.api_hash = param.tgApi.api_hash - self.session_name = param.base.name + self.session_name = session_name self.proxy_param = { 'proxy_type': param.proxy.proxy_type, 'addr': param.proxy.addr, @@ -26,26 +29,108 @@ class TgFileSystemClient(object): self.client = TelegramClient( self.session_name, self.api_id, self.api_hash, proxy=self.proxy_param) + def __del__(self) -> None: + self.client.disconnect() def __repr__(self) -> str: if not self.client.is_connected: return f"client disconnected, session_name:{self.session_name}" return f"client connected, session_name:{self.session_name}, username:{self.me.username}, phone:{self.me.phone}, detail:{self.me.stringify()}" - async def init_client(self): + def _call_before_check(func): + def call_check_wrapper(self, *args, **kwargs): + if not self.is_valid(): + raise RuntimeError("Client does not run.") + result = func(self, *args, **kwargs) + return result + return call_check_wrapper + + def _acall_before_check(func): + async def call_check_wrapper(self, *args, **kwargs): + if not self.is_valid(): + raise RuntimeError("Client does not run.") + result = await func(self, *args, **kwargs) + return result + return call_check_wrapper + + @_call_before_check + def to_dict(self) -> dict: + return self.me.to_dict() + + @_call_before_check + def to_json(self) -> str: + return self.me.to_json() + + def is_valid(self) -> bool: + return self.client.is_connected() and self.me is not None + + async def start(self) -> None: + if not self.client.is_connected(): + await self.client.connect() self.me = await self.client.get_me() + if self.me is None: + raise RuntimeError( + f"The {self.session_name} Client Does Not Login") + + async def stop(self) -> None: + await self.client.disconnect() + + @_acall_before_check + async def get_message(self, chat_id: int, msg_id: int) -> types.Message: + msg = await self.client.get_messages(chat_id, ids=msg_id) + return msg + + @_acall_before_check + async def get_dialogs(self, limit: int = 10, offset: int = 0, refresh: bool = False) -> hints.TotalList: + def _to_json(item) -> str: + return json.dumps({"id": item.id, "is_channel": item.is_channel, + "is_group": item.is_group, "is_user": item.is_user, "name": item.name, }) + if self.dialogs_cache is not None and refresh is False: + return self.dialogs_cache[offset:offset+limit] + self.dialogs_cache = await self.client.get_dialogs() + for item in self.dialogs_cache: + item.to_json = _to_json + return self.dialogs_cache[offset:offset+limit] + + async def _get_offset_msg_id(self, chat_id: int, offset: int) -> int: + if offset != 0: + begin = await self.client.get_messages(chat_id, limit=1) + if len(begin) == 0: + return hints.TotalList() + first_id = begin[0].id + offset = first_id + offset + return offset + + @_acall_before_check + async def get_messages(self, chat_id: int, limit: int = 10, offset: int = 0) -> hints.TotalList: + offset = await self._get_offset_msg_id(chat_id, offset) + res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset) + return res_list + + @_acall_before_check + async def get_messages_by_search(self, chat_id: int, search_word: str, limit: int = 10, offset: int = 0, inner_search: bool = False) -> hints.TotalList: + offset = await self._get_offset_msg_id(chat_id, offset) + if inner_search: + res_list = await self.client.get_messages(chat_id, limit=limit, offset_id=offset, search=search_word) + return res_list + # search by myself + res_list = hints.TotalList() + async for msg in self.client.iter_messages(chat_id, offset_id=offset): + if msg.text.find(search_word) == -1 and apiutils.get_message_media_name(msg).find(search_word) == -1: + continue + res_list.append(msg) + if len(res_list) >= limit: + break + return res_list def __enter__(self): - self.client.__enter__() - self.client.loop.run_until_complete(self.init_client()) + raise NotImplemented def __exit__(self): - self.client.__exit__() - self.me = None + raise NotImplemented async def __aenter__(self): - await self.client.__enter__() - await self.init_client() + await self.start() async def __aexit__(self): - await self.client.__aexit__() + await self.stop() diff --git a/TgFileSystemClientManager.py b/TgFileSystemClientManager.py index 9998f88..9586531 100644 --- a/TgFileSystemClientManager.py +++ b/TgFileSystemClientManager.py @@ -1,35 +1,63 @@ from typing import Any +import time +import hashlib +import os + from TgFileSystemClient import TgFileSystemClient +from UserManager import UserManager +import configParse class TgFileSystemClientManager(object): MAX_MANAGE_CLIENTS: int = 10 - clients: dict[int, TgFileSystemClient] - - def __init__(self) -> None: + param: configParse.TgToFileSystemParameter + clients: dict[str, TgFileSystemClient] = {} + + def __init__(self, param: configParse.TgToFileSystemParameter) -> None: + self.param = param + self.db = UserManager() + + def __del__(self) -> None: pass - - def push_client(self, client: TgFileSystemClient) -> int: - """ - push client to manager. - Arguments - client + def check_client_session_exist(self, client_id: str) -> bool: + return os.path.isfile(client_id + '.session') - Returns - client id + def generate_client_id(self) -> str: + return hashlib.md5( + (str(time.perf_counter()) + self.param.base.salt).encode('utf-8')).hexdigest() - """ - self.clients[id(client)] = client - return id(client) - - def get_client(self, client_id: int) -> TgFileSystemClient: + def create_client(self, client_id: str = None) -> TgFileSystemClient: + if client_id is None: + client_id = self.generate_client_id() + client = TgFileSystemClient(client_id, self.param) + return client + + def register_client(self, client: TgFileSystemClient) -> bool: + self.clients[client.session_name] = client + return True + + def deregister_client(self, client_id: str) -> bool: + self.clients.pop(client_id) + return True + + def get_client(self, client_id: str) -> TgFileSystemClient: client = self.clients.get(client_id) return client - + async def get_client_force(self, client_id: str) -> TgFileSystemClient: + client = self.get_client(client_id) + if client is None: + if not self.check_client_session_exist(client_id): + raise RuntimeError("Client session does not found.") + client = self.create_client(client_id=client_id) + if not client.is_valid(): + await client.start() + self.register_client(client) + return client + if __name__ == "__main__": import configParse - t: TgFileSystemClient = TgFileSystemClient(configParse.get_TgToFileSystemParameter()) + # t: TgFileSystemClient = TgFileSystemClient(configParse.get_TgToFileSystemParameter()) print(f"{t.session_name=}") diff --git a/UserManager.py b/UserManager.py new file mode 100644 index 0000000..a83a6b1 --- /dev/null +++ b/UserManager.py @@ -0,0 +1,59 @@ +import sqlite3 + +from pydantic import BaseModel + + +class UserUpdateParam(BaseModel): + client_id: str + username: str + phone: str + tg_user_id: int + last_login_time: int + + +class MessageUpdateParam(BaseModel): + tg_chat_id: int + tg_message_id: int + client_id: str + username: str + phone: str + tg_user_id: int + + +class UserManager(object): + def __init__(self) -> None: + self.con = sqlite3.connect("user.db") + self.cur = self.con.cursor() + if not self._table_has_been_inited(): + self._first_runtime_run_once() + + def __del__(self) -> None: + self.con.commit() + self.con.close() + + def update_user(self) -> None: + raise NotImplemented + + def update_message(self) -> None: + raise NotImplemented + + def get_user_info() -> None: + raise NotImplemented + + def _table_has_been_inited(self) -> bool: + res = self.cur.execute("SELECT name FROM sqlite_master") + return len(res.fetchall()) != 0 + + def _first_runtime_run_once(self) -> None: + if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='user'").fetchall()) == 0: + self.cur.execute( + "CREATE TABLE user(client_id, username, phone, tg_user_id, last_login_time)") + if len(self.cur.execute("SELECT name FROM sqlite_master WHERE name='message'").fetchall()) == 0: + self.cur.execute( + "CREATE TABLE message(tg_chat_id, tg_message_id, client_id, username, phone, tg_user_id, msg_ctx, msg_type)") + + +if __name__ == "__main__": + db = UserManager() + res = db.cur.execute("SELECT name FROM sqlite_master") + print(res.fetchall()) diff --git a/apiutils.py b/apiutils.py new file mode 100644 index 0000000..74804fa --- /dev/null +++ b/apiutils.py @@ -0,0 +1,64 @@ +import time + +from fastapi import status, HTTPException +from telethon import types +from functools import wraps + +import configParse + + +def get_range_header(range_header: str, file_size: int) -> tuple[int, int]: + def _invalid_range(): + return HTTPException( + status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, + detail=f"Invalid request range (Range:{range_header!r})", + ) + + try: + h = range_header.replace("bytes=", "").split("-") + start = int(h[0]) if h[0] != "" else 0 + end = int(h[1]) if h[1] != "" else file_size - 1 + except ValueError: + raise _invalid_range() + + if start > end or start < 0 or end > file_size - 1: + raise _invalid_range() + return start, end + + +def get_message_media_name(msg: types.Message) -> str: + if msg.media is None or msg.media.document is None: + return "" + for attr in msg.media.document.attributes: + if isinstance(attr, types.DocumentAttributeFilename): + return attr.file_name + + +def timeit(func): + if configParse.get_TgToFileSystemParameter().base.timeit_enable: + @wraps(func) + def timeit_wrapper(*args, **kwargs): + start_time = time.perf_counter() + result = func(*args, **kwargs) + end_time = time.perf_counter() + total_time = end_time - start_time + print( + f'Function {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') + return result + return timeit_wrapper + return func + + +def atimeit(func): + if configParse.get_TgToFileSystemParameter().base.timeit_enable: + @wraps(func) + async def timeit_wrapper(*args, **kwargs): + start_time = time.perf_counter() + result = await func(*args, **kwargs) + end_time = time.perf_counter() + total_time = end_time - start_time + print( + f'AFunction {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') + return result + return timeit_wrapper + return func diff --git a/config.toml.example b/config.toml.example new file mode 100644 index 0000000..c8c08ec --- /dev/null +++ b/config.toml.example @@ -0,0 +1,14 @@ +[base] +salt = "AnyTokenYouWanted" +port = 7777 +timeit_enable = false + +[tgApi] +api_id = int_app_id_from_tg +api_hash = "api_hash_from_tg" + +[proxy] +enable = false +proxy_type = "socks5" +addr = "172.25.32.1" +port = 7890 diff --git a/configParse.py b/configParse.py index 69da618..3c0a2f4 100644 --- a/configParse.py +++ b/configParse.py @@ -4,8 +4,9 @@ from pydantic import BaseModel class TgToFileSystemParameter(BaseModel): class BaseParameter(BaseModel): - name: str + salt: str port: int + timeit_enable: bool base: BaseParameter class ApiParameter(BaseModel): diff --git a/start.py b/start.py index 1b46449..ce4e215 100644 --- a/start.py +++ b/start.py @@ -1,22 +1,29 @@ import asyncio +import time +import json import uvicorn -from fastapi import FastAPI -from fastapi import status +from fastapi import FastAPI, status, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import Response +from fastapi.responses import Response, StreamingResponse from contextlib import asynccontextmanager -from telethon import TelegramClient +from telethon import types, hints +from pydantic import BaseModel import configParse +import apiutils +from TgFileSystemClientManager import TgFileSystemClientManager +from TgFileSystemClient import TgFileSystemClient + +clients_mgr: TgFileSystemClientManager = None + @asynccontextmanager async def lifespan(app: FastAPI): + global clients_mgr param = configParse.get_TgToFileSystemParameter() - loop = asyncio.get_event_loop() - tg_client_task = loop.create_task(start_tg_client(param)) + clients_mgr = TgFileSystemClientManager(param) yield - asyncio.gather(*[tg_client_task]) app = FastAPI(lifespan=lifespan) @@ -28,47 +35,100 @@ app.add_middleware( allow_headers=["*"], ) -@app.post("/tg/{chat_id}/{message_id}") -async def get_test(chat_id: str, message_id: str): - print(f"test: {chat_id=}, {message_id=}") - return Response(status_code=status.HTTP_200_OK) - -async def start_tg_client(param: configParse.TgToFileSystemParameter): - api_id = param.tgApi.api_id - api_hash = param.tgApi.api_hash - session_name = param.base.name - proxy_param = { - 'proxy_type': param.proxy.proxy_type, - 'addr': param.proxy.addr, - 'port': param.proxy.port, - } if param.proxy.enable else {} - client = TelegramClient(session_name, api_id, api_hash, proxy=proxy_param) +class TgToFileListRequestBody(BaseModel): + token: str + search: str = "" + chat_id: int = 0 + index: int = 0 + length: int = 10 + refresh: bool = False + inner: bool = False - async def tg_client_main(): - # Getting information about yourself - me = await client.get_me() - # "me" is a user object. You can pretty-print - # any Telegram object with the "stringify" method: - print(me.stringify()) +@app.post("/tg/api/v1/file/list") +@apiutils.atimeit +async def get_tg_file_list(body: TgToFileListRequestBody): + try: + res = hints.TotalList() + res_type = "chat" + client = await clients_mgr.get_client_force(body.token) + res_dict = {} + if body.chat_id == 0: + res = await client.get_dialogs(limit=body.length, offset=body.index, refresh=body.refresh) + res_dict = [{"id": item.id, "is_channel": item.is_channel, + "is_group": item.is_group, "is_user": item.is_user, "name": item.name, } for item in res] + elif body.search != "": + res = await client.get_messages_by_search(body.chat_id, search_word=body.search, limit=body.length, offset=body.index, inner_search=body.inner) + res_type = "msg" + res_dict = [json.loads(item.to_json()) for item in res] + else: + res = await client.get_messages(body.chat_id, limit=body.length, offset=body.index) + res_type = "msg" + res_dict = [json.loads(item.to_json()) for item in res] - # When you print something, you see a representation of it. - # You can access all attributes of Telegram objects with - # the dot operator. For example, to get the username: - username = me.username - print(username) - print(me.phone) - # You can print all the dialogs/conversations that you are part of: - dialogs = await client.get_dialogs() - for dialog in dialogs: - print(f"{dialog.name} has ID {dialog.id}") - # async for dialog in client.iter_dialogs(): - # print(dialog.name, 'has ID', dialog.id) + response_dict = { + "client": json.loads(client.to_json()), + "type": res_type, + "list": res_dict, + } + return Response(json.dumps(response_dict), status_code=status.HTTP_200_OK) + except Exception as err: + print(f"{err=}") + return Response(f"{err=}", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) - async with client: - await tg_client_main() +@app.get("/tg/api/v1/file/msg") +@apiutils.atimeit +async def get_tg_file_media_stream(token: str, cid: int, mid: int, request: Request): + async def get_msg_media_range_requests(client: TgFileSystemClient, msg: types.Message, start: int, end: int): + MAX_CHUNK_SIZE = 1024 * 1024 + pos = start + async for chunk in client.client.iter_download(msg, offset=pos, chunk_size=min(end + 1 - pos, MAX_CHUNK_SIZE)): + pos = pos + len(chunk) + yield chunk.tobytes() + msg_id = mid + chat_id = cid + headers = { + # "content-type": "video/mp4", + "accept-ranges": "bytes", + "content-encoding": "identity", + # "content-length": stream_file_size, + "access-control-expose-headers": ( + "content-type, accept-ranges, content-length, " + "content-range, content-encoding" + ), + } + range_header = request.headers.get("range") + try: + client = await clients_mgr.get_client_force(token) + msg = await client.get_message(chat_id, msg_id) + file_size = msg.media.document.size + start = 0 + end = file_size - 1 + status_code = status.HTTP_200_OK + mime_type = msg.media.document.mime_type + headers["content-type"] = mime_type + file_name = apiutils.get_message_media_name(msg) + if file_name == "": + maybe_file_type = mime_type.split("/")[-1] + file_name = f"{chat_id}.{msg_id}.{maybe_file_type}" + headers["Content-Disposition"] = f'Content-Disposition: inline; filename="{file_name}"' + + if range_header is not None: + start, end = apiutils.get_range_header(range_header, file_size) + size = end - start + 1 + headers["content-length"] = str(size) + headers["content-range"] = f"bytes {start}-{end}/{file_size}" + status_code = status.HTTP_206_PARTIAL_CONTENT + return StreamingResponse( + get_msg_media_range_requests(client, msg, start, end), + headers=headers, + status_code=status_code, + ) + except Exception as err: + print(f"{err=}") + return Response(f"{err=}", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) if __name__ == "__main__": diff --git a/test.py b/test.py index 13b67b2..a983274 100644 --- a/test.py +++ b/test.py @@ -4,8 +4,8 @@ import configParse param = configParse.get_TgToFileSystemParameter() # Remember to use your own values from my.telegram.org! -api_id = param.ApiParameter.api_id -api_hash = param.ApiParameter.api_hash +api_id = param.tgApi.api_id +api_hash = param.tgApi.api_hash client = TelegramClient('anon', api_id, api_hash, proxy={ 'proxy_type': 'socks5', 'addr': '172.25.32.1', @@ -30,8 +30,12 @@ async def main(): print(me.phone) # You can print all the dialogs/conversations that you are part of: - async for dialog in client.iter_dialogs(): - print(dialog.name, 'has ID', dialog.id) + # async for dialog in client.iter_dialogs(): + # print(dialog.name, 'has ID', dialog.id) + # test_res = await client.get_input_entity(dialog.id) + # print(test_res) + # await client.send_message(-1001150067822, "test message from python") + # nep_channel = await client.get_dialogs("-1001251458407") # You can send messages to yourself... # await client.send_message('me', 'Hello, myself!') @@ -60,9 +64,14 @@ async def main(): # await client.send_file('me', './test.py') # You can print the message history of any chat: - message = await client.get_messages('me', ids=206963) - async for message in client.iter_messages('me'): + # message = await client.get_messages(nep_channel[0]) + chat = await client.get_input_entity(-1001216816802) + async for message in client.iter_messages(chat, ids=98724): print(message.id, message.text) + # print(message.stringify()) + # print(message.to_json()) + # print(message.to_dict()) + # await client.download_media(message) # You can download media from messages, too! # The method will return the path where the file was saved. @@ -72,3 +81,54 @@ async def main(): with client: client.loop.run_until_complete(main()) + + +async def start_tg_client(param: configParse.TgToFileSystemParameter): + api_id = param.tgApi.api_id + api_hash = param.tgApi.api_hash + session_name = "test" + proxy_param = { + 'proxy_type': param.proxy.proxy_type, + 'addr': param.proxy.addr, + 'port': param.proxy.port, + } if param.proxy.enable else {} + client = TelegramClient(session_name, api_id, api_hash, proxy=proxy_param) + + async def tg_client_main(): + # Getting information about yourself + me = await client.get_me() + + # "me" is a user object. You can pretty-print + # any Telegram object with the "stringify" method: + print(me.stringify()) + + # When you print something, you see a representation of it. + # You can access all attributes of Telegram objects with + # the dot operator. For example, to get the username: + username = me.username + print(username) + print(me.phone) + # You can print all the dialogs/conversations that you are part of: + # dialogs = await client.get_dialogs() + # for dialog in dialogs: + # print(f"{dialog.name} has ID {dialog.id}")\ + path_task_list = [] + async for dialog in client.iter_dialogs(): + print(dialog.name, 'has ID', dialog.id) + # path = await client.download_profile_photo(dialog.id) + # t = client.loop.create_task( + # client.download_profile_photo(dialog.id)) + # path_task_list.append(t) + # res = await asyncio.gather(*path_task_list) + # for path in res: + # print(path) + + # async with client: + # await tg_client_main() + await client.connect() + # qr_login = await client.qr_login() + await client.start() + # print(qr_login.url) + # await qr_login.wait() + await tg_client_main() + await client.disconnect()