完成 EDL 的构建,完成 QA 的拒答流
This commit is contained in:
parent
3423bddc7b
commit
0525539ad8
@ -0,0 +1,46 @@
|
|||||||
|
import '../plugins/image';
|
||||||
|
|
||||||
|
import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send } from 'lagrange.onebot'
|
||||||
|
|
||||||
|
import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';
|
||||||
|
|
||||||
|
import { handleGroupIntent } from './intent';
|
||||||
|
|
||||||
|
let lastCall = undefined;
|
||||||
|
|
||||||
|
export class Impl {
    /**
     * Group-message handler registered for group 932987873 with `at: false`
     * (presumably meaning the bot does not need to be @-mentioned — confirm
     * against the lagrange.onebot mapper options).
     *
     * Flow:
     *  1. Join the text segments of the incoming message into one query.
     *  2. Ask the intent-recognition API to classify the query.
     *  3. On success, pass the (possibly overridden) intent to
     *     `handleGroupIntent` and reply with its answer plus a debug line.
     *  4. On a non-200 result, send a throttled "offline" notice.
     *
     * @param c Context of the incoming group message; used to read the
     *          message segments and to send replies.
     */
    @mapper.onGroup(932987873, { at: false })
    async handleDigitalGroup(c: LagrangeContext<GroupMessage>) {
        // Only plain-text segments contribute to the query.
        const texts: string[] = [];
        for (const msg of c.message.message) {
            if (msg.type === 'text') {
                texts.push(msg.data.text);
            }
        }

        const query = texts.join('\n');

        // A message without any text (e.g. image-only) has nothing to
        // classify. Skip the API call so an empty query cannot be reported
        // back as an error and trigger the "offline" notice below.
        if (query.trim().length === 0) {
            return;
        }

        const axiosRes = await apiGetIntentRecogition({ query });
        const res = axiosRes.data;

        if (res.code === 200) {
            const intentResult = res.data;

            // If the uncertainty is too high, override the intent to 'others'.
            if (intentResult.uncertainty >= 0.33) {
                intentResult.name = 'others';
            }

            // Round to three decimal places for the debug line.
            const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
            const intentDebug = `【意图: ${intentResult.name} 不确定度: ${uncertainty}】`;

            // `undefined` means there is nothing to answer; stay silent then.
            const anwser = await handleGroupIntent(c, intentResult);
            if (anwser !== undefined) {
                c.sendMessage(anwser + '\n' + intentDebug);
            }
        } else {
            const now = Date.now();

            // Send the notice only if the previous failure was at least
            // 10 minutes ago (or there was none).
            if (lastCall === undefined || (now - lastCall) >= 60 * 10 * 1000) {
                c.sendMessage('RAG 系统目前离线');
            }

            // Refreshed on every failure, so continuous failures keep the
            // notice suppressed until 10 minutes pass without one.
            lastCall = Date.now();
        }
    }
}
|
@ -23,13 +23,15 @@ async function useRagLLM(c: LagrangeContext<GroupMessage>, intentResult: IntentR
|
|||||||
texts.push(msg.data.text);
|
texts.push(msg.data.text);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const query = texts.join(' ').trim();
|
||||||
|
if (query.length === 0) {
|
||||||
const { data } = await apiQueryVecdb({ query: texts.join(' '), k: 3 });
|
return undefined;
|
||||||
|
}
|
||||||
|
const { data } = await apiQueryVecdb({ query, k: 3 });
|
||||||
if (data.code === 200) {
|
if (data.code === 200) {
|
||||||
const messages: apiQueryVecdbDataItem[] = data.data.filter(m => m.score <= 0.8);
|
const messages: apiQueryVecdbDataItem[] = data.data.filter(m => m.score <= 0.7);
|
||||||
if (messages.length === 0) {
|
if (messages.length === 0) {
|
||||||
c.sendMessage('未在数据库中检索到相关内容。');
|
return '未在数据库中检索到相关内容。';
|
||||||
} else {
|
} else {
|
||||||
const query = makePrompt(messages);
|
const query = makePrompt(messages);
|
||||||
const res = await llm.answer([
|
const res = await llm.answer([
|
||||||
@ -40,16 +42,16 @@ async function useRagLLM(c: LagrangeContext<GroupMessage>, intentResult: IntentR
|
|||||||
]);
|
]);
|
||||||
if (typeof res === 'string') {
|
if (typeof res === 'string') {
|
||||||
const links = messages.map(m => m.source);
|
const links = messages.map(m => m.source);
|
||||||
|
const linkSet = new Set<string>(links);
|
||||||
const reference = ['参考链接:', ...links].join('\n');
|
const reference = ['参考链接:', ...linkSet].join('\n');
|
||||||
const anwser = res + '\n\n' + reference;
|
const anwser = res + '\n\n' + reference;
|
||||||
c.sendMessage(anwser);
|
return anwser;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
logger.error('apiQueryVecdb 接口访问失败: ' + JSON.stringify(data));
|
logger.error('apiQueryVecdb 接口访问失败: ' + JSON.stringify(data));
|
||||||
return false;
|
return undefined;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import '../plugins/image';
|
import '../plugins/image';
|
||||||
|
|
||||||
import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send } from 'lagrange.onebot'
|
import { mapper, plugins, LagrangeContext, PrivateMessage, GroupMessage, Send, logger } from 'lagrange.onebot'
|
||||||
|
|
||||||
import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';
|
import { apiGetIntentRecogition, apiQueryVecdb } from '../api/vecdb';
|
||||||
|
|
||||||
@ -41,9 +41,16 @@ export class Impl {
|
|||||||
if (intentResult.uncertainty >= 0.33) {
|
if (intentResult.uncertainty >= 0.33) {
|
||||||
intentResult.name = 'others';
|
intentResult.name = 'others';
|
||||||
}
|
}
|
||||||
|
|
||||||
const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
|
const uncertainty = Math.round(intentResult.uncertainty * 1000) / 1000;
|
||||||
c.sendMessage(`【意图: ${intentResult.name} 不确定度: ${uncertainty}】`);
|
const intentDebug = `【意图: ${intentResult.name} 不确定度: ${uncertainty}】`;
|
||||||
handleGroupIntent(c, intentResult);
|
const anwser = await handleGroupIntent(c, intentResult);
|
||||||
|
if (anwser === undefined) {
|
||||||
|
c.sendMessage('拒答' + '\n' + intentDebug);
|
||||||
|
} else {
|
||||||
|
c.sendMessage(anwser + '\n' + intentDebug);
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
c.sendMessage('RAG 系统目前离线');
|
c.sendMessage('RAG 系统目前离线');
|
||||||
}
|
}
|
||||||
|
1
config/push-to-server.sh
Normal file
1
config/push-to-server.sh
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Recursively copy the local `config` directory to
# /home/ubuntu/files/data/llm-rag on the remote host (user `ubuntu`) via scp.
scp -r config ubuntu@101.43.239.71:/home/ubuntu/files/data/llm-rag
|
99
notebook/clear-logs.ipynb
Normal file
99
notebook/clear-logs.ipynb
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import zipfile\n",
|
||||||
|
"from typing import *\n",
|
||||||
|
"import os\n",
|
||||||
|
"import json\n",
|
||||||
|
"import yaml\n",
|
||||||
|
"\n",
|
||||||
|
"log_home = '../logs'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def get_log_file_line(file: str, level='DEBUG') -> Generator:\n",
|
||||||
|
" if file.endswith('.log'):\n",
|
||||||
|
" for line in open(file, 'r', encoding='utf-8'):\n",
|
||||||
|
" if level in line:\n",
|
||||||
|
" yield line\n",
|
||||||
|
" elif file.endswith('.zip'):\n",
|
||||||
|
" with zipfile.ZipFile(file, 'r') as zip:\n",
|
||||||
|
" for zip_file in zip.namelist():\n",
|
||||||
|
" file_bytes = zip.read(zip_file)\n",
|
||||||
|
" for line in file_bytes.decode('utf-8').split('\\n'):\n",
|
||||||
|
" if level in line:\n",
|
||||||
|
" yield line\n",
|
||||||
|
" "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"rag_log_files = [os.path.join(log_home, file) for file in os.listdir(log_home) if file.startswith('rag')]\n",
|
||||||
|
"\n",
|
||||||
|
"interesting_data = []\n",
|
||||||
|
"\n",
|
||||||
|
"for rag_file in rag_log_files:\n",
|
||||||
|
" for line in get_log_file_line(rag_file, level='DEBUG'):\n",
|
||||||
|
" try:\n",
|
||||||
|
" data = line.split('|')[-1].strip()\n",
|
||||||
|
" data = eval(data)\n",
|
||||||
|
" if data['intent']['name'] != 'others':\n",
|
||||||
|
" interesting_data.append(data)\n",
|
||||||
|
" except Exception:\n",
|
||||||
|
" pass"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"story = {'stories': []}\n",
|
||||||
|
"for d in interesting_data:\n",
|
||||||
|
" story['stories'].append({\n",
|
||||||
|
" 'message': d['query'],\n",
|
||||||
|
" 'intent': d['intent']['name']\n",
|
||||||
|
" })\n",
|
||||||
|
"\n",
|
||||||
|
"with open('../config/qq.story.yml', 'w', encoding='utf-8') as fp:\n",
|
||||||
|
" yaml.dump(story, fp, allow_unicode=True)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "base",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
File diff suppressed because one or more lines are too long
@ -8,6 +8,7 @@ import torch
|
|||||||
from rag.db.embedding import embedding
|
from rag.db.embedding import embedding
|
||||||
from rag.api.constant import StatusCode, MsgCode
|
from rag.api.constant import StatusCode, MsgCode
|
||||||
from rag.api.admin import app
|
from rag.api.admin import app
|
||||||
|
from rag.api.admin import logger
|
||||||
from rag.api.config import necessary_files
|
from rag.api.config import necessary_files
|
||||||
from rag.model.enn import LinearEnn, train_enn
|
from rag.model.enn import LinearEnn, train_enn
|
||||||
|
|
||||||
@ -60,30 +61,31 @@ def reload_embedding_mapping():
|
|||||||
response.status_code = StatusCode.success.value
|
response.status_code = StatusCode.success.value
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
# TODO: 删除该接口
|
||||||
@app.route('/intent/retrain-embedding-mapping', methods=['post'])
|
@app.route('/intent/retrain-embedding-mapping', methods=['post'])
|
||||||
def retrain_embedding_mapping():
|
def retrain_embedding_mapping():
|
||||||
engine = PromptEngine(necessary_files['intent-story'])
|
# engine = PromptEngine(necessary_files['intent-story'])
|
||||||
engine.merge_stories_from_yml(necessary_files['issue-story'])
|
# engine.merge_stories_from_yml(necessary_files['issue-story'])
|
||||||
sentences = []
|
# sentences = []
|
||||||
labels = []
|
# labels = []
|
||||||
for story in engine.stories:
|
# for story in engine.stories:
|
||||||
sentences.append(story.message)
|
# sentences.append(story.message)
|
||||||
labels.append(engine.intent2id[story.intent])
|
# labels.append(engine.intent2id[story.intent])
|
||||||
try:
|
# try:
|
||||||
labels = np.array(labels)
|
# labels = np.array(labels)
|
||||||
embed = embedding.embed_documents(sentences)
|
# embed = embedding.embed_documents(sentences)
|
||||||
enn_model = intent_recogition.classifier
|
# enn_model = intent_recogition.classifier
|
||||||
train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
|
# train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
|
||||||
torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])
|
# torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
response = jsonify({
|
# response = jsonify({
|
||||||
'code': StatusCode.process_error.value,
|
# 'code': StatusCode.process_error.value,
|
||||||
'data': str(e),
|
# 'data': str(e),
|
||||||
'msg': MsgCode.query_not_empty.value
|
# 'msg': MsgCode.query_not_empty.value
|
||||||
})
|
# })
|
||||||
response.status_code = StatusCode.success.value
|
# response.status_code = StatusCode.success.value
|
||||||
return response
|
# return response
|
||||||
|
|
||||||
response = jsonify({
|
response = jsonify({
|
||||||
'code': StatusCode.success.value,
|
'code': StatusCode.success.value,
|
||||||
@ -113,6 +115,12 @@ def get_intent_recogition():
|
|||||||
|
|
||||||
result = intent_recogition.get_intent_recogition(query)
|
result = intent_recogition.get_intent_recogition(query)
|
||||||
|
|
||||||
|
logger_chunk = json.dumps({
|
||||||
|
'query': query,
|
||||||
|
'intent': result
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
logger.debug(logger_chunk)
|
||||||
|
|
||||||
response = jsonify({
|
response = jsonify({
|
||||||
'code': StatusCode.success.value,
|
'code': StatusCode.success.value,
|
||||||
'data': result,
|
'data': result,
|
||||||
|
@ -42,6 +42,12 @@ suite('test intent recogition', () => {
|
|||||||
{ input: '因为这是养蛊的虚拟机,放了些国产垃圾软件,得用国产流氓之王才能镇得住他们', expect: 'others' },
|
{ input: '因为这是养蛊的虚拟机,放了些国产垃圾软件,得用国产流氓之王才能镇得住他们', expect: 'others' },
|
||||||
{ input: '你咋装了个360', expect: 'others' },
|
{ input: '你咋装了个360', expect: 'others' },
|
||||||
{ input: '???', expect: 'expression,others' },
|
{ input: '???', expect: 'expression,others' },
|
||||||
|
{ input: '有点怪,这里没有和竖线对齐', expect: 'others' },
|
||||||
|
{ input: '调试时候下载bit流和ltx,固化需要在xdc约束里面添加spi或bpi相关约束后生成bit流。建议一开始xdc约束当中就添加相关约束,后续生成mcs文件需要先添加flash型号,根据相关型号生成对应的mcs文件后再固化到flash当中。', expect: 'advice,others' },
|
||||||
|
{ input: 'at 触发不够人性化,很多用户不知道 Tip 的存在就不会去使用它', expect: 'expression,others' },
|
||||||
|
{ input: '非问句会触发吗', expect: 'expression,others' },
|
||||||
|
{ input: 'zlib', expect: 'expression,others' },
|
||||||
|
{ input: 'https://github.com/Digital-EDA/Digital-IDE/discussions', expect: 'others' },
|
||||||
];
|
];
|
||||||
|
|
||||||
for (const s of intent_suites) {
|
for (const s of intent_suites) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user