diff --git a/renderer/src/api/message-bridge.ts b/renderer/src/api/message-bridge.ts index c561845..7b10ce1 100644 --- a/renderer/src/api/message-bridge.ts +++ b/renderer/src/api/message-bridge.ts @@ -221,9 +221,13 @@ export class MessageBridge { // 单例实例 let messageBridge: MessageBridge; +export function createMessageBridge(setupSignature: any) { + messageBridge = new MessageBridge(setupSignature); +} + // 向外暴露一个独立函数,保证 MessageBridge 是单例的 export function useMessageBridge() { - if (!messageBridge) { + if (!messageBridge && getPlatform() !== 'nodejs') { messageBridge = new MessageBridge('ws://localhost:8080'); } const bridge = messageBridge; diff --git a/renderer/src/components/main-panel/chat/core/handle-tool-calls.ts b/renderer/src/components/main-panel/chat/core/handle-tool-calls.ts index b185055..2682d45 100644 --- a/renderer/src/components/main-panel/chat/core/handle-tool-calls.ts +++ b/renderer/src/components/main-panel/chat/core/handle-tool-calls.ts @@ -1,8 +1,13 @@ -import { ToolCallResponse } from "@/hook/type"; +import { ToolCallContent, ToolCallResponse } from "@/hook/type"; import { callTool } from "../../tool/tools"; import { MessageState, ToolCall } from "../chat-box/chat"; -export async function handleToolCalls(toolCall: ToolCall) { +export interface ToolCallResult { + state: MessageState; + content: ToolCallContent[]; +} + +export async function handleToolCalls(toolCall: ToolCall): Promise<ToolCallResult> { // 反序列化 streaming 来的参数字符串 const toolName = toolCall.function.name; const argsResult = deserializeToolCallResponse(toolCall.function.arguments); diff --git a/renderer/src/components/main-panel/chat/core/task-loop.ts b/renderer/src/components/main-panel/chat/core/task-loop.ts index 082d556..137ddfe 100644 --- a/renderer/src/components/main-panel/chat/core/task-loop.ts +++ b/renderer/src/components/main-panel/chat/core/task-loop.ts @@ -1,12 +1,12 @@ /* eslint-disable */ import { ref, type Ref } from "vue"; import { ToolCall, ChatStorage, getToolSchema, MessageState } 
from "../chat-box/chat"; -import { useMessageBridge, MessageBridge } from "@/api/message-bridge"; +import { useMessageBridge, MessageBridge, createMessageBridge } from "@/api/message-bridge"; import type { OpenAI } from 'openai'; import { llmManager, llms } from "@/views/setting/llm"; import { pinkLog, redLog } from "@/views/setting/util"; import { ElMessage } from "element-plus"; -import { handleToolCalls } from "./handle-tool-calls"; +import { handleToolCalls, ToolCallResult } from "./handle-tool-calls"; import { getPlatform } from "@/api/platform"; export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk; @@ -39,6 +39,7 @@ export class TaskLoop { private onError: (error: IErrorMssage) => void = (msg) => {}; private onChunk: (chunk: ChatCompletionChunk) => void = (chunk) => {}; private onDone: () => void = () => {}; + private onToolCalled: (toolCallResult: ToolCallResult) => void = (toolCall) => {}; private onEpoch: () => void = () => {}; private completionUsage: ChatCompletionChunk['usage'] | undefined; private llmConfig: any; @@ -52,17 +53,17 @@ export class TaskLoop { // 根据当前环境决定是否要开启 messageBridge const platform = getPlatform(); if (platform === 'nodejs') { - const adapter = taskOptions.adapter; if (!adapter) { throw new Error('adapter is required'); } - this.bridge = new MessageBridge(adapter.emitter); - } else { - this.bridge = useMessageBridge(); - } + createMessageBridge(adapter.emitter); + } + + // web 环境下 bridge 会自动加载完成 + this.bridge = useMessageBridge(); } private handleChunkDeltaContent(chunk: ChatCompletionChunk) { @@ -235,6 +236,10 @@ export class TaskLoop { this.onEpoch = handler; } + public registerOnToolCalled(handler: (toolCallResult: ToolCallResult) => void) { + this.onToolCalled = handler; + } + public setMaxEpochs(maxEpochs: number) { this.taskOptions.maxEpochs = maxEpochs; } @@ -331,6 +336,7 @@ export class TaskLoop { for (const toolCall of this.streamingToolCalls.value || []) { const toolCallResult = await 
handleToolCalls(toolCall); + this.onToolCalled(toolCallResult); if (toolCallResult.state === MessageState.ParseJsonError) { // 如果是因为解析 JSON 错误,则重新开始 diff --git a/renderer/src/components/main-panel/tool/tools.ts b/renderer/src/components/main-panel/tool/tools.ts index 02ec19a..5bd4570 100644 --- a/renderer/src/components/main-panel/tool/tools.ts +++ b/renderer/src/components/main-panel/tool/tools.ts @@ -16,9 +16,9 @@ export interface ToolStorage { formData: Record<string, any>; } -const bridge = useMessageBridge(); export function callTool(toolName: string, toolArgs: Record<string, any>) { + const bridge = useMessageBridge(); return new Promise((resolve, reject) => { bridge.addCommandListener('tools/call', (data: CasualRestAPI) => { console.log(data.msg); diff --git a/resources/openmcp-sdk-release/task-loop.d.ts b/resources/openmcp-sdk-release/task-loop.d.ts index be03b38..befd21e 100644 --- a/resources/openmcp-sdk-release/task-loop.d.ts +++ b/resources/openmcp-sdk-release/task-loop.d.ts @@ -24,6 +24,17 @@ export interface ToolCall { } } +export interface ToolCallContent { + type: string; + text: string; + [key: string]: any; +} + +export interface ToolCallResult { + state: MessageState; + content: ToolCallContent[]; +} + export enum MessageState { ServerError = 'server internal error', ReceiveChunkError = 'receive chunk error', @@ -58,6 +69,7 @@ export class TaskLoop { private onError; private onChunk; private onDone; + private onToolCalled; private onEpoch; private completionUsage; private llmConfig; @@ -72,6 +84,7 @@ export class TaskLoop { registerOnChunk(handler: (chunk: ChatCompletionChunk) => void): void; registerOnDone(handler: () => void): void; registerOnEpoch(handler: () => void): void; + registerOnToolCalled(handler: (toolCallResult: ToolCallResult) => void): void; setMaxEpochs(maxEpochs: number): void; /** * @description 设置当前的 LLM 配置,用于 nodejs 环境运行