This commit is contained in:
锦恢 2025-07-13 04:04:15 +08:00
parent 200ccdd0ac
commit 0dc727e781
5 changed files with 107 additions and 226 deletions

View File

@@ -18,6 +18,7 @@ import { getXmlWrapperPrompt, getToolCallFromXmlString, getXmlsFromString, handl
 export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk;
 export interface TaskLoopChatOption {
     id?: string
+    sessionId: string;
     proxyServer?: string
     enableXmlWrapper?: boolean
 }
@@ -193,9 +194,15 @@ export class TaskLoop {
     }
     private doConversation(chatData: ChatCompletionCreateParamsBase, toolcallIndexAdapter: (toolCall: ToolCall) => IToolCallIndex) {
+        const sessionId = chatData.sessionId;
         return new Promise<IDoConversationResult>((resolve, reject) => {
             const chunkHandler = this.bridge.addCommandListener('llm/chat/completions/chunk', data => {
+                if (data.sessionId !== sessionId) {
+                    return;
+                }
                 // data.code is always 200 here; otherwise this route is never taken
                 const { chunk } = data.msg as { chunk: ChatCompletionChunk };
@@ -214,6 +221,10 @@ export class TaskLoop {
             }, { once: false });
             const doneHandler = this.bridge.addCommandListener('llm/chat/completions/done', data => {
+                if (data.sessionId !== sessionId) {
+                    return;
+                }
                 this.consumeDones();
                 chunkHandler();
@@ -225,6 +236,10 @@ export class TaskLoop {
             }, { once: true });
             const errorHandler = this.bridge.addCommandListener('llm/chat/completions/error', data => {
+                if (data.sessionId !== sessionId) {
+                    return;
+                }
                 this.consumeErrors({
                     state: MessageState.ReceiveChunkError,
                     msg: data.msg || '请求模型服务时发生错误'
@@ -304,7 +319,7 @@ export class TaskLoop {
         const id = crypto.randomUUID();
         const chatData = {
-            id,
+            sessionId: id,
             baseURL,
             apiKey,
             model,
@@ -575,7 +590,7 @@ export class TaskLoop {
                 break;
             }
-            this.currentChatId = chatData.id!;
+            this.currentChatId = chatData.sessionId;
             const llm = this.getLlmConfig();
             const toolcallIndexAdapter = getToolCallIndexAdapter(llm, chatData);
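Taken together, these hunks make every bridge listener session-scoped: a conversation only reacts to events stamped with its own sessionId. A minimal standalone sketch of this filtering pattern, assuming a simplified event bridge (the project's real bridge API is only partially visible in the diff):

type Listener = (data: { sessionId: string; msg: unknown }) => void;

class Bridge {
    private listeners = new Map<string, Listener[]>();

    addCommandListener(command: string, fn: Listener): () => void {
        const list = this.listeners.get(command) ?? [];
        list.push(fn);
        this.listeners.set(command, list);
        // return a disposer, mirroring how chunkHandler() is invoked above
        return () => {
            const arr = this.listeners.get(command) ?? [];
            this.listeners.set(command, arr.filter(f => f !== fn));
        };
    }

    emit(command: string, data: { sessionId: string; msg: unknown }): void {
        (this.listeners.get(command) ?? []).forEach(fn => fn(data));
    }
}

// Two conversations share one bridge; each handler drops foreign sessions.
const bridge = new Bridge();
function listen(sessionId: string) {
    bridge.addCommandListener('llm/chat/completions/chunk', data => {
        if (data.sessionId !== sessionId) {
            return; // another session's chunk: ignore it
        }
        console.log(`[${sessionId}]`, data.msg);
    });
}

listen('a');
listen('b');
bridge.emit('llm/chat/completions/chunk', { sessionId: 'a', msg: 'hello' }); // only [a] logs

Without the sessionId guard, two concurrent chats listening on the same command would each consume the other's chunks, which is exactly the cross-talk this commit removes.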

View File

@@ -161,124 +161,6 @@ export function topoSortParallel(state: DiagramState): string[][] {
     return result;
 }
-export async function makeParallelTest(
-    dataViews: Reactive<NodeDataView>[],
-    enableXmlWrapper: boolean,
-    prompt: string | null = null,
-    context: DiagramContext
-) {
-    if (dataViews.length === 0) {
-        return;
-    }
-    // mark all nodes as running
-    const createAt = Date.now();
-    dataViews.forEach(dataView => {
-        dataView.status = 'running';
-        dataView.createAt = createAt;
-    });
-    context.render();
-    try {
-        const loop = new TaskLoop({ maxEpochs: 1 });
-        // collect the metadata of every tool
-        const allTools = dataViews.map(dataView => ({
-            name: dataView.tool.name,
-            description: dataView.tool.description,
-            inputSchema: dataView.tool.inputSchema,
-            enabled: true
-        }));
-        // build the test prompt covering all tools
-        const toolNames = dataViews.map(dv => dv.tool.name).join(', ');
-        const usePrompt = (prompt || 'please call the tools {tool} to make some test').replace('{tool}', toolNames);
-        const chatStorage = {
-            messages: [],
-            settings: {
-                temperature: 0.6,
-                systemPrompt: '',
-                enableTools: allTools,
-                enableWebSearch: false,
-                contextLength: 5,
-                enableXmlWrapper,
-                parallelToolCalls: true // enable parallel tool calls
-            }
-        } as ChatStorage;
-        loop.setMaxEpochs(1);
-        // record tool-call info so results can be matched back to their calls
-        const toolCallMap: Map<string, Reactive<NodeDataView>> = new Map(); // toolCallId -> dataView
-        let completedToolsCount = 0;
-        loop.registerOnToolCall(toolCall => {
-            // find the matching dataView
-            const dataView = dataViews.find(dv => dv.tool.name === toolCall.function?.name);
-            if (dataView) {
-                dataView.function = toolCall.function;
-                dataView.llmTimecost = Date.now() - createAt;
-                context.render();
-                // record the mapping from tool-call id to dataView
-                if (toolCall.id) {
-                    toolCallMap.set(toolCall.id, dataView);
-                }
-            }
-            return toolCall;
-        });
-        loop.registerOnToolCalled(toolCalled => {
-            // the result carries no tool-call id, so the strategy has to change:
-            // for parallel calls, match results to calls by order;
-            // the simple heuristic is to take the first tool still in the running state
-            const runningView = dataViews.find(dv => dv.status === 'running');
-            if (runningView) {
-                runningView.toolcallTimecost = Date.now() - createAt - (runningView.llmTimecost || 0);
-                if (toolCalled.state === MessageState.Success) {
-                    runningView.status = 'success';
-                    runningView.result = toolCalled.content;
-                } else {
-                    runningView.status = 'error';
-                    runningView.result = toolCalled.content;
-                }
-                completedToolsCount++;
-                context.render();
-                // once every tool has finished, abort the loop
-                if (completedToolsCount >= dataViews.length) {
-                    loop.abort();
-                }
-            }
-            return toolCalled;
-        });
-        loop.registerOnError(error => {
-            dataViews.forEach(dataView => {
-                if (dataView.status === 'running') {
-                    dataView.status = 'error';
-                    dataView.result = error;
-                }
-            });
-            context.render();
-        });
-        await loop.start(chatStorage, usePrompt);
-    } finally {
-        const finishAt = Date.now();
-        dataViews.forEach(dataView => {
-            dataView.finishAt = finishAt;
-            if (dataView.status === 'running') {
-                dataView.status = 'success';
-            }
-        });
-        context.render();
-    }
-}
 export async function makeNodeTest(
     dataView: Reactive<NodeDataView>,

View File

@@ -62,11 +62,7 @@
     </div>
     <div v-else>
         <span class="caption">
-            <el-tooltip
-                placement="top"
-                effect="light"
-                :content="t('self-detect-caption')"
-            >
+            <el-tooltip placement="top" effect="light" :content="t('self-detect-caption')">
                 <span class="iconfont icon-about"></span>
             </el-tooltip>
         </span>
@@ -144,7 +140,7 @@ if (autoDetectDiagram) {
 //
 const testFormVisible = ref(false);
 const enableXmlWrapper = ref(false);
-const enableParallelTest = ref(false);
+const enableParallelTest = ref(true);
 const testPrompt = ref('please call the tool {tool} to make some test');
 async function onTestConfirm() {
@@ -154,13 +150,15 @@ async function onTestConfirm() {
     tabStorage.autoDetectDiagram!.views = [];
     if (state) {
-        if (enableParallelTest.value) {
-            //
-            const allViews = Array.from(state.dataView.values());
-            await makeParallelTest(allViews, enableXmlWrapper.value, testPrompt.value, context);
-            //
-            allViews.forEach(view => {
+        const dispatches = topoSortParallel(state);
+        if (enableParallelTest.value) {
+            for (const nodeIds of dispatches) {
+                await Promise.all(nodeIds.map(async id => {
+                    const view = state.dataView.get(id);
+                    if (view) {
+                        await makeNodeTest(view, enableXmlWrapper.value, testPrompt.value, context);
                 tabStorage.autoDetectDiagram!.views!.push({
                     tool: view.tool,
                     status: view.status,
@@ -171,10 +169,12 @@ async function onTestConfirm() {
                     llmTimecost: view.llmTimecost,
                     toolcallTimecost: view.toolcallTimecost,
                 });
-            });
+                        context.render();
+                    }
+                }));
+            }
         } else {
             //
-            const dispatches = topoSortParallel(state);
             for (const nodeIds of dispatches) {
                 for (const id of nodeIds) {
                     const view = state.dataView.get(id);
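The parallel branch now drives per-node tests off the output of topoSortParallel: node ids within one batch run concurrently via Promise.all, while batches run strictly in dependency order. A self-contained sketch of that dispatch pattern (names are illustrative, not the component's actual helpers):

// Level-parallel dispatch: each inner array is a batch of independent
// node ids; batches run in order, members of a batch run together.
async function runInBatches(
    dispatches: string[][],
    runOne: (id: string) => Promise<void>
): Promise<void> {
    for (const batch of dispatches) {          // sequential across levels
        await Promise.all(batch.map(runOne));  // concurrent within a level
    }
}

// Usage sketch: [['a', 'b'], ['c']] runs a and b together, then c.
void runInBatches([['a', 'b'], ['c']], async id => {
    console.log('testing node', id);
});

This replaces the removed makeParallelTest, whose result matching relied on "first tool still running" and could attribute results to the wrong node; dispatching makeNodeTest per id keeps each result tied to its own view.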

View File

@@ -17,6 +17,8 @@ export class LlmController {
         webview.postMessage({
             command: 'llm/chat/completions/error',
             data: {
+                sessionId: data.sessionId,
+                code: 500,
                 msg: error
             }
         });
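With sessionId and code added here, all three completion events (chunk, done, error) share one envelope shape, which is what lets the TaskLoop listeners above filter reliably. A hedged sketch of that shape (the type name is mine, not the project's):

// Assumed common envelope for chunk/done/error messages (illustrative type).
interface CompletionEnvelope<T> {
    command: 'llm/chat/completions/chunk'
           | 'llm/chat/completions/done'
           | 'llm/chat/completions/error';
    data: {
        sessionId: string; // lets receivers drop messages from other sessions
        code: number;      // 200 on chunk/done, 500 on error
        msg: T;
    };
}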

View File

@@ -5,15 +5,16 @@ import { RestfulResponse } from "../common/index.dto.js";
 import { ocrDB } from "../hook/db.js";
 import type { ToolCallContent } from "../mcp/client.dto.js";
 import { ocrWorkerStorage } from "../mcp/ocr.service.js";
-import Table from 'cli-table3';
-export let currentStream: AsyncIterable<any> | null = null;
+// manage multiple streams with a Map<string, AsyncIterable<any> | null>
+export const chatStreams = new Map<string, AsyncIterable<any>>();
 export async function streamingChatCompletion(
     data: any,
     webview: PostMessageble
 ) {
     const {
+        sessionId,
         baseURL,
         apiKey,
         model,
@@ -24,16 +25,6 @@ export async function streamingChatCompletion(
         proxyServer = ''
     } = data;
-    // build a table of request parameters
-    const requestTable = new Table({
-        head: ['Parameter', 'Value'],
-        colWidths: [20, 40],
-        style: {
-            head: ['cyan'],
-            border: ['grey']
-        }
-    });
     // build the OpenRouter-specific request headers
     const defaultHeaders: Record<string, string> = {};
@@ -48,25 +39,12 @@ export async function streamingChatCompletion(
         defaultHeaders: Object.keys(defaultHeaders).length > 0 ? defaultHeaders : undefined
     });
-    const seriableTools = (tools.length === 0) ? undefined: tools;
-    const seriableParallelToolCalls = (tools.length === 0)?
-        undefined: model.startsWith('gemini') ? undefined : parallelToolCalls;
+    const seriableTools = (tools.length === 0) ? undefined : tools;
+    const seriableParallelToolCalls = (tools.length === 0) ?
+        undefined : model.startsWith('gemini') ? undefined : parallelToolCalls;
     await postProcessMessages(messages);
-    // // render the request parameters as a table
-    // requestTable.push(
-    //     ['Model', model],
-    //     ['Base URL', baseURL || 'Default'],
-    //     ['Temperature', temperature],
-    //     ['Tools Count', tools.length],
-    //     ['Parallel Tool Calls', parallelToolCalls],
-    //     ['Proxy Server', proxyServer || 'No Proxy']
-    // );
-    // console.log('\nOpenAI Request Parameters:');
-    // console.log(requestTable.toString());
     const stream = await client.chat.completions.create({
         model,
         messages,
@@ -76,19 +54,20 @@ export async function streamingChatCompletion(
         stream: true
     });
-    // store the current streaming object
-    currentStream = stream;
+    // store the stream keyed by sessionId
+    if (sessionId) {
+        chatStreams.set(sessionId, stream);
+    }
     // stream the results
     for await (const chunk of stream) {
-        if (!currentStream) {
+        if (!chatStreams.has(sessionId)) {
             // if the stream was aborted, stop the loop
-            // TODO: set up a separate currentStream manager for every tab
             stream.controller.abort();
-            // transmission finished
             webview.postMessage({
                 command: 'llm/chat/completions/done',
                 data: {
+                    sessionId,
                     code: 200,
                     msg: {
                         success: true,
@@ -100,24 +79,27 @@ export async function streamingChatCompletion(
         }
         if (chunk.choices) {
-            const chunkResult = {
+            webview.postMessage({
+                command: 'llm/chat/completions/chunk',
+                data: {
+                    sessionId,
                     code: 200,
                     msg: {
                         chunk
                     }
-            };
-            webview.postMessage({
-                command: 'llm/chat/completions/chunk',
-                data: chunkResult
+                }
             });
         }
     }
-    // transmission finished
+    // transmission finished: remove the corresponding stream
+    if (sessionId) {
+        chatStreams.delete(sessionId);
+    }
     webview.postMessage({
         command: 'llm/chat/completions/done',
         data: {
+            sessionId,
             code: 200,
             msg: {
                 success: true,
@@ -130,9 +112,9 @@ export async function streamingChatCompletion(
 // function that handles abort messages
 export function abortMessageService(data: any, webview: PostMessageble): RestfulResponse {
-    if (currentStream) {
-        // mark the stream as aborted
-        currentStream = null;
+    const sessionId = data?.sessionId;
+    if (sessionId) {
+        chatStreams.delete(sessionId);
     }
     return {
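The single mutable currentStream becomes a per-session registry: aborting one tab's stream now only removes that tab's entry, and the producer loop notices via Map#has. A self-contained sketch of the pattern, with simplified types (the real streams are OpenAI SDK async iterables, not plain strings):

// Per-session stream registry: abortStream(sessionId) stops only that session.
const streams = new Map<string, AbortController>();

function startStream(sessionId: string): AbortController {
    const controller = new AbortController();
    streams.set(sessionId, controller);
    return controller;
}

function abortStream(sessionId: string): void {
    streams.get(sessionId)?.abort(); // signal the underlying request to stop
    streams.delete(sessionId);       // Map#has(sessionId) now returns false
}

async function consume(sessionId: string, chunks: AsyncIterable<string>) {
    for await (const chunk of chunks) {
        if (!streams.has(sessionId)) {
            break; // aborted elsewhere: stop reading this session's stream
        }
        console.log(`[${sessionId}]`, chunk);
    }
    streams.delete(sessionId); // normal completion also cleans up
}

One consequence of the deletion-based design, as in the commit itself: the consumer only observes an abort on its next chunk, so a stalled upstream still holds the connection until the controller's abort actually fires.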