完成 EDL 的构建,完成 QA 的拒答流

This commit is contained in:
锦恢 2024-06-09 23:49:25 +08:00
parent 3423bddc7b
commit 0525539ad8
8 changed files with 346 additions and 67 deletions

View File

@ -0,0 +1,46 @@
import '../plugins/image';
import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send } from 'lagrange.onebot'
import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';
import { handleGroupIntent } from './intent';
// Timestamp (ms since epoch) of the most recent failed intent-recognition
// call, used to throttle the "offline" notice. `undefined` until the first failure.
let lastCall: number | undefined = undefined;

/**
 * Calls the intent-recognition API and returns the response payload.
 *
 * A rejected request (e.g. the backend is unreachable) is logged and reported
 * as `undefined` instead of propagating, so the caller can treat it the same
 * way as a non-200 answer.
 *
 * @param query Text to classify.
 * @returns The response body, or `undefined` when the request itself failed.
 */
async function fetchIntentResponse(query: string) {
    try {
        const axiosRes = await apiGetIntentRecogition({ query });
        return axiosRes.data;
    } catch (error: unknown) {
        console.error('apiGetIntentRecogition request failed:', error);
        return undefined;
    }
}

/**
 * Tells the group that the RAG backend is offline, at most once per quiet
 * period of ten minutes.
 *
 * `lastCall` is refreshed on every failure (not only when a notice is sent),
 * so the notice is repeated only after ten minutes without any failure.
 *
 * @param c Context of the group message that triggered the failed call.
 */
function notifyRagOffline(c: LagrangeContext<GroupMessage>): void {
    const quietPeriodMs = 60 * 10 * 1000;
    const now = Date.now();
    if (lastCall === undefined || (now - lastCall) >= quietPeriodMs) {
        c.sendMessage('RAG 系统目前离线');
    }
    lastCall = Date.now();
}

export class Impl {
    /**
     * Handles every message in group 932987873 (no @-mention required).
     *
     * Joins the text segments of the message, asks the intent-recognition API
     * to classify them, and forwards the result to `handleGroupIntent`. When
     * that returns an answer, the answer is sent back followed by a debug line
     * with the intent name and uncertainty. When the API call fails or answers
     * with a non-200 code, a throttled "offline" notice is sent instead.
     *
     * @param c Context of the incoming group message.
     */
    @mapper.onGroup(932987873, { at: false })
    async handleDigitalGroup(c: LagrangeContext<GroupMessage>) {
        // Only text segments take part in intent recognition.
        const texts: string[] = [];
        for (const msg of c.message.message) {
            if (msg.type === 'text') {
                texts.push(msg.data.text);
            }
        }

        const res = await fetchIntentResponse(texts.join('\n'));
        if (res === undefined || res.code !== 200) {
            notifyRagOffline(c);
            return;
        }

        const intentResult = res.data;
        // If the uncertainty is too high, fall back to the 'others' intent.
        if (intentResult.uncertainty >= 0.33) {
            intentResult.name = 'others';
        }

        // Round to three decimals for the debug line.
        const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
        // NOTE(review): the opening bracket is never closed in this string — confirm whether that is intended.
        const intentDebug = `【意图: ${intentResult.name} 不确定度: ${uncertainty}`;

        const anwser = await handleGroupIntent(c, intentResult);
        if (anwser !== undefined) {
            c.sendMessage(anwser + '\n' + intentDebug);
        }
    }
}

View File

@ -23,13 +23,15 @@ async function useRagLLM(c: LagrangeContext<GroupMessage>, intentResult: IntentR
texts.push(msg.data.text);
}
}
const { data } = await apiQueryVecdb({ query: texts.join(' '), k: 3 });
const query = texts.join(' ').trim();
if (query.length === 0) {
return undefined;
}
const { data } = await apiQueryVecdb({ query, k: 3 });
if (data.code === 200) {
const messages: apiQueryVecdbDataItem[] = data.data.filter(m => m.score <= 0.8);
const messages: apiQueryVecdbDataItem[] = data.data.filter(m => m.score <= 0.7);
if (messages.length === 0) {
c.sendMessage('未在数据库中检索到相关内容。');
return '未在数据库中检索到相关内容。';
} else {
const query = makePrompt(messages);
const res = await llm.answer([
@ -40,16 +42,16 @@ async function useRagLLM(c: LagrangeContext<GroupMessage>, intentResult: IntentR
]);
if (typeof res === 'string') {
const links = messages.map(m => m.source);
const reference = ['参考链接:', ...links].join('\n');
const linkSet = new Set<string>(links);
const reference = ['参考链接:', ...linkSet].join('\n');
const anwser = res + '\n\n' + reference;
c.sendMessage(anwser);
return anwser;
}
}
} else {
logger.error('apiQueryVecdb 接口访问失败: ' + JSON.stringify(data));
return false;
return undefined;
}
}

View File

@ -1,6 +1,6 @@
import '../plugins/image';
import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send } from 'lagrange.onebot'
import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send, logger } from 'lagrange.onebot'
import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';
@ -41,9 +41,16 @@ export class Impl {
if (intentResult.uncertainty >= 0.33) {
intentResult.name = 'others';
}
const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
c.sendMessage(`【意图: ${intentResult.name} 不确定度: ${uncertainty}`);
handleGroupIntent(c, intentResult);
const intentDebug = `【意图: ${intentResult.name} 不确定度: ${uncertainty}`;
const anwser = await handleGroupIntent(c, intentResult);
if (anwser === undefined) {
c.sendMessage('拒答' + '\n' + intentDebug);
} else {
c.sendMessage(anwser + '\n' + intentDebug);
}
} else {
c.sendMessage('RAG 系统目前离线');
}

1
config/push-to-server.sh Normal file
View File

@ -0,0 +1 @@
# Recursively copy the local `config` directory to the llm-rag data directory
# on the remote server. `config` is a relative path, so it is resolved against
# the current working directory.
scp -r config ubuntu@101.43.239.71:/home/ubuntu/files/data/llm-rag

99
notebook/clear-logs.ipynb Normal file
View File

@ -0,0 +1,99 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import zipfile\n",
"from typing import *\n",
"import os\n",
"import json\n",
"import yaml\n",
"\n",
"log_home = '../logs'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def get_log_file_line(file: str, level='DEBUG') -> Generator:\n",
" if file.endswith('.log'):\n",
" for line in open(file, 'r', encoding='utf-8'):\n",
" if level in line:\n",
" yield line\n",
" elif file.endswith('.zip'):\n",
" with zipfile.ZipFile(file, 'r') as zip:\n",
" for zip_file in zip.namelist():\n",
" file_bytes = zip.read(zip_file)\n",
" for line in file_bytes.decode('utf-8').split('\\n'):\n",
" if level in line:\n",
" yield line\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"rag_log_files = [os.path.join(log_home, file) for file in os.listdir(log_home) if file.startswith('rag')]\n",
"\n",
"interesting_data = []\n",
"\n",
"for rag_file in rag_log_files:\n",
" for line in get_log_file_line(rag_file, level='DEBUG'):\n",
" try:\n",
" data = line.split('|')[-1].strip()\n",
" data = eval(data)\n",
" if data['intent']['name'] != 'others':\n",
" interesting_data.append(data)\n",
" except Exception:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"story = {'stories': []}\n",
"for d in interesting_data:\n",
" story['stories'].append({\n",
" 'message': d['query'],\n",
" 'intent': d['intent']['name']\n",
" })\n",
"\n",
"with open('../config/qq.story.yml', 'w', encoding='utf-8') as fp:\n",
" yaml.dump(story, fp, allow_unicode=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

View File

@ -8,6 +8,7 @@ import torch
from rag.db.embedding import embedding
from rag.api.constant import StatusCode, MsgCode
from rag.api.admin import app
from rag.api.admin import logger
from rag.api.config import necessary_files
from rag.model.enn import LinearEnn, train_enn
@ -60,30 +61,31 @@ def reload_embedding_mapping():
response.status_code = StatusCode.success.value
return response
# TODO: 删除该接口
@app.route('/intent/retrain-embedding-mapping', methods=['post'])
def retrain_embedding_mapping():
engine = PromptEngine(necessary_files['intent-story'])
engine.merge_stories_from_yml(necessary_files['issue-story'])
sentences = []
labels = []
for story in engine.stories:
sentences.append(story.message)
labels.append(engine.intent2id[story.intent])
try:
labels = np.array(labels)
embed = embedding.embed_documents(sentences)
enn_model = intent_recogition.classifier
train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])
# engine = PromptEngine(necessary_files['intent-story'])
# engine.merge_stories_from_yml(necessary_files['issue-story'])
# sentences = []
# labels = []
# for story in engine.stories:
# sentences.append(story.message)
# labels.append(engine.intent2id[story.intent])
# try:
# labels = np.array(labels)
# embed = embedding.embed_documents(sentences)
# enn_model = intent_recogition.classifier
# train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
# torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])
except Exception as e:
response = jsonify({
'code': StatusCode.process_error.value,
'data': str(e),
'msg': MsgCode.query_not_empty.value
})
response.status_code = StatusCode.success.value
return response
# except Exception as e:
# response = jsonify({
# 'code': StatusCode.process_error.value,
# 'data': str(e),
# 'msg': MsgCode.query_not_empty.value
# })
# response.status_code = StatusCode.success.value
# return response
response = jsonify({
'code': StatusCode.success.value,
@ -113,6 +115,12 @@ def get_intent_recogition():
result = intent_recogition.get_intent_recogition(query)
logger_chunk = json.dumps({
'query': query,
'intent': result
}, ensure_ascii=False)
logger.debug(logger_chunk)
response = jsonify({
'code': StatusCode.success.value,
'data': result,

View File

@ -42,6 +42,12 @@ suite('test intent recogition', () => {
{ input: '因为这是养蛊的虚拟机,放了些国产垃圾软件,得用国产流氓之王才能镇得住他们', expect: 'others' },
{ input: '你咋装了个360', expect: 'others' },
{ input: '', expect: 'expression,others' },
{ input: '有点怪,这里没有和竖线对齐', expect: 'others' },
{ input: '调试时候下载bit流和ltx固化需要在xdc约束里面添加spi或bpi相关约束后生成bit流。建议一开始xdc约束当中就添加相关约束后续生成mcs文件需要先添加flash型号根据相关型号生成对应的mcs文件后再固化到flash当中。', expect: 'advice,others' },
{ input: 'at 触发不够人性化,很多用户不知道 Tip 的存在就不会去使用它', expect: 'expression,others' },
{ input: '非问句会触发吗', expect: 'expression,others' },
{ input: 'zlib', expect: 'expression,others' },
{ input: 'https://github.com/Digital-EDA/Digital-IDE/discussions', expect: 'others' },
];
for (const s of intent_suites) {