完成 EDL 的构建,完成 QA 的拒答流

This commit is contained in:
锦恢 2024-06-09 23:49:25 +08:00
parent 3423bddc7b
commit 0525539ad8
8 changed files with 346 additions and 67 deletions

View File

@ -0,0 +1,46 @@
import '../plugins/image';
import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send } from 'lagrange.onebot'
import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';
import { handleGroupIntent } from './intent';
let lastCall = undefined;
export class Impl {
@mapper.onGroup(932987873, { at: false })
async handleDigitalGroup(c: LagrangeContext<GroupMessage>) {
const texts = [];
const message = c.message;
for (const msg of message.message) {
if (msg.type === 'text') {
texts.push(msg.data.text);
}
}
const reply: Send.Default[] = [];
const axiosRes = await apiGetIntentRecogition({ query: texts.join('\n') });
const res = axiosRes.data;
if (res.code == 200) {
const intentResult = res.data;
// 如果 不确定性 太高,就将意图修改为
if (intentResult.uncertainty >= 0.33) {
intentResult.name = 'others';
}
const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
const intentDebug = `【意图: ${intentResult.name} 不确定度: ${uncertainty}`;
const anwser = await handleGroupIntent(c, intentResult);
if (anwser !== undefined) {
c.sendMessage(anwser + '\n' + intentDebug);
}
} else {
const now = Date.now();
if (lastCall === undefined || (now - lastCall) >= 60 * 10 * 1000) {
c.sendMessage('RAG 系统目前离线');
}
lastCall = Date.now();
}
}
}

View File

@ -23,13 +23,15 @@ async function useRagLLM(c: LagrangeContext<GroupMessage>, intentResult: IntentR
texts.push(msg.data.text); texts.push(msg.data.text);
} }
} }
const query = texts.join(' ').trim();
if (query.length === 0) {
const { data } = await apiQueryVecdb({ query: texts.join(' '), k: 3 }); return undefined;
}
const { data } = await apiQueryVecdb({ query, k: 3 });
if (data.code === 200) { if (data.code === 200) {
const messages: apiQueryVecdbDataItem[] = data.data.filter(m => m.score <= 0.8); const messages: apiQueryVecdbDataItem[] = data.data.filter(m => m.score <= 0.7);
if (messages.length === 0) { if (messages.length === 0) {
c.sendMessage('未在数据库中检索到相关内容。'); return '未在数据库中检索到相关内容。';
} else { } else {
const query = makePrompt(messages); const query = makePrompt(messages);
const res = await llm.answer([ const res = await llm.answer([
@ -40,16 +42,16 @@ async function useRagLLM(c: LagrangeContext<GroupMessage>, intentResult: IntentR
]); ]);
if (typeof res === 'string') { if (typeof res === 'string') {
const links = messages.map(m => m.source); const links = messages.map(m => m.source);
const linkSet = new Set<string>(links);
const reference = ['参考链接:', ...links].join('\n'); const reference = ['参考链接:', ...linkSet].join('\n');
const anwser = res + '\n\n' + reference; const anwser = res + '\n\n' + reference;
c.sendMessage(anwser); return anwser;
} }
} }
} else { } else {
logger.error('apiQueryVecdb 接口访问失败: ' + JSON.stringify(data)); logger.error('apiQueryVecdb 接口访问失败: ' + JSON.stringify(data));
return false; return undefined;
} }
} }

View File

@ -1,6 +1,6 @@
import '../plugins/image'; import '../plugins/image';
import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send } from 'lagrange.onebot' import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send, logger } from 'lagrange.onebot'
import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb'; import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';
@ -41,9 +41,16 @@ export class Impl {
if (intentResult.uncertainty >= 0.33) { if (intentResult.uncertainty >= 0.33) {
intentResult.name = 'others'; intentResult.name = 'others';
} }
const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000; const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
c.sendMessage(`【意图: ${intentResult.name} 不确定度: ${uncertainty}`); const intentDebug = `【意图: ${intentResult.name} 不确定度: ${uncertainty}`;
handleGroupIntent(c, intentResult); const anwser = await handleGroupIntent(c, intentResult);
if (anwser === undefined) {
c.sendMessage('拒答' + '\n' + intentDebug);
} else {
c.sendMessage(anwser + '\n' + intentDebug);
}
} else { } else {
c.sendMessage('RAG 系统目前离线'); c.sendMessage('RAG 系统目前离线');
} }

1
config/push-to-server.sh Normal file
View File

@ -0,0 +1 @@
scp -r config ubuntu@101.43.239.71:/home/ubuntu/files/data/llm-rag

99
notebook/clear-logs.ipynb Normal file
View File

@ -0,0 +1,99 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import zipfile\n",
"from typing import *\n",
"import os\n",
"import json\n",
"import yaml\n",
"\n",
"log_home = '../logs'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def get_log_file_line(file: str, level='DEBUG') -> Generator:\n",
" if file.endswith('.log'):\n",
" for line in open(file, 'r', encoding='utf-8'):\n",
" if level in line:\n",
" yield line\n",
" elif file.endswith('.zip'):\n",
" with zipfile.ZipFile(file, 'r') as zip:\n",
" for zip_file in zip.namelist():\n",
" file_bytes = zip.read(zip_file)\n",
" for line in file_bytes.decode('utf-8').split('\\n'):\n",
" if level in line:\n",
" yield line\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"rag_log_files = [os.path.join(log_home, file) for file in os.listdir(log_home) if file.startswith('rag')]\n",
"\n",
"interesting_data = []\n",
"\n",
"for rag_file in rag_log_files:\n",
" for line in get_log_file_line(rag_file, level='DEBUG'):\n",
" try:\n",
" data = line.split('|')[-1].strip()\n",
" data = eval(data)\n",
" if data['intent']['name'] != 'others':\n",
" interesting_data.append(data)\n",
" except Exception:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"story = {'stories': []}\n",
"for d in interesting_data:\n",
" story['stories'].append({\n",
" 'message': d['query'],\n",
" 'intent': d['intent']['name']\n",
" })\n",
"\n",
"with open('../config/qq.story.yml', 'w', encoding='utf-8') as fp:\n",
" yaml.dump(story, fp, allow_unicode=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

View File

@ -8,6 +8,7 @@ import torch
from rag.db.embedding import embedding from rag.db.embedding import embedding
from rag.api.constant import StatusCode, MsgCode from rag.api.constant import StatusCode, MsgCode
from rag.api.admin import app from rag.api.admin import app
from rag.api.admin import logger
from rag.api.config import necessary_files from rag.api.config import necessary_files
from rag.model.enn import LinearEnn, train_enn from rag.model.enn import LinearEnn, train_enn
@ -60,30 +61,31 @@ def reload_embedding_mapping():
response.status_code = StatusCode.success.value response.status_code = StatusCode.success.value
return response return response
# TODO: 删除该接口
@app.route('/intent/retrain-embedding-mapping', methods=['post']) @app.route('/intent/retrain-embedding-mapping', methods=['post'])
def retrain_embedding_mapping(): def retrain_embedding_mapping():
engine = PromptEngine(necessary_files['intent-story']) # engine = PromptEngine(necessary_files['intent-story'])
engine.merge_stories_from_yml(necessary_files['issue-story']) # engine.merge_stories_from_yml(necessary_files['issue-story'])
sentences = [] # sentences = []
labels = [] # labels = []
for story in engine.stories: # for story in engine.stories:
sentences.append(story.message) # sentences.append(story.message)
labels.append(engine.intent2id[story.intent]) # labels.append(engine.intent2id[story.intent])
try: # try:
labels = np.array(labels) # labels = np.array(labels)
embed = embedding.embed_documents(sentences) # embed = embedding.embed_documents(sentences)
enn_model = intent_recogition.classifier # enn_model = intent_recogition.classifier
train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100) # train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
torch.save(enn_model.state_dict(), necessary_files['intent-classifier']) # torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])
except Exception as e: # except Exception as e:
response = jsonify({ # response = jsonify({
'code': StatusCode.process_error.value, # 'code': StatusCode.process_error.value,
'data': str(e), # 'data': str(e),
'msg': MsgCode.query_not_empty.value # 'msg': MsgCode.query_not_empty.value
}) # })
response.status_code = StatusCode.success.value # response.status_code = StatusCode.success.value
return response # return response
response = jsonify({ response = jsonify({
'code': StatusCode.success.value, 'code': StatusCode.success.value,
@ -113,6 +115,12 @@ def get_intent_recogition():
result = intent_recogition.get_intent_recogition(query) result = intent_recogition.get_intent_recogition(query)
logger_chunk = json.dumps({
'query': query,
'intent': result
}, ensure_ascii=False)
logger.debug(logger_chunk)
response = jsonify({ response = jsonify({
'code': StatusCode.success.value, 'code': StatusCode.success.value,
'data': result, 'data': result,

View File

@ -42,6 +42,12 @@ suite('test intent recogition', () => {
{ input: '因为这是养蛊的虚拟机,放了些国产垃圾软件,得用国产流氓之王才能镇得住他们', expect: 'others' }, { input: '因为这是养蛊的虚拟机,放了些国产垃圾软件,得用国产流氓之王才能镇得住他们', expect: 'others' },
{ input: '你咋装了个360', expect: 'others' }, { input: '你咋装了个360', expect: 'others' },
{ input: '', expect: 'expression,others' }, { input: '', expect: 'expression,others' },
{ input: '有点怪,这里没有和竖线对齐', expect: 'others' },
{ input: '调试时候下载bit流和ltx固化需要在xdc约束里面添加spi或bpi相关约束后生成bit流。建议一开始xdc约束当中就添加相关约束后续生成mcs文件需要先添加flash型号根据相关型号生成对应的mcs文件后再固化到flash当中。', expect: 'advice,others' },
{ input: 'at 触发不够人性化,很多用户不知道 Tip 的存在就不会去使用它', expect: 'expression,others' },
{ input: '非问句会触发吗', expect: 'expression,others' },
{ input: 'zlib', expect: 'expression,others' },
{ input: 'https://github.com/Digital-EDA/Digital-IDE/discussions', expect: 'others' },
]; ];
for (const s of intent_suites) { for (const s of intent_suites) {