123 lines
3.9 KiB
Python

from flask import Flask, request, jsonify
import numpy as np
import joblib
import json
import torch
from rag.db.embedding import embedding
from rag.api.constant import StatusCode, MsgCode
from rag.api.admin import app
from rag.api.config import necessary_files
from rag.model.enn import LinearEnn, train_enn
import sys
import os
sys.path.append(os.path.abspath('.'))
from prompt import PromptEngine
class IntentRecogition:
def __init__(self) -> None:
self.engine = PromptEngine(necessary_files['intent-story'])
self.classifier = LinearEnn(in_dim=768, out_dim=7, focal=0, alpha_kl=0)
self.classifier.load_state_dict(torch.load(necessary_files['intent-classifier']))
def get_intent_recogition(self, query: str) -> dict:
query_embed = embedding.embed_documents([query])
prob, u = self.classifier.predict(query_embed)
result_id = prob.argmax(dim=1)
u = u.item()
result_id = int(result_id.item())
return {
'id': int(result_id),
'name': self.engine.id2intent[result_id],
'uncertainty': float(u)
}
intent_recogition = IntentRecogition()
@app.route('/intent/reload-embedding-mapping', methods=['post'])
def reload_embedding_mapping():
try:
intent_recogition.classifier.load_state_dict(torch.load(necessary_files['intent-classifier']))
except Exception as e:
response = jsonify({
'code': StatusCode.process_error.value,
'data': str(e),
'msg': MsgCode.query_not_empty.value
})
response.status_code = StatusCode.success.value
return response
response = jsonify({
'code': StatusCode.success.value,
'data': 'load model from ' + necessary_files['intent-classifier'],
'msg': StatusCode.success.value
})
response.status_code = StatusCode.success.value
return response
@app.route('/intent/retrain-embedding-mapping', methods=['post'])
def retrain_embedding_mapping():
engine = PromptEngine(necessary_files['intent-story'])
engine.merge_stories_from_yml(necessary_files['issue-story'])
sentences = []
labels = []
for story in engine.stories:
sentences.append(story.message)
labels.append(engine.intent2id[story.intent])
try:
labels = np.array(labels)
embed = embedding.embed_documents(sentences)
enn_model = intent_recogition.classifier
train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])
except Exception as e:
response = jsonify({
'code': StatusCode.process_error.value,
'data': str(e),
'msg': MsgCode.query_not_empty.value
})
response.status_code = StatusCode.success.value
return response
response = jsonify({
'code': StatusCode.success.value,
'data': 'save data to ' + necessary_files['intent-classifier'],
'msg': StatusCode.success.value
})
response.status_code = StatusCode.success.value
return response
@app.route('/intent/get-intent-recogition', methods=['post'])
def get_intent_recogition():
params = request.data.decode('utf-8')
params: dict = json.loads(params)
result_data = {}
query = params.get('query', None)
if query is None:
response = jsonify({
'code': StatusCode.user_error.value,
'data': result_data,
'msg': MsgCode.query_not_empty.value
})
response.status_code = StatusCode.success.value
return response
result = intent_recogition.get_intent_recogition(query)
response = jsonify({
'code': StatusCode.success.value,
'data': result,
'msg': StatusCode.success.value
})
response.status_code = StatusCode.success.value
return response