feat(auth):实现OAuth认证

This commit is contained in:
li1553770945 2025-05-30 00:01:29 +08:00
parent c50c75821c
commit 054017bfe5
2 changed files with 232 additions and 14 deletions

View File

@ -0,0 +1,186 @@
import { createServer } from 'node:http';
import { URL } from 'node:url';
import { OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth.js';
import { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
import open from 'open';
// const CALLBACK_PORT = 16203; // Use different port than auth server (3001)
// const CALLBACK_URL = `http://localhost:${CALLBACK_PORT}/callback`;
/**
* @description OAuth客户端提供者
*/
class InMemoryOAuthClientProvider implements OAuthClientProvider {
private _clientInformation?: OAuthClientInformationFull;
private _tokens?: OAuthTokens;
private _codeVerifier?: string;
constructor(
private readonly _redirectUrl: string | URL,
private readonly _clientMetadata: OAuthClientMetadata,
onRedirect?: (url: URL) => void
) {
this._onRedirect = onRedirect || ((url) => {
console.log(`Redirect to: ${url.toString()}`);
});
}
private _onRedirect: (url: URL) => void;
get redirectUrl(): string | URL {
return this._redirectUrl;
}
get clientMetadata(): OAuthClientMetadata {
return this._clientMetadata;
}
clientInformation(): OAuthClientInformation | undefined {
return this._clientInformation;
}
saveClientInformation(clientInformation: OAuthClientInformationFull): void {
this._clientInformation = clientInformation;
}
tokens(): OAuthTokens | undefined {
return this._tokens;
}
saveTokens(tokens: OAuthTokens): void {
this._tokens = tokens;
}
redirectToAuthorization(authorizationUrl: URL): void {
this._onRedirect(authorizationUrl);
}
saveCodeVerifier(codeVerifier: string): void {
this._codeVerifier = codeVerifier;
}
codeVerifier(): string {
if (!this._codeVerifier) {
throw new Error('No code verifier saved');
}
return this._codeVerifier;
}
}
export class OAuthClient {
port: number;
callbackUrl: string;
constructor() {
console.log('🔐 Initializing OAuth client...');
// 初始化OAuth客户端
this.port = Math.floor(Math.random() * (50000 - 40000 + 1)) + 40000;
//TODO: 如果端口被占用,重新生成一个端口
this.callbackUrl = `http://localhost:${this.port}/callback`;
}
/**
* @description OAuth回调请求
* @returns {Promise<string>}
* @throws {Error}
*/
public async waitForOAuthCallback(): Promise<string> {
return new Promise<string>((resolve, reject) => {
const server = createServer((req, res) => {
// Ignore favicon requests
if (req.url === '/favicon.ico') {
res.writeHead(404);
res.end();
return;
}
console.log(`📥 Received callback: ${req.url}`);
const parsedUrl = new URL(req.url || '', 'http://localhost');
const code = parsedUrl.searchParams.get('code');
const error = parsedUrl.searchParams.get('error');
if (code) {
console.log(`✅ Authorization code received: ${code?.substring(0, 10)}...`);
res.writeHead(200, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authorization Successful!</h1>
<p>You can close this window and return to the terminal.</p>
<script>setTimeout(() => window.close(), 2000);</script>
</body>
</html>
`);
resolve(code);
setTimeout(() => server.close(), 3000);
} else if (error) {
console.log(`❌ Authorization error: ${error}`);
res.writeHead(400, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authorization Failed</h1>
<p>Error: ${error}</p>
</body>
</html>
`);
reject(new Error(`OAuth authorization failed: ${error}`));
} else {
console.log(`❌ No authorization code or error in callback`);
res.writeHead(400);
res.end('Bad request');
reject(new Error('No authorization code provided'));
}
});
server.listen(this.port, () => {
console.log(`OAuth callback server started on http://localhost:${this.port}`);
});
});
}
/**
* @description Oauth认证provider
* @return {Promise<OAuthClientProvider>} OAuthClientProvider实例
*/
public async getOAuthProvider(): Promise<OAuthClientProvider> {
const clientMetadata: OAuthClientMetadata = {
client_name: 'Simple OAuth MCP Client',
redirect_uris: [this.callbackUrl],
grant_types: ['authorization_code', 'refresh_token'],
response_types: ['code'],
token_endpoint_auth_method: 'none',
};
console.log('🔐 Creating OAuth provider...');
const oauthProvider = new InMemoryOAuthClientProvider(
this.callbackUrl,
clientMetadata,
(redirectUrl: URL) => {
console.log(`📌 OAuth redirect handler called - opening browser`);
console.log(`Opening browser to: ${redirectUrl.toString()}`);
this.openBrowser(redirectUrl.toString());
}
);
console.log('🔐 OAuth provider created');
return oauthProvider;
}
/**
* @description
* @param url URL
*/
public async openBrowser(url: string): Promise<void> {
console.log(`🌐 Opening browser for authorization: ${url}`);
await open(url); // 自动适配不同操作系统
}
}

View File

@ -6,13 +6,17 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/
import type { McpOptions, McpTransport, IServerVersion, ToolCallResponse, ToolCallContent } from './client.dto.js'; import type { McpOptions, McpTransport, IServerVersion, ToolCallResponse, ToolCallContent } from './client.dto.js';
import { PostMessageble } from "../hook/adapter.js"; import { PostMessageble } from "../hook/adapter.js";
import { createOcrWorker, saveBase64ImageData } from "./ocr.service.js"; import { createOcrWorker, saveBase64ImageData } from "./ocr.service.js";
import { OAuthClient } from "./auth.service.js";
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js';
import { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
// 增强的客户端类 // 增强的客户端类
export class McpClient { export class McpClient {
private client: Client; private client: Client;
private transport?: McpTransport; private transport?: McpTransport;
private options: McpOptions; private options: McpOptions;
private serverVersion: IServerVersion; private serverVersion: IServerVersion;
private oAuthClient: OAuthClient;
private oauthPovider?: OAuthClientProvider;
constructor(options: McpOptions) { constructor(options: McpOptions) {
this.options = options; this.options = options;
@ -31,11 +35,15 @@ export class McpClient {
} }
} }
); );
this.oAuthClient = new OAuthClient();
} }
// 连接方法 // 连接方法
public async connect(): Promise<void> { public async connect(): Promise<void> {
if (!this.oauthPovider){
this.oauthPovider = await this.oAuthClient.getOAuthProvider();
}
// 根据连接类型创建传输层 // 根据连接类型创建传输层
switch (this.options.connectionType) { switch (this.options.connectionType) {
case 'STDIO': case 'STDIO':
@ -55,7 +63,7 @@ export class McpClient {
this.transport = new SSEClientTransport( this.transport = new SSEClientTransport(
new URL(this.options.url), new URL(this.options.url),
{ {
// authProvider: authProvider: this.oauthPovider
} }
); );
@ -66,7 +74,10 @@ export class McpClient {
throw new Error('URL is required for STREAMABLE_HTTP connection'); throw new Error('URL is required for STREAMABLE_HTTP connection');
} }
this.transport = new StreamableHTTPClientTransport( this.transport = new StreamableHTTPClientTransport(
new URL(this.options.url) new URL(this.options.url),
{
authProvider:this.oauthPovider
}
); );
break; break;
default: default:
@ -75,11 +86,32 @@ export class McpClient {
// 建立连接 // 建立连接
if (this.transport) { if (this.transport) {
await this.client.connect(this.transport); try {
console.log(`Connected to MCP server via ${this.options.connectionType}`); console.log(`🔌 Connecting to MCP server via ${this.options.connectionType}...`);
await this.client.connect(this.transport);
console.log(`Connected to MCP server via ${this.options.connectionType}`);
} catch (error) {
if (error instanceof UnauthorizedError) {
if (!(this.transport instanceof StreamableHTTPClientTransport) && !(this.transport instanceof SSEClientTransport)) {
console.error('❌ OAuth is only supported for StreamableHTTP and SSE transports. Please use one of these transports for OAuth authentication.');
return;
}
console.log('🔐 OAuth required - waiting for authorization...');
const callbackPromise = this.oAuthClient.waitForOAuthCallback();
const authCode = await callbackPromise;
await this.transport.finishAuth(authCode);
console.log('🔐 Authorization code received:', authCode);
console.log('🔌 Reconnecting with authenticated transport...');
await this.connect(); // 递归重试
} else {
console.error('❌ Connection failed with non-auth error:', error);
throw error;
}
}
} }
} }
public getServerVersion() { public getServerVersion() {
if (this.serverVersion) { if (this.serverVersion) {
return this.serverVersion; return this.serverVersion;