{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Intent classification from sentence embeddings\n",
    "\n",
    "Embeds the example messages from the story config with `maidalun1020/bce-embedding-base_v1`, visualises them with t-SNE, fits a logistic-regression classifier that maps an embedding to an intent id, and saves that classifier to `MODEL_PATH`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import joblib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import yaml\n",
    "from langchain_community.embeddings import HuggingFaceEmbeddings\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# Make the parent directory importable so the local `prompt` module resolves.\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "from prompt import PromptEngine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative paths resolve against the kernel's working directory.\n",
    "STORY_CONFIG_PATH = '../config/story.yml'\n",
    "EMBEDDING_MODEL_NAME = 'maidalun1020/bce-embedding-base_v1'\n",
    "# Single location for the classifier artifact: it is both written and read\n",
    "# back from here, so the round-trip check at the end tests the fresh model.\n",
    "MODEL_PATH = '../model/embedding_mapping.sklearn'\n",
    "\n",
    "SEED = 42\n",
    "TSNE_PERPLEXITY = 3  # t-SNE requires perplexity < number of samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embedding model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative backend (not used): BCEmbedding.EmbeddingModel(model_name_or_path=EMBEDDING_MODEL_NAME)\n",
    "model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: the model should return one vector per input sentence.\n",
    "smoke_sentences = ['python 是什么', '请介绍一下 python']\n",
    "smoke_embeddings = model.embed_documents(smoke_sentences)\n",
    "assert len(smoke_embeddings) == len(smoke_sentences)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load labelled stories\n",
    "\n",
    "Each story contributes its `message` as the input text and `engine.intent2id[story.intent]` as its integer label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "engine = PromptEngine(STORY_CONFIG_PATH)\n",
    "\n",
    "sentences = []\n",
    "labels = []\n",
    "# Single pass over engine.stories (works even if it is a one-shot iterable).\n",
    "for story in engine.stories:\n",
    "    sentences.append(story.message)\n",
    "    labels.append(engine.intent2id[story.intent])\n",
    "label_ids = np.array(labels)\n",
    "assert len(sentences) > 0, 'no stories loaded from ' + STORY_CONFIG_PATH\n",
    "sentences, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embed the stories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding = np.array(model.embed_documents(sentences))\n",
    "assert embedding.shape[0] == len(sentences)\n",
    "embedding.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualise with t-SNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# random_state makes the 2-D layout reproducible across runs.\n",
    "tsne = TSNE(n_components=2, perplexity=TSNE_PERPLEXITY, random_state=SEED)\n",
    "points_2d = tsne.fit_transform(embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Invert the intent -> id mapping so the legend shows intent names\n",
    "# (assumes intent2id is a dict; TODO confirm against PromptEngine).\n",
    "id2intent = {idx: intent for intent, idx in engine.intent2id.items()}\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for label in np.unique(label_ids):  # sorted, so legend order is stable\n",
    "    mask = label_ids == label\n",
    "    ax.scatter(points_2d[mask, 0], points_2d[mask, 1], s=50, alpha=0.9,\n",
    "               label=f'{label}: {id2intent[label]}')\n",
    "ax.set(title='t-SNE of story embeddings by intent', xlabel='t-SNE 1', ylabel='t-SNE 2')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the intent classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_model = LogisticRegression()\n",
    "log_model.fit(embedding, label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_sentence = ['咖啡喝不了,喝了胃不舒服']\n",
    "test_embedding = model.embed_documents(test_sentence)\n",
    "log_model.predict(test_embedding)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save and verify the artifact\n",
    "\n",
    "The classifier is written to and reloaded from the same `MODEL_PATH`; the reloaded copy must predict exactly what the in-memory model does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)\n",
    "joblib.dump(log_model, MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# joblib.load unpickles arbitrary objects: only load artifacts you produced yourself.\n",
    "loaded_model = joblib.load(MODEL_PATH)\n",
    "loaded_pred = loaded_model.predict(test_embedding)\n",
    "assert (loaded_pred == log_model.predict(test_embedding)).all(), 'saved model differs from trained model'\n",
    "loaded_pred"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}