From 0dc727e781b0d076a8a6190a7469fa3ffd54d815 Mon Sep 17 00:00:00 2001 From: Kirigaya <1193466151@qq.com> Date: Sun, 13 Jul 2025 04:04:15 +0800 Subject: [PATCH] save --- .../main-panel/chat/core/task-loop.ts | 19 ++- .../main-panel/tool/auto-detector/diagram.ts | 118 --------------- .../main-panel/tool/auto-detector/index.vue | 56 +++---- service/src/llm/llm.controller.ts | 2 + service/src/llm/llm.service.ts | 138 ++++++++---------- 5 files changed, 107 insertions(+), 226 deletions(-) diff --git a/renderer/src/components/main-panel/chat/core/task-loop.ts b/renderer/src/components/main-panel/chat/core/task-loop.ts index e48cc76..96c46b6 100644 --- a/renderer/src/components/main-panel/chat/core/task-loop.ts +++ b/renderer/src/components/main-panel/chat/core/task-loop.ts @@ -18,6 +18,7 @@ import { getXmlWrapperPrompt, getToolCallFromXmlString, getXmlsFromString, handl export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk; export interface TaskLoopChatOption { id?: string + sessionId: string; proxyServer?: string enableXmlWrapper?: boolean } @@ -193,9 +194,15 @@ export class TaskLoop { } private doConversation(chatData: ChatCompletionCreateParamsBase, toolcallIndexAdapter: (toolCall: ToolCall) => IToolCallIndex) { + const sessionId = chatData.sessionId; return new Promise((resolve, reject) => { const chunkHandler = this.bridge.addCommandListener('llm/chat/completions/chunk', data => { + + if (data.sessionId !== sessionId) { + return; + } + // data.code 一定为 200,否则不会走这个 route const { chunk } = data.msg as { chunk: ChatCompletionChunk }; @@ -214,6 +221,10 @@ export class TaskLoop { }, { once: false }); const doneHandler = this.bridge.addCommandListener('llm/chat/completions/done', data => { + if (data.sessionId !== sessionId) { + return; + } + this.consumeDones(); chunkHandler(); @@ -225,6 +236,10 @@ export class TaskLoop { }, { once: true }); const errorHandler = this.bridge.addCommandListener('llm/chat/completions/error', data => { + if (data.sessionId !== sessionId) { + return; + } + this.consumeErrors({ state: MessageState.ReceiveChunkError, msg: data.msg || '请求模型服务时发生错误' @@ -304,7 +319,7 @@ export class TaskLoop { const id = crypto.randomUUID(); const chatData = { - id, + sessionId: id, baseURL, apiKey, model, @@ -575,7 +590,7 @@ export class TaskLoop { break; } - this.currentChatId = chatData.id!; + this.currentChatId = chatData.sessionId; const llm = this.getLlmConfig(); const toolcallIndexAdapter = getToolCallIndexAdapter(llm, chatData); diff --git a/renderer/src/components/main-panel/tool/auto-detector/diagram.ts b/renderer/src/components/main-panel/tool/auto-detector/diagram.ts index bf6596b..633140f 100644 --- a/renderer/src/components/main-panel/tool/auto-detector/diagram.ts +++ b/renderer/src/components/main-panel/tool/auto-detector/diagram.ts @@ -161,124 +161,6 @@ export function topoSortParallel(state: DiagramState): string[][] { return result; } -export async function makeParallelTest( - dataViews: Reactive[], - enableXmlWrapper: boolean, - prompt: string | null = null, - context: DiagramContext -) { - if (dataViews.length === 0) { - return; - } - - // 设置所有节点状态为运行中 - const createAt = Date.now(); - dataViews.forEach(dataView => { - dataView.status = 'running'; - dataView.createAt = createAt; - }); - context.render(); - - try { - const loop = new TaskLoop({ maxEpochs: 1 }); - - // 构建所有工具的信息 - const allTools = dataViews.map(dataView => ({ - name: dataView.tool.name, - description: dataView.tool.description, - inputSchema: dataView.tool.inputSchema, - enabled: true - })); - - // 构建测试提示词,包含所有工具 - const toolNames = dataViews.map(dv => dv.tool.name).join(', '); - const usePrompt = (prompt || 'please call the tools {tool} to make some test').replace('{tool}', toolNames); - - const chatStorage = { - messages: [], - settings: { - temperature: 0.6, - systemPrompt: '', - enableTools: allTools, - enableWebSearch: false, - contextLength: 5, - enableXmlWrapper, - parallelToolCalls: true // 开启并行工具调用 - } - } as ChatStorage; - - loop.setMaxEpochs(1); - - // 记录工具调用信息,用于匹配工具调用结果 - const toolCallMap: Map> = new Map(); // toolCallId -> dataView - let completedToolsCount = 0; - - loop.registerOnToolCall(toolCall => { - // 找到对应的dataView - const dataView = dataViews.find(dv => dv.tool.name === toolCall.function?.name); - if (dataView) { - dataView.function = toolCall.function; - dataView.llmTimecost = Date.now() - createAt; - context.render(); - - // 记录工具调用ID与dataView的映射 - if (toolCall.id) { - toolCallMap.set(toolCall.id, dataView); - } - } - return toolCall; - }); - - loop.registerOnToolCalled(toolCalled => { - // 这里我们需要改变策略,因为没有工具调用ID信息 - // 对于并行调用,我们可以根据工具调用的顺序来匹配结果 - // 简单的策略:找到第一个仍在运行状态的工具 - const runningView = dataViews.find(dv => dv.status === 'running'); - if (runningView) { - runningView.toolcallTimecost = Date.now() - createAt - (runningView.llmTimecost || 0); - - if (toolCalled.state === MessageState.Success) { - runningView.status = 'success'; - runningView.result = toolCalled.content; - } else { - runningView.status = 'error'; - runningView.result = toolCalled.content; - } - - completedToolsCount++; - context.render(); - - // 如果所有工具都完成了,终止循环 - if (completedToolsCount >= dataViews.length) { - loop.abort(); - } - } - return toolCalled; - }); - - loop.registerOnError(error => { - dataViews.forEach(dataView => { - if (dataView.status === 'running') { - dataView.status = 'error'; - dataView.result = error; - } - }); - context.render(); - }); - - await loop.start(chatStorage, usePrompt); - - } finally { - const finishAt = Date.now(); - dataViews.forEach(dataView => { - dataView.finishAt = finishAt; - if (dataView.status === 'running') { - dataView.status = 'success'; - } - }); - context.render(); - } -} export async function makeNodeTest( dataView: Reactive, diff --git a/renderer/src/components/main-panel/tool/auto-detector/index.vue b/renderer/src/components/main-panel/tool/auto-detector/index.vue index 217af14..5e38d4c 100644 --- a/renderer/src/components/main-panel/tool/auto-detector/index.vue +++ b/renderer/src/components/main-panel/tool/auto-detector/index.vue @@ -56,19 +56,15 @@ - +
{{ caption }}
- - - + + +
@@ -144,37 +140,41 @@ if (autoDetectDiagram) { // 新增:自检参数表单相关 const testFormVisible = ref(false); const enableXmlWrapper = ref(false); -const enableParallelTest = ref(false); +const enableParallelTest = ref(true); const testPrompt = ref('please call the tool {tool} to make some test'); async function onTestConfirm() { testFormVisible.value = false; const state = context.state; - + tabStorage.autoDetectDiagram!.views = []; if (state) { + const dispatches = topoSortParallel(state); + if (enableParallelTest.value) { - // 并行测试模式:一次性测试所有工具 - const allViews = Array.from(state.dataView.values()); - await makeParallelTest(allViews, enableXmlWrapper.value, testPrompt.value, context); - - // 保存结果 - allViews.forEach(view => { - tabStorage.autoDetectDiagram!.views!.push({ - tool: view.tool, - status: view.status, - function: view.function, - result: view.result, - createAt: view.createAt, - finishAt: view.finishAt, - llmTimecost: view.llmTimecost, - toolcallTimecost: view.toolcallTimecost, - }); - }); + for (const nodeIds of dispatches) { + + await Promise.all(nodeIds.map(async id => { + const view = state.dataView.get(id); + if (view) { + await makeNodeTest(view, enableXmlWrapper.value, testPrompt.value, context); + tabStorage.autoDetectDiagram!.views!.push({ + tool: view.tool, + status: view.status, + function: view.function, + result: view.result, + createAt: view.createAt, + finishAt: view.finishAt, + llmTimecost: view.llmTimecost, + toolcallTimecost: view.toolcallTimecost, + }); + context.render(); + } + })); + } } else { // 串行测试模式:按拓扑顺序逐个测试 - const dispatches = topoSortParallel(state); for (const nodeIds of dispatches) { for (const id of nodeIds) { const view = state.dataView.get(id); diff --git a/service/src/llm/llm.controller.ts b/service/src/llm/llm.controller.ts index c42d203..08301d2 100644 --- a/service/src/llm/llm.controller.ts +++ b/service/src/llm/llm.controller.ts @@ -17,6 +17,8 @@ export class LlmController { webview.postMessage({ command: 'llm/chat/completions/error', data: { + sessionId: data.sessionId, + code: 500, msg: error } }); diff --git a/service/src/llm/llm.service.ts b/service/src/llm/llm.service.ts index a132fab..7cc68fa 100644 --- a/service/src/llm/llm.service.ts +++ b/service/src/llm/llm.service.ts @@ -5,15 +5,16 @@ import { RestfulResponse } from "../common/index.dto.js"; import { ocrDB } from "../hook/db.js"; import type { ToolCallContent } from "../mcp/client.dto.js"; import { ocrWorkerStorage } from "../mcp/ocr.service.js"; -import Table from 'cli-table3'; -export let currentStream: AsyncIterable | null = null; +// 用 Map | null> 管理多个流 +export const chatStreams = new Map>(); export async function streamingChatCompletion( data: any, webview: PostMessageble ) { const { + sessionId, baseURL, apiKey, model, @@ -24,16 +25,6 @@ export async function streamingChatCompletion( proxyServer = '' } = data; - // 创建请求参数表格 - const requestTable = new Table({ - head: ['Parameter', 'Value'], - colWidths: [20, 40], - style: { - head: ['cyan'], - border: ['grey'] - } - }); - // 构建OpenRouter特定的请求头 const defaultHeaders: Record = {}; @@ -47,26 +38,13 @@ export async function streamingChatCompletion( apiKey, defaultHeaders: Object.keys(defaultHeaders).length > 0 ? defaultHeaders : undefined }); - - const seriableTools = (tools.length === 0) ? undefined: tools; - const seriableParallelToolCalls = (tools.length === 0)? - undefined: model.startsWith('gemini') ? undefined : parallelToolCalls; - + + const seriableTools = (tools.length === 0) ? undefined : tools; + const seriableParallelToolCalls = (tools.length === 0) ? + undefined : model.startsWith('gemini') ? undefined : parallelToolCalls; + await postProcessMessages(messages); - // // 使用表格渲染请求参数 - // requestTable.push( - // ['Model', model], - // ['Base URL', baseURL || 'Default'], - // ['Temperature', temperature], - // ['Tools Count', tools.length], - // ['Parallel Tool Calls', parallelToolCalls], - // ['Proxy Server', proxyServer || 'No Proxy'] - // ); - - // console.log('\nOpenAI Request Parameters:'); - // console.log(requestTable.toString()); - const stream = await client.chat.completions.create({ model, messages, @@ -76,19 +54,20 @@ export async function streamingChatCompletion( stream: true }); - // 存储当前的流式传输对象 - currentStream = stream; + // 用 sessionId 作为 key 存储流 + if (sessionId) { + chatStreams.set(sessionId, stream); + } // 流式传输结果 for await (const chunk of stream) { - if (!currentStream) { + if (!chatStreams.has(sessionId)) { // 如果流被中止,则停止循环 - // TODO: 为每一个标签页设置不同的 currentStream 管理器 stream.controller.abort(); - // 传输结束 webview.postMessage({ command: 'llm/chat/completions/done', data: { + sessionId, code: 200, msg: { success: true, @@ -98,26 +77,29 @@ export async function streamingChatCompletion( }); break; } - - if (chunk.choices) { - const chunkResult = { - code: 200, - msg: { - chunk - } - }; + if (chunk.choices) { webview.postMessage({ command: 'llm/chat/completions/chunk', - data: chunkResult + data: { + sessionId, + code: 200, + msg: { + chunk + } + } }); } } - // 传输结束 + // 传输结束,移除对应的 stream + if (sessionId) { + chatStreams.delete(sessionId); + } webview.postMessage({ command: 'llm/chat/completions/done', data: { + sessionId, code: 200, msg: { success: true, @@ -130,9 +112,9 @@ export async function streamingChatCompletion( // 处理中止消息的函数 export function abortMessageService(data: any, webview: PostMessageble): RestfulResponse { - if (currentStream) { - // 标记流已中止 - currentStream = null; + const sessionId = data?.sessionId; + if (sessionId) { + chatStreams.delete(sessionId); } return { @@ -144,18 +126,18 @@ export function abortMessageService(data: any, webview: PostMessageble): Restful } async function postProcessToolMessages(message: MyToolMessageType) { - if (typeof message.content === 'string') { - return; - } + if (typeof message.content === 'string') { + return; + } - for (const content of message.content) { - const contentType = content.type as string; - const rawContent = content as ToolCallContent; + for (const content of message.content) { + const contentType = content.type as string; + const rawContent = content as ToolCallContent; - if (contentType === 'image') { - rawContent.type = 'text'; - - // 此时图片只会存在三个状态 + if (contentType === 'image') { + rawContent.type = 'text'; + + // 此时图片只会存在三个状态 // 1. 图片在 ocrDB 中 // 2. 图片的 OCR 仍然在进行中 // 3. 图片已被删除 @@ -164,7 +146,7 @@ async function postProcessToolMessages(message: MyToolMessageType) { // rawContent.data 就是 filename const result = await ocrDB.findById(rawContent.data); if (result) { - rawContent.text = result.text || ''; + rawContent.text = result.text || ''; } else if (rawContent._meta) { const workerId = rawContent._meta.workerId; const worker = ocrWorkerStorage.get(workerId); @@ -177,35 +159,35 @@ async function postProcessToolMessages(message: MyToolMessageType) { } delete rawContent._meta; - } - } + } + } - message.content = JSON.stringify(message.content); + message.content = JSON.stringify(message.content); } export async function postProcessMessages(messages: MyMessageType[]) { - for (const message of messages) { - // 去除 extraInfo 属性 - delete message.extraInfo; + for (const message of messages) { + // 去除 extraInfo 属性 + delete message.extraInfo; - switch (message.role) { - case 'user': + switch (message.role) { + case 'user': break; - case 'assistant': + case 'assistant': break; - - case 'system': - break; + case 'system': - case 'tool': - await postProcessToolMessages(message); - break; - default: - break; - } - } + break; + + case 'tool': + await postProcessToolMessages(message); + break; + default: + break; + } + } } \ No newline at end of file