"""Few-shot prompt construction and tree-structured intent inference.

A YAML config declares an intent tree (``schema``), labelled example messages
(``stories``) and ``rejects``. ``PromptEngine`` turns them into chat-style
few-shot prompts; ``TreeIntent`` walks the tree by asking an LLM, level by
level, which child intent a question belongs to.
"""
from __future__ import annotations

from dataclasses import dataclass, field
import warnings
import random
import math
from abc import ABC, abstractmethod

import yaml
import json5
from loguru import logger

logger.add(
    sink='./logs/prompt.log',
    level='DEBUG',
    rotation='00:00',
    retention='7 days',
    compression='zip',
    encoding='utf-8',
    enqueue=True,
    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
)


@dataclass(frozen=True)
class IntentNode:
    """One node of the intent tree declared under ``schema`` in the config.

    The instance is frozen (attributes cannot be rebound), but ``children``
    and ``stories`` are ordinary lists that are filled in place while the
    config is loaded.
    """
    name: str
    description: str | None
    children: list[IntentNode]
    # Back-reference to the parent node. Excluded from ``==`` and ``repr``:
    # parent <-> children form a cycle, and comparing two distinct but equal
    # trees through it recursed without bound.
    parent: IntentNode | None = field(compare=False, repr=False)
    stories: list[Story]


@dataclass(frozen=True)
class Story:
    """A labelled example: a user message and the name of its intent."""
    message: str
    intent: str


class PromptEngine:
    """Load an intent schema with example stories and build few-shot prompts."""
    path: str
    config: dict
    schema: IntentNode | None
    stories: list[Story]
    rejects: list[str]
    intent2id: dict[str, int]
    id2intent: dict[int, str]
    name2node: dict[str, IntentNode]

    def __init__(self, path: str) -> None:
        """Read the YAML config at *path* and index its schema and stories.

        Raises:
            KeyError: if the config lacks a ``schema``, ``stories`` or
                ``rejects`` key.
        """
        self.path = path
        self.config = self._load_yaml(path)

        self.intent2id = {}
        self.id2intent = {}
        self.name2node = {}

        # The schema must be handled first: stories are validated against the
        # intent names it registers.
        self.schema = self.handle_schema(self.config['schema'])
        self.stories = self.handle_stories(self.config['stories'])
        self.rejects = self.handle_rejects(self.config['rejects'])

    @staticmethod
    def _load_yaml(path: str) -> dict:
        """Parse the YAML file at *path*; an empty document yields ``{}``."""
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Only load trusted files, or switch to yaml.SafeLoader if the configs
        # do not rely on Python-specific tags.
        with open(path, 'r', encoding='utf-8') as stream:
            return yaml.load(stream, Loader=yaml.Loader) or {}

    def handle_schema(self, raw_schema: dict) -> IntentNode | None:
        """Build the intent tree and fill the name/id lookup tables.

        Nodes are visited level by level starting at ``raw_schema['root']``,
        and each previously unseen name receives the next integer id.

        Returns:
            The root node, or ``None`` (after a warning) when there is no
            ``root`` entry.

        Raises:
            NameError: if any node has no ``name``.
        """
        raw_root = (raw_schema or {}).get('root', None)
        if raw_root is None:
            warnings.warn('schema must have a root node as the beginning, otherwise intent recogition will not work')
            return None

        root: IntentNode | None = None
        current_layer: list[tuple[dict, IntentNode | None]] = [(raw_root, None)]

        # Level-order traversal.
        while current_layer:
            next_layer: list[tuple[dict, IntentNode | None]] = []
            for raw_node, parent in current_layer:
                name = raw_node.get('name', None)
                if name is None:
                    raise NameError('you must specify a name in story item, current item : {}'.format(raw_node))

                # A repeated name keeps its first id; name2node below still
                # ends up pointing at the most recently created node.
                if name not in self.intent2id:
                    assign_id = len(self.intent2id)
                    self.intent2id[name] = assign_id
                    self.id2intent[assign_id] = name

                node = IntentNode(name, raw_node.get('description', None), [], parent, [])
                self.name2node[name] = node
                if parent is None:
                    root = node
                else:
                    parent.children.append(node)

                raw_children = raw_node.get('children', None) or []
                next_layer.extend((raw_child, node) for raw_child in raw_children)
            current_layer = next_layer

        return root

    def handle_stories(self, raw_stories: list[dict] | None) -> list[Story]:
        """Convert raw ``{message, intent}`` dicts into ``Story`` objects.

        Each accepted story is also attached to the node of its intent.
        Entries whose intent is not declared in the schema are skipped with a
        warning; entries with an empty message or intent are skipped silently.
        ``None`` (a YAML key with no value) is treated as an empty list.
        """
        stories: list[Story] = []
        for pair in raw_stories or []:
            message = pair.get('message', None)
            intent = pair.get('intent', None)
            if intent not in self.intent2id:
                warnings.warn('{} is not the intent you declare in schema, so this pair will be ignored'.format(intent))
                continue

            if message and intent:
                story = Story(message, intent)
                self.name2node[intent].stories.append(story)
                stories.append(story)

        return stories

    def merge_stories_from_yml(self, path: str):
        """Load the ``stories`` list of another YAML file and merge it in."""
        config = self._load_yaml(path)
        self.merge_stories(config.get('stories', []))

    def merge_stories(self, raw_stories: list[dict]):
        """Validate *raw_stories* and append the accepted ones to ``self.stories``."""
        self.stories.extend(self.handle_stories(raw_stories))

    def handle_rejects(self, raw_rejects: list[str] | None) -> list[str]:
        """Return a shallow copy of *raw_rejects* (``None`` becomes ``[]``)."""
        return list(raw_rejects or [])

    def generate_chunk(self, stories: list[Story]) -> tuple[str, str]:
        """Render one few-shot exchange from a non-empty batch of stories.

        All stories but the last appear as solved ``Message``/``Intent`` pairs
        in the user turn; the last story is left open so that its id becomes
        the assistant turn.

        Returns:
            ``(user_content, assistant_content)``.

        Raises:
            ValueError: if *stories* is empty.
        """
        if not stories:
            raise ValueError('generate_chunk requires at least one story')

        lines: list[str] = []
        for story in stories[:-1]:
            lines.append('Message: ' + story.message.strip())
            lines.append('Intent: { id: %s }' % self.intent2id.get(story.intent))

        last_story = stories[-1]
        lines.append('Message: ' + last_story.message.strip())

        user_content = '\n'.join(lines) + '\n' + 'Intent: '
        assistant_content = '{id : %s}' % self.intent2id.get(last_story.intent)
        return user_content, assistant_content

    def generate_llm_message(self, question: str, intent: IntentNode | None = None, chunk_size: int = 5, max_chunk_num: int = 10) -> list[dict]:
        """Build the chat message list used to classify *question*.

        The stories of every child of *intent* are shuffled, cut into chunks
        of *chunk_size* and rendered as user/assistant example turns, followed
        by a final user turn holding the question. A preset listing the
        candidate intent ids is prepended to the first turn.

        Args:
            question: The message to classify.
            intent: Node whose children are the candidates; defaults to the
                schema root.
            chunk_size: Number of stories per example exchange (at least 1).
            max_chunk_num: Upper bound on example exchanges; values <= 0
                produce no examples.

        Raises:
            ValueError: if *chunk_size* is below 1, or no node is given and
                the engine has no schema root.
        """
        if chunk_size < 1:
            raise ValueError('chunk_size must be a positive integer')
        if intent is None:
            intent = self.schema
        if intent is None:
            raise ValueError('no intent node given and the engine has no schema root')

        story_cache = [story for node in intent.children for story in node.stories]
        random.shuffle(story_cache)

        chunk_num = min(math.ceil(len(story_cache) / chunk_size), max(max_chunk_num, 0))
        message: list[dict] = []
        for chunk_id in range(chunk_num):
            start = chunk_id * chunk_size
            user_content, assistant_content = self.generate_chunk(story_cache[start: start + chunk_size])
            message.append({
                'role': 'user',
                'content': user_content
            })
            message.append({
                'role': 'assistant',
                'content': assistant_content
            })

        message.append({
            'role': 'user',
            'content': question + '\nIntent: '
        })

        # Build the opening preset.
        # NOTE(review): the preset asks for the intent *name*, while the
        # candidates are listed by id and the examples answer with `{id : N}`
        # -- confirm that this wording is intended.
        preset = 'Label a users message from a conversation with an intent. Reply ONLY with the name of the intent.'
        intent_preset = ['The intent should be one of the following:']
        for node in intent.children:
            intent_preset.append('- {}'.format(self.intent2id.get(node.name)))
        intent_preset = '\n'.join(intent_preset)

        message[0]['content'] = preset + '\n' + intent_preset + '\n' + message[0]['content']
        return message


class TreeIntent(ABC):
    """Classify a question by descending the intent tree with an LLM.

    Subclasses supply ``call_llm``; everything else is provided here.
    """
    path: str
    engine: PromptEngine

    def __init__(self, path: str) -> None:
        """Create the ``PromptEngine`` for the YAML config at *path*."""
        self.path = path
        self.engine = PromptEngine(path)

    @abstractmethod
    def call_llm(self, message: list[dict]) -> str:
        """Send *message* to the model and return its raw text reply.

        *message* is a chat history of ``{"role", "content"}`` dicts as built
        by ``PromptEngine.generate_llm_message``. The reply is expected to
        contain an object with an ``id`` field. Example of message::

            [
                {
                    "role": "user",
                    "content": "Message: Why did none of the features start after my Digital IDE launched? I set up the property file and added the vivado and modelsim paths\\nIntent: "
                },
                {
                    "role": "assistant",
                    "content": "{ id: 0 }"
                },
                {
                    "role": "user",
                    "content": "digital-ide freezes when opening a large verilog file\\nIntent: "
                },
                {
                    "role": "assistant",
                    "content": "{ id: 1 }"
                }
            ]
        """
        pass

    def purify_json(self, json_string: str) -> str:
        """Return the first balanced ``{...}`` span found in *json_string*.

        The input is returned unchanged when no balanced span exists. A ``}``
        with no matching ``{`` before it is ignored instead of aborting the
        scan. Braces inside string literals are not treated specially.
        """
        depth = 0
        start_index = None
        for i, ch in enumerate(json_string):
            if ch == '{':
                if depth == 0:
                    start_index = i
                depth += 1
            elif ch == '}' and depth > 0:
                depth -= 1
                if depth == 0:
                    return json_string[start_index: i + 1]
        return json_string

    def try_generate_intent_id(self, question: str, intent: IntentNode | None = None, chunk_size: int = 5, max_chunk_num: int = 10, retry: int = 3) -> int | None:
        """Ask the LLM for an intent id, retrying up to *retry* times.

        The prompt is rebuilt on every attempt, so the shuffled examples
        differ between attempts.

        Returns:
            The integer id parsed from the reply, or ``None`` when every
            attempt failed.
        """
        engine = self.engine
        for attempt in range(retry):
            try:
                message = engine.generate_llm_message(question, intent, chunk_size, max_chunk_num)
                reply = self.call_llm(message)
                payload = json5.loads(self.purify_json(reply))
                return int(payload['id'])
            except Exception as e:
                # Deliberately broad: call_llm is user-supplied and the reply
                # is free text. Record the reason so the log file mentioned
                # by inference() actually explains the failure.
                logger.debug('attempt {}/{} to generate intent id failed: {!r}', attempt + 1, retry, e)
        return None

    def inference(self, question: str, chunk_size: int = 5, max_chunk_num: int = 10) -> list[IntentNode] | None:
        """Classify *question* from the schema root downwards.

        At each level the LLM picks one child of the current node. The walk
        continues into that child while it has at least two children.

        Returns:
            The chosen nodes from the top level down, or ``None`` when there
            is no schema root, the LLM produced no usable id, the id is
            unknown, or the id does not name a child of the current node.
        """
        engine = self.engine
        node = engine.schema
        if node is None:
            logger.warning('engine has no schema root, intent inference is impossible')
            return None

        results: list[IntentNode] = []
        while node is not None:
            intent_id = self.try_generate_intent_id(question, node, chunk_size, max_chunk_num)
            if intent_id is None:
                logger.warning('fail to generate intent id from message, check log file for details')
                logger.debug(json5.dumps({
                    'question': question,
                    'node.name': node.name
                }, ensure_ascii=False))
                return None
            if intent_id not in engine.id2intent:
                logger.warning('inferred intent id {} not in the list of engine.id2intent {}'.format(intent_id, list(engine.id2intent.keys())))
                logger.debug(json5.dumps({
                    'question': question,
                    'node.name': node.name,
                    'intent_id': intent_id
                }, ensure_ascii=False))
                return None

            intent_name = engine.id2intent[intent_id]
            # Only children of the current node are valid answers. Accepting
            # any known id could return out-of-tree paths or re-query the same
            # node over and over.
            intent_node = next((child for child in node.children if child.name == intent_name), None)
            if intent_node is None:
                logger.warning('inferred intent {} is not a child of node {}', intent_name, node.name)
                logger.debug(json5.dumps({
                    'question': question,
                    'node.name': node.name,
                    'intent_id': intent_id
                }, ensure_ascii=False))
                return None

            results.append(intent_node)
            node = intent_node if len(intent_node.children) >= 2 else None

        return results


if __name__ == '__main__':
    prompt_engine = PromptEngine('./config/story.yml')
    prompt_engine.merge_stories_from_yml('./config/github-issue.story.yml')
    print(len(prompt_engine.stories))