from flask import Flask, request, jsonify
import numpy as np
import joblib
import json
import torch

from rag.db.embedding import embedding
from rag.api.constant import StatusCode, MsgCode
from rag.api.admin import app
from rag.api.config import necessary_files
from rag.model.enn import LinearEnn, train_enn

import sys
import os

# The current working directory is appended to the module search path before
# importing `prompt`; keep this ordering (presumably `prompt` lives there).
sys.path.append(os.path.abspath('.'))
from prompt import PromptEngine


def _make_response(code, data, msg):
    """Build the JSON envelope shared by every route in this module.

    The payload always has the keys ``code``, ``data`` and ``msg``. The HTTP
    status is always ``StatusCode.success.value``; failures are reported
    through the ``code`` field of the payload, not through the HTTP status.
    """
    response = jsonify({'code': code, 'data': data, 'msg': msg})
    response.status_code = StatusCode.success.value
    return response


class IntentRecogition:
    """Classify a query into an intent using an embedding + ENN classifier."""

    def __init__(self) -> None:
        """Load the intent stories and the persisted classifier weights."""
        self.engine = PromptEngine(necessary_files['intent-story'])
        # NOTE(review): in_dim presumably matches the embedding size and
        # out_dim the number of intents -- confirm against the story file.
        self.classifier = LinearEnn(in_dim=768, out_dim=7, focal=0, alpha_kl=0)
        self.classifier.load_state_dict(
            torch.load(necessary_files['intent-classifier'])
        )

    def get_intent_recogition(self, query: str) -> dict:
        """Predict the intent of a single query.

        Args:
            query: The text to classify.

        Returns:
            A dict with ``id`` (index of the highest-probability class),
            ``name`` (looked up through ``engine.id2intent``) and
            ``uncertainty`` (second value returned by the classifier's
            ``predict``, converted to ``float``).
        """
        # The embedder takes a batch, so wrap the single query in a list.
        query_embed = embedding.embed_documents([query])
        prob, u = self.classifier.predict(query_embed)
        result_id = int(prob.argmax(dim=1).item())
        return {
            'id': result_id,
            'name': self.engine.id2intent[result_id],
            'uncertainty': float(u.item()),
        }


intent_recogition = IntentRecogition()


@app.route('/intent/reload-embedding-mapping', methods=['post'])
def reload_embedding_mapping():
    """Reload the classifier weights of the shared recognizer from disk.

    Returns:
        A success envelope naming the file that was loaded, or a
        ``process_error`` envelope carrying the exception text.
    """
    model_path = necessary_files['intent-classifier']
    try:
        intent_recogition.classifier.load_state_dict(torch.load(model_path))
    except Exception as e:
        return _make_response(
            StatusCode.process_error.value,
            str(e),
            MsgCode.query_not_empty.value,
        )

    return _make_response(
        StatusCode.success.value,
        'load model from ' + model_path,
        StatusCode.success.value,
    )


@app.route('/intent/retrain-embedding-mapping', methods=['post'])
def retrain_embedding_mapping():
    """Retrain the shared classifier on the intent and issue stories.

    Stories from the intent file and the issue file are merged, embedded and
    passed to ``train_enn``; the resulting state dict is then saved to the
    classifier file.

    Returns:
        A success envelope naming the file that was written, or a
        ``process_error`` envelope carrying the exception text.
    """
    try:
        # Data preparation lives inside the try block so that a broken story
        # file or an intent missing from `intent2id` is reported through the
        # same error envelope as a training failure instead of escaping.
        engine = PromptEngine(necessary_files['intent-story'])
        engine.merge_stories_from_yml(necessary_files['issue-story'])

        sentences = [story.message for story in engine.stories]
        labels = np.array(
            [engine.intent2id[story.intent] for story in engine.stories]
        )

        embed = embedding.embed_documents(sentences)
        # NOTE(review): the model passed here is the instance used for
        # serving; presumably train_enn updates it in place -- confirm.
        enn_model = intent_recogition.classifier
        train_enn(enn_model, embed, labels, bs=64, lr=1e-3, epoch=100)
        torch.save(enn_model.state_dict(), necessary_files['intent-classifier'])
    except Exception as e:
        return _make_response(
            StatusCode.process_error.value,
            str(e),
            MsgCode.query_not_empty.value,
        )

    return _make_response(
        StatusCode.success.value,
        'save data to ' + necessary_files['intent-classifier'],
        StatusCode.success.value,
    )


@app.route('/intent/get-intent-recogition', methods=['post'])
def get_intent_recogition():
    """Classify the ``query`` field of a JSON request body.

    The body is expected to be a JSON object with a non-empty string
    ``query``. A body that is not valid UTF-8 JSON, is not an object, or has
    a missing, empty or non-string ``query`` yields a ``user_error`` envelope
    with empty ``data``.

    Returns:
        A success envelope whose ``data`` is the recognizer's result dict, or
        a ``user_error`` envelope as described above.
    """
    try:
        params = json.loads(request.data.decode('utf-8'))
    except ValueError:
        # Covers both json.JSONDecodeError and UnicodeDecodeError.
        params = None

    # A JSON array/scalar has no `.get`; treat it like a missing query.
    query = params.get('query') if isinstance(params, dict) else None
    if not isinstance(query, str) or not query:
        return _make_response(
            StatusCode.user_error.value,
            {},
            MsgCode.query_not_empty.value,
        )

    result = intent_recogition.get_intent_recogition(query)
    return _make_response(
        StatusCode.success.value,
        result,
        StatusCode.success.value,
    )