openmcp-client/service/src/llm/llm.service.ts
2025-05-30 00:26:11 +08:00

177 lines
4.2 KiB
TypeScript

import { PostMessageble } from "../hook/adapter.js";
import { OpenAI } from "openai";
import { MyMessageType, MyToolMessageType } from "./llm.dto.js";
import { RestfulResponse } from "../common/index.dto.js";
import { ocrDB } from "../hook/db.js";
import type { ToolCallContent } from "../mcp/client.dto.js";
import { ocrWorkerStorage } from "../mcp/ocr.service.js";
// Handle to the chat-completion stream most recently started by
// streamingChatCompletion. It doubles as the abort flag: abortMessageService
// sets it to null, and the streaming loop stops when it observes null.
// This is a single module-wide slot, so every stream shares it.
export let currentStream: AsyncIterable<any> | null = null;
/**
 * Streams a chat completion from an OpenAI-compatible endpoint to the webview.
 *
 * Every chunk that carries `choices` is forwarded as an
 * `llm/chat/completions/chunk` message. Exactly one
 * `llm/chat/completions/done` message follows the loop: its `stage` is
 * `'abort'` when the stream was cancelled (the shared `currentStream` handle
 * was cleared while streaming) and `'done'` when the stream ran to completion.
 *
 * @param data Request payload. Reads `baseURL`, `apiKey`, `model`, `messages`,
 *   `temperature`, and optionally `tools` (defaults to `[]`) and
 *   `parallelToolCalls` (defaults to `true`). `messages` is post-processed
 *   in place before the request is sent.
 * @param webview Channel used to post chunk and completion messages.
 * @returns A promise that resolves once the completion message has been posted.
 */
export async function streamingChatCompletion(
    data: any,
    webview: PostMessageble
): Promise<void> {
    const {
        baseURL,
        apiKey,
        model,
        messages,
        temperature,
        tools = [],
        parallelToolCalls = true
    } = data;

    const client = new OpenAI({
        baseURL,
        apiKey
    });

    // An empty tool list is passed as `undefined` rather than `[]`.
    const requestTools = tools.length === 0 ? undefined : tools;

    await postProcessMessages(messages);

    const stream = await client.chat.completions.create({
        model,
        messages,
        temperature,
        tools: requestTools,
        parallel_tool_calls: parallelToolCalls,
        stream: true
    });

    // Remember the active stream so that an abort request can cancel it.
    currentStream = stream;

    // Tracks how the loop ended so that a single completion message is sent.
    let stage: 'done' | 'abort' = 'done';

    try {
        for await (const chunk of stream) {
            if (!currentStream) {
                // The handle was cleared by an abort request: stop streaming.
                // TODO: give every tab its own currentStream manager.
                stream.controller.abort();
                stage = 'abort';
                break;
            }

            if (chunk.choices) {
                webview.postMessage({
                    command: 'llm/chat/completions/chunk',
                    data: {
                        code: 200,
                        msg: {
                            chunk
                        }
                    }
                });
            }
        }
    } finally {
        // Release the handle once this stream is finished (or failed), but
        // only if it still refers to this call's stream.
        if (currentStream === stream) {
            currentStream = null;
        }
    }

    // Transmission finished: report once whether it completed or was aborted.
    webview.postMessage({
        command: 'llm/chat/completions/done',
        data: {
            code: 200,
            msg: {
                success: true,
                stage
            }
        }
    });
}
/**
 * Handles an abort request for the chat completion that is currently streaming.
 *
 * Clears the shared `currentStream` handle; the streaming loop treats a null
 * handle as the signal to stop. Always reports success, whether or not a
 * stream was active.
 *
 * @param data Request payload (not read).
 * @param webview Message channel (not read).
 * @returns A response with code 200 and `success: true`.
 */
export function abortMessageService(data: any, webview: PostMessageble): RestfulResponse {
    // Mark the stream as aborted. Assigning null when the handle is already
    // null changes nothing, so no guard is needed.
    currentStream = null;

    const response: RestfulResponse = {
        code: 200,
        msg: {
            success: true
        }
    };
    return response;
}
/**
 * Rewrites a tool message in place so that it can be sent to the model.
 *
 * String content is left untouched. Otherwise every `image` content item is
 * converted into a `text` item holding its OCR text, its `_meta` is removed,
 * and the whole content array is finally replaced by its JSON string.
 *
 * @param message Tool message to rewrite; mutated in place.
 */
async function postProcessToolMessages(message: MyToolMessageType) {
    if (typeof message.content === 'string') {
        return;
    }

    for (const content of message.content) {
        const contentType = content.type as string;
        if (contentType !== 'image') {
            continue;
        }

        const rawContent = content as ToolCallContent;
        rawContent.type = 'text';

        // At this point an image can be in one of three states:
        //   1. its OCR result is stored in ocrDB;
        //   2. its OCR is still in progress;
        //   3. it has been deleted.
        // rawContent.data is the file name used as the ocrDB key.
        const result = await ocrDB.findById(rawContent.data);

        if (result) {
            rawContent.text = result.text || '';
        } else {
            // Look for a pending OCR worker referenced by the item's metadata.
            const worker = rawContent._meta
                ? ocrWorkerStorage.get(rawContent._meta.workerId)
                : undefined;

            // Wait for a pending worker. Without one (no metadata, or the
            // worker is no longer registered) use the "invalid image"
            // placeholder so the item never ends up without any text.
            rawContent.text = worker ? await worker.fut : '无效的图片';
        }

        delete rawContent._meta;
    }

    message.content = JSON.stringify(message.content);
}
/**
 * Prepares a list of chat messages, in place, before they are sent to the model.
 *
 * The `extraInfo` property is removed from every message. Messages whose role
 * is `tool` are additionally passed through `postProcessToolMessages`; all
 * other roles are left as they are. Messages are handled one after another,
 * in order.
 *
 * @param messages Messages to process; each one is mutated in place.
 */
export async function postProcessMessages(messages: MyMessageType[]) {
    for (const entry of messages) {
        // Strip the extraInfo property.
        delete entry.extraInfo;

        // Only tool messages need further processing.
        if (entry.role === 'tool') {
            await postProcessToolMessages(entry);
        }
    }
}