save
This commit is contained in:
parent
200ccdd0ac
commit
0dc727e781
@ -18,6 +18,7 @@ import { getXmlWrapperPrompt, getToolCallFromXmlString, getXmlsFromString, handl
|
|||||||
export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk;
|
export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk;
|
||||||
export interface TaskLoopChatOption {
|
export interface TaskLoopChatOption {
|
||||||
id?: string
|
id?: string
|
||||||
|
sessionId: string;
|
||||||
proxyServer?: string
|
proxyServer?: string
|
||||||
enableXmlWrapper?: boolean
|
enableXmlWrapper?: boolean
|
||||||
}
|
}
|
||||||
@ -193,9 +194,15 @@ export class TaskLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private doConversation(chatData: ChatCompletionCreateParamsBase, toolcallIndexAdapter: (toolCall: ToolCall) => IToolCallIndex) {
|
private doConversation(chatData: ChatCompletionCreateParamsBase, toolcallIndexAdapter: (toolCall: ToolCall) => IToolCallIndex) {
|
||||||
|
const sessionId = chatData.sessionId;
|
||||||
|
|
||||||
return new Promise<IDoConversationResult>((resolve, reject) => {
|
return new Promise<IDoConversationResult>((resolve, reject) => {
|
||||||
const chunkHandler = this.bridge.addCommandListener('llm/chat/completions/chunk', data => {
|
const chunkHandler = this.bridge.addCommandListener('llm/chat/completions/chunk', data => {
|
||||||
|
|
||||||
|
if (data.sessionId !== sessionId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// data.code 一定为 200,否则不会走这个 route
|
// data.code 一定为 200,否则不会走这个 route
|
||||||
const { chunk } = data.msg as { chunk: ChatCompletionChunk };
|
const { chunk } = data.msg as { chunk: ChatCompletionChunk };
|
||||||
|
|
||||||
@ -214,6 +221,10 @@ export class TaskLoop {
|
|||||||
}, { once: false });
|
}, { once: false });
|
||||||
|
|
||||||
const doneHandler = this.bridge.addCommandListener('llm/chat/completions/done', data => {
|
const doneHandler = this.bridge.addCommandListener('llm/chat/completions/done', data => {
|
||||||
|
if (data.sessionId !== sessionId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
this.consumeDones();
|
this.consumeDones();
|
||||||
|
|
||||||
chunkHandler();
|
chunkHandler();
|
||||||
@ -225,6 +236,10 @@ export class TaskLoop {
|
|||||||
}, { once: true });
|
}, { once: true });
|
||||||
|
|
||||||
const errorHandler = this.bridge.addCommandListener('llm/chat/completions/error', data => {
|
const errorHandler = this.bridge.addCommandListener('llm/chat/completions/error', data => {
|
||||||
|
if (data.sessionId !== sessionId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
this.consumeErrors({
|
this.consumeErrors({
|
||||||
state: MessageState.ReceiveChunkError,
|
state: MessageState.ReceiveChunkError,
|
||||||
msg: data.msg || '请求模型服务时发生错误'
|
msg: data.msg || '请求模型服务时发生错误'
|
||||||
@ -304,7 +319,7 @@ export class TaskLoop {
|
|||||||
const id = crypto.randomUUID();
|
const id = crypto.randomUUID();
|
||||||
|
|
||||||
const chatData = {
|
const chatData = {
|
||||||
id,
|
sessionId: id,
|
||||||
baseURL,
|
baseURL,
|
||||||
apiKey,
|
apiKey,
|
||||||
model,
|
model,
|
||||||
@ -575,7 +590,7 @@ export class TaskLoop {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.currentChatId = chatData.id!;
|
this.currentChatId = chatData.sessionId;
|
||||||
const llm = this.getLlmConfig();
|
const llm = this.getLlmConfig();
|
||||||
const toolcallIndexAdapter = getToolCallIndexAdapter(llm, chatData);
|
const toolcallIndexAdapter = getToolCallIndexAdapter(llm, chatData);
|
||||||
|
|
||||||
|
@ -161,124 +161,6 @@ export function topoSortParallel(state: DiagramState): string[][] {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function makeParallelTest(
|
|
||||||
dataViews: Reactive<NodeDataView>[],
|
|
||||||
enableXmlWrapper: boolean,
|
|
||||||
prompt: string | null = null,
|
|
||||||
context: DiagramContext
|
|
||||||
) {
|
|
||||||
if (dataViews.length === 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置所有节点状态为运行中
|
|
||||||
const createAt = Date.now();
|
|
||||||
dataViews.forEach(dataView => {
|
|
||||||
dataView.status = 'running';
|
|
||||||
dataView.createAt = createAt;
|
|
||||||
});
|
|
||||||
context.render();
|
|
||||||
|
|
||||||
try {
|
|
||||||
const loop = new TaskLoop({ maxEpochs: 1 });
|
|
||||||
|
|
||||||
// 构建所有工具的信息
|
|
||||||
const allTools = dataViews.map(dataView => ({
|
|
||||||
name: dataView.tool.name,
|
|
||||||
description: dataView.tool.description,
|
|
||||||
inputSchema: dataView.tool.inputSchema,
|
|
||||||
enabled: true
|
|
||||||
}));
|
|
||||||
|
|
||||||
// 构建测试提示词,包含所有工具
|
|
||||||
const toolNames = dataViews.map(dv => dv.tool.name).join(', ');
|
|
||||||
const usePrompt = (prompt || 'please call the tools {tool} to make some test').replace('{tool}', toolNames);
|
|
||||||
|
|
||||||
const chatStorage = {
|
|
||||||
messages: [],
|
|
||||||
settings: {
|
|
||||||
temperature: 0.6,
|
|
||||||
systemPrompt: '',
|
|
||||||
enableTools: allTools,
|
|
||||||
enableWebSearch: false,
|
|
||||||
contextLength: 5,
|
|
||||||
enableXmlWrapper,
|
|
||||||
parallelToolCalls: true // 开启并行工具调用
|
|
||||||
}
|
|
||||||
} as ChatStorage;
|
|
||||||
|
|
||||||
loop.setMaxEpochs(1);
|
|
||||||
|
|
||||||
// 记录工具调用信息,用于匹配工具调用结果
|
|
||||||
const toolCallMap: Map<string, Reactive<NodeDataView>> = new Map(); // toolCallId -> dataView
|
|
||||||
let completedToolsCount = 0;
|
|
||||||
|
|
||||||
loop.registerOnToolCall(toolCall => {
|
|
||||||
// 找到对应的dataView
|
|
||||||
const dataView = dataViews.find(dv => dv.tool.name === toolCall.function?.name);
|
|
||||||
if (dataView) {
|
|
||||||
dataView.function = toolCall.function;
|
|
||||||
dataView.llmTimecost = Date.now() - createAt;
|
|
||||||
context.render();
|
|
||||||
|
|
||||||
// 记录工具调用ID与dataView的映射
|
|
||||||
if (toolCall.id) {
|
|
||||||
toolCallMap.set(toolCall.id, dataView);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return toolCall;
|
|
||||||
});
|
|
||||||
|
|
||||||
loop.registerOnToolCalled(toolCalled => {
|
|
||||||
// 这里我们需要改变策略,因为没有工具调用ID信息
|
|
||||||
// 对于并行调用,我们可以根据工具调用的顺序来匹配结果
|
|
||||||
// 简单的策略:找到第一个仍在运行状态的工具
|
|
||||||
const runningView = dataViews.find(dv => dv.status === 'running');
|
|
||||||
if (runningView) {
|
|
||||||
runningView.toolcallTimecost = Date.now() - createAt - (runningView.llmTimecost || 0);
|
|
||||||
|
|
||||||
if (toolCalled.state === MessageState.Success) {
|
|
||||||
runningView.status = 'success';
|
|
||||||
runningView.result = toolCalled.content;
|
|
||||||
} else {
|
|
||||||
runningView.status = 'error';
|
|
||||||
runningView.result = toolCalled.content;
|
|
||||||
}
|
|
||||||
|
|
||||||
completedToolsCount++;
|
|
||||||
context.render();
|
|
||||||
|
|
||||||
// 如果所有工具都完成了,终止循环
|
|
||||||
if (completedToolsCount >= dataViews.length) {
|
|
||||||
loop.abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return toolCalled;
|
|
||||||
});
|
|
||||||
|
|
||||||
loop.registerOnError(error => {
|
|
||||||
dataViews.forEach(dataView => {
|
|
||||||
if (dataView.status === 'running') {
|
|
||||||
dataView.status = 'error';
|
|
||||||
dataView.result = error;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
context.render();
|
|
||||||
});
|
|
||||||
|
|
||||||
await loop.start(chatStorage, usePrompt);
|
|
||||||
|
|
||||||
} finally {
|
|
||||||
const finishAt = Date.now();
|
|
||||||
dataViews.forEach(dataView => {
|
|
||||||
dataView.finishAt = finishAt;
|
|
||||||
if (dataView.status === 'running') {
|
|
||||||
dataView.status = 'success';
|
|
||||||
}
|
|
||||||
});
|
|
||||||
context.render();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function makeNodeTest(
|
export async function makeNodeTest(
|
||||||
dataView: Reactive<NodeDataView>,
|
dataView: Reactive<NodeDataView>,
|
||||||
|
@ -62,11 +62,7 @@
|
|||||||
</div>
|
</div>
|
||||||
<div v-else>
|
<div v-else>
|
||||||
<span class="caption">
|
<span class="caption">
|
||||||
<el-tooltip
|
<el-tooltip placement="top" effect="light" :content="t('self-detect-caption')">
|
||||||
placement="top"
|
|
||||||
effect="light"
|
|
||||||
:content="t('self-detect-caption')"
|
|
||||||
>
|
|
||||||
<span class="iconfont icon-about"></span>
|
<span class="iconfont icon-about"></span>
|
||||||
</el-tooltip>
|
</el-tooltip>
|
||||||
</span>
|
</span>
|
||||||
@ -144,7 +140,7 @@ if (autoDetectDiagram) {
|
|||||||
// 新增:自检参数表单相关
|
// 新增:自检参数表单相关
|
||||||
const testFormVisible = ref(false);
|
const testFormVisible = ref(false);
|
||||||
const enableXmlWrapper = ref(false);
|
const enableXmlWrapper = ref(false);
|
||||||
const enableParallelTest = ref(false);
|
const enableParallelTest = ref(true);
|
||||||
const testPrompt = ref('please call the tool {tool} to make some test');
|
const testPrompt = ref('please call the tool {tool} to make some test');
|
||||||
|
|
||||||
async function onTestConfirm() {
|
async function onTestConfirm() {
|
||||||
@ -154,13 +150,15 @@ async function onTestConfirm() {
|
|||||||
tabStorage.autoDetectDiagram!.views = [];
|
tabStorage.autoDetectDiagram!.views = [];
|
||||||
|
|
||||||
if (state) {
|
if (state) {
|
||||||
if (enableParallelTest.value) {
|
const dispatches = topoSortParallel(state);
|
||||||
// 并行测试模式:一次性测试所有工具
|
|
||||||
const allViews = Array.from(state.dataView.values());
|
|
||||||
await makeParallelTest(allViews, enableXmlWrapper.value, testPrompt.value, context);
|
|
||||||
|
|
||||||
// 保存结果
|
if (enableParallelTest.value) {
|
||||||
allViews.forEach(view => {
|
for (const nodeIds of dispatches) {
|
||||||
|
|
||||||
|
await Promise.all(nodeIds.map(async id => {
|
||||||
|
const view = state.dataView.get(id);
|
||||||
|
if (view) {
|
||||||
|
await makeNodeTest(view, enableXmlWrapper.value, testPrompt.value, context);
|
||||||
tabStorage.autoDetectDiagram!.views!.push({
|
tabStorage.autoDetectDiagram!.views!.push({
|
||||||
tool: view.tool,
|
tool: view.tool,
|
||||||
status: view.status,
|
status: view.status,
|
||||||
@ -171,10 +169,12 @@ async function onTestConfirm() {
|
|||||||
llmTimecost: view.llmTimecost,
|
llmTimecost: view.llmTimecost,
|
||||||
toolcallTimecost: view.toolcallTimecost,
|
toolcallTimecost: view.toolcallTimecost,
|
||||||
});
|
});
|
||||||
});
|
context.render();
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 串行测试模式:按拓扑顺序逐个测试
|
// 串行测试模式:按拓扑顺序逐个测试
|
||||||
const dispatches = topoSortParallel(state);
|
|
||||||
for (const nodeIds of dispatches) {
|
for (const nodeIds of dispatches) {
|
||||||
for (const id of nodeIds) {
|
for (const id of nodeIds) {
|
||||||
const view = state.dataView.get(id);
|
const view = state.dataView.get(id);
|
||||||
|
@ -17,6 +17,8 @@ export class LlmController {
|
|||||||
webview.postMessage({
|
webview.postMessage({
|
||||||
command: 'llm/chat/completions/error',
|
command: 'llm/chat/completions/error',
|
||||||
data: {
|
data: {
|
||||||
|
sessionId: data.sessionId,
|
||||||
|
code: 500,
|
||||||
msg: error
|
msg: error
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -5,15 +5,16 @@ import { RestfulResponse } from "../common/index.dto.js";
|
|||||||
import { ocrDB } from "../hook/db.js";
|
import { ocrDB } from "../hook/db.js";
|
||||||
import type { ToolCallContent } from "../mcp/client.dto.js";
|
import type { ToolCallContent } from "../mcp/client.dto.js";
|
||||||
import { ocrWorkerStorage } from "../mcp/ocr.service.js";
|
import { ocrWorkerStorage } from "../mcp/ocr.service.js";
|
||||||
import Table from 'cli-table3';
|
|
||||||
|
|
||||||
export let currentStream: AsyncIterable<any> | null = null;
|
// 用 Map<string, AsyncIterable<any> | null> 管理多个流
|
||||||
|
export const chatStreams = new Map<string, AsyncIterable<any>>();
|
||||||
|
|
||||||
export async function streamingChatCompletion(
|
export async function streamingChatCompletion(
|
||||||
data: any,
|
data: any,
|
||||||
webview: PostMessageble
|
webview: PostMessageble
|
||||||
) {
|
) {
|
||||||
const {
|
const {
|
||||||
|
sessionId,
|
||||||
baseURL,
|
baseURL,
|
||||||
apiKey,
|
apiKey,
|
||||||
model,
|
model,
|
||||||
@ -24,16 +25,6 @@ export async function streamingChatCompletion(
|
|||||||
proxyServer = ''
|
proxyServer = ''
|
||||||
} = data;
|
} = data;
|
||||||
|
|
||||||
// 创建请求参数表格
|
|
||||||
const requestTable = new Table({
|
|
||||||
head: ['Parameter', 'Value'],
|
|
||||||
colWidths: [20, 40],
|
|
||||||
style: {
|
|
||||||
head: ['cyan'],
|
|
||||||
border: ['grey']
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
// 构建OpenRouter特定的请求头
|
// 构建OpenRouter特定的请求头
|
||||||
const defaultHeaders: Record<string, string> = {};
|
const defaultHeaders: Record<string, string> = {};
|
||||||
@ -54,19 +45,6 @@ export async function streamingChatCompletion(
|
|||||||
|
|
||||||
await postProcessMessages(messages);
|
await postProcessMessages(messages);
|
||||||
|
|
||||||
// // 使用表格渲染请求参数
|
|
||||||
// requestTable.push(
|
|
||||||
// ['Model', model],
|
|
||||||
// ['Base URL', baseURL || 'Default'],
|
|
||||||
// ['Temperature', temperature],
|
|
||||||
// ['Tools Count', tools.length],
|
|
||||||
// ['Parallel Tool Calls', parallelToolCalls],
|
|
||||||
// ['Proxy Server', proxyServer || 'No Proxy']
|
|
||||||
// );
|
|
||||||
|
|
||||||
// console.log('\nOpenAI Request Parameters:');
|
|
||||||
// console.log(requestTable.toString());
|
|
||||||
|
|
||||||
const stream = await client.chat.completions.create({
|
const stream = await client.chat.completions.create({
|
||||||
model,
|
model,
|
||||||
messages,
|
messages,
|
||||||
@ -76,19 +54,20 @@ export async function streamingChatCompletion(
|
|||||||
stream: true
|
stream: true
|
||||||
});
|
});
|
||||||
|
|
||||||
// 存储当前的流式传输对象
|
// 用 sessionId 作为 key 存储流
|
||||||
currentStream = stream;
|
if (sessionId) {
|
||||||
|
chatStreams.set(sessionId, stream);
|
||||||
|
}
|
||||||
|
|
||||||
// 流式传输结果
|
// 流式传输结果
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
if (!currentStream) {
|
if (!chatStreams.has(sessionId)) {
|
||||||
// 如果流被中止,则停止循环
|
// 如果流被中止,则停止循环
|
||||||
// TODO: 为每一个标签页设置不同的 currentStream 管理器
|
|
||||||
stream.controller.abort();
|
stream.controller.abort();
|
||||||
// 传输结束
|
|
||||||
webview.postMessage({
|
webview.postMessage({
|
||||||
command: 'llm/chat/completions/done',
|
command: 'llm/chat/completions/done',
|
||||||
data: {
|
data: {
|
||||||
|
sessionId,
|
||||||
code: 200,
|
code: 200,
|
||||||
msg: {
|
msg: {
|
||||||
success: true,
|
success: true,
|
||||||
@ -100,24 +79,27 @@ export async function streamingChatCompletion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (chunk.choices) {
|
if (chunk.choices) {
|
||||||
const chunkResult = {
|
webview.postMessage({
|
||||||
|
command: 'llm/chat/completions/chunk',
|
||||||
|
data: {
|
||||||
|
sessionId,
|
||||||
code: 200,
|
code: 200,
|
||||||
msg: {
|
msg: {
|
||||||
chunk
|
chunk
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
webview.postMessage({
|
|
||||||
command: 'llm/chat/completions/chunk',
|
|
||||||
data: chunkResult
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 传输结束
|
// 传输结束,移除对应的 stream
|
||||||
|
if (sessionId) {
|
||||||
|
chatStreams.delete(sessionId);
|
||||||
|
}
|
||||||
webview.postMessage({
|
webview.postMessage({
|
||||||
command: 'llm/chat/completions/done',
|
command: 'llm/chat/completions/done',
|
||||||
data: {
|
data: {
|
||||||
|
sessionId,
|
||||||
code: 200,
|
code: 200,
|
||||||
msg: {
|
msg: {
|
||||||
success: true,
|
success: true,
|
||||||
@ -130,9 +112,9 @@ export async function streamingChatCompletion(
|
|||||||
|
|
||||||
// 处理中止消息的函数
|
// 处理中止消息的函数
|
||||||
export function abortMessageService(data: any, webview: PostMessageble): RestfulResponse {
|
export function abortMessageService(data: any, webview: PostMessageble): RestfulResponse {
|
||||||
if (currentStream) {
|
const sessionId = data?.sessionId;
|
||||||
// 标记流已中止
|
if (sessionId) {
|
||||||
currentStream = null;
|
chatStreams.delete(sessionId);
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user