Lagrange.RagBot/notebook/experiment.ipynb

919 lines
115 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.manifold import TSNE\n",
"import yaml\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import sys\n",
"import os\n",
"\n",
"sys.path.append(os.path.abspath('..'))\n",
"from prompt import PromptEngine"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/data/zhelonghuang/miniconda3/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/data/zhelonghuang/miniconda3/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain_community.embeddings import HuggingFaceEmbeddings\n",
"\n",
"# Sentence-embedding model used by every later cell through `model.embed_documents`.\n",
"EMBEDDING_MODEL_NAME = 'maidalun1020/bce-embedding-base_v1'\n",
"model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)\n",
"\n",
"# Two near-paraphrases used as a quick smoke test of the model.\n",
"sentences = ['python 是什么', '请介绍一下 python']"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test: the model should return exactly one vector per input sentence.\n",
"embeddings = model.embed_documents(sentences)\n",
"assert len(embeddings) == len(sentences), 'expected one embedding per sentence'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(['请问 property.json 如何配置?',\n",
" '我的自动补全无法使用是不是有bug',\n",
" '帮我上传一下这份数据',\n",
" 'surface了解一下',\n",
" '大佬们为啥我的digital ide启动之后所有功能都没启动捏我配置了property文件然后插件的vivado路经和modelsim路经都加上了',\n",
" '这群要被chisel夺舍了吗',\n",
" 'Metals一开直接报错',\n",
" '话说digital-ide打开大的verilog卡死了',\n",
" '请问一下第一次点击对文件仿真可以出波形文件再次点击的时候就会提示unknown module type了。是哪个配置没配置好',\n",
" '怎么调整是哪个版本的vivado来构建工程呢',\n",
" '咱们这个插件win7的vscode是不是只能用很早之前的版本',\n",
" '帮我将这份数据保存到服务器上',\n",
" '他这个意思是 单个功耗很低 但是功耗低那肯定性能就寄 频率肯定不高 靠人多',\n",
" '我平时写代码就喜欢喝茶',\n",
" '感觉现在啥都在往AI靠',\n",
" '请问你们自动对齐插件用的啥?',\n",
" '不得不放一下我的咖啡笔记了',\n",
" 'stm32有什么好玩的应用不',\n",
" '别人设置的肯定有点不合适自己的',\n",
" 'http://hehezhou.cn/register2024/AArch64-regindex.html',\n",
" '因为他们py本领不是很强需要这些东西辅助',\n",
" '写C写多了顺手在pycharm写了个main.c',\n",
" '好流畅的にほんじんです',\n",
" '有没有接触过UI开发的想做一款寄存器管理的工具想把界面做的好看一点',\n",
" '现在嘉立创也在做FPGA了',\n",
" '大佬们更新0.3.3之后用iverilog仿真testbench中还是例化模块出错unknown module type这是什么原因啊',\n",
" '查了一下记录2017年买的静电容',\n",
" '我小时候电脑刚买回来一星期就被我玩坏了',\n",
" 'command not found: python',\n",
" 'path top.v is not a hdlFile 请问报这个错误大概是啥原因啊',\n",
" '咖啡喝不了,喝了胃不舒服',\n",
" '兄弟们有没有C语言绘图库推荐',\n",
" '在企业里面最大的问题是碰见傻逼怎么办?',\n",
" '如何使用 digital ide 这个插件?',\n",
" '我早上开着机去打论文 回来发现我电脑切换到Linux了',\n",
" '我在Windows下遇到的只要问题就是对于C程序包管理和编译管理器偶尔会不认识彼此但除此之外都很安稳win11除外',\n",
" '不能理解在生产环境用arch的人。。',\n",
" '请问一下xilinx fpga开发在win和linux平台哪个好',\n",
" '好羡慕你们可以开发自己喜欢的东西',\n",
" '???',\n",
" '',\n",
" '我人麻了',\n",
" '艹',\n",
" '我人傻了',\n",
" '我tm',\n",
" 'tnnd',\n",
" 'funny mud goup',\n",
" '衣服混起来洗有一件深色的掉色了现在我一盆白T恤全变成泥土色了',\n",
" '这可是陪伴了我六七年的衣服啊',\n",
" '唉神金',\n",
" '为啥要用手机号码啊',\n",
" '本人于今日12时点了4瓶500ml无糖可乐收到四瓶888ml',\n",
" '我是小趴菜',\n",
" '我也想组乐队 www',\n",
" '这个乐队谱,我看哭了',\n",
" '我的手机最大的用处就是当一个麦克风+相机',\n",
" '那你可能更适合iPhone不过稍微贵点入门级别5-6k好一点的得8k-1w',\n",
" 'F和弦弹不起来太真实了 hhh',\n",
" '草,这一套,比我买过的所有电子产品加起来还贵',\n",
" '奶茶点少糖游泳打乒乓球结果最近还胖了4斤',\n",
" '你感觉有必要搞一个全局的数据库之类的么,比如日程插件和聊天插件可能都会用到用户的一些信息,如果不共享就要写两套用户处理逻辑',\n",
" '确切说,前期他在外放声音看庆余年',\n",
" '因为这是养蛊的虚拟机,放了些国产垃圾软件,得用国产流氓之王才能镇得住他们',\n",
" '真的开线程是要tcl指令去改的',\n",
" '下个版本核心就是LSP设计了',\n",
" '后面研究下怎么玩这个插件',\n",
" '得上 Dirichlet 分布做一下不确定性拟合了',\n",
" '用群友的数据训了一下bias有点大',\n",
" '我觉得关键时刻可以防身',\n",
" '我记得windows默认max是2吧',\n",
" '再过几年毛利大叔就变成毛利老弟了',\n",
" '我也感觉自己老的好快',\n",
" '',\n",
" '我的反撤回还能用',\n",
" '关于波形显示的一些建议',\n",
" '采用iverilog生成的VCD貌似无法解析仿真数据',\n",
" '【v0.3.2】模块调用后netlist生成错误且仿真报错',\n",
" '网表优化与插入文档',\n",
" '插件文档导出问题',\n",
" '【v0.3.2】testbench修改之后再次仿真会报错',\n",
" '[0.3.2] [问题] 含参数的 Verilog 模块自动例化,代码格式不正确',\n",
" '【报错】RuntimeError: null function or function signature mismatch',\n",
" '报错verilog解析器无法解析以下代码',\n",
" '报错verilog解析器的bug',\n",
" '功能建议-能增加点类似Verilog-Mode的功能么',\n",
" '报错RuntimeError: null function or function signature mismatch、无法识别HDL文件',\n",
" '例化模块自动生成tb文件报错Unknown module type',\n",
" 'Errors happen when parsing d:/danpj/fpga/modelsim/mod1/user/src/count4.v. Error: \"RuntimeError: null function or function signature mismatch\". Just propose a valuable issue in our github repo ',\n",
" '自动例化报错',\n",
" '【问题】【0.3.2】重复提示 Error: \"RuntimeError: null function or function signature mismatch\"',\n",
" '【0.3.2】【问题】1无法解析localparam 2 带参数模块例化',\n",
" '基础教程太少',\n",
" '.v源文件未被正确识别',\n",
" '重复仿真时报错',\n",
" '插件不能使用',\n",
" '关于netlist的生成错误',\n",
" '[0.3.2] 支持对verilator 的dpi-c机制的支持 ',\n",
" '[0.3.2] 离线支持+SV支持',\n",
" 'Bad webstie connection on README',\n",
" '在声明数据位宽时使用宏定义会报错',\n",
" '0.3.2 无verilog语法检查且提示RuntimeError',\n",
" '[0.3.2] 模块定义跳转偶尔会出现问题',\n",
" '[0.3.2] 代码补全有多个内容完全相同的选项',\n",
" '[0.3.2] 例化模块的类型,模块名称的代码高亮不变色',\n",
" '文档中的params和ports数反了',\n",
" '[建议]优化Formatter与文档生成',\n",
" '[0.3.2] Linter(vivado) 启用无效 (还是说我用的vivado2023太新了?)',\n",
" '[0.3.2]module的#后的parameter能悬停显示数值, 但内部parameter的不能',\n",
" '[0.3.2]param语法错误会弹右下角报错弹窗, TreeView刷新按钮无效',\n",
" '0.3.0版本后存在bug,构建项目后仿真无法运行',\n",
" '建议:模块例化可以基于文件夹来检索,当文件比较多时更整洁一点。',\n",
" 'Add Questa-Sim into the linter option',\n",
" '悬停提示对 /**/ 型的注释有误',\n",
" '仿真时因为文件夹名字存在空格产生错误',\n",
" 'filetype from json to jsonc (support comments)',\n",
" '语法识别错误',\n",
" 'code to doc',\n",
" 'WSL环境点击“显示当前文件的FSM图”时会发生扩展远程主机终止的错误',\n",
" '在架构里不能解析include后的模块',\n",
" 'treeview在文件移动时遇到问题',\n",
" 'treeview在文件变动时产生错误',\n",
" '[0.3.0 beta] iverilog指令错误',\n",
" 'Will verilator be supported?',\n",
" '[0.3.0 beta] 高亮颜色错误',\n",
" '状态机显示有问题',\n",
" 'xdc文件无法高亮显示',\n",
" 'Ubuntu环境下digital ide对配置环境有问题。',\n",
" '[0.3.0 beta] \"定义跳转\"定义位置出错',\n",
" '[0.3.0 beta] \"library导入文件\"模块解析报错',\n",
" '[0.3.0 beta] \"对当前文件进行仿真\"功能报错',\n",
" '[0.3.0 beta] 由于插件不能分析出使用了\"include\"语法, 从而导致Sim失败',\n",
" '[0.3.0 beta] 文件跳转功能失效',\n",
" '[0.3.0 beta] 例化中的[xx:xx]连线会导致后续颜色错误',\n",
" '[0.3.0 beta]\"显示当前文件的netlist\"无法正确显示出网表',\n",
" '[0.3.0 beta] Bitwidth of 1-bit signal is incorrectly recognized as \"Unknown\" in auto instantiation and auto document',\n",
" '[0.3.3 beta] 含参数模型的例化存在问题以及拓展快捷键失效的问题',\n",
" \"[0.3.2]数值悬停提示不支持'_'语法\",\n",
" '请问有没有verilog的学习PDF啊',\n",
" 'zlib',\n",
" '我只有标准文件',\n",
" '[不会吧]',\n",
" '哟西',\n",
" 'vivado中program device没有把程序固化进板子吧',\n",
" '调试时候下载bit流和ltx固化需要在xdc约束里面添加spi或bpi相关约束后生成bit流。建议一开始xdc约束当中就添加相关约束后续生成mcs文件需要先添加flash型号根据相关型号生成对应的mcs文件后再固化到flash当中。',\n",
" '符号能自己空格',\n",
" '换行能缩进4个字符',\n",
" '现在我们插件的 format 不足是什么',\n",
" 'github 的 discussion',\n",
" 'others',\n",
" '这种也就是在 scope 中进行缩进',\n",
" '不会呐',\n",
" '有点怪,这里没有和竖线对齐',\n",
" '我手动跳过',\n",
" '好了,',\n",
" '自动格式化出来就没对齐',\n",
" '手动能对齐',\n",
" ' 好了',\n",
" '哪个呀?',\n",
" 'https://github.com/Digital-EDA/Digital-IDE/discussions',\n",
" ' \\n 这个格式看着舒服,怎么设置的呀',\n",
" ' \\n 手动',\n",
" '我把wire和reg位置固定了',\n",
" '采用iverilog生成的VCD貌似无法解析仿真数据这是为什么',\n",
" ' 我们目前上线了一个 QA 机器人,正在迭代中,它会根据问题自动进行回答,并给出对应的链接。我们设置了拒答流,并不是所有问题它都会进行回答,当然,意图检查模块仍然在迭代中,我们需要更多的数据。',\n",
" '请问什么时候会支持verilator',\n",
" 'at 触发不够人性化,很多用户不知道 Tip 的存在就不会去使用它',\n",
" '支持verilator',\n",
" '应该不会',\n",
" '非问句会触发吗',\n",
" 'tql',\n",
" 'qwq',\n",
" '目前digital ide的netlist是用什么综合出的',\n",
" '目前Digital-IDE插件的netlist是用什么综合出的',\n",
" '微调不了,数据太少了',\n",
" '大模型只是其中的一个组件',\n",
" '摸了',\n",
" '这个可以伪造转发消息吗?',\n",
" '这个机器人有回答修正功能吗',\n",
" '我回去把格式发你',\n",
" '这个垃圾插件不好用',\n",
" '我靠,牛',\n",
" '但是可以帮我敲顺序的状态机定义和寄存器地址',\n",
" '或者自己写脚本解决',\n",
" '后台就要我有时间才能统一上传',\n",
" '网络问题应该是'],\n",
" [0,\n",
" 1,\n",
" 2,\n",
" 6,\n",
" 0,\n",
" 6,\n",
" 6,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 2,\n",
" 6,\n",
" 0,\n",
" 6,\n",
" 0,\n",
" 6,\n",
" 0,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 0,\n",
" 6,\n",
" 0,\n",
" 6,\n",
" 0,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 0,\n",
" 6,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 6,\n",
" 5,\n",
" 5,\n",
" 4,\n",
" 4,\n",
" 6,\n",
" 6,\n",
" 4,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 5,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 4,\n",
" 6,\n",
" 3,\n",
" 1,\n",
" 1,\n",
" 3,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 3,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 3,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 3,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 0,\n",
" 6,\n",
" 6,\n",
" 4,\n",
" 5,\n",
" 0,\n",
" 3,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 4,\n",
" 6,\n",
" 4,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 1,\n",
" 1,\n",
" 4,\n",
" 4,\n",
" 6,\n",
" 6,\n",
" 4,\n",
" 6,\n",
" 1,\n",
" 6,\n",
" 1,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 4,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 6,\n",
" 6,\n",
" 4,\n",
" 1,\n",
" 1,\n",
" 6,\n",
" 1,\n",
" 4,\n",
" 6,\n",
" 6,\n",
" 6,\n",
" 6])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"engine = PromptEngine('../config/story.yml')\n",
"engine.merge_stories_from_yml('../config/github-issue.story.yml')\n",
"engine.merge_stories_from_yml('../config/qq.story.yml')\n",
"\n",
"# Parallel lists: raw message text and its integer intent id (same order).\n",
"sentences = [story.message for story in engine.stories]\n",
"labels = [engine.intent2id[story.intent] for story in engine.stories]\n",
"\n",
"# Compact summary (sample count and per-intent counts) instead of dumping every sentence.\n",
"len(sentences), {engine.id2intent[i]: labels.count(i) for i in sorted(set(labels))}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(185, 768)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One row per story sentence; rows stay aligned with `labels`.\n",
"embedding = np.array(model.embed_documents(sentences))\n",
"assert embedding.shape[0] == len(sentences) == len(labels)\n",
"embedding.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# t-SNE is stochastic; a fixed random_state makes the 2-D layout reproducible.\n",
"tsne = TSNE(n_components=2, random_state=0)\n",
"plots = tsne.fit_transform(embedding)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f76e6a3d8d0>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Later cells index `labels` with integer arrays, so keep it as a numpy array.\n",
"labels = np.array(labels)\n",
"\n",
"fig, ax = plt.subplots()\n",
"# Iterate intents in sorted order so colours and legend order are stable between runs.\n",
"for label in sorted(set(labels.tolist())):\n",
"    points = plots[labels == label]\n",
"    ax.scatter(points[:, 0], points[:, 1], s=50, alpha=0.9, label=engine.id2intent[label])\n",
"ax.set(title='t-SNE of story embeddings by intent', xlabel='t-SNE dim 1', ylabel='t-SNE dim 2')\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression()</pre></div></div></div></div></div>"
],
"text/plain": [
"LogisticRegression()"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# Baseline intent classifier: logistic regression with default hyperparameters on raw embeddings.\n",
"log_model = LogisticRegression().fit(embedding, labels)\n",
"log_model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'others'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check on a sentence that also appears among the training stories.\n",
"test_sentence = ['咖啡喝不了,喝了胃不舒服']\n",
"test_embedding = model.embed_documents(test_sentence)\n",
"predictions = log_model.predict(test_embedding)\n",
"res = predictions[0]\n",
"engine.id2intent[res]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['../model/embedding_mapping.sklearn']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import joblib\n",
"\n",
"LOGREG_PATH = '../model/embedding_mapping.sklearn'\n",
"# joblib.dump does not create missing directories.\n",
"os.makedirs(os.path.dirname(LOGREG_PATH), exist_ok=True)\n",
"joblib.dump(log_model, LOGREG_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Reload the serialised classifier (pickle-based: only load files you produced yourself).\n",
"log_model = joblib.load(filename='../model/embedding_mapping.sklearn')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'others'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Repeat the earlier query against the reloaded model; the predicted intent should be unchanged.\n",
"test_sentence = ['咖啡喝不了,喝了胃不舒服']\n",
"test_embedding = model.embed_documents(test_sentence)\n",
"[res] = log_model.predict(test_embedding)\n",
"engine.id2intent[res]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 尝试使用 EDL\n",
"\n",
"使用证据网络(evidential network)增加不确定性计算,详细可看:[EDL (Evidential Deep Learning) 原理与代码实现](https://kirigaya.cn/blog/article?seq=154)\n",
"\n",
"损失函数\n",
"\n",
"$$\n",
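"\\mathcal L(\\theta) = \\sum_{i=1}^N \\mathcal L_i(\\theta) + \\lambda_t \\sum_{i=1}^N \\mathrm{KL}\\left(D(p_i \\mid \\tilde{\\alpha}_i) \\,\\|\\, D(p_i \\mid \\mathbf{1})\\right)\n",
"$$\n",
"\n",
"其中 $\\tilde{\\alpha}_i = y_i + (1 - y_i) \\odot \\alpha_i$ 为去除真实类别证据后的 Dirichlet 参数,$\\lambda_t = \\min(0.9, t / 10)$ 为随训练轮次 $t$ 递增的 KL 退火系数。"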
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import tqdm\n",
"\n",
"class SimpleERNN(nn.Module):\n",
" in_dim: int\n",
" out_dim: int\n",
" alpha_kl: float\n",
" def __init__(self, in_dim: int, out_dim: int, focal: int, alpha_kl: float):\n",
" super().__init__()\n",
" self.in_dim = in_dim\n",
" self.out_dim = out_dim\n",
" self.alpha_kl = alpha_kl\n",
" self.focal = focal\n",
" self.classifier = nn.Sequential(\n",
" nn.Linear(in_dim, out_dim),\n",
" nn.ELU(),\n",
" )\n",
" \n",
" def forward(self, inputs: torch.FloatTensor) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n",
" logits = self.classifier(inputs)\n",
" evidence = torch.exp(logits)\n",
" prob = F.normalize(evidence + 1, p=1, dim=1)\n",
" return evidence, prob\n",
"\n",
" def criterion(self, evidence: torch.FloatTensor, label: torch.LongTensor) -> torch.FloatTensor:\n",
" if len(label.shape) == 1:\n",
" label = F.one_hot(label, self.out_dim)\n",
" alpha = evidence + 1\n",
" alpha_0 = alpha.sum(1).unsqueeze(-1).repeat(1, self.out_dim)\n",
" loss_ece = torch.sum(label * (torch.digamma(alpha_0) - torch.digamma(alpha)), dim=1)\n",
" loss_ece = torch.mean(loss_ece)\n",
" if self.alpha_kl > 0:\n",
" tilde_alpha = label + (1 - label) * alpha\n",
" uncertainty_alpha = torch.ones_like(tilde_alpha).cuda()\n",
" estimate_dirichlet = torch.distributions.Dirichlet(tilde_alpha)\n",
" uncertainty_dirichlet = torch.distributions.Dirichlet(uncertainty_alpha)\n",
" kl = torch.distributions.kl_divergence(estimate_dirichlet, uncertainty_dirichlet)\n",
" loss_kl = torch.mean(kl)\n",
" else:\n",
" loss_kl = 0\n",
" return loss_ece + self.alpha_kl * loss_kl "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 200/200 [00:00<00:00, 230.43it/s]\n"
]
}
],
"source": [
"in_dim = embedding.shape[1]\n",
"out_dim = max(labels) + 1\n",
"enn_model = SimpleERNN(in_dim, out_dim, 0, 0)\n",
"optimizer = torch.optim.AdamW(enn_model.parameters(), lr=1e-3)\n",
"\n",
"bs = 64\n",
"sample_num = len(embedding)\n",
"sample_indice = np.arange(sample_num)\n",
"bs_num = int(np.ceil(sample_num / bs))\n",
"\n",
"training_losses = []\n",
"\n",
"for i in tqdm.trange(200):\n",
" alpha_kl = min(0.9, i / 10)\n",
" np.random.shuffle(sample_indice) \n",
" train_loss = 0\n",
" for bs_i in range(bs_num):\n",
" start = bs_i * bs\n",
" end = min(sample_num, start + bs)\n",
" data_indice = sample_indice[start: end]\n",
" data = torch.FloatTensor(embedding[data_indice])\n",
" label = torch.LongTensor(labels[data_indice])\n",
" evidence, prob = enn_model(data)\n",
" loss = enn_model.criterion(evidence, label)\n",
" train_loss += loss.item()\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" training_losses.append(train_loss / bs_num)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f76e602f190>]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.plot(training_losses)\n",
"ax.set(title='EDL training loss', xlabel='epoch', ylabel='mean batch loss');"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"ENN_PATH = '../model/intent.enn.pth'\n",
"# torch.save does not create missing directories.\n",
"os.makedirs(os.path.dirname(ENN_PATH), exist_ok=True)\n",
"torch.save(enn_model.state_dict(), ENN_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"in_dim = embedding.shape[1]\n",
"out_dim = max(labels) + 1\n",
"enn_model = SimpleERNN(in_dim, out_dim, 0, 0)\n",
"# map_location makes loading work whether or not the weights were saved from a GPU.\n",
"state_dict = torch.load('../model/intent.enn.pth', map_location='cpu')\n",
"load_result = enn_model.load_state_dict(state_dict)\n",
"enn_model.eval()\n",
"load_result"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0.0972]),\n",
" tensor([[0.2197, 0.1341, 0.0202, 0.0307, 0.0252, 0.0267, 0.5433]]),\n",
" tensor([[0.2059, 0.1202, 0.0064, 0.0168, 0.0113, 0.0129, 0.5295]]))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embd = model.embed_documents(['其实是计划给linux适配的linux和mac都是posix接口那不是自然就适配mac了吗'])\n",
"embd = torch.FloatTensor(embd)\n",
"with torch.no_grad():\n",
"    evidence, prob = enn_model(embd)\n",
"\n",
"# Subjective-logic quantities. keepdim keeps S as a (batch, 1) column so the\n",
"# division broadcasts per sample rather than across the class axis.\n",
"alpha = evidence + 1\n",
"S = alpha.sum(dim=1, keepdim=True)   # Dirichlet strength\n",
"b = evidence / S                     # belief mass per class\n",
"u = (out_dim / S).squeeze(1)         # uncertainty mass per sample\n",
"u, prob, b"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"usage usage √ tensor([0.0501])\n",
"bug usage,bug √ tensor([0.0773])\n",
"bug usage,bug √ tensor([0.0758])\n",
"others others √ tensor([0.1678])\n",
"others others √ tensor([0.0887])\n",
"bug usage,bug √ tensor([0.0902])\n",
"others others √ tensor([0.0453])\n",
"others others √ tensor([0.0424])\n",
"others others √ tensor([0.1416])\n",
"others others √ tensor([0.1441])\n",
"bug usage,bug,others √ tensor([0.1615])\n",
"usage usage √ tensor([0.0562])\n",
"others usage,bug,others √ tensor([0.0820])\n",
"others usage,others √ tensor([0.0798])\n",
"others others √ tensor([0.1282])\n",
"others others √ tensor([0.1034])\n",
"others others √ tensor([0.0967])\n",
"expression expression √ tensor([0.0802])\n"
]
}
],
"source": [
"# `expect` lists every acceptable intent name, comma-separated.\n",
"test_suite = [\n",
"    { 'input': '如何使用 digital ide 这个插件?', 'expect': 'usage' },\n",
"    { 'input': '我今天打开 vscode发现 自动补全失效了,我是哪里没有配置好吗?', 'expect': 'usage,bug' },\n",
"    { 'input': 'path top.v is not a hdlFile 请问报这个错误大概是啥原因啊', 'expect': 'usage,bug' },\n",
"    { 'input': '我同学在学习强国看到小麦收割了,然后就买相应的股就赚了', 'expect': 'others' },\n",
"    { 'input': '我平时写代码就喜欢喝茶', 'expect': 'others' },\n",
"    { 'input': '请问报这个错误大概是啥原因啊', 'expect': 'usage,bug' },\n",
"    { 'input': '感觉现在啥都在往AI靠', 'expect': 'others' },\n",
"    { 'input': '别人设置的肯定有点不合适自己的', 'expect': 'others' },\n",
"    { 'input': '在企业里面最大的问题是碰见傻逼怎么办?', 'expect': 'others' },\n",
"    { 'input': '几乎完全不喝牛奶2333', 'expect': 'others' },\n",
"    { 'input': 'command not found: python', 'expect': 'usage,bug,others' },\n",
"    { 'input': '兄弟们有没有C语言绘图库推荐', 'expect': 'usage' },\n",
"    { 'input': '我早上开着机去打论文 回来发现我电脑切换到Linux了', 'expect': 'usage,bug,others' },\n",
"    { 'input': '我在Windows下遇到的只要问题就是对于C程序包管理和编译管理器偶尔会不认识彼此但除此之外都很安稳win11除外', 'expect': 'usage,others' },\n",
"    { 'input': '我的反撤回还能用', 'expect': 'others' },\n",
"    { 'input': '因为这是养蛊的虚拟机,放了些国产垃圾软件,得用国产流氓之王才能镇得住他们', 'expect': 'others' },\n",
"    { 'input': '你咋装了个360', 'expect': 'others' },\n",
"    { 'input': '', 'expect': 'expression' },\n",
"]\n",
"\n",
"n_passed = 0\n",
"for test in test_suite:\n",
"    embd = torch.FloatTensor(model.embed_documents([test['input']]))\n",
"    with torch.no_grad():\n",
"        evidence, prob = enn_model(embd)\n",
"\n",
"    # Uncertainty mass u = K / S per sample (keepdim avoids cross-class broadcasting).\n",
"    S = (evidence + 1).sum(dim=1, keepdim=True)\n",
"    u = (out_dim / S).squeeze(1)\n",
"    name = engine.id2intent[prob.argmax(1)[0].item()]\n",
"    # Whole-name match; a substring test would accept e.g. 'bug' inside 'debug'.\n",
"    passed = name in test['expect'].split(',')\n",
"    n_passed += passed\n",
"    print(name, test['expect'], '√' if passed else '×', u)\n",
"\n",
"print(f'{n_passed}/{len(test_suite)} passed')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}