From 1b85207b7f46acda6270d76195b46d77c6d51f7b Mon Sep 17 00:00:00 2001
From: Kirigaya <1193466151@qq.com>
Date: Mon, 2 Jun 2025 18:03:19 +0800
Subject: [PATCH] support google gemini

---
 CHANGELOG.md                          |   4 +
 package-lock.json                     |   7 +-
 package.json                          |   2 +-
 service/package.json                  |   5 +-
 service/src/hook/axios-fetch.ts       | 200 ++++++++++++++++++++++++++
 service/src/hook/llm.ts               |   2 +
 service/src/llm/llm.dto.ts            |  17 +++
 service/src/llm/llm.service.ts        |  54 ++++---
 service/src/main.ts                   |   1 +
 service/src/panel/panel.controller.ts |   1 -
 10 files changed, 270 insertions(+), 23 deletions(-)
 create mode 100644 service/src/hook/axios-fetch.ts

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 72fc119..50f569a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 # Change Log
 
+## [main] 0.1.4
+- Support Google Gemini models.
+- Support streaming tool calls for Grok3.
+
 ## [main] 0.1.3
 - Fix issue#21: the input box was not cleared after sending text via a button click.
 - Fix the pause button disappearing after multi-turn conversations.
diff --git a/package-lock.json b/package-lock.json
index dac9f15..0247b63 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "openmcp",
-  "version": "0.1.2",
+  "version": "0.1.3",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "openmcp",
-      "version": "0.1.2",
+      "version": "0.1.3",
       "workspaces": [
         "service",
         "renderer",
@@ -15,7 +15,7 @@
       "dependencies": {
         "@modelcontextprotocol/sdk": "^1.12.1",
         "@seald-io/nedb": "^4.1.1",
-        "axios": "^1.7.7",
+        "axios": "^1.9.0",
         "bson": "^6.8.0",
         "openai": "^5.0.1",
         "pako": "^2.1.0",
@@ -11830,6 +11830,7 @@
       "dependencies": {
         "@modelcontextprotocol/sdk": "^1.12.1",
         "@seald-io/nedb": "^4.1.1",
+        "axios": "^1.9.0",
         "openai": "^5.0.1",
         "pako": "^2.1.0",
         "pino": "^9.6.0",
diff --git a/package.json b/package.json
index 2a31ec9..c1cabd4 100644
--- a/package.json
+++ b/package.json
@@ -235,7 +235,7 @@
   "dependencies": {
     "@modelcontextprotocol/sdk": "^1.12.1",
     "@seald-io/nedb": "^4.1.1",
-    "axios": "^1.7.7",
+    "axios": "^1.9.0",
     "bson": "^6.8.0",
     "openai": "^5.0.1",
     "pako": "^2.1.0",
diff --git a/service/package.json b/service/package.json
index 933c099..0c3c624 100644
--- a/service/package.json
+++ b/service/package.json
@@ -4,13 +4,13 @@
   "description": "",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
-  "type": "commonjs",
+  "type": "commonjs",
   "scripts": {
     "dev": "ts-node-dev --respawn --transpile-only src/main.ts",
     "serve": "ts-node-dev --respawn --transpile-only src/main.ts",
     "build": "tsc",
     "build:watch": "tsc --watch",
-    "postbuild": "node ./scripts/post-build.mjs",
+    "postbuild": "node ./scripts/post-build.mjs",
     "start": "node dist/main.js",
     "start:prod": "NODE_ENV=production node dist/main.js",
     "debug": "node --inspect -r ts-node/register src/main.ts",
@@ -38,6 +38,7 @@
   "dependencies": {
     "@modelcontextprotocol/sdk": "^1.12.1",
     "@seald-io/nedb": "^4.1.1",
+    "axios": "^1.9.0",
     "openai": "^5.0.1",
     "pako": "^2.1.0",
     "pino": "^9.6.0",
diff --git a/service/src/hook/axios-fetch.ts b/service/src/hook/axios-fetch.ts
new file mode 100644
index 0000000..d7ecb66
--- /dev/null
+++ b/service/src/hook/axios-fetch.ts
@@ -0,0 +1,200 @@
+import axios, { AxiosResponse } from "axios";
+
+interface FetchOptions {
+    method?: string;
+    headers?: Record<string, string>;
+    body?: string | Buffer | FormData | URLSearchParams | object;
+    [key: string]: any;
+}
+
+interface FetchResponse {
+    ok: boolean;
+    status: number;
+    statusText: string;
+    headers: Headers;
+    url: string;
+    redirected: boolean;
+    type: string;
+    body: any;
+
+    json(): Promise<any>;
+    text(): Promise<string>;
+    arrayBuffer(): Promise<ArrayBuffer>;
+    getReader(): ReadableStreamDefaultReader;
+}
+
+interface ReadableStreamDefaultReader {
+    read(): Promise<{ done: boolean, value?: any }>;
+    cancel(): Promise<void>;
+    releaseLock(): void;
+    get closed(): boolean;
+}
+
+/**
+ * Convert fetch-style request options into an axios request config
+ */
+function adaptRequestOptions(url: string, options: FetchOptions = {}): any {
+    const axiosConfig: any = {
+        url,
+        method: options.method || 'GET',
+        headers: options.headers,
+        responseType: 'stream'
+    };
+
+    // Convert the fetch body into axios data
+    if (options.body) {
+        if (typeof options.body === 'string' || Buffer.isBuffer(options.body)) {
+            axiosConfig.data = options.body;
+        } else if (typeof options.body === 'object') {
+            // Special types such as FormData and URLSearchParams need dedicated handling
+            if (options.body instanceof FormData) {
+                axiosConfig.data = options.body;
+                axiosConfig.headers = {
+                    ...axiosConfig.headers,
+                    'Content-Type': 'multipart/form-data'
+                };
+            } else if (options.body instanceof URLSearchParams) {
+                axiosConfig.data = options.body.toString();
+                axiosConfig.headers = {
+                    ...axiosConfig.headers,
+                    'Content-Type': 'application/x-www-form-urlencoded'
+                };
+            } else {
+                // Plain JSON object
+                axiosConfig.data = JSON.stringify(options.body);
+                axiosConfig.headers = {
+                    ...axiosConfig.headers,
+                    'Content-Type': 'application/json'
+                };
+            }
+        }
+    }
+
+    return axiosConfig;
+}
+
+/**
+ * Convert an axios response into a fetch-style Response object
+ */
+function adaptResponse(axiosResponse: FetchOptions): FetchResponse {
+    // Build the Headers object
+    const headers = new Headers();
+    Object.entries(axiosResponse.headers || {}).forEach(([key, value]) => {
+        headers.append(key, value);
+    });
+
+    // Build a Response object that follows the Fetch API shape
+    const fetchResponse = {
+        ok: axiosResponse.status >= 200 && axiosResponse.status < 300,
+        status: axiosResponse.status,
+        statusText: axiosResponse.statusText,
+        headers: headers,
+        url: axiosResponse.config.url,
+        redirected: false, // axios does not expose this directly
+        type: 'basic', // simple response type
+        body: null,
+
+        // Standard methods
+        json: async () => {
+            if (typeof axiosResponse.data === 'object') {
+                return axiosResponse.data;
+            }
+            throw new Error('Response is not JSON');
+        },
+        text: async () => {
+            if (typeof axiosResponse.data === 'string') {
+                return axiosResponse.data;
+            }
+            return JSON.stringify(axiosResponse.data);
+        },
+        arrayBuffer: async () => {
+            throw new Error('arrayBuffer not implemented for streaming');
+        },
+
+        // Streaming support
+        getReader: () => {
+            if (!axiosResponse.data.on || typeof axiosResponse.data.on !== 'function') {
+                throw new Error('Not a stream response');
+            }
+
+            // Adapt the Node.js stream to a Web Streams style reader
+            const nodeStream = axiosResponse.data;
+            let isCancelled = false;
+
+            return {
+                read: () => {
+                    if (isCancelled) {
+                        return Promise.resolve({ done: true });
+                    }
+
+                    return new Promise((resolve, reject) => {
+                        const onData = (chunk: any) => {
+                            cleanup();
+                            resolve({ done: false, value: chunk });
+                        };
+
+                        const onEnd = () => {
+                            cleanup();
+                            resolve({ done: true });
+                        };
+
+                        const onError = (err: Error) => {
+                            cleanup();
+                            reject(err);
+                        };
+
+                        const cleanup = () => {
+                            nodeStream.off('data', onData);
+                            nodeStream.off('end', onEnd);
+                            nodeStream.off('error', onError);
+                        };
+
+                        nodeStream.once('data', onData);
+                        nodeStream.once('end', onEnd);
+                        nodeStream.once('error', onError);
+                    });
+                },
+
+                cancel: () => {
+                    isCancelled = true;
+                    nodeStream.destroy();
+                    return Promise.resolve();
+                },
+
+                releaseLock: () => {
+                    // TODO: implement releaseLock
+                },
+
+                get closed() {
+                    return isCancelled;
+                }
+            };
+        }
+    } as FetchResponse;
+
+    // Expose the readable stream as the response body
+    if (axiosResponse.data.on && typeof axiosResponse.data.on === 'function') {
+        fetchResponse.body = {
+            getReader: fetchResponse.getReader
+        };
+    }
+
+    return fetchResponse;
+}
+
+/**
+ * @description Main entry - implement fetch on top of axios
+ */
+export async function axiosFetch(url: any, options: any): Promise<FetchResponse> {
+    const axiosConfig = adaptRequestOptions(url, options);
+
+    try {
+        const response = await axios(axiosConfig) as FetchOptions;
+        return adaptResponse(response);
+    } catch (error: any) {
+        if (error.response) {
+            return adaptResponse(error.response);
+        }
+        throw error;
+    }
+}
\ No newline at end of file
diff --git a/service/src/hook/llm.ts b/service/src/hook/llm.ts
index f31d1c9..70b2913 100644
--- a/service/src/hook/llm.ts
+++ b/service/src/hook/llm.ts
@@ -108,3 +108,5 @@ export const llms = [
         userModel: 'moonshot-v1-8k'
     }
 ];
+
+
diff --git a/service/src/llm/llm.dto.ts b/service/src/llm/llm.dto.ts
index 2514b92..9328eae 100644
--- a/service/src/llm/llm.dto.ts
+++ b/service/src/llm/llm.dto.ts
@@ -6,4 +6,21 @@ export type MyMessageType = OpenAI.Chat.ChatCompletionMessageParam & {
 
 export type MyToolMessageType = OpenAI.Chat.ChatCompletionToolMessageParam & {
     extraInfo?: any;
+}
+
+export interface OpenMcpChatOption {
+    baseURL: string;
+    apiKey: string;
+    model: string;
+    messages: any[];
+    temperature?: number;
+    tools?: any[];
+    parallelToolCalls?: boolean;
+}
+
+export interface MyStream extends AsyncIterable<any> {
+    [Symbol.asyncIterator](): AsyncIterator<any>;
+    controller: {
+        abort(): void;
+    };
 }
\ No newline at end of file
diff --git a/service/src/llm/llm.service.ts b/service/src/llm/llm.service.ts
index 26952fb..b9dbcaa 100644
--- a/service/src/llm/llm.service.ts
+++ b/service/src/llm/llm.service.ts
@@ -5,6 +5,7 @@ import { RestfulResponse } from "../common/index.dto";
 import { ocrDB } from "../hook/db";
 import type { ToolCallContent } from "../mcp/client.dto";
 import { ocrWorkerStorage } from "../mcp/ocr.service";
+import { axiosFetch } from "../hook/axios-fetch";
 
 export let currentStream: AsyncIterable<any> | null = null;
 
@@ -12,33 +13,54 @@ export async function streamingChatCompletion(
     data: any,
     webview: PostMessageble
 ) {
-    let {
-        baseURL,
-        apiKey,
-        model,
-        messages,
-        temperature,
-        tools = [],
-        parallelToolCalls = true
-    } = data;
+    const {
+        baseURL,
+        apiKey,
+        model,
+        messages,
+        temperature,
+        tools = [],
+        parallelToolCalls = true
+    } = data;
 
     const client = new OpenAI({
         baseURL,
-        apiKey
+        apiKey,
+        fetch: async (input: string | URL | Request, init?: RequestInit) => {
+
+            console.log('openai fetch begin');
+
+            if (model.startsWith('gemini')) {
+                // damn google
+                if (init) {
+                    init.headers = {
+                        'Content-Type': 'application/json',
+                        'Authorization': `Bearer ${apiKey}`
+                    }
+                }
+
+                console.log('input:', input);
+                console.log('init:', init);
+
+                return await axiosFetch(input, init);
+            } else {
+                return await fetch(input, init);
+            }
+        }
     });
 
-    if (tools.length === 0) {
-        tools = undefined;
-    }
-
+    const seriableTools = (tools.length === 0) ? undefined : tools;
+    const seriableParallelToolCalls = (tools.length === 0) ?
+        undefined : model.startsWith('gemini') ?
+            undefined : parallelToolCalls;
+
     await postProcessMessages(messages);
 
     const stream = await client.chat.completions.create({
         model,
         messages,
         temperature,
-        tools,
-        parallel_tool_calls: parallelToolCalls,
+        tools: seriableTools,
+        parallel_tool_calls: seriableParallelToolCalls,
         stream: true
     });
 
diff --git a/service/src/main.ts b/service/src/main.ts
index 0e6b778..54c9a09 100644
--- a/service/src/main.ts
+++ b/service/src/main.ts
@@ -6,6 +6,7 @@ import { VSCodeWebViewLike } from './hook/adapter';
 import path from 'node:path';
 import * as fs from 'node:fs';
 import { setRunningCWD } from './hook/setting';
+import axios from 'axios';
 
 export interface VSCodeMessage {
     command: string;
diff --git a/service/src/panel/panel.controller.ts b/service/src/panel/panel.controller.ts
index 92332ef..7b7321d 100644
--- a/service/src/panel/panel.controller.ts
+++ b/service/src/panel/panel.controller.ts
@@ -79,7 +79,6 @@ export class PanelController {
 
     @Controller('system-prompts/load')
     async loadSystemPrompts(data: RequestData, webview: PostMessageble) {
-        const client = getClient(data.clientId);
         const queryPrompts = await systemPromptDB.findAll();
         const prompts = [];
         for (const prompt of queryPrompts) {
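
For reviewers who want to try the new adapter outside of the OpenAI client, here is a minimal sketch of how axiosFetch and the reader it returns could be exercised. It is not part of the patch: the endpoint URL, the API key placeholder, and the file location are assumptions for illustration only; the read loop follows the reader contract defined in service/src/hook/axios-fetch.ts.

// hypothetical smoke test, e.g. service/src/hook/axios-fetch.demo.ts
import { axiosFetch } from './axios-fetch';

async function smokeTest() {
    const response = await axiosFetch('https://example.com/v1/stream', {
        method: 'POST',
        headers: { 'Authorization': 'Bearer <api-key>' },
        // plain objects are JSON.stringify-ed and sent as application/json
        body: { prompt: 'hello' }
    });

    if (!response.ok) {
        throw new Error(`request failed: ${response.status} ${response.statusText}`);
    }

    // drain the underlying Node.js stream through the fetch-style reader
    const reader = response.getReader();
    while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        process.stdout.write(value.toString());
    }
}

smokeTest().catch(console.error);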