83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
import os
|
|
import json
|
|
|
|
import requests as r
|
|
|
|
from core import TreeIntent, logger
|
|
|
|
|
|
class ErineIntent(TreeIntent):
    """TreeIntent whose LLM calls go to Baidu's wenxinworkshop chat endpoint.

    The API key / secret key pair is exchanged for an OAuth access token when
    the instance is created; that token is appended to every chat request.
    """

    api_key: str
    secret_key: str
    access_token: str

    def __init__(self, path: str, api_key: str = None, secret_key: str = None) -> None:
        """Load the intent tree from *path* and fetch an access token.

        Args:
            path: Passed unchanged to ``TreeIntent.__init__``.
            api_key: Baidu API key; falls back to ``$BAIDU_API_KEY`` when falsy.
            secret_key: Baidu secret key; falls back to ``$BAIDU_SECRET_KEY``
                when falsy.

        Raises:
            KeyError: A credential was not passed and its environment
                variable is not set.
            ValueError: The access token could not be obtained. The original
                failure is attached as ``__cause__``.
        """
        super().__init__(path)
        self.api_key = api_key or os.environ['BAIDU_API_KEY']
        self.secret_key = secret_key or os.environ['BAIDU_SECRET_KEY']

        try:
            self.access_token = self.get_access_token()
        except Exception as e:
            # Chain the cause so the real reason (network error, rejected
            # credentials, ...) is not lost behind the generic message.
            raise ValueError('fail to get access token in initialization') from e

    def get_access_token(self) -> str:
        """Exchange the stored API key / secret key for an OAuth access token.

        Returns:
            The ``access_token`` string from the token endpoint's JSON reply.

        Raises:
            AssertionError: The reply carries no string ``access_token``; the
                message includes the full decoded reply.
        """
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
        }
        url = (
            'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials'
            f'&client_id={self.api_key}&client_secret={self.secret_key}'
        )
        payload = json.dumps("")

        res = r.post(url=url, data=payload, headers=headers)
        body = res.json()
        access_token = body.get('access_token')
        # Raised explicitly instead of using `assert`, which is stripped under
        # `python -O`; AssertionError is kept so existing callers still match.
        if not isinstance(access_token, str):
            raise AssertionError('access_token 获取失败,详细信息' + str(body))
        return access_token

    def post_message(self, message: list[dict]):
        """POST *message* to the ernie-lite-8k chat endpoint.

        Args:
            message: Value sent as the ``messages`` field of the request body.

        Returns:
            The raw response object from ``requests.post``; it is not
            validated here.
        """
        headers = {
            'Content-Type': 'application/json',
        }
        payload = json.dumps({
            'messages': message,
            'penalty_score': 2.0,
        })
        url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k?access_token=' + self.access_token
        return r.post(url, headers=headers, data=payload)

    @staticmethod
    def _token_rejected(res) -> bool:
        """Return True when *res* is a JSON error reply about the access token."""
        try:
            body = res.json()
        except ValueError:
            # Not JSON at all; let the caller's normal parsing report it.
            return False
        # NOTE(review): 110 / 111 are taken from Baidu's public error-code
        # table (token invalid / token expired) - confirm against the docs.
        return isinstance(body, dict) and body.get('error_code') in (110, 111)

    def call_llm(self, message: list[dict]) -> str:
        """Send *message* to the model and return its ``result`` field.

        The access token is refreshed and the request retried once when the
        first attempt raises, or when the service answers with a
        token-invalid / token-expired error body.

        Args:
            message: Chat messages forwarded to ``post_message``.

        Returns:
            The ``result`` value of the JSON reply, or ``None`` when the reply
            cannot be parsed or has no ``result`` key (the problem is logged).
        """
        try:
            res = self.post_message(message)
        except Exception:
            self.access_token = self.get_access_token()
            res = self.post_message(message)
        else:
            # A stale token does not raise: the service replies with an
            # error payload, so it has to be detected from the body.
            if self._token_rejected(res):
                self.access_token = self.get_access_token()
                res = self.post_message(message)

        try:
            return res.json()['result']
        except Exception as e:
            logger.error('get error when parse response of wenxinyiyan: ' + str(e))
            # Log the raw text: calling res.json() again here would re-raise
            # for a non-JSON body and escape this handler.
            logger.debug(res.text)
            return None
|
|
|
|
|
|
if __name__ == '__main__':
    from collections import Counter

    # Manual smoke test: classify the same utterance 20 times and print how
    # often each top-ranked node name comes back.
    intent = ErineIntent('./config/story.yml')
    tally = Counter()
    for _ in range(20):
        matched = intent.inference('那不就是rv芯片往上堆扩展吗')
        if matched is None:
            print('none -> ohters')
            continue
        top_name = matched[0].name
        tally[top_name] += 1
        print(top_name)

    print(tally)