176 lines
6.4 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
import warnings
import random
import math
import yaml
@dataclass(frozen=True)
class IntentNode:
name: str
description: str | None
children: list[IntentNode]
parent: IntentNode | None
stories: list[Story]
@dataclass(frozen=True)
class Story:
message: str
intent: str
class PromptEngine:
path: str
schema: IntentNode | None
stories: list[Story]
rejects: list[str]
intent2id: dict[str, int]
id2intent: dict[int, str]
name2node: dict[str, IntentNode]
def __init__(self, path: str) -> None:
self.path = path
self.config = yaml.load(open(path, 'r', encoding='utf-8'), yaml.Loader)
self.intent2id = {}
self.id2intent = {}
self.name2node = {}
self.schema = self.handle_schema(self.config['schema'])
self.stories = self.handle_stories(self.config['stories'])
self.rejects = self.handle_rejects(self.config['rejects'])
def handle_schema(self, raw_schema: dict) -> IntentNode:
raw_root = raw_schema.get('root', None)
if raw_root is None:
warnings.warn('schema must have a root node as the beginning, otherwise intent recogition will not work')
return None
current_layers: list[tuple[dict, IntentNode | None]] = [(raw_root, None)]
nodes: list[IntentNode] = []
# 层次遍历
while len(current_layers) > 0:
new_current_layers: list[tuple[dict, IntentNode | None]] = []
for raw_node, intent_node in current_layers:
name = raw_node.get('name', None)
children = raw_node.get('children', None)
description = raw_node.get('description', None)
if name is None:
raise NameError('you must specify a name in story item, current item : {}'.format(raw_node))
if children is None:
children = []
if name not in self.intent2id:
assign_id = len(self.intent2id)
self.intent2id[name] = assign_id
self.id2intent[assign_id] = name
node = IntentNode(name, description, [], intent_node, [])
self.name2node[name] = node
nodes.append(node)
if intent_node:
intent_node.children.append(node)
for raw_node in children:
new_current_layers.append((raw_node, node))
current_layers.clear()
current_layers.extend(new_current_layers)
return nodes[0]
def handle_stories(self, raw_stories: list[dict]) -> list[Story]:
stories: list[Story] = []
for pair in raw_stories:
message = pair.get('message', None)
intent = pair.get('intent', None)
if intent not in self.intent2id:
warnings.warn('{} is not the intent you declare in schema, so this pair will be ignored'.format(intent))
continue
if message and intent:
story = Story(message, intent)
node = self.name2node.get(intent)
node.stories.append(story)
stories.append(story)
return stories
def handle_rejects(self, raw_rejects: list[str]) -> list[str]:
rejects = []
for reject in raw_rejects:
rejects.append(reject)
return rejects
def generate_chunk(self, stories: list[Story]) -> tuple[str]:
prompts = []
for story in stories:
prompts.append('Message: ' + story.message.strip())
intent_id = self.intent2id.get(story.intent)
prompts.append('Intent: { id: %s }' % (intent_id))
prompts.pop()
user_content = '\n'.join(prompts) + '\n' + 'Intent: '
assistant_content = '{id : %s}' % (intent_id)
return user_content, assistant_content
def generate_llm_message(self, question: str, intent: IntentNode = None, chunk_size: int = 5, max_chunk_num: int = 10):
if intent is None:
intent = self.schema
story_cache = []
for node in intent.children:
story_cache.extend(node.stories)
random.shuffle(story_cache)
chunk_num = math.ceil(len(story_cache) / chunk_size)
message = []
for chunk_id in range(chunk_num):
start = chunk_id * chunk_size
end = min(len(story_cache), start + chunk_size)
chunk = story_cache[start: end]
user_content, assistant_content = self.generate_chunk(chunk)
message.append({
'role': 'user',
'content': user_content
})
message.append({
'role': 'assistant',
'content': assistant_content
})
if len(message) / 2 >= max_chunk_num:
break
message.append({
'role': 'user',
'content': question + '\nIntent: '
})
# 创建开头的预设
preset = 'Label a users message from a conversation with an intent. Reply ONLY with the name of the intent.'
intent_preset = ['The intent should be one of the following:']
for node in intent.children:
intent_id = self.intent2id.get(node.name)
intent_preset.append('- {}'.format(intent_id))
intent_preset = '\n'.join(intent_preset)
message[0]['content'] = preset + '\n' + intent_preset + '\n' + message[0]['content']
return message
class KIntent:
path: str
engine: PromptEngine
def __init__(self, path: str) -> None:
self.path = path
self.engine = PromptEngine(path)
def inference(self, question: str, chunk_size: int = 5, max_chunk_num: int = 10) -> list[IntentNode]:
root_node = self.engine.schema
results: list[IntentNode] = []
stack = [root_node]
while len(stack) > 0:
node = stack.pop()
if __name__ == '__main__':
prompt_engine = PromptEngine('./story.yml')
msg = prompt_engine.generate_llm_message('如何解决 digital ide 无法载入配置文件的问题?')
print(msg)