83 lines
2.6 KiB
Python

import os
import json
import requests as r
from core import TreeIntent, logger
class ErineIntent(TreeIntent):
api_key: str
secret_key: str
access_token: str
def __init__(self, path: str, api_key: str = None, secret_key: str = None) -> None:
super().__init__(path)
self.api_key = api_key or os.environ['BAIDU_API_KEY']
self.secret_key = secret_key or os.environ['BAIDU_SECRET_KEY']
try:
self.access_token = self.get_access_token()
except Exception as e:
raise ValueError('fail to get access token in initialization')
def get_access_token(self):
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
api_key = self.api_key
secret_key = self.secret_key
url = f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}'
payload = json.dumps("")
res = r.post(
url=url,
data=payload,
headers=headers
)
resJson = res.json()
access_token = resJson.get('access_token')
assert isinstance(access_token, str), 'access_token 获取失败,详细信息' + str(resJson)
return access_token
def post_message(self, message: list[dict]):
headers = {
'Content-Type': 'application/json'
}
payload = json.dumps({
'messages': message,
'penalty_score': 2.0
})
url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k?access_token=' + self.access_token
return r.post(url, headers=headers, data=payload)
def call_llm(self, message: list[dict]) -> str:
try:
res = self.post_message(message)
except Exception:
self.access_token = self.get_access_token()
res = self.post_message(message)
try:
return res.json()['result']
except Exception as e:
logger.error('get error when parse response of wenxinyiyan: ' + str(e))
logger.debug(res.json())
return None
if __name__ == '__main__':
erine = ErineIntent('./config/story.yml')
result = []
for i in range(20):
nodes = erine.inference('那不就是rv芯片往上堆扩展吗')
if nodes is None:
print('none -> ohters')
else:
node = nodes[0]
result.append(node.name)
print(node.name)
from collections import Counter
print(Counter(result))