From 054017bfe5d2b2fc4dd3fa57e8f7ef2edf716206 Mon Sep 17 00:00:00 2001 From: li1553770945 <1553770945@qq.com> Date: Fri, 30 May 2025 00:01:29 +0800 Subject: [PATCH] =?UTF-8?q?feat(auth):=E5=AE=9E=E7=8E=B0OAuth=E8=AE=A4?= =?UTF-8?q?=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- service/src/mcp/auth.service.ts | 186 ++++++++++++++++++++++++++++++ service/src/mcp/client.service.ts | 60 +++++++--- 2 files changed, 232 insertions(+), 14 deletions(-) create mode 100644 service/src/mcp/auth.service.ts diff --git a/service/src/mcp/auth.service.ts b/service/src/mcp/auth.service.ts new file mode 100644 index 0000000..0bdc631 --- /dev/null +++ b/service/src/mcp/auth.service.ts @@ -0,0 +1,186 @@ +import { createServer } from 'node:http'; +import { URL } from 'node:url'; +import { OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth.js'; +import { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; +import open from 'open'; + +// const CALLBACK_PORT = 16203; // Use different port than auth server (3001) +// const CALLBACK_URL = `http://localhost:${CALLBACK_PORT}/callback`; + +/** + * @description 内存中的OAuth客户端提供者 + */ +class InMemoryOAuthClientProvider implements OAuthClientProvider { + private _clientInformation?: OAuthClientInformationFull; + private _tokens?: OAuthTokens; + private _codeVerifier?: string; + + constructor( + private readonly _redirectUrl: string | URL, + private readonly _clientMetadata: OAuthClientMetadata, + onRedirect?: (url: URL) => void + ) { + this._onRedirect = onRedirect || ((url) => { + console.log(`Redirect to: ${url.toString()}`); + }); + } + + private _onRedirect: (url: URL) => void; + + get redirectUrl(): string | URL { + return this._redirectUrl; + } + + get clientMetadata(): OAuthClientMetadata { + return this._clientMetadata; + } + + clientInformation(): OAuthClientInformation | undefined { + return this._clientInformation; + } + + saveClientInformation(clientInformation: OAuthClientInformationFull): void { + this._clientInformation = clientInformation; + } + + tokens(): OAuthTokens | undefined { + return this._tokens; + } + + saveTokens(tokens: OAuthTokens): void { + this._tokens = tokens; + } + + redirectToAuthorization(authorizationUrl: URL): void { + this._onRedirect(authorizationUrl); + } + + saveCodeVerifier(codeVerifier: string): void { + this._codeVerifier = codeVerifier; + } + + codeVerifier(): string { + if (!this._codeVerifier) { + throw new Error('No code verifier saved'); + } + return this._codeVerifier; + } +} + + +export class OAuthClient { + port: number; + callbackUrl: string; + + constructor() { + console.log('🔐 Initializing OAuth client...'); + // 初始化OAuth客户端 + this.port = Math.floor(Math.random() * (50000 - 40000 + 1)) + 40000; + //TODO: 如果端口被占用,重新生成一个端口 + this.callbackUrl = `http://localhost:${this.port}/callback`; + + } + /** + * @description 开启本地服务器上,并监听OAuth回调请求,并解析授权码或错误信息 + * @returns {Promise} 返回授权码 + * @throws {Error} 如果没有收到授权码或发生错误 + */ + + public async waitForOAuthCallback(): Promise { + return new Promise((resolve, reject) => { + const server = createServer((req, res) => { + // Ignore favicon requests + if (req.url === '/favicon.ico') { + res.writeHead(404); + res.end(); + return; + } + + console.log(`📥 Received callback: ${req.url}`); + const parsedUrl = new URL(req.url || '', 'http://localhost'); + const code = parsedUrl.searchParams.get('code'); + const error = parsedUrl.searchParams.get('error'); + + if (code) { + console.log(`✅ Authorization code received: ${code?.substring(0, 10)}...`); + res.writeHead(200, { 'Content-Type': 'text/html' }); + res.end(` + + +

Authorization Successful!

+

You can close this window and return to the terminal.

+ + + + `); + + resolve(code); + setTimeout(() => server.close(), 3000); + } else if (error) { + console.log(`❌ Authorization error: ${error}`); + res.writeHead(400, { 'Content-Type': 'text/html' }); + res.end(` + + +

Authorization Failed

+

Error: ${error}

+ + + `); + reject(new Error(`OAuth authorization failed: ${error}`)); + } else { + console.log(`❌ No authorization code or error in callback`); + res.writeHead(400); + res.end('Bad request'); + reject(new Error('No authorization code provided')); + } + }); + + server.listen(this.port, () => { + console.log(`OAuth callback server started on http://localhost:${this.port}`); + }); + }); + } + + + + /** + * @description 获取Oauth认证provider + * @return {Promise} 返回一个OAuthClientProvider实例 + */ + public async getOAuthProvider(): Promise { + + const clientMetadata: OAuthClientMetadata = { + client_name: 'Simple OAuth MCP Client', + redirect_uris: [this.callbackUrl], + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'none', + }; + + console.log('🔐 Creating OAuth provider...'); + const oauthProvider = new InMemoryOAuthClientProvider( + this.callbackUrl, + clientMetadata, + (redirectUrl: URL) => { + console.log(`📌 OAuth redirect handler called - opening browser`); + console.log(`Opening browser to: ${redirectUrl.toString()}`); + this.openBrowser(redirectUrl.toString()); + } + ); + console.log('🔐 OAuth provider created'); + return oauthProvider; + } + + /** + * @description 打开浏览器 + * @param url 授权URL + */ + + public async openBrowser(url: string): Promise { + console.log(`🌐 Opening browser for authorization: ${url}`); + await open(url); // 自动适配不同操作系统 + } +} + + diff --git a/service/src/mcp/client.service.ts b/service/src/mcp/client.service.ts index f3c9aeb..d7f3d4e 100644 --- a/service/src/mcp/client.service.ts +++ b/service/src/mcp/client.service.ts @@ -6,13 +6,17 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/ import type { McpOptions, McpTransport, IServerVersion, ToolCallResponse, ToolCallContent } from './client.dto.js'; import { PostMessageble } from "../hook/adapter.js"; import { createOcrWorker, saveBase64ImageData } from "./ocr.service.js"; - +import { OAuthClient } from "./auth.service.js"; +import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'; +import { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; // 增强的客户端类 export class McpClient { private client: Client; private transport?: McpTransport; private options: McpOptions; private serverVersion: IServerVersion; + private oAuthClient: OAuthClient; + private oauthPovider?: OAuthClientProvider; constructor(options: McpOptions) { this.options = options; @@ -31,14 +35,18 @@ export class McpClient { } } ); + + this.oAuthClient = new OAuthClient(); } // 连接方法 public async connect(): Promise { - + if (!this.oauthPovider){ + this.oauthPovider = await this.oAuthClient.getOAuthProvider(); + } // 根据连接类型创建传输层 switch (this.options.connectionType) { - case 'STDIO': + case 'STDIO': this.transport = new StdioClientTransport({ command: this.options.command || '', args: this.options.args || [], @@ -55,18 +63,21 @@ export class McpClient { this.transport = new SSEClientTransport( new URL(this.options.url), { - // authProvider: + authProvider: this.oauthPovider } ); break; - + case 'STREAMABLE_HTTP': if (!this.options.url) { throw new Error('URL is required for STREAMABLE_HTTP connection'); } this.transport = new StreamableHTTPClientTransport( - new URL(this.options.url) + new URL(this.options.url), + { + authProvider:this.oauthPovider + } ); break; default: @@ -75,11 +86,32 @@ export class McpClient { // 建立连接 if (this.transport) { - await this.client.connect(this.transport); - console.log(`Connected to MCP server via ${this.options.connectionType}`); + try { + console.log(`🔌 Connecting to MCP server via ${this.options.connectionType}...`); + await this.client.connect(this.transport); + console.log(`Connected to MCP server via ${this.options.connectionType}`); + } catch (error) { + if (error instanceof UnauthorizedError) { + if (!(this.transport instanceof StreamableHTTPClientTransport) && !(this.transport instanceof SSEClientTransport)) { + console.error('❌ OAuth is only supported for StreamableHTTP and SSE transports. Please use one of these transports for OAuth authentication.'); + return; + } + console.log('🔐 OAuth required - waiting for authorization...'); + const callbackPromise = this.oAuthClient.waitForOAuthCallback(); + const authCode = await callbackPromise; + await this.transport.finishAuth(authCode); + console.log('🔐 Authorization code received:', authCode); + console.log('🔌 Reconnecting with authenticated transport...'); + await this.connect(); // 递归重试 + } else { + console.error('❌ Connection failed with non-auth error:', error); + throw error; + } + } } } + public getServerVersion() { if (this.serverVersion) { return this.serverVersion; @@ -93,7 +125,7 @@ export class McpClient { // 断开连接 public async disconnect(): Promise { await this.client.close(); - + console.log('Disconnected from MCP server'); } @@ -103,7 +135,7 @@ export class McpClient { } // 获取提示 - public async getPrompt(name: string, args: Record = {}) { + public async getPrompt(name: string, args: Record = {}) { return await this.client.getPrompt({ name, arguments: args }); @@ -138,7 +170,7 @@ export class McpClient { console.log('callToolOption', callToolOption); const res = await this.client.callTool(methodArgs, undefined, callToolOption); console.log('callTool res', res); - + return res; } } @@ -157,10 +189,10 @@ async function handleImage( if (content.data && content.mimeType) { const filename = saveBase64ImageData(content.data, content.mimeType); content.data = filename; - + // 加入工作线程 const worker = createOcrWorker(filename, webview); - + content._meta = { ocr: true, workerId: worker.id @@ -191,7 +223,7 @@ export function postProcessMcpToolcallResponse( case 'image': handleImage(content, webview); break; - + default: break; }