diff --git a/CHANGELOG.md b/CHANGELOG.md index 76cae35..ac6a847 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,9 @@ # Change Log ## [main] 0.0.9 -- 修复 0.0.8 引入的bug:system prompt 返回的是索引而非真实内容 -- +- 修复 0.0.8 引入的 bug:system prompt 返回的是索引而非真实内容 +- 新特性:支持同时连入多个 mcp server +- 新特性:更新协议内容,支持 streamable http 协议,未来将逐步取代 SSE 的连接方式 ## [main] 0.0.8 - 大模型 API 测试时更加完整的报错 diff --git a/renderer/.env.development b/renderer/.env.development index 254e7de..7d53432 100644 --- a/renderer/.env.development +++ b/renderer/.env.development @@ -1 +1,2 @@ +VITE_USE_AUTH=false VITE_WEBSOCKET_URL=ws://localhost:8282 \ No newline at end of file diff --git a/renderer/.env.production b/renderer/.env.production index 254e7de..7d53432 100644 --- a/renderer/.env.production +++ b/renderer/.env.production @@ -1 +1,2 @@ +VITE_USE_AUTH=false VITE_WEBSOCKET_URL=ws://localhost:8282 \ No newline at end of file diff --git a/renderer/README.md b/renderer/README.md index 7493165..2b9f1f2 100644 --- a/renderer/README.md +++ b/renderer/README.md @@ -1,33 +1,18 @@ -# test-vite +## dev -This template should help get you started developing with Vue 3 in Vite. +如果想要部署到公网中,想要通过密码认证才能进入,进行如下步骤: -## Recommended IDE Setup - -[VSCode](https://code.visualstudio.com/) + [Volar](https://marketplace.visualstudio.com/items?itemName=Vue.volar) (and disable Vetur). - -## Type Support for `.vue` Imports in TS - -TypeScript cannot handle type information for `.vue` imports by default, so we replace the `tsc` CLI with `vue-tsc` for type checking. In editors, we need [Volar](https://marketplace.visualstudio.com/items?itemName=Vue.volar) to make the TypeScript language service aware of `.vue` types. - -## Customize configuration - -See [Vite Configuration Reference](https://vite.dev/config/). - -## Project Setup - -```sh -npm install +```bash +touch .env.website.local ``` -### Compile and Hot-Reload for Development +写入: -```sh -npm run dev +```toml +VITE_USE_AUTH=true +VITE_WEBSOCKET_URL=wss:///<路径> ``` -### Type-Check, Compile and Minify for Production +使用 `npm run serve:website` 进行测试(服务端使用 ts-node src/server.ts) -```sh -npm run build -``` +使用 `npm run build:website` 进行打包 \ No newline at end of file diff --git a/renderer/public/CascadiaCode.woff2 b/renderer/public/CascadiaCode.woff2 new file mode 100644 index 0000000..ed0335c Binary files /dev/null and b/renderer/public/CascadiaCode.woff2 differ diff --git a/renderer/src/App.vue b/renderer/src/App.vue index 622c1aa..1370127 100644 --- a/renderer/src/App.vue +++ b/renderer/src/App.vue @@ -4,7 +4,7 @@ - + @@ -37,8 +37,8 @@ bridge.addCommandListener('hello', data => { const route = useRoute(); const router = useRouter(); -const password = Boolean(import.meta.env.VITE_USE_PASSWORD); -privilegeStatus.allow = !Boolean(password); +const useAuth = Boolean(import.meta.env.VITE_USE_AUTH); +privilegeStatus.allow = !Boolean(useAuth); onMounted(async () => { // 初始化 css diff --git a/renderer/src/api/message-bridge.ts b/renderer/src/api/message-bridge.ts index 70ecbed..1229688 100644 --- a/renderer/src/api/message-bridge.ts +++ b/renderer/src/api/message-bridge.ts @@ -7,9 +7,9 @@ export interface VSCodeMessage { callbackId?: string; } -export interface RestFulResponse { +export interface RestFulResponse { code: number; - msg: any; + msg: T; } export type MessageHandler = (message: VSCodeMessage) => void; @@ -206,7 +206,7 @@ export class MessageBridge { * @param data * @returns */ - public commandRequest(command: string, data?: any) { + public commandRequest(command: string, data?: any): Promise> { return new Promise((resolve, reject) => { this.addCommandListener(command, (data) => { resolve(data as RestFulResponse); diff --git a/renderer/src/hook/type.ts b/renderer/src/hook/type.ts index 1967de9..728404f 100644 --- a/renderer/src/hook/type.ts +++ b/renderer/src/hook/type.ts @@ -149,7 +149,7 @@ export type APIRequest = | ToolCallRequest; export interface IStdioConnectionItem { - type: 'stdio'; + type: 'STDIO'; name: string; command: string; args: string[]; @@ -159,7 +159,7 @@ export interface IStdioConnectionItem { } export interface ISSEConnectionItem { - type: 'sse'; + type: 'SSE'; name: string; url: string; oauth?: string; @@ -169,13 +169,13 @@ export interface ISSEConnectionItem { export interface IStdioLaunchSignature { - type: 'stdio'; + type: 'STDIO'; commandString: string; cwd: string; } export interface ISSELaunchSignature { - type:'sse'; + type:'SSE'; url: string; oauth: string; } diff --git a/renderer/src/views/connect/connection-item.ts b/renderer/src/views/connect/connection-item.ts new file mode 100644 index 0000000..891b733 --- /dev/null +++ b/renderer/src/views/connect/connection-item.ts @@ -0,0 +1,58 @@ +import { reactive, type Reactive } from "vue"; + +export type ConnectionType = 'STDIO' | 'SSE' | 'STREAMABLE_HTTP'; + +export interface ConnectionTypeOptionItem { + value: ConnectionType; + label: string; +} + +export const connectionSelectDataViewOption: ConnectionTypeOptionItem[] = [ + { + value: 'STDIO', + label: 'STDIO' + }, + { + value: 'SSE', + label: 'SSE' + }, + { + value: 'STREAMABLE_HTTP', + label: 'STREAMABLE_HTTP' + } +] + +export interface IConnectionArgs { + type: ConnectionType; + commandString?: string; + cwd?: string; + urlString?: string; +} + +export class McpClient { + public clientId?: string; + public name?: string; + public version?: string; + public connectionArgs: Reactive; + + constructor() { + this.connectionArgs = reactive({ + type: 'STDIO', + commandString: '', + cwd: '', + urlString: '' + }); + } + + async connect() { + + } +} + +// 用于描述一个连接的数据结构 +export interface McpServer { + type: ConnectionType; + clientId: string; + name: string; + +} \ No newline at end of file diff --git a/renderer/src/views/connect/connection.ts b/renderer/src/views/connect/connection.ts index 82d23a2..8715b99 100644 --- a/renderer/src/views/connect/connection.ts +++ b/renderer/src/views/connect/connection.ts @@ -5,8 +5,15 @@ import { ElLoading, ElMessage } from 'element-plus'; import { getPlatform, type OpenMcpSupportPlatform } from '@/api/platform'; import { getTour, loadSetting } from '@/hook/setting'; import { loadPanels } from '@/hook/panel'; +import type { ConnectionType } from './connection-item'; -export const connectionMethods = reactive({ +export const connectionMethods = reactive<{ + current: ConnectionType, + data: { + value: ConnectionType, + label: string + }[] +}>({ current: 'STDIO', data: [ { @@ -16,6 +23,10 @@ export const connectionMethods = reactive({ { value: 'SSE', label: 'SSE' + }, + { + value: 'STREAMABLE_HTTP', + label: 'STREAMABLE_HTTP' } ] }); @@ -23,6 +34,7 @@ export const connectionMethods = reactive({ export const connectionSettingRef = ref(null); export const connectionLogRef = ref(null); +// 主 mcp 服务器的连接参数 export const connectionArgs = reactive({ commandString: '', cwd: '', @@ -41,6 +53,13 @@ export interface IConnectionEnv { newValue: string } +export interface ConnectionResult { + status: string + clientId: string + name: string + version: string +} + export const connectionEnv = reactive({ data: [], newKey: '', @@ -56,9 +75,6 @@ export function makeEnv() { } -// 定义连接类型 -type ConnectionType = 'STDIO' | 'SSE'; - // 定义命令行参数接口 export interface McpOptions { connectionType: ConnectionType; @@ -74,6 +90,14 @@ export interface McpOptions { clientVersion?: string; } +/** + * @description 试图启动 mcp 服务器,它会 + * 1. 请求启动参数 + * 2. 启动 mcp 服务器 + * 3. 将本次的启动参数同步到本地 + * @param option + * @returns + */ export async function doConnect( option: { namespace: OpenMcpSupportPlatform @@ -86,32 +110,21 @@ export async function doConnect( updateCommandString = true } = option; + // 如果是初始化,则需要请求启动参数 if (updateCommandString) { pinkLog('请求启动参数'); const connectionItem = await getLaunchSignature(namespace + '/launch-signature'); - - if (connectionItem.type ==='stdio') { - connectionMethods.current = 'STDIO'; - connectionArgs.commandString = connectionItem.commandString; - connectionArgs.cwd = connectionItem.cwd; - - if (connectionArgs.commandString.length === 0) { - return; - } - } else { - connectionMethods.current = 'SSE'; - connectionArgs.urlString = connectionItem.url || ''; - - if (connectionArgs.urlString.length === 0) { - return; - } - } + connectionMethods.current = connectionItem.type; + connectionArgs.commandString = connectionItem.commandString || ''; + connectionArgs.cwd = connectionItem.cwd || ''; + connectionArgs.oauth = connectionItem.oauth || ''; + connectionArgs.urlString = connectionItem.url || ''; } if (connectionMethods.current === 'STDIO') { - await launchStdio(namespace); + return await launchStdio(namespace); } else { - await launchSSE(namespace); + return await launchRemote(namespace); } } @@ -128,26 +141,25 @@ async function launchStdio(namespace: string) { command: command, args: commandComponents, cwd: connectionArgs.cwd, - clientName: 'openmcp.connect.stdio', + clientName: 'openmcp.connect.STDIO', clientVersion: '0.0.1', env }; - const { code, msg } = await bridge.commandRequest('connect', connectOption); + const { code, msg } = await bridge.commandRequest('connect', connectOption); connectionResult.success = (code === 200); if (code === 200) { - connectionResult.logString.push({ - type: 'info', - message: msg - }); - const res = await getServerVersion() as { name: string, version: string }; - connectionResult.serverInfo.name = res.name || ''; - connectionResult.serverInfo.version = res.version || ''; + const message = `connect to ${msg.name} ${msg.version} success, clientId: ${msg.clientId}`; + connectionResult.logString.push({ type: 'info', message }); - // 同步信息到 vscode + connectionResult.serverInfo.name = msg.name || ''; + connectionResult.serverInfo.version = msg.version || ''; + connectionResult.clientId = msg.clientId || ''; + + // 同步信息到 后端 const commandComponents = connectionArgs.commandString.split(/\s+/g); const command = commandComponents[0]; commandComponents.shift(); @@ -155,7 +167,7 @@ async function launchStdio(namespace: string) { const clientStdioConnectionItem = { serverInfo: connectionResult.serverInfo, connectionType: 'STDIO', - name: 'openmcp.connect.stdio', + name: 'openmcp.connect.STDIO', command: command, args: commandComponents, cwd: connectionArgs.cwd, @@ -168,46 +180,49 @@ async function launchStdio(namespace: string) { }); } else { + const messaage = msg.toString(); connectionResult.logString.push({ type: 'error', - message: msg + message: messaage }); - ElMessage.error(msg); + ElMessage.error(messaage); } } -async function launchSSE(namespace: string) { +async function launchRemote(namespace: string) { const bridge = useMessageBridge(); const env = makeEnv(); const connectOption: McpOptions = { - connectionType: 'SSE', + connectionType: connectionMethods.current, url: connectionArgs.urlString, - clientName: 'openmcp.connect.sse', + clientName: 'openmcp.connect.' + connectionMethods.current, clientVersion: '0.0.1', env }; - const { code, msg } = await bridge.commandRequest('connect', connectOption); + const { code, msg } = await bridge.commandRequest('connect', connectOption); connectionResult.success = (code === 200); if (code === 200) { + const message = `connect to ${msg.name} ${msg.version} success, clientId: ${msg.clientId}`; + connectionResult.logString.push({ type: 'info', - message: msg + message: message }); - const res = await getServerVersion() as { name: string, version: string }; - connectionResult.serverInfo.name = res.name || ''; - connectionResult.serverInfo.version = res.version || ''; + connectionResult.serverInfo.name = msg.name || ''; + connectionResult.serverInfo.version = msg.version || ''; + connectionResult.clientId = msg.clientId || ''; // 同步信息到 vscode const clientSseConnectionItem = { serverInfo: connectionResult.serverInfo, - connectionType: 'SSE', - name: 'openmcp.connect.sse', + connectionType: connectionMethods.current, + name: 'openmcp.connect.' + connectionMethods.current, url: connectionArgs.urlString, oauth: connectionArgs.oauth, env: env @@ -219,12 +234,13 @@ async function launchSSE(namespace: string) { }); } else { + const message = msg.toString(); connectionResult.logString.push({ type: 'error', - message: msg + message: message }); - ElMessage.error(msg); + ElMessage.error(message); } } @@ -247,32 +263,19 @@ export const connectionResult = reactive<{ serverInfo: { name: string, version: string - } + }, + clientId: string }>({ success: false, logString: [], serverInfo: { name: '', version: '' - } + }, + clientId: '' }); -export function getServerVersion() { - return new Promise((resolve, reject) => { - const bridge = useMessageBridge(); - bridge.addCommandListener('server/version', data => { - if (data.code === 200) { - resolve(data.msg); - } else { - reject(data.msg); - } - }, { once: true }); - bridge.postMessage({ - command: 'server/version', - }); - }); -} export const envVarStatus = { launched: false diff --git a/servers/uv.lock b/servers/uv.lock index 71c408f..d34f9b7 100644 --- a/servers/uv.lock +++ b/servers/uv.lock @@ -175,10 +175,10 @@ wheels = [ ] [[package]] -name = "httpx-sse" +name = "httpx-SSE" version = "0.4.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624 } +sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-SSE-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624 } wheels = [ { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819 }, ] @@ -258,10 +258,10 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "httpx" }, - { name = "httpx-sse" }, + { name = "httpx-SSE" }, { name = "pydantic" }, { name = "pydantic-settings" }, - { name = "sse-starlette" }, + { name = "SSE-starlette" }, { name = "starlette" }, { name = "uvicorn" }, ] @@ -477,7 +477,7 @@ wheels = [ ] [[package]] -name = "sse-starlette" +name = "SSE-starlette" version = "2.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ diff --git a/service/src/common/index.dto.ts b/service/src/common/index.dto.ts index e6ab551..9d302f7 100644 --- a/service/src/common/index.dto.ts +++ b/service/src/common/index.dto.ts @@ -3,9 +3,13 @@ import { McpClient } from "../mcp/client.service"; export type RequestClientType = McpClient | undefined; +export interface RequestData { + clientId?: string; + [key: string]: any; +} + export type RequestHandler = ( - client: RequestClientType, - data: T, + data: T & RequestData, webview: PostMessageble ) => Promise; diff --git a/service/src/common/router.ts b/service/src/common/router.ts index 297ebd1..5ff9ab9 100644 --- a/service/src/common/router.ts +++ b/service/src/common/router.ts @@ -3,7 +3,6 @@ import type { PostMessageble } from "../hook/adapter"; import { LlmController } from "../llm/llm.controller"; import { ClientController } from "../mcp/client.controller"; import { ConnectController } from "../mcp/connect.controller"; -import { client } from "../mcp/connect.service"; import { OcrController } from "../mcp/ocr.controller"; import { PanelController } from "../panel/panel.controller"; import { SettingController } from "../setting/setting.controller"; @@ -24,7 +23,7 @@ export async function routeMessage(command: string, data: any, webview: PostMess try { // TODO: select client based on something - const res = await handler(client, data, webview); + const res = await handler(data, webview); // res.code = -1 代表当前请求不需要返回发送 if (res.code >= 0) { diff --git a/service/src/hook/adapter.ts b/service/src/hook/adapter.ts index b037468..825625e 100644 --- a/service/src/hook/adapter.ts +++ b/service/src/hook/adapter.ts @@ -2,7 +2,7 @@ import { WebSocket } from 'ws'; import { EventEmitter } from 'events'; import { routeMessage } from '../common/router'; import { McpOptions } from '../mcp/client.dto'; -import { client, connectService } from '../mcp/connect.service'; +import { clientMap, connectService } from '../mcp/connect.service'; // WebSocket 消息格式 export interface WebSocketMessage { @@ -114,9 +114,12 @@ export class TaskLoopAdapter { * @param mcpOption */ public async connectMcpServer(mcpOption: McpOptions) { - const res = await connectService(undefined, mcpOption); + const res = await connectService(mcpOption); if (res.code === 200) { console.log('✅ 成功连接 mcp 服务器: ' + res.msg); + + const uuid = res.msg.uuid; + const client = clientMap.get(uuid); const version = client?.getServerVersion(); console.log(version); } else { @@ -129,14 +132,19 @@ export class TaskLoopAdapter { * @returns */ public async listTools() { - const tools = await client?.listTools(); - if (tools?.tools) { - return tools.tools.map((tool) => { - const enabledTools = { ...tool, enabled: true }; - return enabledTools; - }); + const tools = []; + for (const client of clientMap.values()) { + const clientTools = await client?.listTools(); + if (clientTools?.tools) { + const enabledTools = clientTools.tools.map((tool) => { + const enabledTools = {...tool, enabled: true }; + return enabledTools; + }); + tools.push(...enabledTools); + } } - return []; + + return tools; } } diff --git a/service/src/llm/llm.controller.ts b/service/src/llm/llm.controller.ts index 4252c51..08804bf 100644 --- a/service/src/llm/llm.controller.ts +++ b/service/src/llm/llm.controller.ts @@ -1,13 +1,17 @@ import { Controller, RequestClientType } from "../common"; +import { RequestData } from "../common/index.dto"; import { PostMessageble } from "../hook/adapter"; +import { getClient } from "../mcp/connect.service"; import { abortMessageService, streamingChatCompletion } from "./llm.service"; export class LlmController { @Controller('llm/chat/completions') - async chatCompletion(client: RequestClientType, data: any, webview: PostMessageble) { + async chatCompletion(data: RequestData, webview: PostMessageble) { let { tools = [] } = data; + const client = getClient(data.clientId); + if (tools.length > 0 && !client) { return { code: 501, @@ -37,7 +41,7 @@ export class LlmController { } @Controller('llm/chat/completions/abort') - async abortChatCompletion(client: RequestClientType, data: any, webview: PostMessageble) { + async abortChatCompletion(data: RequestData, webview: PostMessageble) { return abortMessageService(data, webview); } diff --git a/service/src/main.ts b/service/src/main.ts index 6fdde5a..e2983e1 100644 --- a/service/src/main.ts +++ b/service/src/main.ts @@ -28,13 +28,13 @@ const logger = pino({ export type MessageHandler = (message: VSCodeMessage) => void; interface IStdioLaunchSignature { - type: 'stdio'; + type: 'STDIO'; commandString: string; cwd: string; } interface ISSELaunchSignature { - type:'sse'; + type:'SSE'; url: string; oauth: string; } @@ -43,7 +43,7 @@ export type ILaunchSigature = IStdioLaunchSignature | ISSELaunchSignature; function refreshConnectionOption(envPath: string) { const defaultOption = { - type:'stdio', + type:'STDIO', command: 'mcp', args: ['run', 'main.py'], cwd: '../server' @@ -76,7 +76,7 @@ function updateConnectionOption(data: any) { if (data.connectionType === 'STDIO') { const connectionItem = { - type: 'stdio', + type: 'STDIO', command: data.command, args: data.args, cwd: data.cwd.replace(/\\/g, '/') @@ -85,7 +85,7 @@ function updateConnectionOption(data: any) { fs.writeFileSync(envPath, JSON.stringify(connectionItem, null, 4)); } else { const connectionItem = { - type: 'sse', + type: 'SSE', url: data.url, oauth: data.oauth }; @@ -124,14 +124,14 @@ wss.on('connection', (ws: any) => { switch (command) { case 'web/launch-signature': - const launchResultMessage: ILaunchSigature = option.type === 'stdio' ? + const launchResultMessage: ILaunchSigature = option.type === 'STDIO' ? { - type: 'stdio', + type: 'STDIO', commandString: option.command + ' ' + option.args.join(' '), cwd: option.cwd || '' } : { - type: 'sse', + type: 'SSE', url: option.url, oauth: option.oauth || '' }; diff --git a/service/src/mcp/client.controller.ts b/service/src/mcp/client.controller.ts index 0238fb6..19efaf7 100644 --- a/service/src/mcp/client.controller.ts +++ b/service/src/mcp/client.controller.ts @@ -1,11 +1,14 @@ -import { Controller, RequestClientType } from "../common"; +import { Controller } from "../common"; +import { RequestData } from "../common/index.dto"; import { PostMessageble } from "../hook/adapter"; import { postProcessMcpToolcallResponse } from "./client.service"; +import { getClient } from "./connect.service"; export class ClientController { @Controller('server/version') - async getServerVersion(client: RequestClientType, data: any, webview: PostMessageble) { + async getServerVersion(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { return { code: 501, @@ -21,7 +24,8 @@ export class ClientController { } @Controller('prompts/list') - async listPrompts(client: RequestClientType, data: any, webview: PostMessageble) { + async listPrompts(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { const connectResult = { code: 501, @@ -39,7 +43,8 @@ export class ClientController { } @Controller('prompts/get') - async getPrompt(client: RequestClientType, option: any, webview: PostMessageble) { + async getPrompt(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { return { code: 501, @@ -47,7 +52,7 @@ export class ClientController { }; } - const prompt = await client.getPrompt(option.promptId, option.args || {}); + const prompt = await client.getPrompt(data.promptId, data.args || {}); return { code: 200, msg: prompt @@ -55,7 +60,8 @@ export class ClientController { } @Controller('resources/list') - async listResources(client: RequestClientType, data: any, webview: PostMessageble) { + async listResources(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { return { code: 501, @@ -71,8 +77,8 @@ export class ClientController { } @Controller('resources/templates/list') - async listResourceTemplates(client: RequestClientType, data: any, webview: PostMessageble) { - + async listResourceTemplates(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { return { code: 501, @@ -88,7 +94,8 @@ export class ClientController { } @Controller('resources/read') - async readResource(client: RequestClientType, option: any, webview: PostMessageble) { + async readResource(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { return { code: 501, @@ -96,7 +103,7 @@ export class ClientController { }; } - const resource = await client.readResource(option.resourceUri); + const resource = await client.readResource(data.resourceUri); console.log(resource); return { @@ -106,7 +113,8 @@ export class ClientController { } @Controller('tools/list') - async listTools(client: RequestClientType, data: any, webview: PostMessageble) { + async listTools(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { return { code: 501, @@ -122,7 +130,8 @@ export class ClientController { } @Controller('tools/call') - async callTool(client: RequestClientType, option: any, webview: PostMessageble) { + async callTool(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { return { code: 501, @@ -131,18 +140,13 @@ export class ClientController { } const toolResult = await client.callTool({ - name: option.toolName, - arguments: option.toolArgs, - callToolOption: option.callToolOption + name: data.toolName, + arguments: data.toolArgs, + callToolOption: data.callToolOption }); - // console.log(JSON.stringify(toolResult, null, 2)); - postProcessMcpToolcallResponse(toolResult, webview); - // console.log(JSON.stringify(toolResult, null, 2)); - - return { code: 200, msg: toolResult diff --git a/service/src/mcp/client.dto.ts b/service/src/mcp/client.dto.ts index cd30ce0..84f1ce6 100644 --- a/service/src/mcp/client.dto.ts +++ b/service/src/mcp/client.dto.ts @@ -1,5 +1,5 @@ -import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; -import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/STDIO.js"; +import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/SSE.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import { Implementation } from "@modelcontextprotocol/sdk/types"; diff --git a/service/src/mcp/client.service.ts b/service/src/mcp/client.service.ts index 29ebe99..2106012 100644 --- a/service/src/mcp/client.service.ts +++ b/service/src/mcp/client.service.ts @@ -1,7 +1,7 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; -import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/STDIO.js"; +import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/SSE.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import type { McpOptions, McpTransport, IServerVersion, ToolCallResponse, ToolCallContent } from './client.dto'; import { PostMessageble } from "../hook/adapter"; diff --git a/service/src/mcp/connect.controller.ts b/service/src/mcp/connect.controller.ts index 6f979bc..92c13fc 100644 --- a/service/src/mcp/connect.controller.ts +++ b/service/src/mcp/connect.controller.ts @@ -1,17 +1,19 @@ -import { Controller, RequestClientType } from '../common'; +import { Controller } from '../common'; import { PostMessageble } from '../hook/adapter'; -import { connectService } from './connect.service'; +import { RequestData } from '../common/index.dto'; +import { connectService, getClient } from './connect.service'; export class ConnectController { @Controller('connect') - async connect(client: RequestClientType, data: any, webview: PostMessageble) { - const res = await connectService(client, data); + async connect(data: any, webview: PostMessageble) { + const res = await connectService(data); return res; } @Controller('lookup-env-var') - async lookupEnvVar(client: RequestClientType, data: any, webview: PostMessageble) { + async lookupEnvVar(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); const { keys } = data; const values = keys.map((key: string) => process.env[key] || ''); @@ -22,7 +24,8 @@ export class ConnectController { } @Controller('ping') - async ping(client: RequestClientType, data: any, webview: PostMessageble) { + async ping(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); if (!client) { const connectResult = { code: 501, diff --git a/service/src/mcp/connect.service.ts b/service/src/mcp/connect.service.ts index 83f2cef..ed9461f 100644 --- a/service/src/mcp/connect.service.ts +++ b/service/src/mcp/connect.service.ts @@ -3,10 +3,12 @@ import { RequestClientType } from '../common'; import { connect } from './client.service'; import { RestfulResponse } from '../common/index.dto'; import { McpOptions } from './client.dto'; +import { randomUUID } from 'node:crypto'; - -// TODO: 更多的 client -export let client: RequestClientType = undefined; +export const clientMap: Map = new Map(); +export function getClient(clientId?: string): RequestClientType | undefined { + return clientMap.get(clientId || ''); +} export function tryGetRunCommandError(command: string, args: string[] = [], cwd?: string): string | null { try { @@ -15,7 +17,7 @@ export function tryGetRunCommandError(command: string, args: string[] = [], cwd? const result = spawnSync(command, args, { cwd: cwd || process.cwd(), - stdio: 'pipe', + STDIO: 'pipe', encoding: 'utf-8' }); @@ -32,7 +34,6 @@ export function tryGetRunCommandError(command: string, args: string[] = [], cwd? } export async function connectService( - _client: RequestClientType, option: McpOptions ): Promise { try { @@ -48,19 +49,25 @@ export async function connectService( }); } - client = await connect(option); + const client = await connect(option); + const uuid = randomUUID(); + clientMap.set(uuid, client); + + const versionInfo = client.getServerVersion(); + const connectResult = { code: 200, - msg: 'Connect to OpenMCP successfully\nWelcome back, Kirigaya' + msg: { + status: 'success', + clientId: uuid, + name: versionInfo?.name, + version: versionInfo?.version + } }; return connectResult; } catch (error) { - - console.log('meet error'); - console.log(error); - // TODO: 这边获取到的 error 不够精致,如何才能获取到更加精准的错误 // 比如 error: Failed to spawn: `server.py` // Caused by: No such file or directory (os error 2) diff --git a/service/src/panel/panel.controller.ts b/service/src/panel/panel.controller.ts index f7239a1..92332ef 100644 --- a/service/src/panel/panel.controller.ts +++ b/service/src/panel/panel.controller.ts @@ -1,13 +1,16 @@ -import { Controller, RequestClientType } from "../common"; +import { Controller } from "../common"; import { PostMessageble } from "../hook/adapter"; +import { RequestData } from "../common/index.dto"; +import { getClient } from "../mcp/connect.service"; import { systemPromptDB } from "../hook/db"; import { loadTabSaveConfig, saveTabSaveConfig } from "./panel.service"; export class PanelController { @Controller('panel/save') - async savePanel(client: RequestClientType, data: any, webview: PostMessageble) { + async savePanel(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); const serverInfo = client?.getServerVersion(); - saveTabSaveConfig(serverInfo, data); + saveTabSaveConfig(serverInfo, data); return { code: 200, @@ -15,11 +18,11 @@ export class PanelController { }; } - @Controller('panel/load') - async loadPanel(client: RequestClientType, data: any, webview: PostMessageble) { + async loadPanel(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); const serverInfo = client?.getServerVersion(); - const config = loadTabSaveConfig(serverInfo); + const config = loadTabSaveConfig(serverInfo); return { code: 200, @@ -28,7 +31,8 @@ export class PanelController { } @Controller('system-prompts/set') - async setSystemPrompt(client: RequestClientType, data: any, webview: PostMessageble) { + async setSystemPrompt(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); const { name, content } = data; await systemPromptDB.insert({ @@ -44,7 +48,8 @@ export class PanelController { } @Controller('system-prompts/delete') - async deleteSystemPrompt(client: RequestClientType, data: any, webview: PostMessageble) { + async deleteSystemPrompt(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); const { name } = data; await systemPromptDB.delete(name); return { @@ -54,7 +59,8 @@ export class PanelController { } @Controller('system-prompts/save') - async saveSystemPrompts(client: RequestClientType, data: any, webview: PostMessageble) { + async saveSystemPrompts(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); const { prompts } = data; await Promise.all(prompts.map((prompt: any) => { @@ -72,8 +78,8 @@ export class PanelController { } @Controller('system-prompts/load') - async loadSystemPrompts(client: RequestClientType, data: any, webview: PostMessageble) { - + async loadSystemPrompts(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); const queryPrompts = await systemPromptDB.findAll(); const prompts = []; for (const prompt of queryPrompts) { diff --git a/service/src/server.ts b/service/src/server.ts index d5dfb4b..b036b45 100644 --- a/service/src/server.ts +++ b/service/src/server.ts @@ -29,13 +29,13 @@ const logger = pino({ export type MessageHandler = (message: VSCodeMessage) => void; interface IStdioLaunchSignature { - type: 'stdio'; + type: 'STDIO'; commandString: string; cwd: string; } interface ISSELaunchSignature { - type: 'sse'; + type: 'SSE'; url: string; oauth: string; } @@ -44,7 +44,7 @@ export type ILaunchSigature = IStdioLaunchSignature | ISSELaunchSignature; function refreshConnectionOption(envPath: string) { const defaultOption = { - type: 'stdio', + type: 'STDIO', command: 'mcp', args: ['run', 'main.py'], cwd: '../server' @@ -84,7 +84,7 @@ function updateConnectionOption(data: any) { if (data.connectionType === 'STDIO') { const connectionItem = { - type: 'stdio', + type: 'STDIO', command: data.command, args: data.args, cwd: data.cwd.replace(/\\/g, '/') @@ -93,7 +93,7 @@ function updateConnectionOption(data: any) { fs.writeFileSync(envPath, JSON.stringify(connectionItem, null, 4)); } else { const connectionItem = { - type: 'sse', + type: 'SSE', url: data.url, oauth: data.oauth }; @@ -155,14 +155,14 @@ wss.on('connection', (ws: any) => { switch (command) { case 'web/launch-signature': - const launchResultMessage: ILaunchSigature = option.type === 'stdio' ? + const launchResultMessage: ILaunchSigature = option.type === 'STDIO' ? { - type: 'stdio', + type: 'STDIO', commandString: option.command + ' ' + option.args.join(' '), cwd: option.cwd || '' } : { - type: 'sse', + type: 'SSE', url: option.url, oauth: option.oauth || '' }; diff --git a/service/src/setting/setting.controller.ts b/service/src/setting/setting.controller.ts index 12ddc3b..3a0a6b2 100644 --- a/service/src/setting/setting.controller.ts +++ b/service/src/setting/setting.controller.ts @@ -1,11 +1,14 @@ -import { Controller, RequestClientType } from "../common"; +import { Controller } from "../common"; import { PostMessageble } from "../hook/adapter"; +import { RequestData } from "../common/index.dto"; +import { getClient } from "../mcp/connect.service"; import { getTour, loadSetting, saveSetting, setTour } from "./setting.service"; export class SettingController { @Controller('setting/save') - async saveSetting(client: RequestClientType, data: any, webview: PostMessageble) { + async saveSetting(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); saveSetting(data); console.log('Settings saved successfully'); @@ -16,8 +19,8 @@ export class SettingController { } @Controller('setting/load') - async loadSetting(client: RequestClientType, data: any, webview: PostMessageble) { - + async loadSetting(data: RequestData, webview: PostMessageble) { + const client = getClient(data.clientId); const config = loadSetting(); return { code: 200, @@ -26,10 +29,8 @@ export class SettingController { } @Controller('setting/set-tour') - async setTourController(client: RequestClientType, data: any, webview: PostMessageble) { - + async setTourController(data: any, webview: PostMessageble) { const { userHasReadGuide } = data; - setTour(userHasReadGuide); return { @@ -39,7 +40,7 @@ export class SettingController { } @Controller('setting/get-tour') - async getTourController(client: RequestClientType, data: any, webview: PostMessageble) { + async getTourController(data: any, webview: PostMessageble) { const { userHasReadGuide } = getTour(); diff --git a/software/src/main.ts b/software/src/main.ts index bbbdc87..3d098da 100644 --- a/software/src/main.ts +++ b/software/src/main.ts @@ -40,14 +40,14 @@ function createWindow(): void { switch (command) { case 'electron/launch-signature': - const launchResultMessage: ILaunchSigature = option.type === 'stdio' ? + const launchResultMessage: ILaunchSigature = option.type === 'STDIO' ? { - type: 'stdio', + type: 'STDIO', commandString: option.command + ' ' + option.args.join(' '), cwd: option.cwd || '' } : { - type: 'sse', + type: 'SSE', url: option.url, oauth: option.oauth || '' }; diff --git a/software/src/util.ts b/software/src/util.ts index 21fa7d2..43584f6 100644 --- a/software/src/util.ts +++ b/software/src/util.ts @@ -24,13 +24,13 @@ export class ElectronIPCLike { interface IStdioLaunchSignature { - type: 'stdio'; + type: 'STDIO'; commandString: string; cwd: string; } interface ISSELaunchSignature { - type:'sse'; + type:'SSE'; url: string; oauth: string; } @@ -39,7 +39,7 @@ export type ILaunchSigature = IStdioLaunchSignature | ISSELaunchSignature; export function refreshConnectionOption(envPath: string) { const defaultOption = { - type:'stdio', + type:'STDIO', command: 'mcp', args: ['run', 'main.py'], cwd: '../server' @@ -80,7 +80,7 @@ export function updateConnectionOption(data: any) { if (data.connectionType === 'STDIO') { const connectionItem = { - type: 'stdio', + type: 'STDIO', command: data.command, args: data.args, cwd: data.cwd.replace(/\\/g, '/') @@ -89,7 +89,7 @@ export function updateConnectionOption(data: any) { fs.writeFileSync(envPath, JSON.stringify(connectionItem, null, 4)); } else { const connectionItem = { - type: 'sse', + type: 'SSE', url: data.url, oauth: data.oauth }; diff --git a/src/global.ts b/src/global.ts index 504733f..9cef060 100644 --- a/src/global.ts +++ b/src/global.ts @@ -7,7 +7,7 @@ export type FsPath = string; export const panels = new Map(); export interface IStdioConnectionItem { - type: 'stdio'; + type: 'STDIO'; name: string; version?: string; command: string; @@ -18,7 +18,7 @@ export interface IStdioConnectionItem { } export interface ISSEConnectionItem { - type: 'sse'; + type: 'SSE'; name: string; version: string; url: string; @@ -29,13 +29,13 @@ export interface ISSEConnectionItem { interface IStdioLaunchSignature { - type: 'stdio'; + type: 'STDIO'; commandString: string; cwd: string; } interface ISSELaunchSignature { - type:'sse'; + type:'SSE'; url: string; oauth: string; } @@ -123,7 +123,7 @@ export function getWorkspaceConnectionConfig() { if (item.filePath && item.filePath.startsWith('{workspace}')) { item.filePath = item.filePath.replace('{workspace}', workspacePath).replace(/\\/g, '/'); } - if (item.type === 'stdio' && item.cwd && item.cwd.startsWith('{workspace}')) { + if (item.type === 'STDIO' && item.cwd && item.cwd.startsWith('{workspace}')) { item.cwd = item.cwd.replace('{workspace}', workspacePath).replace(/\\/g, '/'); } } @@ -169,7 +169,7 @@ export function saveWorkspaceConnectionConfig(workspace: string) { if (item.filePath && item.filePath.replace(/\\/g, '/').startsWith(workspacePath)) { item.filePath = item.filePath.replace(workspacePath, '{workspace}').replace(/\\/g, '/'); } - if (item.type ==='stdio' && item.cwd && item.cwd.replace(/\\/g, '/').startsWith(workspacePath)) { + if (item.type ==='STDIO' && item.cwd && item.cwd.replace(/\\/g, '/').startsWith(workspacePath)) { item.cwd = item.cwd.replace(workspacePath, '{workspace}').replace(/\\/g, '/'); } } @@ -213,7 +213,7 @@ export function updateWorkspaceConnectionConfig( if (data.connectionType === 'STDIO') { const connectionItem: IStdioConnectionItem = { - type: 'stdio', + type: 'STDIO', name: data.serverInfo.name, version: data.serverInfo.version, command: data.command, @@ -234,7 +234,7 @@ export function updateWorkspaceConnectionConfig( } else { const connectionItem: ISSEConnectionItem = { - type: 'sse', + type: 'SSE', name: data.serverInfo.name, version: data.serverInfo.version, url: data.url, @@ -267,7 +267,7 @@ export function updateInstalledConnectionConfig( if (data.connectionType === 'STDIO') { const connectionItem: IStdioConnectionItem = { - type: 'stdio', + type: 'STDIO', name: data.serverInfo.name, version: data.serverInfo.version, command: data.command, @@ -287,7 +287,7 @@ export function updateInstalledConnectionConfig( } else { const connectionItem: ISSEConnectionItem = { - type: 'sse', + type: 'SSE', name: data.serverInfo.name, version: data.serverInfo.version, url: data.url, diff --git a/src/sidebar/installed.service.ts b/src/sidebar/installed.service.ts index 42b2720..542ebc3 100644 --- a/src/sidebar/installed.service.ts +++ b/src/sidebar/installed.service.ts @@ -52,7 +52,7 @@ export async function validateAndGetCommandPath(commandString: string, cwd?: str export async function acquireInstalledConnection(): Promise { // 让用户选择连接类型 - const connectionType = await vscode.window.showQuickPick(['stdio', 'sse'], { + const connectionType = await vscode.window.showQuickPick(['STDIO', 'SSE'], { placeHolder: '请选择连接类型', canPickMany: false }); @@ -61,7 +61,7 @@ export async function acquireInstalledConnection(): Promise { // 让用户选择连接类型 - const connectionType = await vscode.window.showQuickPick(['stdio', 'sse'], { + const connectionType = await vscode.window.showQuickPick(['STDIO', 'SSE'], { placeHolder: '请选择连接类型' }); @@ -14,7 +14,7 @@ export async function acquireUserCustomConnection(): Promise