98 lines
2.9 KiB
Python
98 lines
2.9 KiB
Python
from flask import Flask, request, jsonify
|
|
import numpy as np
|
|
import joblib
|
|
import json
|
|
from sklearn.linear_model import LogisticRegression
|
|
|
|
from embedding import embedding
|
|
from constant import StatusCode, MsgCode
|
|
from admin import app
|
|
from configs import necessary_files
|
|
|
|
import sys
|
|
import os
|
|
sys.path.append(os.path.abspath('.'))
|
|
|
|
from prompt import PromptEngine
|
|
|
|
class IntentRecogition:
    """Map a free-text query to an intent id and intent name.

    State (both loaded from ``necessary_files`` when constructed):

    * ``embed_intent_classificator`` -- the fitted classifier restored with
      ``joblib.load`` from ``necessary_files['intent-classifier']``.
    * ``engine`` -- a ``PromptEngine`` built from
      ``necessary_files['intent-story']``; its ``id2intent`` mapping converts a
      predicted label id back into an intent name.
    """

    def __init__(self) -> None:
        files = necessary_files
        # NOTE(review): joblib.load unpickles the file, so it must only ever
        # point at a trusted, locally written model file.
        self.embed_intent_classificator = joblib.load(files['intent-classifier'])
        self.engine = PromptEngine(files['intent-story'])

    def get_intent_recogition(self, query: str) -> dict:
        """Classify *query* and return ``{'id': <int>, 'name': <intent name>}``.

        The query is embedded with ``embedding.embed_documents``, the first
        prediction of the classifier is cast to ``int``, and that id is looked
        up in ``self.engine.id2intent`` (a missing id raises ``KeyError``).
        """
        vectors = embedding.embed_documents([query])
        predictions = self.embed_intent_classificator.predict(vectors)
        intent_id = int(predictions[0])
        intent_name = self.engine.id2intent[intent_id]
        return {'id': intent_id, 'name': intent_name}


# Module-level instance, created (and its model files loaded) at import time.
intent_recogition = IntentRecogition()
|
|
|
|
|
|
@app.route('/intent/retrain-embedding-mapping', methods=['post'])
def retrain_embedding_mapping():
    """Refit the intent classifier from the story file and persist it.

    Steps:

    1. Reload ``necessary_files['intent-story']`` into a fresh ``PromptEngine``.
    2. Embed every story message.
    3. Fit a ``LogisticRegression`` on the (embedding, intent id) pairs.
    4. Write the model to ``necessary_files['intent-classifier']`` with
       ``joblib.dump``.

    The shared ``intent_recogition`` engine and classifier are replaced only
    after the model has been saved. A failure at any step therefore leaves the
    running service on its previous model and reports the error.

    Returns:
        A JSON response with HTTP status ``StatusCode.success``:

        * On success, ``code`` is ``StatusCode.success`` and ``data`` names
          the save path.
        * On failure, ``code`` is ``StatusCode.process_error`` and ``data``
          holds the exception text.
    """
    try:
        engine = PromptEngine(necessary_files['intent-story'])
        # Materialize once so both passes below see the same stories even if
        # ``engine.stories`` is a one-shot iterable.
        stories = list(engine.stories)
        sentences = [story.message for story in stories]
        labels = np.array([engine.intent2id[story.intent] for story in stories])

        model = LogisticRegression()
        model.fit(embedding.embed_documents(sentences), labels)

        classifier_path = necessary_files['intent-classifier']
        # Persist before touching the live instance: if saving fails, the
        # in-memory model must not diverge from what is on disk.
        joblib.dump(model, classifier_path)
    except Exception as e:
        response = jsonify({
            'code': StatusCode.process_error.value,
            'data': str(e),
            # NOTE(review): this message belongs to the query-validation
            # error; replace it with a retrain-specific MsgCode member.
            'msg': MsgCode.query_not_empty.value
        })
        response.status_code = StatusCode.success.value
        return response

    intent_recogition.engine = engine
    intent_recogition.embed_intent_classificator = model

    response = jsonify({
        'code': StatusCode.success.value,
        'data': f'save data to {classifier_path}',
        'msg': StatusCode.success.value
    })
    response.status_code = StatusCode.success.value
    return response
|
|
|
|
|
|
|
|
@app.route('/intent/get-intent-recogition', methods=['post'])
def get_intent_recogition():
    """Classify the ``query`` field of a JSON request body.

    Expects a body like ``{"query": "<text>"}``. The request is answered with
    a ``StatusCode.user_error`` response and ``MsgCode.query_not_empty`` when:

    * the body is not valid UTF-8 JSON;
    * the body is not a JSON object;
    * the body lacks a non-empty string ``query``.

    Such requests no longer raise an unhandled exception.

    Returns:
        A JSON response with HTTP status ``StatusCode.success`` and the keys
        ``code``, ``data`` and ``msg``. On success, ``data`` is the
        ``{'id': ..., 'name': ...}`` dict from
        ``intent_recogition.get_intent_recogition``; on a user error it is
        ``{}``.
    """
    try:
        params = json.loads(request.data.decode('utf-8'))
    except ValueError:
        # Both UnicodeDecodeError and json.JSONDecodeError subclass ValueError.
        params = None
    if not isinstance(params, dict):
        # Missing, malformed, or non-object bodies are treated as "no query".
        params = {}

    query = params.get('query')
    if not isinstance(query, str) or not query:
        response = jsonify({
            'code': StatusCode.user_error.value,
            'data': {},
            'msg': MsgCode.query_not_empty.value
        })
        response.status_code = StatusCode.success.value
        return response

    result = intent_recogition.get_intent_recogition(query)

    response = jsonify({
        'code': StatusCode.success.value,
        'data': result,
        'msg': StatusCode.success.value
    })
    response.status_code = StatusCode.success.value
    return response
|