131 lines
4.1 KiB
Python
131 lines
4.1 KiB
Python
from flask import Flask, request, jsonify
|
|
import numpy as np
|
|
import joblib
|
|
import json
|
|
import torch
|
|
|
|
|
|
from rag.db.embedding import embedding
|
|
from rag.api.constant import StatusCode, MsgCode
|
|
from rag.api.admin import app
|
|
from rag.api.admin import logger
|
|
from rag.api.config import necessary_files
|
|
from rag.model.enn import LinearEnn, train_enn
|
|
|
|
import sys
|
|
import os
|
|
sys.path.append(os.path.abspath('.'))
|
|
|
|
from prompt import PromptEngine
|
|
|
|
class IntentRecogition:
    """Classify a user query into one of the intents known to the prompt engine.

    Combines a ``PromptEngine`` built from ``necessary_files['intent-story']``
    (used to map class ids to intent names) with a ``LinearEnn`` classifier
    whose weights are loaded from ``necessary_files['intent-classifier']``.
    """

    # Input width and number of output classes of the classifier. These must
    # match the saved checkpoint, or load_state_dict raises. EMBED_DIM
    # presumably equals the width of `embedding.embed_documents` vectors —
    # TODO confirm.
    EMBED_DIM = 768
    NUM_INTENTS = 7

    def __init__(self) -> None:
        self.engine = PromptEngine(necessary_files['intent-story'])
        self.classifier = LinearEnn(
            in_dim=self.EMBED_DIM,
            out_dim=self.NUM_INTENTS,
            focal=0,
            alpha_kl=0,
        )
        # map_location='cpu' lets a checkpoint saved on a GPU machine load on a
        # CPU-only host; load_state_dict then copies the tensors into the
        # classifier's own parameters.
        state_dict = torch.load(necessary_files['intent-classifier'], map_location='cpu')
        self.classifier.load_state_dict(state_dict)

    def get_intent_recogition(self, query: str) -> dict:
        """Predict the intent of ``query``.

        Args:
            query: Raw user text to classify.

        Returns:
            A dict with ``'id'`` (int class index), ``'name'`` (the intent name
            looked up in ``self.engine.id2intent``) and ``'uncertainty'`` (the
            float second value returned by ``classifier.predict``).
        """
        query_embed = embedding.embed_documents([query])
        prob, uncertainty = self.classifier.predict(query_embed)
        # Exactly one query is embedded, and .item() requires a single element.
        # This assumes predict returns one row per document — TODO confirm.
        # argmax over dim=1 picks the most probable class of that row.
        result_id = int(prob.argmax(dim=1).item())
        return {
            'id': result_id,
            'name': self.engine.id2intent[result_id],
            'uncertainty': float(uncertainty.item()),
        }
|
|
|
|
# Module-level instance shared by the route handlers in this module.
# Constructing it builds the PromptEngine and loads the classifier weights from
# disk, so that work happens when this module is imported.
intent_recogition = IntentRecogition()
|
|
|
|
@app.route('/intent/reload-embedding-mapping', methods=['post'])
def reload_embedding_mapping():
    """Reload the intent classifier weights from ``necessary_files['intent-classifier']``.

    The weights are loaded into a copy of the current classifier, and the copy
    replaces ``intent_recogition.classifier`` only if loading succeeds.

    Returns:
        A JSON response. On success, ``code`` is ``StatusCode.success`` and
        ``data`` names the loaded file. On any exception, ``code`` is
        ``StatusCode.process_error``, ``data`` is the exception text, and the
        live classifier is left untouched. Both responses use HTTP status
        ``StatusCode.success``; the outcome is carried in the body's ``code``.
    """
    import copy

    try:
        model_path = necessary_files['intent-classifier']
        state_dict = torch.load(model_path, map_location='cpu')
        # Load into a copy rather than the live model. A strict load_state_dict
        # that raises on mismatched keys or shapes has already copied the
        # matching tensors, which would leave the serving classifier
        # half-updated.
        candidate = copy.deepcopy(intent_recogition.classifier)
        candidate.load_state_dict(state_dict)
    except Exception as e:
        logger.exception('failed to reload intent classifier')
        response = jsonify({
            'code': StatusCode.process_error.value,
            'data': str(e),
            # NOTE(review): query_not_empty does not describe a reload failure.
            # It is kept for client compatibility until a dedicated MsgCode
            # exists.
            'msg': MsgCode.query_not_empty.value
        })
        response.status_code = StatusCode.success.value
        return response

    # Rebinding the attribute swaps in the fully-loaded model in one step.
    intent_recogition.classifier = candidate

    response = jsonify({
        'code': StatusCode.success.value,
        'data': 'load model from ' + model_path,
        'msg': StatusCode.success.value
    })
    response.status_code = StatusCode.success.value
    return response
|
|
|
|
# TODO: remove this endpoint.
@app.route('/intent/retrain-embedding-mapping', methods=['post'])
def retrain_embedding_mapping():
    """Deprecated endpoint: retraining is disabled and no file is written.

    A former implementation embedded the intent stories, refit the classifier
    with ``train_enn`` and saved it to ``necessary_files['intent-classifier']``.
    That logic is gone. This handler only acknowledges the call, and its
    ``data`` field states that the model was not saved.

    Returns:
        A JSON response with ``code`` set to ``StatusCode.success``; the HTTP
        status is also ``StatusCode.success``.
    """
    response = jsonify({
        'code': StatusCode.success.value,
        'data': 'retraining disabled; model not saved to ' + necessary_files['intent-classifier'],
        'msg': StatusCode.success.value
    })
    response.status_code = StatusCode.success.value
    return response
|
|
|
|
|
|
|
|
@app.route('/intent/get-intent-recogition', methods=['post'])
def get_intent_recogition():
    """Classify the ``query`` field of a JSON request body.

    Expects a UTF-8 JSON object body such as ``{"query": "<text>"}``.

    Returns:
        A JSON response. If the body is not valid UTF-8 JSON, is not a JSON
        object, or its ``query`` is missing, empty or not a string, ``code``
        is ``StatusCode.user_error`` and ``msg`` is ``MsgCode.query_not_empty``.
        Otherwise ``data`` is the dict returned by
        ``intent_recogition.get_intent_recogition``. Both responses built here
        use HTTP status ``StatusCode.success``; the outcome is carried in the
        body's ``code``.
    """
    result_data = {}

    try:
        params = json.loads(request.data.decode('utf-8'))
    except ValueError:
        # Both json.JSONDecodeError and UnicodeDecodeError subclass ValueError.
        # Treat a malformed body as one without a query.
        params = None

    # A JSON array or scalar body has no 'query' key to read.
    query = params.get('query', None) if isinstance(params, dict) else None

    if not isinstance(query, str) or not query:
        response = jsonify({
            'code': StatusCode.user_error.value,
            'data': result_data,
            'msg': MsgCode.query_not_empty.value
        })
        response.status_code = StatusCode.success.value
        return response

    result = intent_recogition.get_intent_recogition(query)

    logger_chunk = json.dumps({
        'query': query,
        'intent': result
    }, ensure_ascii=False)
    logger.debug(logger_chunk)

    response = jsonify({
        'code': StatusCode.success.value,
        'data': result,
        'msg': StatusCode.success.value
    })
    response.status_code = StatusCode.success.value
    return response
|