优化 streaming 架构

This commit is contained in:
锦恢 2025-05-11 18:26:15 +08:00
parent a622fdd1e2
commit 1f55126c82
7 changed files with 50 additions and 18 deletions

View File

@ -93,7 +93,8 @@ function handleSend(newMessage?: string) {
isLoading.value = true; isLoading.value = true;
autoScroll.value = true; autoScroll.value = true;
loop = new TaskLoop(streamingContent, streamingToolCalls); loop = new TaskLoop();
loop.bindStreaming(streamingContent, streamingToolCalls);
loop.registerOnError((error) => { loop.registerOnError((error) => {
console.log('error.msg'); console.log('error.msg');

View File

@ -1,5 +1,5 @@
/* eslint-disable */ /* eslint-disable */
import type { Ref } from "vue"; import { ref, type Ref } from "vue";
import { ToolCall, ChatStorage, getToolSchema, MessageState } from "../chat-box/chat"; import { ToolCall, ChatStorage, getToolSchema, MessageState } from "../chat-box/chat";
import { useMessageBridge, MessageBridge } from "@/api/message-bridge"; import { useMessageBridge, MessageBridge } from "@/api/message-bridge";
import type { OpenAI } from 'openai'; import type { OpenAI } from 'openai';
@ -32,6 +32,9 @@ export interface IDoConversationResult {
*/ */
export class TaskLoop { export class TaskLoop {
private bridge: MessageBridge; private bridge: MessageBridge;
private streamingContent: Ref<string>;
private streamingToolCalls: Ref<ToolCall[]>;
private currentChatId = ''; private currentChatId = '';
private onError: (error: IErrorMssage) => void = (msg) => {}; private onError: (error: IErrorMssage) => void = (msg) => {};
private onChunk: (chunk: ChatCompletionChunk) => void = (chunk) => {}; private onChunk: (chunk: ChatCompletionChunk) => void = (chunk) => {};
@ -41,10 +44,11 @@ export class TaskLoop {
private llmConfig: any; private llmConfig: any;
constructor( constructor(
private readonly streamingContent: Ref<string>,
private readonly streamingToolCalls: Ref<ToolCall[]>,
private readonly taskOptions: TaskLoopOptions = { maxEpochs: 20, maxJsonParseRetry: 3, adapter: undefined }, private readonly taskOptions: TaskLoopOptions = { maxEpochs: 20, maxJsonParseRetry: 3, adapter: undefined },
) { ) {
this.streamingContent = ref('');
this.streamingToolCalls = ref([]);
// 根据当前环境决定是否要开启 messageBridge // 根据当前环境决定是否要开启 messageBridge
const platform = getPlatform(); const platform = getPlatform();
if (platform === 'nodejs') { if (platform === 'nodejs') {
@ -150,7 +154,7 @@ export class TaskLoop {
}, { once: true }); }, { once: true });
console.log(chatData); // console.log(chatData);
this.bridge.postMessage({ this.bridge.postMessage({
command: 'llm/chat/completions', command: 'llm/chat/completions',
@ -250,6 +254,11 @@ export class TaskLoop {
this.llmConfig = config; this.llmConfig = config;
} }
public bindStreaming(content: Ref<string>, toolCalls: Ref<ToolCall[]>) {
this.streamingContent = content;
this.streamingToolCalls = toolCalls;
}
public getLlmConfig() { public getLlmConfig() {
if (this.llmConfig) { if (this.llmConfig) {
return this.llmConfig; return this.llmConfig;
@ -301,8 +310,8 @@ export class TaskLoop {
// 发送请求 // 发送请求
const doConverationResult = await this.doConversation(chatData); const doConverationResult = await this.doConversation(chatData);
console.log('[doConverationResult] Response');
console.log(doConverationResult); console.log(doConverationResult);
// 如果存在需要调度的工具 // 如果存在需要调度的工具
if (this.streamingToolCalls.value.length > 0) { if (this.streamingToolCalls.value.length > 0) {

View File

@ -26,9 +26,7 @@ export async function makeSimpleTalk() {
// 使用最简单的 hello 来测试 // 使用最简单的 hello 来测试
const testMessage = 'hello'; const testMessage = 'hello';
const s1 = ref(''); const loop = new TaskLoop();
const s2 = ref([]);
const loop = new TaskLoop(s1, s2);
const chatStorage: ChatStorage = { const chatStorage: ChatStorage = {
messages: [], messages: [],

View File

@ -1,5 +1,5 @@
import { requestHandlerStorage } from "."; import { requestHandlerStorage } from ".";
import { PostMessageble } from "../hook/adapter"; import type { PostMessageble } from "../hook/adapter";
import { LlmController } from "../llm/llm.controller"; import { LlmController } from "../llm/llm.controller";
import { ClientController } from "../mcp/client.controller"; import { ClientController } from "../mcp/client.controller";
import { ConnectController } from "../mcp/connect.controller"; import { ConnectController } from "../mcp/connect.controller";

View File

@ -1,5 +1,8 @@
import { WebSocket } from 'ws'; import { WebSocket } from 'ws';
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import { routeMessage } from '../common/router';
import { McpOptions } from '../mcp/client.dto';
import { client, connectService } from '../mcp/connect.service';
// WebSocket 消息格式 // WebSocket 消息格式
export interface WebSocketMessage { export interface WebSocketMessage {
@ -66,7 +69,7 @@ export class VSCodeWebViewLike {
} }
export class EventAdapter { export class TaskLoopAdapter {
public emitter: EventEmitter; public emitter: EventEmitter;
private messageHandlers: Set<MessageHandler>; private messageHandlers: Set<MessageHandler>;
@ -77,28 +80,48 @@ export class EventAdapter {
this.emitter.on('message/renderer', (message: WebSocketMessage) => { this.emitter.on('message/renderer', (message: WebSocketMessage) => {
this.messageHandlers.forEach((handler) => handler(message)); this.messageHandlers.forEach((handler) => handler(message));
}); });
// 默认需要将监听的消息导入到 routeMessage 中
this.onDidReceiveMessage((message) => {
const { command, data } = message;
routeMessage(command, data, this);
});
} }
/** /**
* * @description
* @param message - command args * @param message - command args
*/ */
postMessage(message: WebSocketMessage): void { public postMessage(message: WebSocketMessage): void {
console.log('message/renderer', message);
this.emitter.emit('message/service', message); this.emitter.emit('message/service', message);
} }
/** /**
* * @description
* @param callback - * @param callback -
* @returns {{ dispose: () => void }} - * @returns {{ dispose: () => void }} -
*/ */
onDidReceiveMessage(callback: MessageHandler): { dispose: () => void } { public onDidReceiveMessage(callback: MessageHandler): { dispose: () => void } {
this.messageHandlers.add(callback); this.messageHandlers.add(callback);
return { return {
dispose: () => this.messageHandlers.delete(callback), dispose: () => this.messageHandlers.delete(callback),
}; };
} }
/**
* @description mcp
* @param mcpOption
*/
public async connectMcpServer(mcpOption: McpOptions) {
const res = await connectService(undefined, mcpOption);
if (res.code === 200) {
console.log('✅ 成功连接 mcp 服务器: ' + res.msg);
const version = client?.getServerVersion();
console.log(version);
} else {
console.error('❌ 连接 mcp 服务器失败:' + res.msg);
}
}
} }

View File

@ -1,5 +1,5 @@
export { routeMessage } from './common/router'; export { routeMessage } from './common/router';
export { VSCodeWebViewLike, EventAdapter } from './hook/adapter'; export { VSCodeWebViewLike, TaskLoopAdapter } from './hook/adapter';
export { setVscodeWorkspace, setRunningCWD } from './hook/setting'; export { setVscodeWorkspace, setRunningCWD } from './hook/setting';
// TODO: 更加规范 // TODO: 更加规范
export { client } from './mcp/connect.service'; export { client } from './mcp/connect.service';

View File

@ -45,6 +45,7 @@ module.exports = {
plugins: [ plugins: [
new webpack.DefinePlugin({ new webpack.DefinePlugin({
window: { window: {
nodejs: true,
navigator: { navigator: {
userAgent: 2 userAgent: 2
}, },