增加 toolcall 后的 hook registerOnToolCalled

This commit is contained in:
锦恢 2025-05-11 23:09:19 +08:00
parent de7d118d7b
commit 2db3ab8888
5 changed files with 39 additions and 11 deletions

View File

@ -221,9 +221,13 @@ export class MessageBridge {
// 单例实例 // 单例实例
let messageBridge: MessageBridge; let messageBridge: MessageBridge;
export function createMessageBridge(setupSignature: any) {
messageBridge = new MessageBridge(setupSignature);
}
// 向外暴露一个独立函数,保证 MessageBridge 是单例的 // 向外暴露一个独立函数,保证 MessageBridge 是单例的
export function useMessageBridge() { export function useMessageBridge() {
if (!messageBridge) { if (!messageBridge && getPlatform() !== 'nodejs') {
messageBridge = new MessageBridge('ws://localhost:8080'); messageBridge = new MessageBridge('ws://localhost:8080');
} }
const bridge = messageBridge; const bridge = messageBridge;

View File

@ -1,8 +1,13 @@
import { ToolCallResponse } from "@/hook/type"; import { ToolCallContent, ToolCallResponse } from "@/hook/type";
import { callTool } from "../../tool/tools"; import { callTool } from "../../tool/tools";
import { MessageState, ToolCall } from "../chat-box/chat"; import { MessageState, ToolCall } from "../chat-box/chat";
export async function handleToolCalls(toolCall: ToolCall) { export interface ToolCallResult {
state: MessageState;
content: ToolCallContent[];
}
export async function handleToolCalls(toolCall: ToolCall): Promise<ToolCallResult> {
// 反序列化 streaming 来的参数字符串 // 反序列化 streaming 来的参数字符串
const toolName = toolCall.function.name; const toolName = toolCall.function.name;
const argsResult = deserializeToolCallResponse(toolCall.function.arguments); const argsResult = deserializeToolCallResponse(toolCall.function.arguments);

View File

@ -1,12 +1,12 @@
/* eslint-disable */ /* eslint-disable */
import { ref, type Ref } from "vue"; import { ref, type Ref } from "vue";
import { ToolCall, ChatStorage, getToolSchema, MessageState } from "../chat-box/chat"; import { ToolCall, ChatStorage, getToolSchema, MessageState } from "../chat-box/chat";
import { useMessageBridge, MessageBridge } from "@/api/message-bridge"; import { useMessageBridge, MessageBridge, createMessageBridge } from "@/api/message-bridge";
import type { OpenAI } from 'openai'; import type { OpenAI } from 'openai';
import { llmManager, llms } from "@/views/setting/llm"; import { llmManager, llms } from "@/views/setting/llm";
import { pinkLog, redLog } from "@/views/setting/util"; import { pinkLog, redLog } from "@/views/setting/util";
import { ElMessage } from "element-plus"; import { ElMessage } from "element-plus";
import { handleToolCalls } from "./handle-tool-calls"; import { handleToolCalls, ToolCallResult } from "./handle-tool-calls";
import { getPlatform } from "@/api/platform"; import { getPlatform } from "@/api/platform";
export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk; export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk;
@ -39,6 +39,7 @@ export class TaskLoop {
private onError: (error: IErrorMssage) => void = (msg) => {}; private onError: (error: IErrorMssage) => void = (msg) => {};
private onChunk: (chunk: ChatCompletionChunk) => void = (chunk) => {}; private onChunk: (chunk: ChatCompletionChunk) => void = (chunk) => {};
private onDone: () => void = () => {}; private onDone: () => void = () => {};
private onToolCalled: (toolCallResult: ToolCallResult) => void = (toolCall) => {};
private onEpoch: () => void = () => {}; private onEpoch: () => void = () => {};
private completionUsage: ChatCompletionChunk['usage'] | undefined; private completionUsage: ChatCompletionChunk['usage'] | undefined;
private llmConfig: any; private llmConfig: any;
@ -52,17 +53,17 @@ export class TaskLoop {
// 根据当前环境决定是否要开启 messageBridge // 根据当前环境决定是否要开启 messageBridge
const platform = getPlatform(); const platform = getPlatform();
if (platform === 'nodejs') { if (platform === 'nodejs') {
const adapter = taskOptions.adapter; const adapter = taskOptions.adapter;
if (!adapter) { if (!adapter) {
throw new Error('adapter is required'); throw new Error('adapter is required');
} }
this.bridge = new MessageBridge(adapter.emitter); createMessageBridge(adapter.emitter);
} else { }
this.bridge = useMessageBridge();
} // web 环境下 bridge 会自动加载完成
this.bridge = useMessageBridge();
} }
private handleChunkDeltaContent(chunk: ChatCompletionChunk) { private handleChunkDeltaContent(chunk: ChatCompletionChunk) {
@ -235,6 +236,10 @@ export class TaskLoop {
this.onEpoch = handler; this.onEpoch = handler;
} }
public registerOnToolCalled(handler: (toolCallResult: ToolCallResult) => void) {
this.onToolCalled = handler;
}
public setMaxEpochs(maxEpochs: number) { public setMaxEpochs(maxEpochs: number) {
this.taskOptions.maxEpochs = maxEpochs; this.taskOptions.maxEpochs = maxEpochs;
} }
@ -331,6 +336,7 @@ export class TaskLoop {
for (const toolCall of this.streamingToolCalls.value || []) { for (const toolCall of this.streamingToolCalls.value || []) {
const toolCallResult = await handleToolCalls(toolCall); const toolCallResult = await handleToolCalls(toolCall);
this.onToolCalled(toolCallResult);
if (toolCallResult.state === MessageState.ParseJsonError) { if (toolCallResult.state === MessageState.ParseJsonError) {
// 如果是因为解析 JSON 错误,则重新开始 // 如果是因为解析 JSON 错误,则重新开始

View File

@ -16,9 +16,9 @@ export interface ToolStorage {
formData: Record<string, any>; formData: Record<string, any>;
} }
const bridge = useMessageBridge();
export function callTool(toolName: string, toolArgs: Record<string, any>) { export function callTool(toolName: string, toolArgs: Record<string, any>) {
const bridge = useMessageBridge();
return new Promise<ToolCallResponse>((resolve, reject) => { return new Promise<ToolCallResponse>((resolve, reject) => {
bridge.addCommandListener('tools/call', (data: CasualRestAPI<ToolCallResponse>) => { bridge.addCommandListener('tools/call', (data: CasualRestAPI<ToolCallResponse>) => {
console.log(data.msg); console.log(data.msg);

View File

@ -24,6 +24,17 @@ export interface ToolCall {
} }
} }
export interface ToolCallContent {
type: string;
text: string;
[key: string]: any;
}
export interface ToolCallResult {
state: MessageState;
content: ToolCallContent[];
}
export enum MessageState { export enum MessageState {
ServerError = 'server internal error', ServerError = 'server internal error',
ReceiveChunkError = 'receive chunk error', ReceiveChunkError = 'receive chunk error',
@ -58,6 +69,7 @@ export class TaskLoop {
private onError; private onError;
private onChunk; private onChunk;
private onDone; private onDone;
private onToolCalled;
private onEpoch; private onEpoch;
private completionUsage; private completionUsage;
private llmConfig; private llmConfig;
@ -72,6 +84,7 @@ export class TaskLoop {
registerOnChunk(handler: (chunk: ChatCompletionChunk) => void): void; registerOnChunk(handler: (chunk: ChatCompletionChunk) => void): void;
registerOnDone(handler: () => void): void; registerOnDone(handler: () => void): void;
registerOnEpoch(handler: () => void): void; registerOnEpoch(handler: () => void): void;
registerOnToolCalled(handler: (toolCallResult: ToolCallResult) => void): void;
setMaxEpochs(maxEpochs: number): void; setMaxEpochs(maxEpochs: number): void;
/** /**
* @description LLM nodejs * @description LLM nodejs