refactor: 重写 LLM OpenAI 模块
This commit is contained in:
@@ -84,7 +84,7 @@ def init(cfg):
|
|||||||
global _global_object
|
global _global_object
|
||||||
|
|
||||||
# 迁移旧配置
|
# 迁移旧配置
|
||||||
gu.try_migrate_config(cfg)
|
gu.try_migrate_config()
|
||||||
# 使用新配置
|
# 使用新配置
|
||||||
cfg = cc.get_all()
|
cfg = cc.get_all()
|
||||||
|
|
||||||
@@ -105,6 +105,15 @@ def init(cfg):
|
|||||||
cc.put("reply_prefix", "")
|
cc.put("reply_prefix", "")
|
||||||
else:
|
else:
|
||||||
_global_object.reply_prefix = cfg['reply_prefix']
|
_global_object.reply_prefix = cfg['reply_prefix']
|
||||||
|
|
||||||
|
default_personality_str = cc.get("default_personality_str", "")
|
||||||
|
if default_personality_str == "":
|
||||||
|
_global_object.default_personality = None
|
||||||
|
else:
|
||||||
|
_global_object.default_personality = {
|
||||||
|
"name": "default",
|
||||||
|
"prompt": default_personality_str,
|
||||||
|
}
|
||||||
|
|
||||||
# 语言模型提供商
|
# 语言模型提供商
|
||||||
logger.info("正在载入语言模型...")
|
logger.info("正在载入语言模型...")
|
||||||
@@ -122,6 +131,10 @@ def init(cfg):
|
|||||||
llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal"))
|
llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal"))
|
||||||
chosen_provider = OPENAI_OFFICIAL
|
chosen_provider = OPENAI_OFFICIAL
|
||||||
|
|
||||||
|
instance = llm_instance[OPENAI_OFFICIAL]
|
||||||
|
assert isinstance(instance, ProviderOpenAIOfficial)
|
||||||
|
instance.personality_set(_global_object.default_personality, session_id=None)
|
||||||
|
|
||||||
# 检查provider设置偏好
|
# 检查provider设置偏好
|
||||||
p = cc.get("chosen_provider", None)
|
p = cc.get("chosen_provider", None)
|
||||||
if p is not None and p in llm_instance:
|
if p is not None and p in llm_instance:
|
||||||
@@ -197,14 +210,6 @@ def init(cfg):
|
|||||||
cfg, _global_object), daemon=True).start()
|
cfg, _global_object), daemon=True).start()
|
||||||
platform_str += "QQ_OFFICIAL,"
|
platform_str += "QQ_OFFICIAL,"
|
||||||
|
|
||||||
default_personality_str = cc.get("default_personality_str", "")
|
|
||||||
if default_personality_str == "":
|
|
||||||
_global_object.default_personality = None
|
|
||||||
else:
|
|
||||||
_global_object.default_personality = {
|
|
||||||
"name": "default",
|
|
||||||
"prompt": default_personality_str,
|
|
||||||
}
|
|
||||||
# 初始化dashboard
|
# 初始化dashboard
|
||||||
_global_object.dashboard_data = DashBoardData(
|
_global_object.dashboard_data = DashBoardData(
|
||||||
stats={},
|
stats={},
|
||||||
@@ -430,12 +435,12 @@ async def oper_msg(message: AstrBotMessage,
|
|||||||
official_fc = chosen_provider == OPENAI_OFFICIAL
|
official_fc = chosen_provider == OPENAI_OFFICIAL
|
||||||
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
|
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
|
||||||
else:
|
else:
|
||||||
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality=_global_object.default_personality)
|
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url)
|
||||||
|
|
||||||
llm_result_str = _global_object.reply_prefix + llm_result_str
|
llm_result_str = _global_object.reply_prefix + llm_result_str
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.info(f"调用异常:{traceback.format_exc()}")
|
logger.error(f"调用异常:{traceback.format_exc()}")
|
||||||
return MessageResult(f"调用语言模型例程时出现异常。原因: {str(e)}")
|
return MessageResult(f"调用异常。详细原因:{str(e)}")
|
||||||
|
|
||||||
# 切换回原来的语言模型
|
# 切换回原来的语言模型
|
||||||
if temp_switch != "":
|
if temp_switch != "":
|
||||||
@@ -458,14 +463,10 @@ async def oper_msg(message: AstrBotMessage,
|
|||||||
return MessageResult(f"指令调用错误: \n{str(command_result[1])}")
|
return MessageResult(f"指令调用错误: \n{str(command_result[1])}")
|
||||||
|
|
||||||
# 画图指令
|
# 画图指令
|
||||||
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw':
|
if command == 'draw':
|
||||||
for i in command_result[1]:
|
# 保存到本地
|
||||||
# 保存到本地
|
path = await gu.download_image_by_url(command_result[1])
|
||||||
async with aiohttp.ClientSession() as session:
|
return MessageResult([Image.fromFileSystem(path)])
|
||||||
async with session.get(i) as resp:
|
|
||||||
if resp.status == 200:
|
|
||||||
image = PILImage.open(io.BytesIO(await resp.read()))
|
|
||||||
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
|
|
||||||
# 其他指令
|
# 其他指令
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -223,8 +223,6 @@ class Command:
|
|||||||
"nick": "设置机器人昵称",
|
"nick": "设置机器人昵称",
|
||||||
"plugin": "插件安装、卸载和重载",
|
"plugin": "插件安装、卸载和重载",
|
||||||
"web on/off": "LLM 网页搜索能力",
|
"web on/off": "LLM 网页搜索能力",
|
||||||
"reset": "重置 LLM 对话",
|
|
||||||
"/gpt": "切换到 OpenAI 官方接口"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None):
|
async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None):
|
||||||
|
|||||||
@@ -16,8 +16,7 @@ class CommandOpenAIOfficial(Command):
|
|||||||
self.commands = [
|
self.commands = [
|
||||||
CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"),
|
CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"),
|
||||||
CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"),
|
CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"),
|
||||||
CommandItem("status", self.gpt, "查看 GPT 配置信息和用量状态。", "内置"),
|
CommandItem("status", self.status, "查看 GPT 配置信息和用量状态。", "内置"),
|
||||||
|
|
||||||
]
|
]
|
||||||
super().__init__(provider, global_object)
|
super().__init__(provider, global_object)
|
||||||
|
|
||||||
@@ -59,8 +58,6 @@ class CommandOpenAIOfficial(Command):
|
|||||||
return True, self.update(message, role)
|
return True, self.update(message, role)
|
||||||
elif self.command_start_with(message, "画", "draw"):
|
elif self.command_start_with(message, "画", "draw"):
|
||||||
return True, await self.draw(message)
|
return True, await self.draw(message)
|
||||||
elif self.command_start_with(message, "key"):
|
|
||||||
return True, self.key(message)
|
|
||||||
elif self.command_start_with(message, "switch"):
|
elif self.command_start_with(message, "switch"):
|
||||||
return True, await self.switch(message)
|
return True, await self.switch(message)
|
||||||
elif self.command_start_with(message, "models"):
|
elif self.command_start_with(message, "models"):
|
||||||
@@ -87,12 +84,13 @@ class CommandOpenAIOfficial(Command):
|
|||||||
|
|
||||||
async def help(self):
|
async def help(self):
|
||||||
commands = super().general_commands()
|
commands = super().general_commands()
|
||||||
commands['画'] = '画画'
|
commands['画'] = '调用 OpenAI DallE 模型生成图片'
|
||||||
commands['key'] = '添加OpenAI key'
|
|
||||||
commands['set'] = '人格设置面板'
|
commands['set'] = '人格设置面板'
|
||||||
commands['gpt'] = '查看gpt配置信息'
|
commands['status'] = '查看 Api Key 状态和配置信息'
|
||||||
commands['status'] = '查看key使用状态'
|
commands['token'] = '查看本轮会话 token'
|
||||||
commands['token'] = '查看本轮会话token'
|
commands['reset'] = '重置当前与 LLM 的会话'
|
||||||
|
commands['reset p'] = '重置当前与 LLM 的会话,但保留人格(system prompt)'
|
||||||
|
|
||||||
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
||||||
|
|
||||||
async def reset(self, session_id: str, message: str = "reset"):
|
async def reset(self, session_id: str, message: str = "reset"):
|
||||||
@@ -103,66 +101,34 @@ class CommandOpenAIOfficial(Command):
|
|||||||
await self.provider.forget(session_id)
|
await self.provider.forget(session_id)
|
||||||
return True, "重置成功", "reset"
|
return True, "重置成功", "reset"
|
||||||
if len(l) == 2 and l[1] == "p":
|
if len(l) == 2 and l[1] == "p":
|
||||||
self.provider.forget(session_id)
|
await self.provider.forget(session_id, keep_system_prompt=True)
|
||||||
if self.personality_str != "":
|
|
||||||
self.set(self.personality_str, session_id) # 重新设置人格
|
|
||||||
return True, "重置成功", "reset"
|
|
||||||
|
|
||||||
def his(self, message: str, session_id: str):
|
def his(self, message: str, session_id: str):
|
||||||
if self.provider is None:
|
if self.provider is None:
|
||||||
return False, "未启用 OpenAI 官方 API", "his"
|
return False, "未启用 OpenAI 官方 API", "his"
|
||||||
# 分页,每页5条
|
|
||||||
msg = ''
|
|
||||||
size_per_page = 3
|
size_per_page = 3
|
||||||
page = 1
|
page = 1
|
||||||
if message[4:]:
|
l = message.split(" ")
|
||||||
page = int(message[4:])
|
if len(l) == 2:
|
||||||
# 检查是否有过历史记录
|
try:
|
||||||
if session_id not in self.provider.session_dict:
|
page = int(l[1])
|
||||||
msg = f"历史记录为空"
|
except BaseException as e:
|
||||||
return True, msg, "his"
|
return True, "页码不合法", "his"
|
||||||
l = self.provider.session_dict[session_id]
|
contexts, total_num = self.provider.dump_contexts_page(size_per_page, page=page)
|
||||||
max_page = len(l)//size_per_page + \
|
t_pages = total_num // size_per_page + 1
|
||||||
1 if len(l) % size_per_page != 0 else len(l)//size_per_page
|
return True, f"历史记录如下:\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n*输入 /his 2 跳转到第 2 页", "his"
|
||||||
p = self.provider.get_prompts_by_cache_list(
|
|
||||||
self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
|
|
||||||
return True, f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页", "his"
|
|
||||||
|
|
||||||
def status(self):
|
def status(self):
|
||||||
if self.provider is None:
|
if self.provider is None:
|
||||||
return False, "未启用 OpenAI 官方 API", "status"
|
return False, "未启用 OpenAI 官方 API", "status"
|
||||||
chatgpt_cfg_str = ""
|
keys_data = self.provider.get_keys_data()
|
||||||
key_stat = self.provider.get_key_stat()
|
ret = "OpenAI Key"
|
||||||
index = 1
|
for k in keys_data:
|
||||||
max = 9000000
|
status = "🟢" if keys_data[k]['status'] == 0 else "🔴"
|
||||||
gg_count = 0
|
ret += "\n|- " + k[:8] + " " + status
|
||||||
total = 0
|
|
||||||
tag = ''
|
|
||||||
for key in key_stat.keys():
|
|
||||||
sponsor = ''
|
|
||||||
total += key_stat[key]['used']
|
|
||||||
if key_stat[key]['exceed']:
|
|
||||||
gg_count += 1
|
|
||||||
continue
|
|
||||||
if 'sponsor' in key_stat[key]:
|
|
||||||
sponsor = key_stat[key]['sponsor']
|
|
||||||
chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n"
|
|
||||||
index += 1
|
|
||||||
return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status"
|
|
||||||
|
|
||||||
def key(self, message: str):
|
conf = self.provider.get_configs()
|
||||||
if self.provider is None:
|
ret += "\n当前模型:" + conf['model']
|
||||||
return False, "未启用 OpenAI 官方 API", "reset"
|
|
||||||
l = message.split(" ")
|
|
||||||
if len(l) == 1:
|
|
||||||
msg = "感谢您赞助key,key为官方API使用,请以以下格式赞助:\n/key xxxxx"
|
|
||||||
return True, msg, "key"
|
|
||||||
key = l[1]
|
|
||||||
if self.provider.check_key(key):
|
|
||||||
self.provider.append_key(key)
|
|
||||||
return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~"
|
|
||||||
else:
|
|
||||||
return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key"
|
|
||||||
|
|
||||||
async def switch(self, message: str):
|
async def switch(self, message: str):
|
||||||
'''
|
'''
|
||||||
@@ -179,14 +145,13 @@ class CommandOpenAIOfficial(Command):
|
|||||||
return True, ret, "switch"
|
return True, ret, "switch"
|
||||||
elif len(l) == 2:
|
elif len(l) == 2:
|
||||||
try:
|
try:
|
||||||
key_stat = self.provider.get_key_stat()
|
key_stat = self.provider.get_keys_data()
|
||||||
index = int(l[1])
|
index = int(l[1])
|
||||||
if index > len(key_stat) or index < 1:
|
if index > len(key_stat) or index < 1:
|
||||||
return True, "账号序号不合法。", "switch"
|
return True, "账号序号不合法。", "switch"
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
new_key = list(key_stat.keys())[index-1]
|
new_key = list(key_stat.keys())[index-1]
|
||||||
ret = await self.provider.check_key(new_key)
|
|
||||||
self.provider.set_key(new_key)
|
self.provider.set_key(new_key)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
return True, "账号切换失败,原因: " + str(e), "switch"
|
return True, "账号切换失败,原因: " + str(e), "switch"
|
||||||
@@ -235,58 +200,22 @@ class CommandOpenAIOfficial(Command):
|
|||||||
'name': ps,
|
'name': ps,
|
||||||
'prompt': personalities[ps]
|
'prompt': personalities[ps]
|
||||||
}
|
}
|
||||||
self.provider.session_dict[session_id] = []
|
self.provider.personality_set(ps, session_id)
|
||||||
new_record = {
|
|
||||||
"user": {
|
|
||||||
"role": "user",
|
|
||||||
"content": personalities[ps],
|
|
||||||
},
|
|
||||||
"AI": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "好的,接下来我会扮演这个角色。"
|
|
||||||
},
|
|
||||||
'type': "personality",
|
|
||||||
'usage_tokens': 0,
|
|
||||||
'single-tokens': 0
|
|
||||||
}
|
|
||||||
self.provider.session_dict[session_id].append(new_record)
|
|
||||||
self.personality_str = message
|
|
||||||
return True, f"人格{ps}已设置。", "set"
|
return True, f"人格{ps}已设置。", "set"
|
||||||
else:
|
else:
|
||||||
self.provider.curr_personality = {
|
self.provider.curr_personality = {
|
||||||
'name': '自定义人格',
|
'name': '自定义人格',
|
||||||
'prompt': ps
|
'prompt': ps
|
||||||
}
|
}
|
||||||
new_record = {
|
self.provider.personality_set(ps, session_id)
|
||||||
"user": {
|
|
||||||
"role": "user",
|
|
||||||
"content": ps,
|
|
||||||
},
|
|
||||||
"AI": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "好的,接下来我会扮演这个角色。"
|
|
||||||
},
|
|
||||||
'type': "personality",
|
|
||||||
'usage_tokens': 0,
|
|
||||||
'single-tokens': 0
|
|
||||||
}
|
|
||||||
self.provider.session_dict[session_id] = []
|
|
||||||
self.provider.session_dict[session_id].append(new_record)
|
|
||||||
self.personality_str = message
|
|
||||||
return True, f"自定义人格已设置。 \n人格信息: {ps}", "set"
|
return True, f"自定义人格已设置。 \n人格信息: {ps}", "set"
|
||||||
|
|
||||||
async def draw(self, message):
|
async def draw(self, message: str):
|
||||||
if self.provider is None:
|
if self.provider is None:
|
||||||
return False, "未启用 OpenAI 官方 API", "draw"
|
return False, "未启用 OpenAI 官方 API", "draw"
|
||||||
if message.startswith("/画"):
|
if message.startswith("/画"):
|
||||||
message = message[2:]
|
message = message[2:]
|
||||||
elif message.startswith("画"):
|
elif message.startswith("画"):
|
||||||
message = message[1:]
|
message = message[1:]
|
||||||
try:
|
img_url = await self.provider.image_generate(message)
|
||||||
# 画图模式传回3个参数
|
return True, img_url, "draw"
|
||||||
img_url = await self.provider.image_chat(message)
|
|
||||||
return True, img_url, "draw"
|
|
||||||
except Exception as e:
|
|
||||||
if 'exceeded' in str(e):
|
|
||||||
return f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)"
|
|
||||||
return False, f"图片生成失败: {e}", "draw"
|
|
||||||
@@ -5,10 +5,12 @@ import time
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
|
import base64
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types.images_response import ImagesResponse
|
from openai.types.images_response import ImagesResponse
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
from openai._exceptions import *
|
||||||
|
|
||||||
from cores.database.conn import dbConn
|
from cores.database.conn import dbConn
|
||||||
from model.provider.provider import Provider
|
from model.provider.provider import Provider
|
||||||
@@ -16,84 +18,94 @@ from util import general_utils as gu
|
|||||||
from util.cmd_config import CmdConfig
|
from util.cmd_config import CmdConfig
|
||||||
from SparkleLogging.utils.core import LogManager
|
from SparkleLogging.utils.core import LogManager
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
"gpt-4o": 128000,
|
||||||
|
"gpt-4o-2024-05-13": 128000,
|
||||||
|
"gpt-4-turbo": 128000,
|
||||||
|
"gpt-4-turbo-2024-04-09": 128000,
|
||||||
|
"gpt-4-turbo-preview": 128000,
|
||||||
|
"gpt-4-0125-preview": 128000,
|
||||||
|
"gpt-4-1106-preview": 128000,
|
||||||
|
"gpt-4-vision-preview": 128000,
|
||||||
|
"gpt-4-1106-vision-preview": 128000,
|
||||||
|
"gpt-4": 8192,
|
||||||
|
"gpt-4-0613": 8192,
|
||||||
|
"gpt-4-32k": 32768,
|
||||||
|
"gpt-4-32k-0613": 32768,
|
||||||
|
"gpt-3.5-turbo-0125": 16385,
|
||||||
|
"gpt-3.5-turbo": 16385,
|
||||||
|
"gpt-3.5-turbo-1106": 16385,
|
||||||
|
"gpt-3.5-turbo-instruct": 4096,
|
||||||
|
"gpt-3.5-turbo-16k": 16385,
|
||||||
|
"gpt-3.5-turbo-0613": 16385,
|
||||||
|
"gpt-3.5-turbo-16k-0613": 16385,
|
||||||
|
}
|
||||||
|
|
||||||
class ProviderOpenAIOfficial(Provider):
|
class ProviderOpenAIOfficial(Provider):
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg) -> None:
|
||||||
self.cc = CmdConfig()
|
super().__init__()
|
||||||
|
|
||||||
self.key_list = []
|
os.makedirs("data/openai", exist_ok=True)
|
||||||
# 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错
|
|
||||||
for key in cfg['key']:
|
|
||||||
if len(key) == 1:
|
|
||||||
raise BaseException(
|
|
||||||
"检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。")
|
|
||||||
if cfg['key'] != '' and cfg['key'] != None:
|
|
||||||
self.key_list = cfg['key']
|
|
||||||
if len(self.key_list) == 0:
|
|
||||||
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
|
|
||||||
|
|
||||||
self.key_stat = {}
|
self.cc = CmdConfig
|
||||||
for k in self.key_list:
|
self.key_data_path = "data/openai/keys.json"
|
||||||
self.key_stat[k] = {'exceed': False, 'used': 0}
|
self.api_keys = []
|
||||||
|
self.chosen_api_key = None
|
||||||
|
self.base_url = None
|
||||||
|
self.keys_data = {} # 记录超额
|
||||||
|
|
||||||
self.api_base = None
|
if cfg['key']: self.api_keys = cfg['key']
|
||||||
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
if cfg['api_base']: self.base_url = cfg['api_base']
|
||||||
self.api_base = cfg['api_base']
|
if not self.api_keys:
|
||||||
logger.info(f"设置 api_base 为: {self.api_base}")
|
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
|
||||||
|
else:
|
||||||
|
self.chosen_api_key = self.api_keys[0]
|
||||||
|
|
||||||
|
for key in self.api_keys:
|
||||||
|
self.keys_data[key] = True
|
||||||
|
|
||||||
# 创建 OpenAI Client
|
|
||||||
self.client = AsyncOpenAI(
|
self.client = AsyncOpenAI(
|
||||||
api_key=self.key_list[0],
|
api_key=self.chosen_api_key,
|
||||||
base_url=self.api_base
|
base_url=self.base_url
|
||||||
)
|
)
|
||||||
|
self.model_configs: Dict = cfg['chatGPTConfigs']
|
||||||
self.openai_model_configs: dict = cfg['chatGPTConfigs']
|
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
|
||||||
self.openai_configs = cfg
|
self.session_memory: Dict[str, List] = {} # 会话记忆
|
||||||
# 会话缓存
|
self.session_memory_lock = threading.Lock()
|
||||||
self.session_dict = {}
|
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
|
||||||
# 最大缓存token
|
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
|
||||||
self.max_tokens = cfg['total_tokens_limit']
|
self.curr_personality = {
|
||||||
# 历史记录持久化间隔时间
|
"name": "default",
|
||||||
self.history_dump_interval = 20
|
"prompt": "你是一个很有帮助的 AI 助手。"
|
||||||
|
}
|
||||||
self.enc = tiktoken.get_encoding("cl100k_base")
|
|
||||||
|
|
||||||
# 从 SQLite DB 读取历史记录
|
# 从 SQLite DB 读取历史记录
|
||||||
try:
|
try:
|
||||||
db1 = dbConn()
|
db1 = dbConn()
|
||||||
for session in db1.get_all_session():
|
for session in db1.get_all_session():
|
||||||
self.session_dict[session[0]] = json.loads(session[1])['data']
|
self.session_memory_lock.acquire()
|
||||||
logger.info("读取历史记录成功。")
|
self.session_memory[session[0]] = json.loads(session[1])['data']
|
||||||
|
self.session_memory_lock.release()
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.info("读取历史记录失败,但不影响使用。")
|
logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||||
|
|
||||||
# 创建转储定时器线程
|
# 定时保存历史记录
|
||||||
threading.Thread(target=self.dump_history, daemon=True).start()
|
threading.Thread(target=self.dump_history, daemon=True).start()
|
||||||
|
|
||||||
# 人格
|
|
||||||
self.curr_personality = {}
|
|
||||||
|
|
||||||
def make_tmp_client(self, api_key: str, base_url: str):
|
|
||||||
return AsyncOpenAI(
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url
|
|
||||||
)
|
|
||||||
|
|
||||||
# 转储历史记录
|
|
||||||
def dump_history(self):
|
def dump_history(self):
|
||||||
|
'''
|
||||||
|
转储历史记录
|
||||||
|
'''
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
db = dbConn()
|
db = dbConn()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# print("转储历史记录...")
|
for key in self.session_memory:
|
||||||
for key in self.session_dict:
|
data = self.session_memory[key]
|
||||||
data = self.session_dict[key]
|
|
||||||
data_json = {
|
data_json = {
|
||||||
'data': data
|
'data': data
|
||||||
}
|
}
|
||||||
@@ -101,309 +113,338 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
db.update_session(key, json.dumps(data_json))
|
db.update_session(key, json.dumps(data_json))
|
||||||
else:
|
else:
|
||||||
db.insert_session(key, json.dumps(data_json))
|
db.insert_session(key, json.dumps(data_json))
|
||||||
# print("转储历史记录完毕")
|
logger.debug("已保存 OpenAI 会话历史记录")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
print(e)
|
print(e)
|
||||||
# 每隔10分钟转储一次
|
finally:
|
||||||
time.sleep(10*self.history_dump_interval)
|
time.sleep(10*60)
|
||||||
|
|
||||||
def personality_set(self, default_personality: dict, session_id: str):
|
def personality_set(self, default_personality: dict, session_id: str):
|
||||||
|
if not default_personality: return
|
||||||
|
if session_id not in self.session_memory: self.session_memory[session_id] = []
|
||||||
self.curr_personality = default_personality
|
self.curr_personality = default_personality
|
||||||
|
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
|
||||||
|
tokens_num = len(encoded_prompt)
|
||||||
|
model = self.model_configs['model']
|
||||||
|
if model in MODELS and tokens_num > MODELS[model] - 800:
|
||||||
|
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 800])
|
||||||
|
|
||||||
new_record = {
|
new_record = {
|
||||||
"user": {
|
"user": {
|
||||||
"role": "user",
|
"role": "system",
|
||||||
"content": default_personality['prompt'],
|
"content": default_personality['prompt'],
|
||||||
},
|
},
|
||||||
"AI": {
|
'usage_tokens': 0, # 到该条目的总 token 数
|
||||||
"role": "assistant",
|
'single-tokens': 0 # 该条目的 token 数
|
||||||
"content": "好的,接下来我会扮演这个角色。"
|
|
||||||
},
|
|
||||||
'type': "personality",
|
|
||||||
'usage_tokens': 0,
|
|
||||||
'single-tokens': 0
|
|
||||||
}
|
}
|
||||||
self.session_dict[session_id].append(new_record)
|
|
||||||
|
|
||||||
async def text_chat(self, prompt,
|
self.session_memory[session_id].append(new_record)
|
||||||
session_id=None,
|
|
||||||
image_url=None,
|
async def encode_image_bs64(self, image_url: str) -> str:
|
||||||
function_call=None,
|
'''
|
||||||
extra_conf: dict = None,
|
将图片转换为 base64
|
||||||
default_personality: dict = None):
|
'''
|
||||||
if session_id is None:
|
if image_url.startswith("http"):
|
||||||
|
image_url = await gu.download_image_by_url(image_url)
|
||||||
|
|
||||||
|
with open(image_url, "rb") as f:
|
||||||
|
image_bs64 = base64.b64encode(f.read()).decode()
|
||||||
|
return "data:image/jpeg;base64," + image_bs64
|
||||||
|
|
||||||
|
async def retrieve_context(self, session_id: str):
|
||||||
|
'''
|
||||||
|
根据 session_id 获取保存的 OpenAI 格式的上下文
|
||||||
|
'''
|
||||||
|
if session_id not in self.session_memory:
|
||||||
|
raise Exception("会话 ID 不存在")
|
||||||
|
|
||||||
|
# 转换为 openai 要求的格式
|
||||||
|
context = []
|
||||||
|
for record in self.session_memory[session_id]:
|
||||||
|
if "user" in record and record['user']:
|
||||||
|
context.append(record['user'])
|
||||||
|
if "AI" in record and record['AI']:
|
||||||
|
context.append(record['AI'])
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
|
async def get_models(self):
|
||||||
|
'''
|
||||||
|
获取所有模型
|
||||||
|
'''
|
||||||
|
models = await self.client.models.list()
|
||||||
|
logger.info(f"OpenAI 模型列表:{models}")
|
||||||
|
return models
|
||||||
|
|
||||||
|
async def assemble_context(self, session_id: str, prompt: str, image_url: str = None):
|
||||||
|
'''
|
||||||
|
组装上下文,并且根据当前上下文窗口大小截断
|
||||||
|
'''
|
||||||
|
if session_id not in self.session_memory:
|
||||||
|
raise Exception("会话 ID 不存在")
|
||||||
|
|
||||||
|
tokens_num = len(self.tokenizer.encode(prompt))
|
||||||
|
previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens']
|
||||||
|
|
||||||
|
message = {
|
||||||
|
"usage_tokens": previous_total_tokens_num + tokens_num,
|
||||||
|
"single_tokens": tokens_num,
|
||||||
|
"AI": None
|
||||||
|
}
|
||||||
|
if image_url:
|
||||||
|
user_content = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"url": await self.encode_image_bs64(image_url)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
user_content = {
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
message["user"] = user_content
|
||||||
|
self.session_memory[session_id].append(message)
|
||||||
|
|
||||||
|
# 根据 模型的上下文窗口 淘汰掉多余的记录
|
||||||
|
curr_model = self.model_configs['model']
|
||||||
|
if curr_model in MODELS:
|
||||||
|
maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion
|
||||||
|
# if message['usage_tokens'] > maxium_tokens_num:
|
||||||
|
# 淘汰多余的记录,使得最终的 usage_tokens 不超过 maxium_tokens_num - 300
|
||||||
|
# contexts = self.session_memory[session_id]
|
||||||
|
# need_to_remove_idx = 0
|
||||||
|
# freed_tokens_num = contexts[0]['single-tokens']
|
||||||
|
# while freed_tokens_num < message['usage_tokens'] - maxium_tokens_num:
|
||||||
|
# need_to_remove_idx += 1
|
||||||
|
# freed_tokens_num += contexts[need_to_remove_idx]['single-tokens']
|
||||||
|
# # 更新之后的所有记录的 usage_tokens
|
||||||
|
# for i in range(len(contexts)):
|
||||||
|
# if i > need_to_remove_idx:
|
||||||
|
# contexts[i]['usage_tokens'] -= freed_tokens_num
|
||||||
|
# logger.debug(f"淘汰上下文记录 {need_to_remove_idx+1} 条,释放 {freed_tokens_num} 个 token。当前上下文总 token 为 {contexts[-1]['usage_tokens']}。")
|
||||||
|
# self.session_memory[session_id] = contexts[need_to_remove_idx+1:]
|
||||||
|
while len(self.session_memory[session_id]) and self.session_memory[session_id][-1]['usage_tokens'] > maxium_tokens_num:
|
||||||
|
self.pop_record(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||||
|
'''
|
||||||
|
弹出第一条记录
|
||||||
|
'''
|
||||||
|
if session_id not in self.session_memory:
|
||||||
|
raise Exception("会话 ID 不存在")
|
||||||
|
|
||||||
|
if len(self.session_memory[session_id]) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for i in range(len(self.session_memory[session_id])):
|
||||||
|
# 检查是否是 system prompt
|
||||||
|
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||||
|
continue
|
||||||
|
record = self.session_memory[session_id].pop(i)
|
||||||
|
break
|
||||||
|
|
||||||
|
# 更新之后所有记录的 usage_tokens
|
||||||
|
for i in range(len(self.session_memory[session_id])):
|
||||||
|
self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens']
|
||||||
|
logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}。")
|
||||||
|
return record
|
||||||
|
|
||||||
|
async def text_chat(self,
|
||||||
|
prompt: str,
|
||||||
|
session_id: str,
|
||||||
|
image_url: None,
|
||||||
|
tools: None=None,
|
||||||
|
extra_conf: Dict = None,
|
||||||
|
**kwargs
|
||||||
|
) -> str:
|
||||||
|
if not session_id:
|
||||||
session_id = "unknown"
|
session_id = "unknown"
|
||||||
if "unknown" in self.session_dict:
|
if "unknown" in self.session_memory:
|
||||||
del self.session_dict["unknown"]
|
del self.session_memory["unknown"]
|
||||||
# 会话机制
|
|
||||||
if session_id not in self.session_dict:
|
|
||||||
self.session_dict[session_id] = []
|
|
||||||
|
|
||||||
if len(self.session_dict[session_id]) == 0:
|
if session_id not in self.session_memory:
|
||||||
# 设置默认人格
|
self.session_memory[session_id] = []
|
||||||
if default_personality is not None:
|
self.personality_set(self.curr_personality, session_id)
|
||||||
self.personality_set(default_personality, session_id)
|
|
||||||
|
# 如果 prompt 超过了最大窗口,截断。
|
||||||
|
# 1. 可以保证之后 pop 的时候不会出现问题
|
||||||
|
# 2. 可以保证不会超过最大 token 数
|
||||||
|
_encoded_prompt = self.tokenizer.encode(prompt)
|
||||||
|
curr_model = self.model_configs['model']
|
||||||
|
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
|
||||||
|
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
|
||||||
|
prompt = self.tokenizer.decode(_encoded_prompt)
|
||||||
|
|
||||||
|
# 组装上下文,并且根据当前上下文窗口大小截断
|
||||||
|
await self.assemble_context(session_id, prompt, image_url)
|
||||||
|
|
||||||
# 使用 tictoken 截断消息
|
# 获取上下文,openai 格式
|
||||||
_encoded_prompt = self.enc.encode(prompt)
|
contexts = await self.retrieve_context(session_id)
|
||||||
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
|
|
||||||
prompt = self.enc.decode(_encoded_prompt[:int(
|
|
||||||
self.openai_model_configs['max_tokens']*0.80)])
|
|
||||||
logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。")
|
|
||||||
|
|
||||||
cache_data_list, new_record, req = self.wrap(
|
conf = self.model_configs
|
||||||
prompt, session_id, image_url)
|
if extra_conf: conf.update(extra_conf)
|
||||||
logger.debug(f"cache: {str(cache_data_list)}")
|
|
||||||
logger.debug(f"request: {str(req)}")
|
# start request
|
||||||
retry = 0
|
retry = 0
|
||||||
response = None
|
rate_limit_retry = 0
|
||||||
err = ''
|
while retry < 3 or rate_limit_retry < 5:
|
||||||
|
logger.debug(conf)
|
||||||
# 截断倍率
|
logger.debug(contexts)
|
||||||
truncate_rate = 0.75
|
if tools:
|
||||||
|
completion_coro = self.client.chat.completions.create(
|
||||||
conf = self.openai_model_configs
|
messages=contexts,
|
||||||
if extra_conf is not None:
|
tools=tools,
|
||||||
conf.update(extra_conf)
|
**conf
|
||||||
|
)
|
||||||
while retry < 10:
|
else:
|
||||||
|
completion_coro = self.client.chat.completions.create(
|
||||||
|
messages=contexts,
|
||||||
|
**conf
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if function_call is None:
|
completion = await completion_coro
|
||||||
response = await self.client.chat.completions.create(
|
|
||||||
messages=req,
|
|
||||||
**conf
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = await self.client.chat.completions.create(
|
|
||||||
messages=req,
|
|
||||||
tools=function_call,
|
|
||||||
**conf
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
|
except AuthenticationError as e:
|
||||||
|
api_key = self.chosen_api_key[10:] + "..."
|
||||||
|
logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key(如果有的话)")
|
||||||
|
self.keys_data[self.chosen_api_key] = False
|
||||||
|
ok = await self.switch_to_next_key()
|
||||||
|
if ok: continue
|
||||||
|
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
|
||||||
|
|
||||||
|
except RateLimitError as e:
|
||||||
|
if "You exceeded your current quota" in e:
|
||||||
|
self.keys_data[self.chosen_api_key] = False
|
||||||
|
ok = await self.switch_to_next_key()
|
||||||
|
if ok: continue
|
||||||
|
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
|
||||||
|
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
|
||||||
|
await self.switch_to_next_key()
|
||||||
|
rate_limit_retry += 1
|
||||||
|
time.sleep(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
|
||||||
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
|
|
||||||
raise e
|
|
||||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
|
||||||
logger.info("当前 Key 已超额或异常, 正在切换",
|
|
||||||
)
|
|
||||||
self.key_stat[self.client.api_key]['exceed'] = True
|
|
||||||
is_switched = self.handle_switch_key()
|
|
||||||
if not is_switched:
|
|
||||||
raise e
|
|
||||||
retry -= 1
|
|
||||||
elif 'maximum context length' in str(e):
|
|
||||||
logger.info("token 超限, 清空对应缓存,并进行消息截断")
|
|
||||||
self.session_dict[session_id] = []
|
|
||||||
prompt = prompt[:int(len(prompt)*truncate_rate)]
|
|
||||||
truncate_rate -= 0.05
|
|
||||||
cache_data_list, new_record, req = self.wrap(
|
|
||||||
prompt, session_id)
|
|
||||||
|
|
||||||
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
|
|
||||||
time.sleep(30)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logger.error(str(e))
|
|
||||||
time.sleep(2)
|
|
||||||
err = str(e)
|
|
||||||
retry += 1
|
retry += 1
|
||||||
if retry >= 10:
|
if retry >= 3:
|
||||||
logger.warning(
|
logger.error(traceback.format_exc())
|
||||||
r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki")
|
raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。")
|
||||||
raise BaseException("连接出错: "+str(err))
|
if "maximum context length" in str(e):
|
||||||
assert isinstance(response, ChatCompletion)
|
logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||||
logger.debug(
|
self.pop_record(session_id)
|
||||||
f"OPENAI RESPONSE: {response.usage}")
|
|
||||||
|
|
||||||
# 结果分类
|
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
|
||||||
choice = response.choices[0]
|
time.sleep(1)
|
||||||
if choice.message.content != None:
|
|
||||||
# 文本形式
|
assert isinstance(completion, ChatCompletion)
|
||||||
chatgpt_res = str(choice.message.content).strip()
|
logger.debug(f"openai completion: {completion.usage}")
|
||||||
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
|
|
||||||
|
choice = completion.choices[0]
|
||||||
|
|
||||||
|
usage_tokens = completion.usage.total_tokens
|
||||||
|
completion_tokens = completion.usage.completion_tokens
|
||||||
|
self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens
|
||||||
|
self.session_memory[session_id][-1]['single_tokens'] += completion_tokens
|
||||||
|
|
||||||
|
if choice.message.content:
|
||||||
|
# 返回文本
|
||||||
|
completion_text = str(choice.message.content).strip()
|
||||||
|
elif choice.message.tool_calls and choice.message.tool_calls:
|
||||||
# tools call (function calling)
|
# tools call (function calling)
|
||||||
return choice.message.tool_calls[0].function
|
return choice.message.tool_calls[0].function
|
||||||
|
|
||||||
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
|
self.session_memory[session_id][-1]['AI'] = {
|
||||||
current_usage_tokens = response.usage.total_tokens
|
"role": "assistant",
|
||||||
|
"content": completion_text
|
||||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
|
||||||
if current_usage_tokens > self.max_tokens:
|
|
||||||
t = current_usage_tokens
|
|
||||||
index = 0
|
|
||||||
while t > self.max_tokens:
|
|
||||||
if index >= len(cache_data_list):
|
|
||||||
break
|
|
||||||
# 保留人格信息
|
|
||||||
if cache_data_list[index]['type'] != 'personality':
|
|
||||||
t -= int(cache_data_list[index]['single_tokens'])
|
|
||||||
del cache_data_list[index]
|
|
||||||
else:
|
|
||||||
index += 1
|
|
||||||
# 删除完后更新相关字段
|
|
||||||
self.session_dict[session_id] = cache_data_list
|
|
||||||
|
|
||||||
# 添加新条目进入缓存的prompt
|
|
||||||
new_record['AI'] = {
|
|
||||||
'role': 'assistant',
|
|
||||||
'content': chatgpt_res,
|
|
||||||
}
|
}
|
||||||
new_record['usage_tokens'] = current_usage_tokens
|
|
||||||
if len(cache_data_list) > 0:
|
|
||||||
new_record['single_tokens'] = current_usage_tokens - \
|
|
||||||
int(cache_data_list[-1]['usage_tokens'])
|
|
||||||
else:
|
|
||||||
new_record['single_tokens'] = current_usage_tokens
|
|
||||||
|
|
||||||
cache_data_list.append(new_record)
|
return completion_text
|
||||||
|
|
||||||
self.session_dict[session_id] = cache_data_list
|
async def switch_to_next_key(self):
|
||||||
|
'''
|
||||||
return chatgpt_res
|
切换到下一个 API Key
|
||||||
|
'''
|
||||||
async def image_chat(self, prompt, img_num=1, img_size="1024x1024"):
|
if not self.api_keys:
|
||||||
retry = 0
|
logger.error("OpenAI API Key 不存在。")
|
||||||
image_url = ''
|
|
||||||
|
|
||||||
image_generate_configs = self.cc.get("openai_image_generate", None)
|
|
||||||
|
|
||||||
while retry < 5:
|
|
||||||
try:
|
|
||||||
response: ImagesResponse = await self.client.images.generate(
|
|
||||||
prompt=prompt,
|
|
||||||
**image_generate_configs
|
|
||||||
)
|
|
||||||
image_url = []
|
|
||||||
for i in range(img_num):
|
|
||||||
image_url.append(response.data[i].url)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(str(e))
|
|
||||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
|
|
||||||
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
|
||||||
logger.warning("当前 Key 已超额或者不正常, 正在切换")
|
|
||||||
self.key_stat[self.client.api_key]['exceed'] = True
|
|
||||||
is_switched = self.handle_switch_key()
|
|
||||||
if not is_switched:
|
|
||||||
raise e
|
|
||||||
elif 'Your request was rejected as a result of our safety system.' in str(e):
|
|
||||||
logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试")
|
|
||||||
raise e
|
|
||||||
else:
|
|
||||||
retry += 1
|
|
||||||
if retry >= 5:
|
|
||||||
raise BaseException("连接超时")
|
|
||||||
|
|
||||||
return image_url
|
|
||||||
|
|
||||||
async def forget(self, session_id=None) -> bool:
|
|
||||||
if session_id is None:
|
|
||||||
return False
|
return False
|
||||||
self.session_dict[session_id] = []
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
|
for key in self.keys_data:
|
||||||
|
if self.keys_data[key]:
|
||||||
|
# 没超额
|
||||||
|
self.chosen_api_key = key
|
||||||
|
self.client.api_key = key
|
||||||
|
logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def image_generate(self, prompt, session_id, **kwargs) -> str:
|
||||||
|
'''
|
||||||
|
生成图片
|
||||||
|
'''
|
||||||
|
retry = 0
|
||||||
|
conf = self.image_generator_model_configs
|
||||||
|
if not conf:
|
||||||
|
logger.error("OpenAI 图片生成模型配置不存在。")
|
||||||
|
raise Exception("OpenAI 图片生成模型配置不存在。")
|
||||||
|
|
||||||
|
while retry < 3:
|
||||||
|
try:
|
||||||
|
images_response = await self.client.images.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
**conf
|
||||||
|
)
|
||||||
|
image_url = images_response.data[0].url
|
||||||
|
return image_url
|
||||||
|
except Exception as e:
|
||||||
|
retry += 1
|
||||||
|
if retry >= 3:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。")
|
||||||
|
logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
|
||||||
|
if session_id is None: return False
|
||||||
|
self.session_memory[session_id] = []
|
||||||
|
if keep_system_prompt:
|
||||||
|
self.personality_set(self.curr_personality, session_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def dump_contexts_page(self, size=5, page=1):
|
||||||
'''
|
'''
|
||||||
获取缓存的会话
|
获取缓存的会话
|
||||||
'''
|
'''
|
||||||
prompts = ""
|
contexts_str = ""
|
||||||
if paging:
|
for i, key in enumerate(self.session_memory):
|
||||||
page_begin = (page-1)*size
|
if i < (page-1)*size or i >= page*size:
|
||||||
page_end = page*size
|
|
||||||
if page_begin < 0:
|
|
||||||
page_begin = 0
|
|
||||||
if page_end > len(cache_data_list):
|
|
||||||
page_end = len(cache_data_list)
|
|
||||||
cache_data_list = cache_data_list[page_begin:page_end]
|
|
||||||
for item in cache_data_list:
|
|
||||||
prompts += str(item['user']['role']) + ":\n" + \
|
|
||||||
str(item['user']['content']) + "\n"
|
|
||||||
prompts += str(item['AI']['role']) + ":\n" + \
|
|
||||||
str(item['AI']['content']) + "\n"
|
|
||||||
|
|
||||||
if divide:
|
|
||||||
prompts += "----------\n"
|
|
||||||
return prompts
|
|
||||||
|
|
||||||
def wrap(self, prompt, session_id, image_url=None):
|
|
||||||
if image_url is not None:
|
|
||||||
prompt = [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": image_url
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
# 获得缓存信息
|
|
||||||
context = self.session_dict[session_id]
|
|
||||||
new_record = {
|
|
||||||
"user": {
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt,
|
|
||||||
},
|
|
||||||
"AI": {},
|
|
||||||
'type': "common",
|
|
||||||
'usage_tokens': 0,
|
|
||||||
}
|
|
||||||
req_list = []
|
|
||||||
for i in context:
|
|
||||||
if 'user' in i:
|
|
||||||
req_list.append(i['user'])
|
|
||||||
if 'AI' in i:
|
|
||||||
req_list.append(i['AI'])
|
|
||||||
req_list.append(new_record['user'])
|
|
||||||
return context, new_record, req_list
|
|
||||||
|
|
||||||
def handle_switch_key(self):
|
|
||||||
is_all_exceed = True
|
|
||||||
for key in self.key_stat:
|
|
||||||
if key == None or self.key_stat[key]['exceed']:
|
|
||||||
continue
|
continue
|
||||||
is_all_exceed = False
|
contexts_str += f"Session ID: {key}\n"
|
||||||
self.client.api_key = key
|
for record in self.session_memory[key]:
|
||||||
logger.warning(
|
if "user" in record:
|
||||||
f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})")
|
contexts_str += f"User: {record['user']['content']}\n"
|
||||||
break
|
if "AI" in record:
|
||||||
if is_all_exceed:
|
contexts_str += f"AI: {record['AI']['content']}\n"
|
||||||
logger.warning(
|
contexts_str += "---\n"
|
||||||
"所有 Key 已超额")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
return contexts_str, len(self.session_memory)
|
||||||
|
|
||||||
def get_configs(self):
|
def get_configs(self):
|
||||||
return self.openai_configs
|
return self.model_configs
|
||||||
|
|
||||||
def get_key_stat(self):
|
def get_keys_data(self):
|
||||||
return self.key_stat
|
return self.keys_data
|
||||||
|
|
||||||
def get_key_list(self):
|
|
||||||
return self.key_list
|
|
||||||
|
|
||||||
def get_curr_key(self):
|
def get_curr_key(self):
|
||||||
return self.client.api_key
|
return self.chosen_api_key
|
||||||
|
|
||||||
def set_key(self, key):
|
def set_key(self, key):
|
||||||
self.client.api_key = key
|
self.client.api_key = key
|
||||||
|
|
||||||
# 添加key
|
|
||||||
def append_key(self, key, sponsor):
|
|
||||||
self.key_list.append(key)
|
|
||||||
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
|
|
||||||
|
|
||||||
# 检查key是否可用
|
|
||||||
async def check_key(self, key):
|
|
||||||
client_ = AsyncOpenAI(
|
|
||||||
api_key=key,
|
|
||||||
base_url=self.api_base
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "please just echo `test`"}]
|
|
||||||
await client_.chat.completions.create(
|
|
||||||
messages=messages,
|
|
||||||
**self.openai_model_configs
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import tiktoken
|
|
||||||
import threading
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
from openai.types.images_response import ImagesResponse
|
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
|
||||||
|
|
||||||
from cores.database.conn import dbConn
|
|
||||||
from model.provider.provider import Provider
|
|
||||||
from util import general_utils as gu
|
|
||||||
from util.cmd_config import CmdConfig
|
|
||||||
from SparkleLogging.utils.core import LogManager
|
|
||||||
from logging import Logger
|
|
||||||
|
|
||||||
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
|
||||||
|
|
||||||
class ProviderOpenAIOfficial(Provider):
|
|
||||||
def __init__(self, cfg) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
os.makedirs("data/openai", exist_ok=True)
|
|
||||||
|
|
||||||
self.cc = CmdConfig
|
|
||||||
self.key_data_path = "data/openai/keys.json"
|
|
||||||
self.api_keys = []
|
|
||||||
self.chosen_api_key = None
|
|
||||||
self.base_url = None
|
|
||||||
self.keys_data = {
|
|
||||||
"keys": []
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg['key']: self.api_keys = cfg['key']
|
|
||||||
if cfg['api_base']: self.base_url = cfg['api_base']
|
|
||||||
if not self.api_keys:
|
|
||||||
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
|
|
||||||
else:
|
|
||||||
self.chosen_api_key = self.api_keys[0]
|
|
||||||
|
|
||||||
self.client = AsyncOpenAI(
|
|
||||||
api_key=self.chosen_api_key,
|
|
||||||
base_url=self.base_url
|
|
||||||
)
|
|
||||||
self.model_configs: dict = cfg['chatGPTConfigs']
|
|
||||||
self.session_memory = {} # 会话记忆
|
|
||||||
self.session_memory_lock = threading.Lock()
|
|
||||||
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
|
|
||||||
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
|
|
||||||
|
|
||||||
# 从 SQLite DB 读取历史记录
|
|
||||||
try:
|
|
||||||
db1 = dbConn()
|
|
||||||
for session in db1.get_all_session():
|
|
||||||
self.session_memory_lock.acquire()
|
|
||||||
self.session_memory[session[0]] = json.loads(session[1])['data']
|
|
||||||
self.session_memory_lock.release()
|
|
||||||
except BaseException as e:
|
|
||||||
logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
|
||||||
|
|
||||||
# 定时保存历史记录
|
|
||||||
threading.Thread(target=self.dump_history, daemon=True).start()
|
|
||||||
|
|
||||||
409
model/provider/openai_official_old.py
Normal file
409
model/provider/openai_official_old.py
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import tiktoken
|
||||||
|
import threading
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
from openai.types.images_response import ImagesResponse
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
|
||||||
|
from cores.database.conn import dbConn
|
||||||
|
from model.provider.provider import Provider
|
||||||
|
from util import general_utils as gu
|
||||||
|
from util.cmd_config import CmdConfig
|
||||||
|
from SparkleLogging.utils.core import LogManager
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||||
|
|
||||||
|
|
||||||
|
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderOpenAIOfficial(Provider):
|
||||||
|
def __init__(self, cfg):
|
||||||
|
self.cc = CmdConfig()
|
||||||
|
|
||||||
|
self.key_list = []
|
||||||
|
# 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错
|
||||||
|
for key in cfg['key']:
|
||||||
|
if len(key) == 1:
|
||||||
|
raise BaseException(
|
||||||
|
"检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。")
|
||||||
|
if cfg['key'] != '' and cfg['key'] != None:
|
||||||
|
self.key_list = cfg['key']
|
||||||
|
if len(self.key_list) == 0:
|
||||||
|
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
|
||||||
|
|
||||||
|
self.key_stat = {}
|
||||||
|
for k in self.key_list:
|
||||||
|
self.key_stat[k] = {'exceed': False, 'used': 0}
|
||||||
|
|
||||||
|
self.api_base = None
|
||||||
|
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
||||||
|
self.api_base = cfg['api_base']
|
||||||
|
logger.info(f"设置 api_base 为: {self.api_base}")
|
||||||
|
|
||||||
|
# 创建 OpenAI Client
|
||||||
|
self.client = AsyncOpenAI(
|
||||||
|
api_key=self.key_list[0],
|
||||||
|
base_url=self.api_base
|
||||||
|
)
|
||||||
|
|
||||||
|
self.openai_model_configs: dict = cfg['chatGPTConfigs']
|
||||||
|
self.openai_configs = cfg
|
||||||
|
# 会话缓存
|
||||||
|
self.session_dict = {}
|
||||||
|
# 最大缓存token
|
||||||
|
self.max_tokens = cfg['total_tokens_limit']
|
||||||
|
# 历史记录持久化间隔时间
|
||||||
|
self.history_dump_interval = 20
|
||||||
|
|
||||||
|
self.enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
# 从 SQLite DB 读取历史记录
|
||||||
|
try:
|
||||||
|
db1 = dbConn()
|
||||||
|
for session in db1.get_all_session():
|
||||||
|
self.session_dict[session[0]] = json.loads(session[1])['data']
|
||||||
|
logger.info("读取历史记录成功。")
|
||||||
|
except BaseException as e:
|
||||||
|
logger.info("读取历史记录失败,但不影响使用。")
|
||||||
|
|
||||||
|
# 创建转储定时器线程
|
||||||
|
threading.Thread(target=self.dump_history, daemon=True).start()
|
||||||
|
|
||||||
|
# 人格
|
||||||
|
self.curr_personality = {}
|
||||||
|
|
||||||
|
def make_tmp_client(self, api_key: str, base_url: str):
|
||||||
|
return AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转储历史记录
|
||||||
|
def dump_history(self):
|
||||||
|
time.sleep(10)
|
||||||
|
db = dbConn()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# print("转储历史记录...")
|
||||||
|
for key in self.session_dict:
|
||||||
|
data = self.session_dict[key]
|
||||||
|
data_json = {
|
||||||
|
'data': data
|
||||||
|
}
|
||||||
|
if db.check_session(key):
|
||||||
|
db.update_session(key, json.dumps(data_json))
|
||||||
|
else:
|
||||||
|
db.insert_session(key, json.dumps(data_json))
|
||||||
|
# print("转储历史记录完毕")
|
||||||
|
except BaseException as e:
|
||||||
|
print(e)
|
||||||
|
# 每隔10分钟转储一次
|
||||||
|
time.sleep(10*self.history_dump_interval)
|
||||||
|
|
||||||
|
def personality_set(self, default_personality: dict, session_id: str):
|
||||||
|
self.curr_personality = default_personality
|
||||||
|
new_record = {
|
||||||
|
"user": {
|
||||||
|
"role": "user",
|
||||||
|
"content": default_personality['prompt'],
|
||||||
|
},
|
||||||
|
"AI": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "好的,接下来我会扮演这个角色。"
|
||||||
|
},
|
||||||
|
'type': "personality",
|
||||||
|
'usage_tokens': 0,
|
||||||
|
'single-tokens': 0
|
||||||
|
}
|
||||||
|
self.session_dict[session_id].append(new_record)
|
||||||
|
|
||||||
|
async def text_chat(self, prompt,
|
||||||
|
session_id=None,
|
||||||
|
image_url=None,
|
||||||
|
function_call=None,
|
||||||
|
extra_conf: dict = None,
|
||||||
|
default_personality: dict = None):
|
||||||
|
if session_id is None:
|
||||||
|
session_id = "unknown"
|
||||||
|
if "unknown" in self.session_dict:
|
||||||
|
del self.session_dict["unknown"]
|
||||||
|
# 会话机制
|
||||||
|
if session_id not in self.session_dict:
|
||||||
|
self.session_dict[session_id] = []
|
||||||
|
|
||||||
|
if len(self.session_dict[session_id]) == 0:
|
||||||
|
# 设置默认人格
|
||||||
|
if default_personality is not None:
|
||||||
|
self.personality_set(default_personality, session_id)
|
||||||
|
|
||||||
|
# 使用 tictoken 截断消息
|
||||||
|
_encoded_prompt = self.enc.encode(prompt)
|
||||||
|
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
|
||||||
|
prompt = self.enc.decode(_encoded_prompt[:int(
|
||||||
|
self.openai_model_configs['max_tokens']*0.80)])
|
||||||
|
logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。")
|
||||||
|
|
||||||
|
cache_data_list, new_record, req = self.wrap(
|
||||||
|
prompt, session_id, image_url)
|
||||||
|
logger.debug(f"cache: {str(cache_data_list)}")
|
||||||
|
logger.debug(f"request: {str(req)}")
|
||||||
|
retry = 0
|
||||||
|
response = None
|
||||||
|
err = ''
|
||||||
|
|
||||||
|
# 截断倍率
|
||||||
|
truncate_rate = 0.75
|
||||||
|
|
||||||
|
conf = self.openai_model_configs
|
||||||
|
if extra_conf is not None:
|
||||||
|
conf.update(extra_conf)
|
||||||
|
|
||||||
|
while retry < 10:
|
||||||
|
try:
|
||||||
|
if function_call is None:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
messages=req,
|
||||||
|
**conf
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
messages=req,
|
||||||
|
tools=function_call,
|
||||||
|
**conf
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
|
||||||
|
raise e
|
||||||
|
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||||
|
logger.info("当前 Key 已超额或异常, 正在切换",
|
||||||
|
)
|
||||||
|
self.key_stat[self.client.api_key]['exceed'] = True
|
||||||
|
is_switched = self.handle_switch_key()
|
||||||
|
if not is_switched:
|
||||||
|
raise e
|
||||||
|
retry -= 1
|
||||||
|
elif 'maximum context length' in str(e):
|
||||||
|
logger.info("token 超限, 清空对应缓存,并进行消息截断")
|
||||||
|
self.session_dict[session_id] = []
|
||||||
|
prompt = prompt[:int(len(prompt)*truncate_rate)]
|
||||||
|
truncate_rate -= 0.05
|
||||||
|
cache_data_list, new_record, req = self.wrap(
|
||||||
|
prompt, session_id)
|
||||||
|
|
||||||
|
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
|
||||||
|
time.sleep(30)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error(str(e))
|
||||||
|
time.sleep(2)
|
||||||
|
err = str(e)
|
||||||
|
retry += 1
|
||||||
|
if retry >= 10:
|
||||||
|
logger.warning(
|
||||||
|
r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki")
|
||||||
|
raise BaseException("连接出错: "+str(err))
|
||||||
|
assert isinstance(response, ChatCompletion)
|
||||||
|
logger.debug(
|
||||||
|
f"OPENAI RESPONSE: {response.usage}")
|
||||||
|
|
||||||
|
# 结果分类
|
||||||
|
choice = response.choices[0]
|
||||||
|
if choice.message.content != None:
|
||||||
|
# 文本形式
|
||||||
|
chatgpt_res = str(choice.message.content).strip()
|
||||||
|
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
|
||||||
|
# tools call (function calling)
|
||||||
|
return choice.message.tool_calls[0].function
|
||||||
|
|
||||||
|
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
|
||||||
|
current_usage_tokens = response.usage.total_tokens
|
||||||
|
|
||||||
|
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||||
|
if current_usage_tokens > self.max_tokens:
|
||||||
|
t = current_usage_tokens
|
||||||
|
index = 0
|
||||||
|
while t > self.max_tokens:
|
||||||
|
if index >= len(cache_data_list):
|
||||||
|
break
|
||||||
|
# 保留人格信息
|
||||||
|
if cache_data_list[index]['type'] != 'personality':
|
||||||
|
t -= int(cache_data_list[index]['single_tokens'])
|
||||||
|
del cache_data_list[index]
|
||||||
|
else:
|
||||||
|
index += 1
|
||||||
|
# 删除完后更新相关字段
|
||||||
|
self.session_dict[session_id] = cache_data_list
|
||||||
|
|
||||||
|
# 添加新条目进入缓存的prompt
|
||||||
|
new_record['AI'] = {
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': chatgpt_res,
|
||||||
|
}
|
||||||
|
new_record['usage_tokens'] = current_usage_tokens
|
||||||
|
if len(cache_data_list) > 0:
|
||||||
|
new_record['single_tokens'] = current_usage_tokens - \
|
||||||
|
int(cache_data_list[-1]['usage_tokens'])
|
||||||
|
else:
|
||||||
|
new_record['single_tokens'] = current_usage_tokens
|
||||||
|
|
||||||
|
cache_data_list.append(new_record)
|
||||||
|
|
||||||
|
self.session_dict[session_id] = cache_data_list
|
||||||
|
|
||||||
|
return chatgpt_res
|
||||||
|
|
||||||
|
async def image_chat(self, prompt, img_num=1, img_size="1024x1024"):
|
||||||
|
retry = 0
|
||||||
|
image_url = ''
|
||||||
|
|
||||||
|
image_generate_configs = self.cc.get("openai_image_generate", None)
|
||||||
|
|
||||||
|
while retry < 5:
|
||||||
|
try:
|
||||||
|
response: ImagesResponse = await self.client.images.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
**image_generate_configs
|
||||||
|
)
|
||||||
|
image_url = []
|
||||||
|
for i in range(img_num):
|
||||||
|
image_url.append(response.data[i].url)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(str(e))
|
||||||
|
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
|
||||||
|
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||||
|
logger.warning("当前 Key 已超额或者不正常, 正在切换")
|
||||||
|
self.key_stat[self.client.api_key]['exceed'] = True
|
||||||
|
is_switched = self.handle_switch_key()
|
||||||
|
if not is_switched:
|
||||||
|
raise e
|
||||||
|
elif 'Your request was rejected as a result of our safety system.' in str(e):
|
||||||
|
logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试")
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
retry += 1
|
||||||
|
if retry >= 5:
|
||||||
|
raise BaseException("连接超时")
|
||||||
|
|
||||||
|
return image_url
|
||||||
|
|
||||||
|
async def forget(self, session_id=None) -> bool:
|
||||||
|
if session_id is None:
|
||||||
|
return False
|
||||||
|
self.session_dict[session_id] = []
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
|
||||||
|
'''
|
||||||
|
获取缓存的会话
|
||||||
|
'''
|
||||||
|
prompts = ""
|
||||||
|
if paging:
|
||||||
|
page_begin = (page-1)*size
|
||||||
|
page_end = page*size
|
||||||
|
if page_begin < 0:
|
||||||
|
page_begin = 0
|
||||||
|
if page_end > len(cache_data_list):
|
||||||
|
page_end = len(cache_data_list)
|
||||||
|
cache_data_list = cache_data_list[page_begin:page_end]
|
||||||
|
for item in cache_data_list:
|
||||||
|
prompts += str(item['user']['role']) + ":\n" + \
|
||||||
|
str(item['user']['content']) + "\n"
|
||||||
|
prompts += str(item['AI']['role']) + ":\n" + \
|
||||||
|
str(item['AI']['content']) + "\n"
|
||||||
|
|
||||||
|
if divide:
|
||||||
|
prompts += "----------\n"
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
def wrap(self, prompt, session_id, image_url=None):
|
||||||
|
if image_url is not None:
|
||||||
|
prompt = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
# 获得缓存信息
|
||||||
|
context = self.session_dict[session_id]
|
||||||
|
new_record = {
|
||||||
|
"user": {
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
},
|
||||||
|
"AI": {},
|
||||||
|
'type': "common",
|
||||||
|
'usage_tokens': 0,
|
||||||
|
}
|
||||||
|
req_list = []
|
||||||
|
for i in context:
|
||||||
|
if 'user' in i:
|
||||||
|
req_list.append(i['user'])
|
||||||
|
if 'AI' in i:
|
||||||
|
req_list.append(i['AI'])
|
||||||
|
req_list.append(new_record['user'])
|
||||||
|
return context, new_record, req_list
|
||||||
|
|
||||||
|
def handle_switch_key(self):
|
||||||
|
is_all_exceed = True
|
||||||
|
for key in self.key_stat:
|
||||||
|
if key == None or self.key_stat[key]['exceed']:
|
||||||
|
continue
|
||||||
|
is_all_exceed = False
|
||||||
|
self.client.api_key = key
|
||||||
|
logger.warning(
|
||||||
|
f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})")
|
||||||
|
break
|
||||||
|
if is_all_exceed:
|
||||||
|
logger.warning(
|
||||||
|
"所有 Key 已超额")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_configs(self):
|
||||||
|
return self.openai_configs
|
||||||
|
|
||||||
|
def get_key_stat(self):
|
||||||
|
return self.key_stat
|
||||||
|
|
||||||
|
def get_key_list(self):
|
||||||
|
return self.key_list
|
||||||
|
|
||||||
|
def get_curr_key(self):
|
||||||
|
return self.client.api_key
|
||||||
|
|
||||||
|
def set_key(self, key):
|
||||||
|
self.client.api_key = key
|
||||||
|
|
||||||
|
# 添加key
|
||||||
|
def append_key(self, key, sponsor):
|
||||||
|
self.key_list.append(key)
|
||||||
|
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
|
||||||
|
|
||||||
|
# 检查key是否可用
|
||||||
|
async def check_key(self, key):
|
||||||
|
client_ = AsyncOpenAI(
|
||||||
|
api_key=key,
|
||||||
|
base_url=self.api_base
|
||||||
|
)
|
||||||
|
messages = [{"role": "user", "content": "please just echo `test`"}]
|
||||||
|
await client_.chat.completions.create(
|
||||||
|
messages=messages,
|
||||||
|
**self.openai_model_configs
|
||||||
|
)
|
||||||
|
return True
|
||||||
@@ -3,7 +3,7 @@ class Provider:
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
image_url: None,
|
image_url: None,
|
||||||
function_call: None,
|
tools: None,
|
||||||
extra_conf: dict = None,
|
extra_conf: dict = None,
|
||||||
default_personality: dict = None,
|
default_personality: dict = None,
|
||||||
**kwargs) -> str:
|
**kwargs) -> str:
|
||||||
@@ -14,7 +14,7 @@ class Provider:
|
|||||||
|
|
||||||
[optional]
|
[optional]
|
||||||
image_url: 图片url(识图)
|
image_url: 图片url(识图)
|
||||||
function_call: 函数调用
|
tools: 函数调用工具
|
||||||
extra_conf: 额外配置
|
extra_conf: 额外配置
|
||||||
default_personality: 默认人格
|
default_personality: 默认人格
|
||||||
'''
|
'''
|
||||||
|
|||||||
@@ -1,19 +1,23 @@
|
|||||||
import datetime
|
|
||||||
import time
|
import time
|
||||||
import socket
|
import socket
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import requests
|
import requests
|
||||||
from util.cmd_config import CmdConfig
|
import aiohttp
|
||||||
import socket
|
import socket
|
||||||
from cores.astrbot.types import GlobalObject
|
|
||||||
import platform
|
import platform
|
||||||
import logging
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
from cores.astrbot.types import GlobalObject
|
||||||
|
from SparkleLogging.utils.core import LogManager
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||||
|
|
||||||
|
|
||||||
def port_checker(port: int, host: str = "localhost"):
|
def port_checker(port: int, host: str = "localhost"):
|
||||||
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
sk.settimeout(1)
|
sk.settimeout(1)
|
||||||
@@ -355,10 +359,23 @@ def save_temp_img(img: Image) -> str:
|
|||||||
|
|
||||||
# 获得时间戳
|
# 获得时间戳
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
p = f"temp/{timestamp}.png"
|
p = f"temp/{timestamp}.jpg"
|
||||||
img.save(p)
|
img.save(p)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
async def download_image_by_url(url: str) -> str:
|
||||||
|
'''
|
||||||
|
下载图片
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url) as resp:
|
||||||
|
img = Image.open(await resp.read())
|
||||||
|
p = save_temp_img(img)
|
||||||
|
return p
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def create_text_image(title: str, text: str, max_width=30, font_size=20):
|
def create_text_image(title: str, text: str, max_width=30, font_size=20):
|
||||||
'''
|
'''
|
||||||
@@ -455,6 +472,21 @@ def upload(_global_object: GlobalObject):
|
|||||||
pass
|
pass
|
||||||
time.sleep(10*60)
|
time.sleep(10*60)
|
||||||
|
|
||||||
|
def retry(n: int = 3):
|
||||||
|
'''
|
||||||
|
重试装饰器
|
||||||
|
'''
|
||||||
|
def decorator(func):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
for i in range(n):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
if i == n-1: raise e
|
||||||
|
logger.warning(f"函数 {func.__name__} 第 {i+1} 次重试... {e}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def run_monitor(global_object: GlobalObject):
|
def run_monitor(global_object: GlobalObject):
|
||||||
'''
|
'''
|
||||||
|
|||||||
Reference in New Issue
Block a user