176 lines
6.4 KiB
Python
176 lines
6.4 KiB
Python
from __future__ import annotations

import math
import random
import warnings
from dataclasses import dataclass, field

import yaml
|
|
|
|
@dataclass(frozen=True)
class IntentNode:
    """One node of the intent tree built from the ``schema`` config section.

    The dataclass is frozen, but ``children`` and ``stories`` are ordinary
    lists that are filled in after construction (children while the tree is
    being built, stories when examples are attached). Because those fields
    are lists, calling ``hash()`` on a node raises ``TypeError``.
    """

    # Intent name as declared in the schema.
    name: str
    # Optional free-text description from the schema node.
    description: str | None = None
    # Child intents; appended to while the tree is being built.
    children: list[IntentNode] = field(default_factory=list)
    # Back-reference to the parent (None for the root). Excluded from ``==``
    # and ``repr``: following it walks back down through ``parent.children``,
    # which recurses without bound when comparing look-alike nodes in
    # different subtrees and makes the repr of any child dump the whole tree.
    parent: IntentNode | None = field(default=None, compare=False, repr=False)
    # Example stories labelled with this intent.
    stories: list[Story] = field(default_factory=list)
|
|
|
|
@dataclass(eq=True, frozen=True)
class Story:
    """A labelled example pairing a user ``message`` with an ``intent`` name."""

    # Example text of the user message.
    message: str
    # Name of the schema intent this example belongs to.
    intent: str
|
|
|
|
class PromptEngine:
    """Build few-shot chat prompts for LLM-based intent recognition.

    The engine reads a YAML file with three top-level sections:

    * ``schema`` -- a tree of intents starting at ``schema['root']``; every
      node needs a ``name`` and may have a ``description`` and ``children``.
    * ``stories`` -- example ``{message, intent}`` pairs; each valid pair is
      attached to the schema node whose name equals ``intent``.
    * ``rejects`` -- a list of strings, copied as-is.

    Every intent name gets an integer id in breadth-first order of the
    schema; the generated prompts refer to intents by these ids.
    """

    path: str
    config: dict
    schema: IntentNode | None
    stories: list[Story]
    rejects: list[str]
    intent2id: dict[str, int]
    id2intent: dict[int, str]
    name2node: dict[str, IntentNode]

    def __init__(self, path: str) -> None:
        """Load the YAML config at *path* and build the intent tree.

        Missing or empty ``schema``/``stories``/``rejects`` sections are
        treated as empty instead of raising ``KeyError``/``TypeError``.
        """
        self.path = path
        # Use a context manager so the file handle is closed deterministically.
        # NOTE(review): ``yaml.Loader`` can construct arbitrary Python objects;
        # if this config can ever come from an untrusted source, switch to
        # ``yaml.safe_load``.
        with open(path, 'r', encoding='utf-8') as stream:
            self.config = yaml.load(stream, yaml.Loader) or {}

        self.intent2id = {}
        self.id2intent = {}
        self.name2node = {}

        # Order matters: stories are validated against the intents that
        # handle_schema registers.
        self.schema = self.handle_schema(self.config.get('schema') or {})
        self.stories = self.handle_stories(self.config.get('stories') or [])
        self.rejects = self.handle_rejects(self.config.get('rejects') or [])

    def handle_schema(self, raw_schema: dict | None) -> IntentNode | None:
        """Build the intent tree from the raw ``schema`` mapping.

        Nodes are visited level by level (breadth-first). Each newly seen name
        gets the next integer id in ``intent2id``/``id2intent``, and every
        node is registered in ``name2node``.

        Args:
            raw_schema: mapping expected to contain a ``root`` node dict.

        Returns:
            The root ``IntentNode``, or ``None`` (after a warning) when no
            root is present.

        Raises:
            NameError: if a schema node has no ``name``.
        """
        raw_root = (raw_schema or {}).get('root', None)
        if raw_root is None:
            warnings.warn('schema must have a root node as the beginning, otherwise intent recogition will not work')
            return None

        root: IntentNode | None = None
        # Each entry pairs a raw node dict with its already-built parent.
        current_layer: list[tuple[dict, IntentNode | None]] = [(raw_root, None)]

        # Level-order (breadth-first) traversal.
        while current_layer:
            next_layer: list[tuple[dict, IntentNode | None]] = []
            for raw_node, parent in current_layer:
                name = raw_node.get('name', None)
                if name is None:
                    raise NameError('you must specify a name in schema node, current item : {}'.format(raw_node))
                description = raw_node.get('description', None)
                children = raw_node.get('children', None) or []

                if name in self.intent2id:
                    # name2node is overwritten below, so later stories attach
                    # to the node visited last; surface that instead of hiding it.
                    warnings.warn(
                        'intent name {!r} appears more than once in schema; it keeps id {} and stories '
                        'will be attached to the last node visited with this name'.format(name, self.intent2id[name])
                    )
                else:
                    assign_id = len(self.intent2id)
                    self.intent2id[name] = assign_id
                    self.id2intent[assign_id] = name

                node = IntentNode(name, description, [], parent, [])
                self.name2node[name] = node

                if parent is None:
                    root = node
                else:
                    parent.children.append(node)

                next_layer.extend((raw_child, node) for raw_child in children)
            current_layer = next_layer

        return root

    def handle_stories(self, raw_stories: list[dict] | None) -> list[Story]:
        """Turn raw ``{message, intent}`` pairs into ``Story`` objects.

        Pairs whose intent is not declared in the schema, or that lack a
        message or intent, are skipped with a warning. Each kept story is
        also appended to its intent node's ``stories`` list.
        """
        stories: list[Story] = []
        for pair in raw_stories or []:
            message = pair.get('message', None)
            intent = pair.get('intent', None)
            if intent not in self.intent2id:
                warnings.warn('{} is not the intent you declare in schema, so this pair will be ignored'.format(intent))
                continue
            if not (message and intent):
                warnings.warn('story {} is missing a message or intent, so this pair will be ignored'.format(pair))
                continue

            story = Story(message, intent)
            # Every name in intent2id was also registered in name2node.
            self.name2node[intent].stories.append(story)
            stories.append(story)
        return stories

    def handle_rejects(self, raw_rejects: list[str] | None) -> list[str]:
        """Return a fresh list copy of *raw_rejects* (empty for ``None``)."""
        return list(raw_rejects or [])

    def _format_intent_answer(self, intent_name: str) -> str:
        """Render the answer format for *intent_name*, e.g. ``{id : 3}``."""
        return '{id : %s}' % (self.intent2id.get(intent_name),)

    def generate_chunk(self, stories: list[Story]) -> tuple[str, str]:
        """Render one few-shot chunk as a ``(user_content, assistant_content)`` pair.

        Every story but the last becomes a worked example (``Message: ...``
        followed by its answer). The last story's message is left open with
        a trailing ``Intent: `` and its answer becomes the assistant reply.
        Worked examples and the assistant reply share the same answer format
        so the model sees one consistent pattern.

        Raises:
            ValueError: if *stories* is empty.
        """
        if not stories:
            raise ValueError('generate_chunk needs at least one story')

        *examples, query = stories
        lines: list[str] = []
        for story in examples:
            lines.append('Message: ' + story.message.strip())
            lines.append('Intent: ' + self._format_intent_answer(story.intent))
        lines.append('Message: ' + query.message.strip())
        lines.append('Intent: ')

        user_content = '\n'.join(lines)
        assistant_content = self._format_intent_answer(query.intent)
        return user_content, assistant_content

    def generate_llm_message(self, question: str, intent: IntentNode | None = None, chunk_size: int = 5,
                             max_chunk_num: int = 10, rng: random.Random | None = None) -> list[dict[str, str]]:
        """Build the chat message list that asks an LLM to classify *question*.

        The stories of the direct children of *intent* (the schema root by
        default) are shuffled and split into chunks of *chunk_size*. Each
        chunk becomes a user/assistant exchange (see ``generate_chunk``). The
        final user message is *question* followed by ``Intent: ``. An
        instruction preamble that lists the candidate intent ids is prepended
        to the first message.

        Args:
            question: the user message to classify.
            intent: node whose children are the candidate intents; defaults
                to ``self.schema``.
            chunk_size: stories per few-shot chunk; must be positive.
            max_chunk_num: maximum number of chunks; values <= 0 emit none.
            rng: optional random generator for a reproducible shuffle; the
                module-level ``random`` functions are used when omitted.

        Returns:
            A list of ``{'role': ..., 'content': ...}`` dicts.

        Raises:
            ValueError: if *chunk_size* is not positive, or if no *intent* is
                given and the schema has no root.
        """
        if chunk_size <= 0:
            raise ValueError('chunk_size must be positive, got {}'.format(chunk_size))
        if intent is None:
            intent = self.schema
        if intent is None:
            raise ValueError('no intent node given and the schema has no root')

        story_cache = [story for node in intent.children for story in node.stories]
        (random if rng is None else rng).shuffle(story_cache)

        message: list[dict[str, str]] = []
        # Slicing the range caps the number of chunks at max_chunk_num.
        chunk_starts = range(0, len(story_cache), chunk_size)[:max(0, max_chunk_num)]
        for start in chunk_starts:
            chunk = story_cache[start:start + chunk_size]
            user_content, assistant_content = self.generate_chunk(chunk)
            message.append({'role': 'user', 'content': user_content})
            message.append({'role': 'assistant', 'content': assistant_content})

        message.append({'role': 'user', 'content': question + '\nIntent: '})

        # Build the leading instructions.
        # NOTE(review): the preset asks for the intent *name*, but the options
        # listed and the assistant examples use numeric ids -- confirm which
        # the downstream parser expects.
        preset = 'Label a users message from a conversation with an intent. Reply ONLY with the name of the intent.'
        intent_preset = ['The intent should be one of the following:']
        for node in intent.children:
            intent_preset.append('- {}'.format(self.intent2id.get(node.name)))
        message[0]['content'] = preset + '\n' + '\n'.join(intent_preset) + '\n' + message[0]['content']
        return message
|
|
|
|
|
|
class KIntent:
    """Intent-recognition entry point wrapping a ``PromptEngine``."""

    path: str
    engine: PromptEngine

    def __init__(self, path: str) -> None:
        """Create the prompt engine from the YAML config at *path*."""
        self.path = path
        self.engine = PromptEngine(path)

    def inference(self, question: str, chunk_size: int = 5, max_chunk_num: int = 10) -> list[IntentNode]:
        """Walk the intent tree to recognise the intent of *question*.

        NOTE(review): unfinished skeleton -- the loop only pops nodes and
        never queries a model, so the result is currently always an empty
        list. *question*, *chunk_size* and *max_chunk_num* are presumably
        meant for a per-node ``self.engine.generate_llm_message`` call; TODO
        confirm and implement.

        Returns:
            The list of matched ``IntentNode`` objects (empty for now, and
            empty when the schema has no root).
        """
        root_node = self.engine.schema
        results: list[IntentNode] = []
        # Don't push None when the schema failed to load a root.
        stack: list[IntentNode] = [root_node] if root_node is not None else []
        while stack:
            node = stack.pop()
            # TODO: classify `question` among node.children, append the match
            # to `results` and push it to descend further.
        # Honour the declared return type instead of implicitly returning None.
        return results
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: build the few-shot chat messages for one sample question using
    # the story file in the current working directory, then print them.
    engine = PromptEngine('./story.yml')
    messages = engine.generate_llm_message('如何解决 digital ide 无法载入配置文件的问题?')
    print(messages)