Test the full OCR pipeline

锦恢 2025-04-27 22:28:29 +08:00
parent 31fa5ead4f
commit dd0d6016fa
15 changed files with 417 additions and 187 deletions

View File

@ -13,7 +13,8 @@ export enum MessageState {
Abort = 'abort',
ToolCall = 'tool call failed',
None = 'none',
Success = 'success'
Success = 'success',
ParseJsonError = 'parse json error'
}
export interface IExtraInfo {

View File

@ -23,7 +23,10 @@
<!-- Tool calls made by the assistant -->
<div class="message-content" v-else-if="message.role === 'assistant/tool_calls'">
<Message.Toolcall :message="message" :tab-id="props.tabId" />
<Message.Toolcall
:message="message" :tab-id="props.tabId"
@update:tool-result="(value, index) => (message.toolResult || [])[index] = value"
/>
</div>
</div>

View File

@ -0,0 +1,89 @@
<template>
<el-scrollbar width="100%">
<div v-if="props.item.type === 'text'" class="tool-text">
{{ props.item.text }}
</div>
<div v-else-if="props.item.type === 'image'" class="tool-image">
#{{ props.item.data }}
<span v-if="!finishProcess">
<el-progress
:percentage="progress"
:stroke-width="2"
:show-text="false"
>
<template #default="{ percentage }">
<span class="percentage-label">{{ progressText }}</span>
<span class="percentage-value">{{ percentage }}%</span>
</template>
</el-progress>
</span>
</div>
<div v-else class="tool-other">{{ JSON.stringify(props.item) }}</div>
</el-scrollbar>
</template>
<script setup lang="ts">
import { useMessageBridge } from '@/api/message-bridge';
import { ToolCallContent } from '@/hook/type';
import { defineComponent, PropType, defineProps, ref, defineEmits } from 'vue';
import { tabs } from '../../panel';
import { IRenderMessage } from '../chat';
defineComponent({ name: 'toolcall-result-item' });
const emit = defineEmits(['update:item']);
const props = defineProps({
item: {
type: Object as PropType<ToolCallContent>,
required: true
}
});
const metaInfo = props.item._meta || {};
const { ocr = false, workerId = '' } = metaInfo;
//
const progress = ref(0);
const progressText = ref('');
const finishProcess = ref(true);
if (ocr) {
finishProcess.value = false;
const bridge = useMessageBridge();
const cancel = bridge.addCommandListener('ocr/worker/log', data => {
finishProcess.value = false;
const { id, progress: p = 1.0, status = 'finish' } = data;
if (id === workerId) {
progressText.value = status;
progress.value = Math.min(Math.max(p * 100, 0), 100);
}
}, { once: false });
bridge.addCommandListener('ocr/worker/done', () => {
progress.value = 1;
finishProcess.value = true;
if (props.item._meta) {
const { _meta, ...rest } = props.item;
emit('update:item', { ...rest });
}
cancel();
}, { once: true });
}
</script>
<style>
.tool-image {
position: relative;
}
.el-progress {
position: absolute;
bottom: 0;
left: 0;
right: 0;
}
</style>
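Note: the shape of the `ocr/worker/log` payload this component listens for is not declared anywhere in this commit. A minimal sketch, inferred from the fields destructured above (`id`, `progress`, `status`) and assuming they mirror Tesseract.js logger messages:

// Hypothetical payload for the 'ocr/worker/log' command; the field names are
// taken from what the component reads, the types are assumptions.
interface OcrWorkerLogPayload {
    id: string;         // workerId of the OCR worker emitting the log
    progress?: number;  // 0.0 - 1.0, rendered above as a percentage
    status?: string;    // short status text, shown as progressText
}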

View File

@ -11,7 +11,6 @@
<el-collapse v-model="activeNames" v-if="props.message.tool_calls">
<el-collapse-item name="tool">
<template #title>
<div class="tool-calls">
<div class="tool-call-header">
@ -52,27 +51,24 @@
:inactive-action-style="'backgroundColor: var(--sidebar)'" />
</span>
</div>
<div class="tool-result" v-if="isValid">
<!-- Show as JSON -->
<div v-if="props.message.showJson!.value" class="tool-result-content">
<div class="inner">
<div v-html="toHtml(props.message.toolResult)"></div>
</div>
</div>
<!-- Show as rich text -->
<span v-else>
<div v-for="(item, index) in props.message.toolResult" :key="index"
class="response-item"
>
<el-scrollbar width="100%">
<div v-if="item.type === 'text'" class="tool-text">
{{ item.text }}
</div>
<div v-else-if="item.type === 'image'" class="tool-image">
<img :src="`data:${item.mimeType};base64,${item.data}`" style="max-width: 70%;" />
</div>
<div v-else class="tool-other">{{ JSON.stringify(item) }}</div>
</el-scrollbar>
<ToolcallResultItem
:item="item"
@update:item="value => updateToolCallResultItem(value, index)"
/>
</div>
</span>
</div>
@ -96,7 +92,7 @@
</template>
<script setup lang="ts">
import { defineProps, ref, watch, PropType, computed } from 'vue';
import { defineProps, ref, watch, PropType, computed, defineEmits } from 'vue';
import MessageMeta from './message-meta.vue';
import { markdownToHtml } from '../markdown';
@ -104,6 +100,8 @@ import { createTest } from '@/views/setting/llm';
import { IRenderMessage, MessageState } from '../chat';
import { ToolCallContent } from '@/hook/type';
import ToolcallResultItem from './toolcall-result-item.vue';
const props = defineProps({
message: {
type: Object as PropType<IRenderMessage>,
@ -115,7 +113,6 @@ const props = defineProps({
}
});
const activeNames = ref<string[]>(props.message.toolResult ? [''] : ['tool']);
watch(
@ -129,6 +126,10 @@ watch(
}
);
/**
* @description Convert the tool call result into HTML
* @param toolResult
*/
const toHtml = (toolResult: ToolCallContent[]) => {
const formattedJson = JSON.stringify(toolResult, null, 2);
const html = markdownToHtml('```json\n' + formattedJson + '\n```');
@ -187,6 +188,12 @@ const collectErrors = computed(() => {
}
});
const emit = defineEmits(['update:tool-result']);
function updateToolCallResultItem(value: any, index: number) {
emit('update:tool-result', value, index);
}
</script>
<style>

View File

@ -5,6 +5,7 @@ import { useMessageBridge } from "@/api/message-bridge";
import type { OpenAI } from 'openai';
import { callTool } from "../tool/tools";
import { llmManager, llms } from "@/views/setting/llm";
import { pinkLog, redLog } from "@/views/setting/util";
export type ChatCompletionChunk = OpenAI.Chat.Completions.ChatCompletionChunk;
export type ChatCompletionCreateParamsBase = OpenAI.Chat.Completions.ChatCompletionCreateParams & { id?: string };
@ -40,13 +41,38 @@ export class TaskLoop {
private async handleToolCalls(toolCalls: ToolCall[]) {
// TODO: call multiple tools and return all of their results?
const toolCall = toolCalls[0];
let toolName: string;
let toolArgs: Record<string, any>;
try {
toolName = toolCall.function.name;
toolArgs = JSON.parse(toolCall.function.arguments);
} catch (error) {
return {
content: [{
type: 'error',
text: this.parseErrorObject(error)
}],
state: MessageState.ParseJsonError
};
}
try {
const toolName = toolCall.function.name;
const toolArgs = JSON.parse(toolCall.function.arguments);
const toolResponse = await callTool(toolName, toolArgs);
if (!toolResponse.isError) {
// const content = JSON.stringify(toolResponse.content);
console.log(toolResponse);
if (typeof toolResponse === 'string') {
console.log(toolResponse);
return {
content: toolResponse,
state: MessageState.ToolCall
}
} else if (!toolResponse.isError) {
return {
content: toolResponse.content,
state: MessageState.Success
@ -190,14 +216,6 @@ export class TaskLoop {
const loadMessages = tabStorage.messages.slice(- tabStorage.settings.contextLength);
userMessages.push(...loadMessages);
// Filter userMessages: most current models only accept the text, image_url and video_url content types
const postProcessMessages = [];
for (const msg of userMessages) {
if (msg.role === 'tool') {
}
}
// Add an id used to lock the state
const id = crypto.randomUUID();
@ -287,12 +305,37 @@ export class TaskLoop {
}
});
pinkLog('Number of tool calls: ' + this.streamingToolCalls.value.length);
const toolCallResult = await this.handleToolCalls(this.streamingToolCalls.value);
console.log(toolCallResult);
if (toolCallResult.state === MessageState.ParseJsonError) {
// If the failure was a JSON parse error, start over
tabStorage.messages.pop();
redLog('JSON parse error ' + this.streamingToolCalls.value[0].function.arguments);
continue;
}
if (toolCallResult.content) {
if (toolCallResult.state === MessageState.Success) {
const toolCall = this.streamingToolCalls.value[0];
tabStorage.messages.push({
role: 'tool',
tool_call_id: toolCall.id || toolCall.function.name,
content: toolCallResult.content,
extraInfo: {
created: Date.now(),
state: toolCallResult.state,
serverName: llms[llmManager.currentModelIndex].id || 'unknown',
usage: this.completionUsage
}
});
}
if (toolCallResult.state === MessageState.ToolCall) {
const toolCall = this.streamingToolCalls.value[0];
tabStorage.messages.push({

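Every branch above consumes the same return value from handleToolCalls. A minimal sketch of that contract, assuming only the fields this hunk reads (`content` and `state`):

// Sketch of the result handled above; not an interface declared in this commit.
interface ToolCallHandleResult {
    // ToolCallContent[] on success, a raw string when the tool call itself fails
    content: ToolCallContent[] | string;
    // MessageState.Success, MessageState.ToolCall, or MessageState.ParseJsonError
    state: MessageState;
}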
View File

@ -123,6 +123,10 @@ class DiskStorage {
fs.writeFileSync(filePath, data, options);
}
public getStoragePath(filename: string): string {
return path.join(this.#storageHome, filename);
}
public deleteSync(filename: string): void {
const filePath = path.join(this.#storageHome, filename);
if (fs.existsSync(filePath)) {
@ -139,7 +143,7 @@ interface SettingItem extends Entity {
interface OcrItem extends Entity {
filename: string;
text?: string;
textCreateTime: number;
createTime: number;
}
export const diskStorage = new DiskStorage();

View File

@ -1,102 +0,0 @@
import * as os from 'os';
import * as path from 'path';
import * as fs from 'fs';
import { v4 as uuidv4 } from 'uuid';
import { diskStorage } from './db';
import { PostMessageble } from './adapter';
export function saveBase64ImageData(
base64String: string,
mimeType: string
): string {
// Extract the data portion from the base64 string
const base64Data = base64String.replace(/^data:.+;base64,/, '');
// Generate a unique file name
const fileName = `${uuidv4()}.${mimeType.split('/')[1]}`;
diskStorage.setSync(fileName, base64Data, { encoding: 'base64' });
return fileName;
}
export function loadBase64ImageData(fileName: string): string {
const homedir = os.homedir();
const imageStorageFolder = path.join(homedir, '.openmcp','storage');
const filePath = path.join(imageStorageFolder, fileName);
// Read the file contents
if (!fs.existsSync(filePath)) {
return '';
}
const fileContent = fs.readFileSync(filePath, { encoding: 'base64' });
// Build the base64 string
const base64String = `data:image/png;base64,${fileContent}`;
return base64String;
}
interface ToolCallContent {
type: string;
text?: string;
data?: any;
mimeType?: string;
_meta?: any;
[key: string]: any;
}
interface ToolCallResponse {
_meta?: any;
content?: ToolCallContent[];
isError?: boolean;
toolResult?: any;
}
async function handleImage(
content: ToolCallContent,
webview: PostMessageble
) {
if (content.data && content.mimeType) {
const fileName = saveBase64ImageData(content.data, content.mimeType);
content.data = fileName;
content._meta = {
ocr: true,
status: 'pending'
};
// Hand off to a worker thread
}
}
/**
* @description Post-process the tool call response returned by the MCP server:
* save inline base64 images to local storage and reference them by filename (image url, 0.x.x)
* @param response
* @returns
*/
export function postProcessMcpToolcallResponse(
response: ToolCallResponse,
webview: PostMessageble
): ToolCallResponse {
if (response.isError) {
// If the response is an error, return it as the error message
return response;
}
// Extract the base64 images in content and save them locally
for (const content of response.content || []) {
switch (content.type) {
case 'image':
handleImage(content, webview);
break;
default:
break;
}
}
return response;
}

View File

@ -1,32 +0,0 @@
import Tesseract from 'tesseract.js';
export async function tesseractOCR(
imagePath: string,
logger: (message: Tesseract.LoggerMessage) => void,
lang: string = 'eng+chi_sim'
) {
try {
const { data: { text } } = await Tesseract.recognize(
imagePath,
lang,
{
logger
}
);
return text;
} catch (error) {
console.error('OCR error:', error);
}
return 'Unable to recognize the image';
}
export async function ocr(
filename: string,
logger: (message: Tesseract.LoggerMessage) => void,
lang: string = 'eng+chi_sim'
) {
}

View File

@ -2,6 +2,9 @@ import { PostMessageble } from "../hook/adapter";
import { OpenAI } from "openai";
import { MyMessageType, MyToolMessageType } from "./llm.dto";
import { RestfulResponse } from "../common/index.dto";
import { ocrDB } from "../hook/db";
import { ToolCallContent } from "../mcp/client.dto";
import { ocrWorkerStorage } from "../mcp/ocr.service";
export let currentStream: AsyncIterable<any> | null = null;
@ -20,7 +23,7 @@ export async function streamingChatCompletion(
tools = undefined;
}
postProcessMessages(messages);
await postProcessMessages(messages);
const stream = await client.chat.completions.create({
model,
@ -99,45 +102,66 @@ export function abortMessageService(data: any, webview: PostMessageble): Restful
}
}
function postProcessToolMessages(message: MyToolMessageType) {
async function postProcessToolMessages(message: MyToolMessageType) {
if (typeof message.content === 'string') {
return;
}
for (const content of message.content) {
const contentType = content.type as string;
const rawContent = content as any;
const rawContent = content as ToolCallContent;
if (contentType === 'image') {
delete rawContent._meta;
rawContent.type = 'text';
// Retrieve the image data from the cache
rawContent.text = 'The image has been deleted';
// At this point the image can only be in one of three states:
// 1. The image's OCR text is already in ocrDB
// 2. The image's OCR is still in progress
// 3. The image has been deleted
// rawContent.data is the filename
const result = await ocrDB.findById(rawContent.data);
if (result) {
rawContent.text = result.text || '';
} else if (rawContent._meta) {
const workerId = rawContent._meta.workerId;
const worker = ocrWorkerStorage.get(workerId);
if (worker) {
const text = await worker.fut;
rawContent.text = text;
}
} else {
rawContent.text = 'Invalid image';
}
delete rawContent._meta;
}
}
message.content = JSON.stringify(message.content);
}
export function postProcessMessages(messages: MyMessageType[]) {
export async function postProcessMessages(messages: MyMessageType[]) {
for (const message of messages) {
// Strip the extraInfo property
delete message.extraInfo;
switch (message.role) {
case 'user':
break;
break;
case 'assistant':
break;
break;
case 'system':
break;
case 'tool':
postProcessToolMessages(message);
await postProcessToolMessages(message);
break;
default:
break;

View File

@ -1,5 +1,6 @@
import { Controller, RequestClientType } from "../common";
import { PostMessageble } from "../hook/adapter";
import { postProcessMcpToolcallResponse } from "./client.service";
export class ClientController {
@ -130,6 +131,14 @@ export class ClientController {
name: option.toolName,
arguments: option.toolArgs
});
console.log(JSON.stringify(toolResult, null, 2));
postProcessMcpToolcallResponse(toolResult, webview);
console.log(JSON.stringify(toolResult, null, 2));
return {
code: 200,
msg: toolResult

View File

@ -35,3 +35,19 @@ export interface McpOptions {
clientName?: string;
clientVersion?: string;
}
export interface ToolCallContent {
type: string;
text?: string;
data?: any;
mimeType?: string;
_meta?: any;
[key: string]: any;
}
export interface ToolCallResponse {
_meta?: any;
content?: ToolCallContent[];
isError?: boolean;
toolResult?: any;
}

View File

@ -2,7 +2,9 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import { McpOptions, McpTransport, IServerVersion } from './client.dto';
import { McpOptions, McpTransport, IServerVersion, ToolCallResponse, ToolCallContent } from './client.dto';
import { PostMessageble } from "../hook/adapter";
import { createOcrWorker, saveBase64ImageData } from "./ocr.service";
// Enhanced client class
export class McpClient {
@ -11,8 +13,6 @@ export class McpClient {
private options: McpOptions;
private serverVersion: IServerVersion;
private transportStdErr: string = '';
constructor(options: McpOptions) {
this.options = options;
this.serverVersion = undefined;
@ -34,7 +34,6 @@ export class McpClient {
// Connection method
public async connect(): Promise<void> {
this.transportStdErr = '';
// Create the transport layer based on the connection type
switch (this.options.connectionType) {
@ -129,3 +128,53 @@ export async function connect(options: McpOptions): Promise<McpClient> {
await client.connect();
return client;
}
async function handleImage(
content: ToolCallContent,
webview: PostMessageble
) {
if (content.data && content.mimeType) {
const filename = saveBase64ImageData(content.data, content.mimeType);
content.data = filename;
// Hand off to an OCR worker thread
const worker = createOcrWorker(filename, webview);
content._meta = {
ocr: true,
workerId: worker.id
};
}
}
/**
* @description Post-process the tool call response returned by the MCP server:
* save inline base64 images to local storage, queue an OCR worker for each one,
* and reference them by filename (image url, 0.x.x)
* @param response
* @returns
*/
export function postProcessMcpToolcallResponse(
response: ToolCallResponse,
webview: PostMessageble
): ToolCallResponse {
if (response.isError) {
// If the response is an error, return it as the error message
return response;
}
// Extract the base64 images in content and save them locally
for (const content of response.content || []) {
switch (content.type) {
case 'image':
handleImage(content, webview);
break;
default:
break;
}
}
return response;
}

View File

@ -0,0 +1,7 @@
export interface OcrWorker {
id: string;
name: string;
filename: string;
createTime: number;
fut: Promise<string>;
}

View File

@ -0,0 +1,112 @@
import Tesseract from 'tesseract.js';
import { PostMessageble } from '../hook/adapter';
import { v4 as uuidv4 } from 'uuid';
import { OcrWorker } from './ocr.dto';
import { diskStorage, ocrDB } from '../hook/db';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
export const ocrWorkerStorage = new Map<string, OcrWorker>();
export function saveBase64ImageData(
base64String: string,
mimeType: string
): string {
// Extract the data portion from the base64 string
const base64Data = base64String.replace(/^data:.+;base64,/, '');
// Generate a unique file name
const fileName = `${uuidv4()}.${mimeType.split('/')[1]}`;
diskStorage.setSync(fileName, base64Data, { encoding: 'base64' });
return fileName;
}
export function loadBase64ImageData(fileName: string): string {
const homedir = os.homedir();
const imageStorageFolder = path.join(homedir, '.openmcp','storage');
const filePath = path.join(imageStorageFolder, fileName);
// Read the file contents
if (!fs.existsSync(filePath)) {
return '';
}
const fileContent = fs.readFileSync(filePath, { encoding: 'base64' });
// Build the base64 string
const base64String = `data:image/png;base64,${fileContent}`;
return base64String;
}
export async function tesseractOCR(
imagePath: string,
logger: (message: Tesseract.LoggerMessage) => void,
lang: string = 'eng+chi_sim'
) {
try {
const { data: { text } } = await Tesseract.recognize(
imagePath,
lang,
{
logger
}
);
return text;
} catch (error) {
console.error('OCR error:', error);
}
return 'Unable to recognize the image';
}
export function createOcrWorker(filename: string, webview: PostMessageble): OcrWorker {
const workerId = uuidv4();
const logger = (message: Tesseract.LoggerMessage) => {
webview.postMessage({
command: 'ocr/worker/log',
data: {
id: workerId,
...message
}
});
};
const imagePath = diskStorage.getStoragePath(filename);
const fut = tesseractOCR(imagePath, logger);
fut.then((text) => {
webview.postMessage({
command: 'ocr/worker/done',
data: {
id: workerId,
text
}
});
ocrDB.insert({
id: filename,
filename,
text,
createTime: Date.now()
});
ocrWorkerStorage.delete(workerId);
});
const worker = {
id: workerId,
name: 'ocr-' + filename,
filename,
createTime: Date.now(),
fut
};
ocrWorkerStorage.set(workerId, worker);
return worker;
}
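For orientation, the worker returned here is consumed in two other files of this commit: client.service.ts stores `worker.id` in the content's `_meta`, and llm.service.ts later awaits `worker.fut` when the OCR text is not yet in ocrDB. A minimal usage sketch (inside an async context; `somePostMessageble` and `imageBase64` are placeholders, not names from this commit):

// Save an image returned by a tool call, start OCR on it, and await the text later.
const filename = saveBase64ImageData(imageBase64, 'image/png');
const worker = createOcrWorker(filename, somePostMessageble);

// ... later, during message post-processing:
const pending = ocrWorkerStorage.get(worker.id);
if (pending) {
    const text = await pending.fut; // resolves when tesseractOCR finishes
    console.log(text);
}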

View File

@ -97,7 +97,7 @@ export function revealOpenMcpWebviewPanel(
break;
default:
OpenMCPService.messageController(command, data, panel.webview);
OpenMCPService.routeMessage(command, data, panel.webview);
break;
}