Complete the EDL construction and the QA refusal flow
This commit is contained in:
parent 3423bddc7b
commit 0525539ad8
@@ -0,0 +1,46 @@
+import '../plugins/image';
+
+import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send } from 'lagrange.onebot'
+
+import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';
+
+import { handleGroupIntent } from './intent';
+
+let lastCall = undefined;
+
+export class Impl {
+    @mapper.onGroup(932987873, { at: false })
+    async handleDigitalGroup(c: LagrangeContext<GroupMessage>) {
+        const texts = [];
+        const message = c.message;
+        for (const msg of message.message) {
+            if (msg.type === 'text') {
+                texts.push(msg.data.text);
+            }
+        }
+        const reply: Send.Default[] = [];
+        const axiosRes = await apiGetIntentRecogition({ query: texts.join('\n') });
+        const res = axiosRes.data;
+
+        if (res.code == 200) {
+            const intentResult = res.data;
+            // If the uncertainty is too high, fall back to the 'others' intent
+            if (intentResult.uncertainty >= 0.33) {
+                intentResult.name = 'others';
+            }
+            const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
+            const intentDebug = `【意图: ${intentResult.name} 不确定度: ${uncertainty}】`;
+            const anwser = await handleGroupIntent(c, intentResult);
+            if (anwser !== undefined) {
+                c.sendMessage(anwser + '\n' + intentDebug);
+            }
+
+        } else {
+            const now = Date.now();
+            if (lastCall === undefined || (now - lastCall) >= 60 * 10 * 1000) {
+                c.sendMessage('RAG 系统目前离线');
+            }
+            lastCall = Date.now();
+        }
+    }
+}
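Note on the `uncertainty` field consumed above: it presumably comes from the evidential (EDL) intent classifier on the Python side (`LinearEnn` / `train_enn` in `rag.model.enn`, touched further down in this diff). The sketch below is an assumption about how such a score is commonly derived, not the repository's actual implementation: an evidential head outputs non-negative evidence per class, and the leftover Dirichlet mass `K / S` is reported as uncertainty.

```python
# Hypothetical sketch of an EDL-style uncertainty score; the repo's enn.py may differ.
import torch
import torch.nn.functional as F

def edl_predict(logits: torch.Tensor):
    """logits: (batch, K) raw outputs of a linear head over K intent classes."""
    evidence = F.softplus(logits)                 # non-negative evidence e_k
    alpha = evidence + 1.0                        # Dirichlet parameters alpha_k = e_k + 1
    strength = alpha.sum(dim=-1, keepdim=True)    # S = sum_k alpha_k
    prob = alpha / strength                       # expected class probabilities
    uncertainty = logits.shape[-1] / strength.squeeze(-1)  # u = K / S, in (0, 1]
    return prob, uncertainty
```

Under this reading, the bot-side threshold `uncertainty >= 0.33` means: whenever less than roughly two thirds of the Dirichlet mass is backed by actual evidence, the intent is forced to 'others' and the refusal path can kick in.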
@@ -23,13 +23,15 @@ async function useRagLLM(c: LagrangeContext<GroupMessage>, intentResult: IntentR
             texts.push(msg.data.text);
         }
     }

-    const { data } = await apiQueryVecdb({ query: texts.join(' '), k: 3 });
+    const query = texts.join(' ').trim();
+    if (query.length === 0) {
+        return undefined;
+    }
+    const { data } = await apiQueryVecdb({ query, k: 3 });
     if (data.code === 200) {
-        const messages: apiQueryVecdbDataItem[] = data.data.filter(m => m.score <= 0.8);
+        const messages: apiQueryVecdbDataItem[] = data.data.filter(m => m.score <= 0.7);
         if (messages.length === 0) {
-            c.sendMessage('未在数据库中检索到相关内容。');
+            return '未在数据库中检索到相关内容。';
         } else {
             const query = makePrompt(messages);
             const res = await llm.answer([
@@ -40,16 +42,16 @@ async function useRagLLM(c: LagrangeContext<GroupMessage>, intentResult: IntentR
             ]);
             if (typeof res === 'string') {
                 const links = messages.map(m => m.source);

-                const reference = ['参考链接:', ...links].join('\n');
+                const linkSet = new Set<string>(links);
+                const reference = ['参考链接:', ...linkSet].join('\n');
                 const anwser = res + '\n\n' + reference;
                 c.sendMessage(anwser);
                 return anwser;
             }
         }

     } else {
         logger.error('apiQueryVecdb 接口访问失败: ' + JSON.stringify(data));
-        return false;
+        return undefined;
     }
 }
@@ -1,6 +1,6 @@
 import '../plugins/image';

-import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send } from 'lagrange.onebot'
+import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send, logger } from 'lagrange.onebot'

 import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';

@@ -41,9 +41,16 @@ export class Impl {
             if (intentResult.uncertainty >= 0.33) {
                 intentResult.name = 'others';
             }

             const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
-            c.sendMessage(`【意图: ${intentResult.name} 不确定度: ${uncertainty}】`);
-            handleGroupIntent(c, intentResult);
+            const intentDebug = `【意图: ${intentResult.name} 不确定度: ${uncertainty}】`;
+            const anwser = await handleGroupIntent(c, intentResult);
+            if (anwser === undefined) {
+                c.sendMessage('拒答' + '\n' + intentDebug);
+            } else {
+                c.sendMessage(anwser + '\n' + intentDebug);
+            }

         } else {
             c.sendMessage('RAG 系统目前离线');
         }
config/push-to-server.sh (new file, 1 line)
@@ -0,0 +1 @@
+scp -r config ubuntu@101.43.239.71:/home/ubuntu/files/data/llm-rag
notebook/clear-logs.ipynb (new file, 99 lines)
@@ -0,0 +1,99 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import zipfile\n",
+    "from typing import *\n",
+    "import os\n",
+    "import json\n",
+    "import yaml\n",
+    "\n",
+    "log_home = '../logs'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_log_file_line(file: str, level='DEBUG') -> Generator:\n",
+    "    if file.endswith('.log'):\n",
+    "        for line in open(file, 'r', encoding='utf-8'):\n",
+    "            if level in line:\n",
+    "                yield line\n",
+    "    elif file.endswith('.zip'):\n",
+    "        with zipfile.ZipFile(file, 'r') as zip:\n",
+    "            for zip_file in zip.namelist():\n",
+    "                file_bytes = zip.read(zip_file)\n",
+    "                for line in file_bytes.decode('utf-8').split('\\n'):\n",
+    "                    if level in line:\n",
+    "                        yield line\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rag_log_files = [os.path.join(log_home, file) for file in os.listdir(log_home) if file.startswith('rag')]\n",
+    "\n",
+    "interesting_data = []\n",
+    "\n",
+    "for rag_file in rag_log_files:\n",
+    "    for line in get_log_file_line(rag_file, level='DEBUG'):\n",
+    "        try:\n",
+    "            data = line.split('|')[-1].strip()\n",
+    "            data = eval(data)\n",
+    "            if data['intent']['name'] != 'others':\n",
+    "                interesting_data.append(data)\n",
+    "        except Exception:\n",
+    "            pass"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "story = {'stories': []}\n",
+    "for d in interesting_data:\n",
+    "    story['stories'].append({\n",
+    "        'message': d['query'],\n",
+    "        'intent': d['intent']['name']\n",
+    "    })\n",
+    "\n",
+    "with open('../config/qq.story.yml', 'w', encoding='utf-8') as fp:\n",
+    "    yaml.dump(story, fp, allow_unicode=True)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
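The notebook above rebuilds `config/qq.story.yml` from DEBUG lines in the RAG logs. Since the API change further down writes each chunk with `json.dumps(..., ensure_ascii=False)`, the payload is valid JSON, so `json.loads` would be a safer parser than the notebook's `eval`. A minimal sketch, assuming the `... | <json>` line layout produced by `logger.debug`:

```python
# Sketch of a json.loads-based alternative to the eval() call in the notebook.
# Assumes each DEBUG line ends with "| <json chunk>" as written by logger.debug below.
import json
from typing import Iterable, List, Dict

def parse_intent_chunks(lines: Iterable[str]) -> List[Dict]:
    records = []
    for line in lines:
        payload = line.split('|')[-1].strip()
        try:
            data = json.loads(payload)        # chunk was produced by json.dumps
        except json.JSONDecodeError:
            continue                          # skip lines that are not intent chunks
        if data.get('intent', {}).get('name') != 'others':
            records.append(data)
    return records

# usage inside the notebook loop:
# interesting_data = parse_intent_chunks(get_log_file_line(rag_file, level='DEBUG'))
```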
File diff suppressed because one or more lines are too long
@@ -8,6 +8,7 @@ import torch
 from rag.db.embedding import embedding
 from rag.api.constant import StatusCode, MsgCode
 from rag.api.admin import app
+from rag.api.admin import logger
 from rag.api.config import necessary_files
 from rag.model.enn import LinearEnn, train_enn

@@ -60,30 +61,31 @@ def reload_embedding_mapping():
     response.status_code = StatusCode.success.value
     return response

+# TODO: remove this endpoint
 @app.route('/intent/retrain-embedding-mapping', methods=['post'])
 def retrain_embedding_mapping():
-    engine = PromptEngine(necessary_files['intent-story'])
-    engine.merge_stories_from_yml(necessary_files['issue-story'])
-    sentences = []
-    labels = []
-    for story in engine.stories:
-        sentences.append(story.message)
-        labels.append(engine.intent2id[story.intent])
-    try:
-        labels = np.array(labels)
-        embed = embedding.embed_documents(sentences)
-        enn_model = intent_recogition.classifier
-        train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
-        torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])
+    # engine = PromptEngine(necessary_files['intent-story'])
+    # engine.merge_stories_from_yml(necessary_files['issue-story'])
+    # sentences = []
+    # labels = []
+    # for story in engine.stories:
+    #     sentences.append(story.message)
+    #     labels.append(engine.intent2id[story.intent])
+    # try:
+    #     labels = np.array(labels)
+    #     embed = embedding.embed_documents(sentences)
+    #     enn_model = intent_recogition.classifier
+    #     train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
+    #     torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])

-    except Exception as e:
-        response = jsonify({
-            'code': StatusCode.process_error.value,
-            'data': str(e),
-            'msg': MsgCode.query_not_empty.value
-        })
-        response.status_code = StatusCode.success.value
-        return response
+    # except Exception as e:
+    #     response = jsonify({
+    #         'code': StatusCode.process_error.value,
+    #         'data': str(e),
+    #         'msg': MsgCode.query_not_empty.value
+    #     })
+    #     response.status_code = StatusCode.success.value
+    #     return response

     response = jsonify({
         'code': StatusCode.success.value,
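The retraining path commented out above calls `train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)` on sentence embeddings and intent ids. For context, a common objective for training such an evidential classifier is the expected-MSE loss under the induced Dirichlet (Sensoy et al., 2018); the sketch below is only a hedged illustration of that standard formulation, not the repository's actual `train_enn`:

```python
# Hypothetical EDL training objective (expected MSE under the Dirichlet);
# the actual train_enn() in rag.model.enn may use a different formulation.
import torch
import torch.nn.functional as F

def edl_mse_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """logits: (batch, K); labels: (batch,) integer intent ids."""
    evidence = F.softplus(logits)
    alpha = evidence + 1.0
    strength = alpha.sum(dim=-1, keepdim=True)
    prob = alpha / strength
    onehot = F.one_hot(labels.long(), num_classes=logits.shape[-1]).float()
    err = ((onehot - prob) ** 2).sum(dim=-1)                    # expected squared error
    var = (prob * (1.0 - prob) / (strength + 1.0)).sum(dim=-1)  # Dirichlet variance term
    return (err + var).mean()
```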
@@ -113,6 +115,12 @@ def get_intent_recogition():

     result = intent_recogition.get_intent_recogition(query)

+    logger_chunk = json.dumps({
+        'query': query,
+        'intent': result
+    }, ensure_ascii=False)
+    logger.debug(logger_chunk)
+
     response = jsonify({
         'code': StatusCode.success.value,
         'data': result,
@@ -42,6 +42,12 @@ suite('test intent recogition', () => {
     { input: '因为这是养蛊的虚拟机,放了些国产垃圾软件,得用国产流氓之王才能镇得住他们', expect: 'others' },
     { input: '你咋装了个360', expect: 'others' },
     { input: '???', expect: 'expression,others' },
+    { input: '有点怪,这里没有和竖线对齐', expect: 'others' },
+    { input: '调试时候下载bit流和ltx,固化需要在xdc约束里面添加spi或bpi相关约束后生成bit流。建议一开始xdc约束当中就添加相关约束,后续生成mcs文件需要先添加flash型号,根据相关型号生成对应的mcs文件后再固化到flash当中。', expect: 'advice,others' },
+    { input: 'at 触发不够人性化,很多用户不知道 Tip 的存在就不会去使用它', expect: 'expression,others' },
+    { input: '非问句会触发吗', expect: 'expression,others' },
+    { input: 'zlib', expect: 'expression,others' },
+    { input: 'https://github.com/Digital-EDA/Digital-IDE/discussions', expect: 'others' },
 ];

 for (const s of intent_suites) {