save
This commit is contained in:
parent
200ccdd0ac
commit
0dc727e781
@ -18,6 +18,7 @@ import { getXmlWrapperPrompt, getToolCallFromXmlString, getXmlsFromString, handl
|
||||
export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk;
|
||||
export interface TaskLoopChatOption {
|
||||
id?: string
|
||||
sessionId: string;
|
||||
proxyServer?: string
|
||||
enableXmlWrapper?: boolean
|
||||
}
|
||||
@ -193,9 +194,15 @@ export class TaskLoop {
|
||||
}
|
||||
|
||||
private doConversation(chatData: ChatCompletionCreateParamsBase, toolcallIndexAdapter: (toolCall: ToolCall) => IToolCallIndex) {
|
||||
const sessionId = chatData.sessionId;
|
||||
|
||||
return new Promise<IDoConversationResult>((resolve, reject) => {
|
||||
const chunkHandler = this.bridge.addCommandListener('llm/chat/completions/chunk', data => {
|
||||
|
||||
if (data.sessionId !== sessionId) {
|
||||
return;
|
||||
}
|
||||
|
||||
// data.code 一定为 200,否则不会走这个 route
|
||||
const { chunk } = data.msg as { chunk: ChatCompletionChunk };
|
||||
|
||||
@ -214,6 +221,10 @@ export class TaskLoop {
|
||||
}, { once: false });
|
||||
|
||||
const doneHandler = this.bridge.addCommandListener('llm/chat/completions/done', data => {
|
||||
if (data.sessionId !== sessionId) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.consumeDones();
|
||||
|
||||
chunkHandler();
|
||||
@ -225,6 +236,10 @@ export class TaskLoop {
|
||||
}, { once: true });
|
||||
|
||||
const errorHandler = this.bridge.addCommandListener('llm/chat/completions/error', data => {
|
||||
if (data.sessionId !== sessionId) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.consumeErrors({
|
||||
state: MessageState.ReceiveChunkError,
|
||||
msg: data.msg || '请求模型服务时发生错误'
|
||||
@ -304,7 +319,7 @@ export class TaskLoop {
|
||||
const id = crypto.randomUUID();
|
||||
|
||||
const chatData = {
|
||||
id,
|
||||
sessionId: id,
|
||||
baseURL,
|
||||
apiKey,
|
||||
model,
|
||||
@ -575,7 +590,7 @@ export class TaskLoop {
|
||||
break;
|
||||
}
|
||||
|
||||
this.currentChatId = chatData.id!;
|
||||
this.currentChatId = chatData.sessionId;
|
||||
const llm = this.getLlmConfig();
|
||||
const toolcallIndexAdapter = getToolCallIndexAdapter(llm, chatData);
|
||||
|
||||
|
@ -161,124 +161,6 @@ export function topoSortParallel(state: DiagramState): string[][] {
|
||||
return result;
|
||||
}
|
||||
|
||||
export async function makeParallelTest(
|
||||
dataViews: Reactive<NodeDataView>[],
|
||||
enableXmlWrapper: boolean,
|
||||
prompt: string | null = null,
|
||||
context: DiagramContext
|
||||
) {
|
||||
if (dataViews.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 设置所有节点状态为运行中
|
||||
const createAt = Date.now();
|
||||
dataViews.forEach(dataView => {
|
||||
dataView.status = 'running';
|
||||
dataView.createAt = createAt;
|
||||
});
|
||||
context.render();
|
||||
|
||||
try {
|
||||
const loop = new TaskLoop({ maxEpochs: 1 });
|
||||
|
||||
// 构建所有工具的信息
|
||||
const allTools = dataViews.map(dataView => ({
|
||||
name: dataView.tool.name,
|
||||
description: dataView.tool.description,
|
||||
inputSchema: dataView.tool.inputSchema,
|
||||
enabled: true
|
||||
}));
|
||||
|
||||
// 构建测试提示词,包含所有工具
|
||||
const toolNames = dataViews.map(dv => dv.tool.name).join(', ');
|
||||
const usePrompt = (prompt || 'please call the tools {tool} to make some test').replace('{tool}', toolNames);
|
||||
|
||||
const chatStorage = {
|
||||
messages: [],
|
||||
settings: {
|
||||
temperature: 0.6,
|
||||
systemPrompt: '',
|
||||
enableTools: allTools,
|
||||
enableWebSearch: false,
|
||||
contextLength: 5,
|
||||
enableXmlWrapper,
|
||||
parallelToolCalls: true // 开启并行工具调用
|
||||
}
|
||||
} as ChatStorage;
|
||||
|
||||
loop.setMaxEpochs(1);
|
||||
|
||||
// 记录工具调用信息,用于匹配工具调用结果
|
||||
const toolCallMap: Map<string, Reactive<NodeDataView>> = new Map(); // toolCallId -> dataView
|
||||
let completedToolsCount = 0;
|
||||
|
||||
loop.registerOnToolCall(toolCall => {
|
||||
// 找到对应的dataView
|
||||
const dataView = dataViews.find(dv => dv.tool.name === toolCall.function?.name);
|
||||
if (dataView) {
|
||||
dataView.function = toolCall.function;
|
||||
dataView.llmTimecost = Date.now() - createAt;
|
||||
context.render();
|
||||
|
||||
// 记录工具调用ID与dataView的映射
|
||||
if (toolCall.id) {
|
||||
toolCallMap.set(toolCall.id, dataView);
|
||||
}
|
||||
}
|
||||
return toolCall;
|
||||
});
|
||||
|
||||
loop.registerOnToolCalled(toolCalled => {
|
||||
// 这里我们需要改变策略,因为没有工具调用ID信息
|
||||
// 对于并行调用,我们可以根据工具调用的顺序来匹配结果
|
||||
// 简单的策略:找到第一个仍在运行状态的工具
|
||||
const runningView = dataViews.find(dv => dv.status === 'running');
|
||||
if (runningView) {
|
||||
runningView.toolcallTimecost = Date.now() - createAt - (runningView.llmTimecost || 0);
|
||||
|
||||
if (toolCalled.state === MessageState.Success) {
|
||||
runningView.status = 'success';
|
||||
runningView.result = toolCalled.content;
|
||||
} else {
|
||||
runningView.status = 'error';
|
||||
runningView.result = toolCalled.content;
|
||||
}
|
||||
|
||||
completedToolsCount++;
|
||||
context.render();
|
||||
|
||||
// 如果所有工具都完成了,终止循环
|
||||
if (completedToolsCount >= dataViews.length) {
|
||||
loop.abort();
|
||||
}
|
||||
}
|
||||
return toolCalled;
|
||||
});
|
||||
|
||||
loop.registerOnError(error => {
|
||||
dataViews.forEach(dataView => {
|
||||
if (dataView.status === 'running') {
|
||||
dataView.status = 'error';
|
||||
dataView.result = error;
|
||||
}
|
||||
});
|
||||
context.render();
|
||||
});
|
||||
|
||||
await loop.start(chatStorage, usePrompt);
|
||||
|
||||
} finally {
|
||||
const finishAt = Date.now();
|
||||
dataViews.forEach(dataView => {
|
||||
dataView.finishAt = finishAt;
|
||||
if (dataView.status === 'running') {
|
||||
dataView.status = 'success';
|
||||
}
|
||||
});
|
||||
context.render();
|
||||
}
|
||||
}
|
||||
|
||||
export async function makeNodeTest(
|
||||
dataView: Reactive<NodeDataView>,
|
||||
|
@ -56,19 +56,15 @@
|
||||
<el-scrollbar height="80vh">
|
||||
<Diagram :tab-id="props.tabId" />
|
||||
</el-scrollbar>
|
||||
|
||||
|
||||
<div class="caption" v-if="showCaption">
|
||||
{{ caption }}
|
||||
</div>
|
||||
<div v-else>
|
||||
<span class="caption">
|
||||
<el-tooltip
|
||||
placement="top"
|
||||
effect="light"
|
||||
:content="t('self-detect-caption')"
|
||||
>
|
||||
<span class="iconfont icon-about"></span>
|
||||
</el-tooltip>
|
||||
<el-tooltip placement="top" effect="light" :content="t('self-detect-caption')">
|
||||
<span class="iconfont icon-about"></span>
|
||||
</el-tooltip>
|
||||
</span>
|
||||
</div>
|
||||
</el-dialog>
|
||||
@ -144,37 +140,41 @@ if (autoDetectDiagram) {
|
||||
// 新增:自检参数表单相关
|
||||
const testFormVisible = ref(false);
|
||||
const enableXmlWrapper = ref(false);
|
||||
const enableParallelTest = ref(false);
|
||||
const enableParallelTest = ref(true);
|
||||
const testPrompt = ref('please call the tool {tool} to make some test');
|
||||
|
||||
async function onTestConfirm() {
|
||||
testFormVisible.value = false;
|
||||
const state = context.state;
|
||||
|
||||
|
||||
tabStorage.autoDetectDiagram!.views = [];
|
||||
|
||||
if (state) {
|
||||
const dispatches = topoSortParallel(state);
|
||||
|
||||
if (enableParallelTest.value) {
|
||||
// 并行测试模式:一次性测试所有工具
|
||||
const allViews = Array.from(state.dataView.values());
|
||||
await makeParallelTest(allViews, enableXmlWrapper.value, testPrompt.value, context);
|
||||
|
||||
// 保存结果
|
||||
allViews.forEach(view => {
|
||||
tabStorage.autoDetectDiagram!.views!.push({
|
||||
tool: view.tool,
|
||||
status: view.status,
|
||||
function: view.function,
|
||||
result: view.result,
|
||||
createAt: view.createAt,
|
||||
finishAt: view.finishAt,
|
||||
llmTimecost: view.llmTimecost,
|
||||
toolcallTimecost: view.toolcallTimecost,
|
||||
});
|
||||
});
|
||||
for (const nodeIds of dispatches) {
|
||||
|
||||
await Promise.all(nodeIds.map(async id => {
|
||||
const view = state.dataView.get(id);
|
||||
if (view) {
|
||||
await makeNodeTest(view, enableXmlWrapper.value, testPrompt.value, context);
|
||||
tabStorage.autoDetectDiagram!.views!.push({
|
||||
tool: view.tool,
|
||||
status: view.status,
|
||||
function: view.function,
|
||||
result: view.result,
|
||||
createAt: view.createAt,
|
||||
finishAt: view.finishAt,
|
||||
llmTimecost: view.llmTimecost,
|
||||
toolcallTimecost: view.toolcallTimecost,
|
||||
});
|
||||
context.render();
|
||||
}
|
||||
}));
|
||||
}
|
||||
} else {
|
||||
// 串行测试模式:按拓扑顺序逐个测试
|
||||
const dispatches = topoSortParallel(state);
|
||||
for (const nodeIds of dispatches) {
|
||||
for (const id of nodeIds) {
|
||||
const view = state.dataView.get(id);
|
||||
|
@ -17,6 +17,8 @@ export class LlmController {
|
||||
webview.postMessage({
|
||||
command: 'llm/chat/completions/error',
|
||||
data: {
|
||||
sessionId: data.sessionId,
|
||||
code: 500,
|
||||
msg: error
|
||||
}
|
||||
});
|
||||
|
@ -5,15 +5,16 @@ import { RestfulResponse } from "../common/index.dto.js";
|
||||
import { ocrDB } from "../hook/db.js";
|
||||
import type { ToolCallContent } from "../mcp/client.dto.js";
|
||||
import { ocrWorkerStorage } from "../mcp/ocr.service.js";
|
||||
import Table from 'cli-table3';
|
||||
|
||||
export let currentStream: AsyncIterable<any> | null = null;
|
||||
// 用 Map<string, AsyncIterable<any> | null> 管理多个流
|
||||
export const chatStreams = new Map<string, AsyncIterable<any>>();
|
||||
|
||||
export async function streamingChatCompletion(
|
||||
data: any,
|
||||
webview: PostMessageble
|
||||
) {
|
||||
const {
|
||||
sessionId,
|
||||
baseURL,
|
||||
apiKey,
|
||||
model,
|
||||
@ -24,16 +25,6 @@ export async function streamingChatCompletion(
|
||||
proxyServer = ''
|
||||
} = data;
|
||||
|
||||
// 创建请求参数表格
|
||||
const requestTable = new Table({
|
||||
head: ['Parameter', 'Value'],
|
||||
colWidths: [20, 40],
|
||||
style: {
|
||||
head: ['cyan'],
|
||||
border: ['grey']
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
// 构建OpenRouter特定的请求头
|
||||
const defaultHeaders: Record<string, string> = {};
|
||||
@ -47,26 +38,13 @@ export async function streamingChatCompletion(
|
||||
apiKey,
|
||||
defaultHeaders: Object.keys(defaultHeaders).length > 0 ? defaultHeaders : undefined
|
||||
});
|
||||
|
||||
const seriableTools = (tools.length === 0) ? undefined: tools;
|
||||
const seriableParallelToolCalls = (tools.length === 0)?
|
||||
undefined: model.startsWith('gemini') ? undefined : parallelToolCalls;
|
||||
|
||||
|
||||
const seriableTools = (tools.length === 0) ? undefined : tools;
|
||||
const seriableParallelToolCalls = (tools.length === 0) ?
|
||||
undefined : model.startsWith('gemini') ? undefined : parallelToolCalls;
|
||||
|
||||
await postProcessMessages(messages);
|
||||
|
||||
// // 使用表格渲染请求参数
|
||||
// requestTable.push(
|
||||
// ['Model', model],
|
||||
// ['Base URL', baseURL || 'Default'],
|
||||
// ['Temperature', temperature],
|
||||
// ['Tools Count', tools.length],
|
||||
// ['Parallel Tool Calls', parallelToolCalls],
|
||||
// ['Proxy Server', proxyServer || 'No Proxy']
|
||||
// );
|
||||
|
||||
// console.log('\nOpenAI Request Parameters:');
|
||||
// console.log(requestTable.toString());
|
||||
|
||||
const stream = await client.chat.completions.create({
|
||||
model,
|
||||
messages,
|
||||
@ -76,19 +54,20 @@ export async function streamingChatCompletion(
|
||||
stream: true
|
||||
});
|
||||
|
||||
// 存储当前的流式传输对象
|
||||
currentStream = stream;
|
||||
// 用 sessionId 作为 key 存储流
|
||||
if (sessionId) {
|
||||
chatStreams.set(sessionId, stream);
|
||||
}
|
||||
|
||||
// 流式传输结果
|
||||
for await (const chunk of stream) {
|
||||
if (!currentStream) {
|
||||
if (!chatStreams.has(sessionId)) {
|
||||
// 如果流被中止,则停止循环
|
||||
// TODO: 为每一个标签页设置不同的 currentStream 管理器
|
||||
stream.controller.abort();
|
||||
// 传输结束
|
||||
webview.postMessage({
|
||||
command: 'llm/chat/completions/done',
|
||||
data: {
|
||||
sessionId,
|
||||
code: 200,
|
||||
msg: {
|
||||
success: true,
|
||||
@ -98,26 +77,29 @@ export async function streamingChatCompletion(
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
if (chunk.choices) {
|
||||
const chunkResult = {
|
||||
code: 200,
|
||||
msg: {
|
||||
chunk
|
||||
}
|
||||
};
|
||||
|
||||
if (chunk.choices) {
|
||||
webview.postMessage({
|
||||
command: 'llm/chat/completions/chunk',
|
||||
data: chunkResult
|
||||
data: {
|
||||
sessionId,
|
||||
code: 200,
|
||||
msg: {
|
||||
chunk
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 传输结束
|
||||
// 传输结束,移除对应的 stream
|
||||
if (sessionId) {
|
||||
chatStreams.delete(sessionId);
|
||||
}
|
||||
webview.postMessage({
|
||||
command: 'llm/chat/completions/done',
|
||||
data: {
|
||||
sessionId,
|
||||
code: 200,
|
||||
msg: {
|
||||
success: true,
|
||||
@ -130,9 +112,9 @@ export async function streamingChatCompletion(
|
||||
|
||||
// 处理中止消息的函数
|
||||
export function abortMessageService(data: any, webview: PostMessageble): RestfulResponse {
|
||||
if (currentStream) {
|
||||
// 标记流已中止
|
||||
currentStream = null;
|
||||
const sessionId = data?.sessionId;
|
||||
if (sessionId) {
|
||||
chatStreams.delete(sessionId);
|
||||
}
|
||||
|
||||
return {
|
||||
@ -144,18 +126,18 @@ export function abortMessageService(data: any, webview: PostMessageble): Restful
|
||||
}
|
||||
|
||||
async function postProcessToolMessages(message: MyToolMessageType) {
|
||||
if (typeof message.content === 'string') {
|
||||
return;
|
||||
}
|
||||
if (typeof message.content === 'string') {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const content of message.content) {
|
||||
const contentType = content.type as string;
|
||||
const rawContent = content as ToolCallContent;
|
||||
for (const content of message.content) {
|
||||
const contentType = content.type as string;
|
||||
const rawContent = content as ToolCallContent;
|
||||
|
||||
if (contentType === 'image') {
|
||||
rawContent.type = 'text';
|
||||
|
||||
// 此时图片只会存在三个状态
|
||||
if (contentType === 'image') {
|
||||
rawContent.type = 'text';
|
||||
|
||||
// 此时图片只会存在三个状态
|
||||
// 1. 图片在 ocrDB 中
|
||||
// 2. 图片的 OCR 仍然在进行中
|
||||
// 3. 图片已被删除
|
||||
@ -164,7 +146,7 @@ async function postProcessToolMessages(message: MyToolMessageType) {
|
||||
// rawContent.data 就是 filename
|
||||
const result = await ocrDB.findById(rawContent.data);
|
||||
if (result) {
|
||||
rawContent.text = result.text || '';
|
||||
rawContent.text = result.text || '';
|
||||
} else if (rawContent._meta) {
|
||||
const workerId = rawContent._meta.workerId;
|
||||
const worker = ocrWorkerStorage.get(workerId);
|
||||
@ -177,35 +159,35 @@ async function postProcessToolMessages(message: MyToolMessageType) {
|
||||
}
|
||||
|
||||
delete rawContent._meta;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message.content = JSON.stringify(message.content);
|
||||
message.content = JSON.stringify(message.content);
|
||||
}
|
||||
|
||||
export async function postProcessMessages(messages: MyMessageType[]) {
|
||||
for (const message of messages) {
|
||||
// 去除 extraInfo 属性
|
||||
delete message.extraInfo;
|
||||
for (const message of messages) {
|
||||
// 去除 extraInfo 属性
|
||||
delete message.extraInfo;
|
||||
|
||||
switch (message.role) {
|
||||
case 'user':
|
||||
switch (message.role) {
|
||||
case 'user':
|
||||
|
||||
break;
|
||||
|
||||
case 'assistant':
|
||||
case 'assistant':
|
||||
|
||||
break;
|
||||
|
||||
case 'system':
|
||||
|
||||
break;
|
||||
case 'system':
|
||||
|
||||
case 'tool':
|
||||
await postProcessToolMessages(message);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case 'tool':
|
||||
await postProcessToolMessages(message);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user