refactor: 重写 LLM OpenAI 模块

This commit is contained in:
Soulter
2024-05-17 22:56:44 +08:00
parent 1775327c2e
commit 934ca94e62
8 changed files with 864 additions and 520 deletions

View File

@@ -84,7 +84,7 @@ def init(cfg):
global _global_object
# 迁移旧配置
gu.try_migrate_config(cfg)
gu.try_migrate_config()
# 使用新配置
cfg = cc.get_all()
@@ -106,6 +106,15 @@ def init(cfg):
else:
_global_object.reply_prefix = cfg['reply_prefix']
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
# 语言模型提供商
logger.info("正在载入语言模型...")
prov = privider_chooser(cfg)
@@ -122,6 +131,10 @@ def init(cfg):
llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal"))
chosen_provider = OPENAI_OFFICIAL
instance = llm_instance[OPENAI_OFFICIAL]
assert isinstance(instance, ProviderOpenAIOfficial)
instance.personality_set(_global_object.default_personality, session_id=None)
# 检查provider设置偏好
p = cc.get("chosen_provider", None)
if p is not None and p in llm_instance:
@@ -197,14 +210,6 @@ def init(cfg):
cfg, _global_object), daemon=True).start()
platform_str += "QQ_OFFICIAL,"
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
# 初始化dashboard
_global_object.dashboard_data = DashBoardData(
stats={},
@@ -430,12 +435,12 @@ async def oper_msg(message: AstrBotMessage,
official_fc = chosen_provider == OPENAI_OFFICIAL
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
else:
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality=_global_object.default_personality)
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url)
llm_result_str = _global_object.reply_prefix + llm_result_str
except BaseException as e:
logger.info(f"调用异常:{traceback.format_exc()}")
return MessageResult(f"调用语言模型例程时出现异常。原因: {str(e)}")
logger.error(f"调用异常:{traceback.format_exc()}")
return MessageResult(f"调用异常。详细原因:{str(e)}")
# 切换回原来的语言模型
if temp_switch != "":
@@ -458,14 +463,10 @@ async def oper_msg(message: AstrBotMessage,
return MessageResult(f"指令调用错误: \n{str(command_result[1])}")
# 画图指令
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw':
for i in command_result[1]:
if command == 'draw':
# 保存到本地
async with aiohttp.ClientSession() as session:
async with session.get(i) as resp:
if resp.status == 200:
image = PILImage.open(io.BytesIO(await resp.read()))
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
path = await gu.download_image_by_url(command_result[1])
return MessageResult([Image.fromFileSystem(path)])
# 其他指令
else:
try:

View File

@@ -223,8 +223,6 @@ class Command:
"nick": "设置机器人昵称",
"plugin": "插件安装、卸载和重载",
"web on/off": "LLM 网页搜索能力",
"reset": "重置 LLM 对话",
"/gpt": "切换到 OpenAI 官方接口"
}
async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None):

View File

@@ -16,8 +16,7 @@ class CommandOpenAIOfficial(Command):
self.commands = [
CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"),
CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"),
CommandItem("status", self.gpt, "查看 GPT 配置信息和用量状态。", "内置"),
CommandItem("status", self.status, "查看 GPT 配置信息和用量状态。", "内置"),
]
super().__init__(provider, global_object)
@@ -59,8 +58,6 @@ class CommandOpenAIOfficial(Command):
return True, self.update(message, role)
elif self.command_start_with(message, "", "draw"):
return True, await self.draw(message)
elif self.command_start_with(message, "key"):
return True, self.key(message)
elif self.command_start_with(message, "switch"):
return True, await self.switch(message)
elif self.command_start_with(message, "models"):
@@ -87,12 +84,13 @@ class CommandOpenAIOfficial(Command):
async def help(self):
commands = super().general_commands()
commands[''] = '画画'
commands['key'] = '添加OpenAI key'
commands[''] = '调用 OpenAI DallE 模型生成图片'
commands['set'] = '人格设置面板'
commands['gpt'] = '查看gpt配置信息'
commands['status'] = '查看key使用状态'
commands['token'] = '查看本轮会话token'
commands['status'] = '查看 Api Key 状态和配置信息'
commands['token'] = '查看本轮会话 token'
commands['reset'] = '重置当前与 LLM 的会话'
commands['reset p'] = '重置当前与 LLM 的会话但保留人格system prompt'
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
async def reset(self, session_id: str, message: str = "reset"):
@@ -103,66 +101,34 @@ class CommandOpenAIOfficial(Command):
await self.provider.forget(session_id)
return True, "重置成功", "reset"
if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id)
if self.personality_str != "":
self.set(self.personality_str, session_id) # 重新设置人格
return True, "重置成功", "reset"
await self.provider.forget(session_id, keep_system_prompt=True)
def his(self, message: str, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "his"
# 分页每页5条
msg = ''
size_per_page = 3
page = 1
if message[4:]:
page = int(message[4:])
# 检查是否有过历史记录
if session_id not in self.provider.session_dict:
msg = f"历史记录为空"
return True, msg, "his"
l = self.provider.session_dict[session_id]
max_page = len(l)//size_per_page + \
1 if len(l) % size_per_page != 0 else len(l)//size_per_page
p = self.provider.get_prompts_by_cache_list(
self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
return True, f"历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页", "his"
l = message.split(" ")
if len(l) == 2:
try:
page = int(l[1])
except BaseException as e:
return True, "页码不合法", "his"
contexts, total_num = self.provider.dump_contexts_page(size_per_page, page=page)
t_pages = total_num // size_per_page + 1
return True, f"历史记录如下:\n{contexts}\n{page} 页 | 共 {t_pages}\n*输入 /his 2 跳转到第 2 页", "his"
def status(self):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "status"
chatgpt_cfg_str = ""
key_stat = self.provider.get_key_stat()
index = 1
max = 9000000
gg_count = 0
total = 0
tag = ''
for key in key_stat.keys():
sponsor = ''
total += key_stat[key]['used']
if key_stat[key]['exceed']:
gg_count += 1
continue
if 'sponsor' in key_stat[key]:
sponsor = key_stat[key]['sponsor']
chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n"
index += 1
return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status"
keys_data = self.provider.get_keys_data()
ret = "OpenAI Key"
for k in keys_data:
status = "🟢" if keys_data[k]['status'] == 0 else "🔴"
ret += "\n|- " + k[:8] + " " + status
def key(self, message: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ")
if len(l) == 1:
msg = "感谢您赞助keykey为官方API使用请以以下格式赞助:\n/key xxxxx"
return True, msg, "key"
key = l[1]
if self.provider.check_key(key):
self.provider.append_key(key)
return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~"
else:
return True, "该Key被验证为无效。也许是输入错误了或者重试。", "key"
conf = self.provider.get_configs()
ret += "\n当前模型:" + conf['model']
async def switch(self, message: str):
'''
@@ -179,14 +145,13 @@ class CommandOpenAIOfficial(Command):
return True, ret, "switch"
elif len(l) == 2:
try:
key_stat = self.provider.get_key_stat()
key_stat = self.provider.get_keys_data()
index = int(l[1])
if index > len(key_stat) or index < 1:
return True, "账号序号不合法。", "switch"
else:
try:
new_key = list(key_stat.keys())[index-1]
ret = await self.provider.check_key(new_key)
self.provider.set_key(new_key)
except BaseException as e:
return True, "账号切换失败,原因: " + str(e), "switch"
@@ -235,58 +200,22 @@ class CommandOpenAIOfficial(Command):
'name': ps,
'prompt': personalities[ps]
}
self.provider.session_dict[session_id] = []
new_record = {
"user": {
"role": "user",
"content": personalities[ps],
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
self.provider.personality_set(ps, session_id)
return True, f"人格{ps}已设置。", "set"
else:
self.provider.curr_personality = {
'name': '自定义人格',
'prompt': ps
}
new_record = {
"user": {
"role": "user",
"content": ps,
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.provider.session_dict[session_id] = []
self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
self.provider.personality_set(ps, session_id)
return True, f"自定义人格已设置。 \n人格信息: {ps}", "set"
async def draw(self, message):
async def draw(self, message: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "draw"
if message.startswith("/画"):
message = message[2:]
elif message.startswith(""):
message = message[1:]
try:
# 画图模式传回3个参数
img_url = await self.provider.image_chat(message)
img_url = await self.provider.image_generate(message)
return True, img_url, "draw"
except Exception as e:
if 'exceeded' in str(e):
return f"OpenAI API错误。原因\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库QQChannelChatGPT)"
return False, f"图片生成失败: {e}", "draw"

View File

@@ -5,10 +5,12 @@ import time
import tiktoken
import threading
import traceback
import base64
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import *
from cores.database.conn import dbConn
from model.provider.provider import Provider
@@ -16,84 +18,94 @@ from util import general_utils as gu
from util.cmd_config import CmdConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
from typing import List, Dict
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
MODELS = {
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4-turbo": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4-1106-vision-preview": 128000,
"gpt-4": 8192,
"gpt-4-0613": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
"gpt-3.5-turbo-16k": 16385,
"gpt-3.5-turbo-0613": 16385,
"gpt-3.5-turbo-16k-0613": 16385,
}
class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg):
self.cc = CmdConfig()
def __init__(self, cfg) -> None:
super().__init__()
self.key_list = []
# 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错
for key in cfg['key']:
if len(key) == 1:
raise BaseException(
"检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。")
if cfg['key'] != '' and cfg['key'] != None:
self.key_list = cfg['key']
if len(self.key_list) == 0:
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
os.makedirs("data/openai", exist_ok=True)
self.key_stat = {}
for k in self.key_list:
self.key_stat[k] = {'exceed': False, 'used': 0}
self.cc = CmdConfig
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.keys_data = {} # 记录超额
self.api_base = None
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
self.api_base = cfg['api_base']
logger.info(f"设置 api_base 为: {self.api_base}")
if cfg['key']: self.api_keys = cfg['key']
if cfg['api_base']: self.base_url = cfg['api_base']
if not self.api_keys:
logger.warn("看起来你没有添加 OpenAI 的 API 密钥OpenAI LLM 能力将不会启用。")
else:
self.chosen_api_key = self.api_keys[0]
for key in self.api_keys:
self.keys_data[key] = True
# 创建 OpenAI Client
self.client = AsyncOpenAI(
api_key=self.key_list[0],
base_url=self.api_base
api_key=self.chosen_api_key,
base_url=self.base_url
)
self.openai_model_configs: dict = cfg['chatGPTConfigs']
self.openai_configs = cfg
# 会话缓存
self.session_dict = {}
# 最大缓存token
self.max_tokens = cfg['total_tokens_limit']
# 历史记录持久化间隔时间
self.history_dump_interval = 20
self.enc = tiktoken.get_encoding("cl100k_base")
self.model_configs: Dict = cfg['chatGPTConfigs']
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
self.curr_personality = {
"name": "default",
"prompt": "你是一个很有帮助的 AI 助手。"
}
# 从 SQLite DB 读取历史记录
try:
db1 = dbConn()
for session in db1.get_all_session():
self.session_dict[session[0]] = json.loads(session[1])['data']
logger.info("读取历史记录成功。")
self.session_memory_lock.acquire()
self.session_memory[session[0]] = json.loads(session[1])['data']
self.session_memory_lock.release()
except BaseException as e:
logger.info("读取历史记录失败,但不影响使用。")
logger.warn(f"读取 OpenAI LLM 对话历史记录 失败{e}。仍可正常使用。")
# 创建转储定时器线程
# 定时保存历史记录
threading.Thread(target=self.dump_history, daemon=True).start()
# 人格
self.curr_personality = {}
def make_tmp_client(self, api_key: str, base_url: str):
return AsyncOpenAI(
api_key=api_key,
base_url=base_url
)
# 转储历史记录
def dump_history(self):
'''
转储历史记录
'''
time.sleep(10)
db = dbConn()
while True:
try:
# print("转储历史记录...")
for key in self.session_dict:
data = self.session_dict[key]
for key in self.session_memory:
data = self.session_memory[key]
data_json = {
'data': data
}
@@ -101,309 +113,338 @@ class ProviderOpenAIOfficial(Provider):
db.update_session(key, json.dumps(data_json))
else:
db.insert_session(key, json.dumps(data_json))
# print("转储历史记录完毕")
logger.debug("已保存 OpenAI 会话历史记录")
except BaseException as e:
print(e)
# 每隔10分钟转储一次
time.sleep(10*self.history_dump_interval)
finally:
time.sleep(10*60)
def personality_set(self, default_personality: dict, session_id: str):
if not default_personality: return
if session_id not in self.session_memory: self.session_memory[session_id] = []
self.curr_personality = default_personality
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
tokens_num = len(encoded_prompt)
model = self.model_configs['model']
if model in MODELS and tokens_num > MODELS[model] - 800:
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 800])
new_record = {
"user": {
"role": "user",
"role": "system",
"content": default_personality['prompt'],
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
'usage_tokens': 0, # 到该条目的总 token 数
'single-tokens': 0 # 该条目的 token 数
}
self.session_dict[session_id].append(new_record)
async def text_chat(self, prompt,
session_id=None,
image_url=None,
function_call=None,
extra_conf: dict = None,
default_personality: dict = None):
if session_id is None:
session_id = "unknown"
if "unknown" in self.session_dict:
del self.session_dict["unknown"]
# 会话机制
if session_id not in self.session_dict:
self.session_dict[session_id] = []
self.session_memory[session_id].append(new_record)
if len(self.session_dict[session_id]) == 0:
# 设置默认人格
if default_personality is not None:
self.personality_set(default_personality, session_id)
async def encode_image_bs64(self, image_url: str) -> str:
'''
将图片转换为 base64
'''
if image_url.startswith("http"):
image_url = await gu.download_image_by_url(image_url)
# 使用 tictoken 截断消息
_encoded_prompt = self.enc.encode(prompt)
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
prompt = self.enc.decode(_encoded_prompt[:int(
self.openai_model_configs['max_tokens']*0.80)])
logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode()
return "data:image/jpeg;base64," + image_bs64
cache_data_list, new_record, req = self.wrap(
prompt, session_id, image_url)
logger.debug(f"cache: {str(cache_data_list)}")
logger.debug(f"request: {str(req)}")
retry = 0
response = None
err = ''
async def retrieve_context(self, session_id: str):
'''
根据 session_id 获取保存的 OpenAI 格式的上下文
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
# 截断倍率
truncate_rate = 0.75
# 转换为 openai 要求的格式
context = []
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
context.append(record['user'])
if "AI" in record and record['AI']:
context.append(record['AI'])
conf = self.openai_model_configs
if extra_conf is not None:
conf.update(extra_conf)
return context
while retry < 10:
try:
if function_call is None:
response = await self.client.chat.completions.create(
messages=req,
**conf
)
else:
response = await self.client.chat.completions.create(
messages=req,
tools=function_call,
**conf
)
break
except Exception as e:
traceback.print_exc()
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
logger.info("当前 Key 已超额或异常, 正在切换",
)
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
retry -= 1
elif 'maximum context length' in str(e):
logger.info("token 超限, 清空对应缓存,并进行消息截断")
self.session_dict[session_id] = []
prompt = prompt[:int(len(prompt)*truncate_rate)]
truncate_rate -= 0.05
cache_data_list, new_record, req = self.wrap(
prompt, session_id)
async def get_models(self):
'''
获取所有模型
'''
models = await self.client.models.list()
logger.info(f"OpenAI 模型列表:{models}")
return models
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
time.sleep(30)
continue
else:
logger.error(str(e))
time.sleep(2)
err = str(e)
retry += 1
if retry >= 10:
logger.warning(
r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki")
raise BaseException("连接出错: "+str(err))
assert isinstance(response, ChatCompletion)
logger.debug(
f"OPENAI RESPONSE: {response.usage}")
async def assemble_context(self, session_id: str, prompt: str, image_url: str = None):
'''
组装上下文,并且根据当前上下文窗口大小截断
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
# 结果分类
choice = response.choices[0]
if choice.message.content != None:
# 文本形式
chatgpt_res = str(choice.message.content).strip()
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
# tools call (function calling)
return choice.message.tool_calls[0].function
tokens_num = len(self.tokenizer.encode(prompt))
previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens']
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
current_usage_tokens = response.usage.total_tokens
# 超过指定tokens 尽可能的保留最多的条目直到小于max_tokens
if current_usage_tokens > self.max_tokens:
t = current_usage_tokens
index = 0
while t > self.max_tokens:
if index >= len(cache_data_list):
break
# 保留人格信息
if cache_data_list[index]['type'] != 'personality':
t -= int(cache_data_list[index]['single_tokens'])
del cache_data_list[index]
else:
index += 1
# 删除完后更新相关字段
self.session_dict[session_id] = cache_data_list
# 添加新条目进入缓存的prompt
new_record['AI'] = {
'role': 'assistant',
'content': chatgpt_res,
message = {
"usage_tokens": previous_total_tokens_num + tokens_num,
"single_tokens": tokens_num,
"AI": None
}
new_record['usage_tokens'] = current_usage_tokens
if len(cache_data_list) > 0:
new_record['single_tokens'] = current_usage_tokens - \
int(cache_data_list[-1]['usage_tokens'])
else:
new_record['single_tokens'] = current_usage_tokens
cache_data_list.append(new_record)
self.session_dict[session_id] = cache_data_list
return chatgpt_res
async def image_chat(self, prompt, img_num=1, img_size="1024x1024"):
retry = 0
image_url = ''
image_generate_configs = self.cc.get("openai_image_generate", None)
while retry < 5:
try:
response: ImagesResponse = await self.client.images.generate(
prompt=prompt,
**image_generate_configs
)
image_url = []
for i in range(img_num):
image_url.append(response.data[i].url)
break
except Exception as e:
logger.warning(str(e))
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
logger.warning("当前 Key 已超额或者不正常, 正在切换")
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
elif 'Your request was rejected as a result of our safety system.' in str(e):
logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试")
raise e
else:
retry += 1
if retry >= 5:
raise BaseException("连接超时")
return image_url
async def forget(self, session_id=None) -> bool:
if session_id is None:
return False
self.session_dict[session_id] = []
return True
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
'''
获取缓存的会话
'''
prompts = ""
if paging:
page_begin = (page-1)*size
page_end = page*size
if page_begin < 0:
page_begin = 0
if page_end > len(cache_data_list):
page_end = len(cache_data_list)
cache_data_list = cache_data_list[page_begin:page_end]
for item in cache_data_list:
prompts += str(item['user']['role']) + ":\n" + \
str(item['user']['content']) + "\n"
prompts += str(item['AI']['role']) + ":\n" + \
str(item['AI']['content']) + "\n"
if divide:
prompts += "----------\n"
return prompts
def wrap(self, prompt, session_id, image_url=None):
if image_url is not None:
prompt = [
if image_url:
user_content = {
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
"url": await self.encode_image_bs64(image_url)
}
]
# 获得缓存信息
context = self.session_dict[session_id]
new_record = {
"user": {
"role": "user",
"content": prompt,
},
"AI": {},
'type': "common",
'usage_tokens': 0,
}
req_list = []
for i in context:
if 'user' in i:
req_list.append(i['user'])
if 'AI' in i:
req_list.append(i['AI'])
req_list.append(new_record['user'])
return context, new_record, req_list
else:
user_content = {
"role": "user",
"content": prompt
}
def handle_switch_key(self):
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
message["user"] = user_content
self.session_memory[session_id].append(message)
# 根据 模型的上下文窗口 淘汰掉多余的记录
curr_model = self.model_configs['model']
if curr_model in MODELS:
maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion
# if message['usage_tokens'] > maxium_tokens_num:
# 淘汰多余的记录,使得最终的 usage_tokens 不超过 maxium_tokens_num - 300
# contexts = self.session_memory[session_id]
# need_to_remove_idx = 0
# freed_tokens_num = contexts[0]['single-tokens']
# while freed_tokens_num < message['usage_tokens'] - maxium_tokens_num:
# need_to_remove_idx += 1
# freed_tokens_num += contexts[need_to_remove_idx]['single-tokens']
# # 更新之后的所有记录的 usage_tokens
# for i in range(len(contexts)):
# if i > need_to_remove_idx:
# contexts[i]['usage_tokens'] -= freed_tokens_num
# logger.debug(f"淘汰上下文记录 {need_to_remove_idx+1} 条,释放 {freed_tokens_num} 个 token。当前上下文总 token 为 {contexts[-1]['usage_tokens']}。")
# self.session_memory[session_id] = contexts[need_to_remove_idx+1:]
while len(self.session_memory[session_id]) and self.session_memory[session_id][-1]['usage_tokens'] > maxium_tokens_num:
self.pop_record(session_id)
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
continue
is_all_exceed = False
self.client.api_key = key
logger.warning(
f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})")
record = self.session_memory[session_id].pop(i)
break
if is_all_exceed:
logger.warning(
"所有 Key 已超额")
# 更新之后所有记录的 usage_tokens
for i in range(len(self.session_memory[session_id])):
self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens']
logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}")
return record
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None,
tools: None=None,
extra_conf: Dict = None,
**kwargs
) -> str:
if not session_id:
session_id = "unknown"
if "unknown" in self.session_memory:
del self.session_memory["unknown"]
if session_id not in self.session_memory:
self.session_memory[session_id] = []
self.personality_set(self.curr_personality, session_id)
# 如果 prompt 超过了最大窗口,截断。
# 1. 可以保证之后 pop 的时候不会出现问题
# 2. 可以保证不会超过最大 token 数
_encoded_prompt = self.tokenizer.encode(prompt)
curr_model = self.model_configs['model']
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
prompt = self.tokenizer.decode(_encoded_prompt)
# 组装上下文,并且根据当前上下文窗口大小截断
await self.assemble_context(session_id, prompt, image_url)
# 获取上下文openai 格式
contexts = await self.retrieve_context(session_id)
conf = self.model_configs
if extra_conf: conf.update(extra_conf)
# start request
retry = 0
rate_limit_retry = 0
while retry < 3 or rate_limit_retry < 5:
logger.debug(conf)
logger.debug(contexts)
if tools:
completion_coro = self.client.chat.completions.create(
messages=contexts,
tools=tools,
**conf
)
else:
completion_coro = self.client.chat.completions.create(
messages=contexts,
**conf
)
try:
completion = await completion_coro
break
except AuthenticationError as e:
api_key = self.chosen_api_key[10:] + "..."
logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key如果有的话")
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
except RateLimitError as e:
if "You exceeded your current quota" in e:
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
await self.switch_to_next_key()
rate_limit_retry += 1
time.sleep(1)
except Exception as e:
retry += 1
if retry >= 3:
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。")
if "maximum context length" in str(e):
logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}")
choice = completion.choices[0]
usage_tokens = completion.usage.total_tokens
completion_tokens = completion.usage.completion_tokens
self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens
self.session_memory[session_id][-1]['single_tokens'] += completion_tokens
if choice.message.content:
# 返回文本
completion_text = str(choice.message.content).strip()
elif choice.message.tool_calls and choice.message.tool_calls:
# tools call (function calling)
return choice.message.tool_calls[0].function
self.session_memory[session_id][-1]['AI'] = {
"role": "assistant",
"content": completion_text
}
return completion_text
async def switch_to_next_key(self):
'''
切换到下一个 API Key
'''
if not self.api_keys:
logger.error("OpenAI API Key 不存在。")
return False
for key in self.keys_data:
if self.keys_data[key]:
# 没超额
self.chosen_api_key = key
self.client.api_key = key
logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。")
return True
return False
async def image_generate(self, prompt, session_id, **kwargs) -> str:
'''
生成图片
'''
retry = 0
conf = self.image_generator_model_configs
if not conf:
logger.error("OpenAI 图片生成模型配置不存在。")
raise Exception("OpenAI 图片生成模型配置不存在。")
while retry < 3:
try:
images_response = await self.client.images.generate(
prompt=prompt,
**conf
)
image_url = images_response.data[0].url
return image_url
except Exception as e:
retry += 1
if retry >= 3:
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。")
logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
if session_id is None: return False
self.session_memory[session_id] = []
if keep_system_prompt:
self.personality_set(self.curr_personality, session_id)
return True
def dump_contexts_page(self, size=5, page=1):
'''
获取缓存的会话
'''
contexts_str = ""
for i, key in enumerate(self.session_memory):
if i < (page-1)*size or i >= page*size:
continue
contexts_str += f"Session ID: {key}\n"
for record in self.session_memory[key]:
if "user" in record:
contexts_str += f"User: {record['user']['content']}\n"
if "AI" in record:
contexts_str += f"AI: {record['AI']['content']}\n"
contexts_str += "---\n"
return contexts_str, len(self.session_memory)
def get_configs(self):
return self.openai_configs
return self.model_configs
def get_key_stat(self):
return self.key_stat
def get_key_list(self):
return self.key_list
def get_keys_data(self):
return self.keys_data
def get_curr_key(self):
return self.client.api_key
return self.chosen_api_key
def set_key(self, key):
self.client.api_key = key
# 添加key
def append_key(self, key, sponsor):
self.key_list.append(key)
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
# 检查key是否可用
async def check_key(self, key):
client_ = AsyncOpenAI(
api_key=key,
base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
await client_.chat.completions.create(
messages=messages,
**self.openai_model_configs
)
return True

View File

@@ -1,66 +0,0 @@
import os
import sys
import json
import time
import tiktoken
import threading
import traceback
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from cores.database.conn import dbConn
from model.provider.provider import Provider
from util import general_utils as gu
from util.cmd_config import CmdConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg) -> None:
super().__init__()
os.makedirs("data/openai", exist_ok=True)
self.cc = CmdConfig
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.keys_data = {
"keys": []
}
if cfg['key']: self.api_keys = cfg['key']
if cfg['api_base']: self.base_url = cfg['api_base']
if not self.api_keys:
logger.warn("看起来你没有添加 OpenAI 的 API 密钥OpenAI LLM 能力将不会启用。")
else:
self.chosen_api_key = self.api_keys[0]
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=self.base_url
)
self.model_configs: dict = cfg['chatGPTConfigs']
self.session_memory = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
# 从 SQLite DB 读取历史记录
try:
db1 = dbConn()
for session in db1.get_all_session():
self.session_memory_lock.acquire()
self.session_memory[session[0]] = json.loads(session[1])['data']
self.session_memory_lock.release()
except BaseException as e:
logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。")
# 定时保存历史记录
threading.Thread(target=self.dump_history, daemon=True).start()

View File

@@ -0,0 +1,409 @@
import os
import sys
import json
import time
import tiktoken
import threading
import traceback
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from cores.database.conn import dbConn
from model.provider.provider import Provider
from util import general_utils as gu
from util.cmd_config import CmdConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg):
self.cc = CmdConfig()
self.key_list = []
# 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错
for key in cfg['key']:
if len(key) == 1:
raise BaseException(
"检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。")
if cfg['key'] != '' and cfg['key'] != None:
self.key_list = cfg['key']
if len(self.key_list) == 0:
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
self.key_stat = {}
for k in self.key_list:
self.key_stat[k] = {'exceed': False, 'used': 0}
self.api_base = None
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
self.api_base = cfg['api_base']
logger.info(f"设置 api_base 为: {self.api_base}")
# 创建 OpenAI Client
self.client = AsyncOpenAI(
api_key=self.key_list[0],
base_url=self.api_base
)
self.openai_model_configs: dict = cfg['chatGPTConfigs']
self.openai_configs = cfg
# 会话缓存
self.session_dict = {}
# 最大缓存token
self.max_tokens = cfg['total_tokens_limit']
# 历史记录持久化间隔时间
self.history_dump_interval = 20
self.enc = tiktoken.get_encoding("cl100k_base")
# 从 SQLite DB 读取历史记录
try:
db1 = dbConn()
for session in db1.get_all_session():
self.session_dict[session[0]] = json.loads(session[1])['data']
logger.info("读取历史记录成功。")
except BaseException as e:
logger.info("读取历史记录失败,但不影响使用。")
# 创建转储定时器线程
threading.Thread(target=self.dump_history, daemon=True).start()
# 人格
self.curr_personality = {}
def make_tmp_client(self, api_key: str, base_url: str):
return AsyncOpenAI(
api_key=api_key,
base_url=base_url
)
# 转储历史记录
def dump_history(self):
time.sleep(10)
db = dbConn()
while True:
try:
# print("转储历史记录...")
for key in self.session_dict:
data = self.session_dict[key]
data_json = {
'data': data
}
if db.check_session(key):
db.update_session(key, json.dumps(data_json))
else:
db.insert_session(key, json.dumps(data_json))
# print("转储历史记录完毕")
except BaseException as e:
print(e)
# 每隔10分钟转储一次
time.sleep(10*self.history_dump_interval)
def personality_set(self, default_personality: dict, session_id: str):
self.curr_personality = default_personality
new_record = {
"user": {
"role": "user",
"content": default_personality['prompt'],
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.session_dict[session_id].append(new_record)
async def text_chat(self, prompt,
session_id=None,
image_url=None,
function_call=None,
extra_conf: dict = None,
default_personality: dict = None):
if session_id is None:
session_id = "unknown"
if "unknown" in self.session_dict:
del self.session_dict["unknown"]
# 会话机制
if session_id not in self.session_dict:
self.session_dict[session_id] = []
if len(self.session_dict[session_id]) == 0:
# 设置默认人格
if default_personality is not None:
self.personality_set(default_personality, session_id)
# 使用 tictoken 截断消息
_encoded_prompt = self.enc.encode(prompt)
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
prompt = self.enc.decode(_encoded_prompt[:int(
self.openai_model_configs['max_tokens']*0.80)])
logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。")
cache_data_list, new_record, req = self.wrap(
prompt, session_id, image_url)
logger.debug(f"cache: {str(cache_data_list)}")
logger.debug(f"request: {str(req)}")
retry = 0
response = None
err = ''
# 截断倍率
truncate_rate = 0.75
conf = self.openai_model_configs
if extra_conf is not None:
conf.update(extra_conf)
while retry < 10:
try:
if function_call is None:
response = await self.client.chat.completions.create(
messages=req,
**conf
)
else:
response = await self.client.chat.completions.create(
messages=req,
tools=function_call,
**conf
)
break
except Exception as e:
traceback.print_exc()
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
logger.info("当前 Key 已超额或异常, 正在切换",
)
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
retry -= 1
elif 'maximum context length' in str(e):
logger.info("token 超限, 清空对应缓存,并进行消息截断")
self.session_dict[session_id] = []
prompt = prompt[:int(len(prompt)*truncate_rate)]
truncate_rate -= 0.05
cache_data_list, new_record, req = self.wrap(
prompt, session_id)
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
time.sleep(30)
continue
else:
logger.error(str(e))
time.sleep(2)
err = str(e)
retry += 1
if retry >= 10:
logger.warning(
r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki")
raise BaseException("连接出错: "+str(err))
assert isinstance(response, ChatCompletion)
logger.debug(
f"OPENAI RESPONSE: {response.usage}")
# 结果分类
choice = response.choices[0]
if choice.message.content != None:
# 文本形式
chatgpt_res = str(choice.message.content).strip()
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
# tools call (function calling)
return choice.message.tool_calls[0].function
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
current_usage_tokens = response.usage.total_tokens
# 超过指定tokens 尽可能的保留最多的条目直到小于max_tokens
if current_usage_tokens > self.max_tokens:
t = current_usage_tokens
index = 0
while t > self.max_tokens:
if index >= len(cache_data_list):
break
# 保留人格信息
if cache_data_list[index]['type'] != 'personality':
t -= int(cache_data_list[index]['single_tokens'])
del cache_data_list[index]
else:
index += 1
# 删除完后更新相关字段
self.session_dict[session_id] = cache_data_list
# 添加新条目进入缓存的prompt
new_record['AI'] = {
'role': 'assistant',
'content': chatgpt_res,
}
new_record['usage_tokens'] = current_usage_tokens
if len(cache_data_list) > 0:
new_record['single_tokens'] = current_usage_tokens - \
int(cache_data_list[-1]['usage_tokens'])
else:
new_record['single_tokens'] = current_usage_tokens
cache_data_list.append(new_record)
self.session_dict[session_id] = cache_data_list
return chatgpt_res
async def image_chat(self, prompt, img_num=1, img_size="1024x1024"):
retry = 0
image_url = ''
image_generate_configs = self.cc.get("openai_image_generate", None)
while retry < 5:
try:
response: ImagesResponse = await self.client.images.generate(
prompt=prompt,
**image_generate_configs
)
image_url = []
for i in range(img_num):
image_url.append(response.data[i].url)
break
except Exception as e:
logger.warning(str(e))
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
logger.warning("当前 Key 已超额或者不正常, 正在切换")
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
elif 'Your request was rejected as a result of our safety system.' in str(e):
logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试")
raise e
else:
retry += 1
if retry >= 5:
raise BaseException("连接超时")
return image_url
async def forget(self, session_id=None) -> bool:
if session_id is None:
return False
self.session_dict[session_id] = []
return True
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
'''
获取缓存的会话
'''
prompts = ""
if paging:
page_begin = (page-1)*size
page_end = page*size
if page_begin < 0:
page_begin = 0
if page_end > len(cache_data_list):
page_end = len(cache_data_list)
cache_data_list = cache_data_list[page_begin:page_end]
for item in cache_data_list:
prompts += str(item['user']['role']) + ":\n" + \
str(item['user']['content']) + "\n"
prompts += str(item['AI']['role']) + ":\n" + \
str(item['AI']['content']) + "\n"
if divide:
prompts += "----------\n"
return prompts
def wrap(self, prompt, session_id, image_url=None):
if image_url is not None:
prompt = [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
}
]
# 获得缓存信息
context = self.session_dict[session_id]
new_record = {
"user": {
"role": "user",
"content": prompt,
},
"AI": {},
'type': "common",
'usage_tokens': 0,
}
req_list = []
for i in context:
if 'user' in i:
req_list.append(i['user'])
if 'AI' in i:
req_list.append(i['AI'])
req_list.append(new_record['user'])
return context, new_record, req_list
def handle_switch_key(self):
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
continue
is_all_exceed = False
self.client.api_key = key
logger.warning(
f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})")
break
if is_all_exceed:
logger.warning(
"所有 Key 已超额")
return False
return True
def get_configs(self):
return self.openai_configs
def get_key_stat(self):
return self.key_stat
def get_key_list(self):
return self.key_list
def get_curr_key(self):
return self.client.api_key
def set_key(self, key):
self.client.api_key = key
# 添加key
def append_key(self, key, sponsor):
self.key_list.append(key)
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
# 检查key是否可用
async def check_key(self, key):
client_ = AsyncOpenAI(
api_key=key,
base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
await client_.chat.completions.create(
messages=messages,
**self.openai_model_configs
)
return True

View File

@@ -3,7 +3,7 @@ class Provider:
prompt: str,
session_id: str,
image_url: None,
function_call: None,
tools: None,
extra_conf: dict = None,
default_personality: dict = None,
**kwargs) -> str:
@@ -14,7 +14,7 @@ class Provider:
[optional]
image_url: 图片url识图
function_call: 函数调用
tools: 函数调用工具
extra_conf: 额外配置
default_personality: 默认人格
'''

View File

@@ -1,19 +1,23 @@
import datetime
import time
import socket
from PIL import Image, ImageDraw, ImageFont
import os
import re
import requests
from util.cmd_config import CmdConfig
import aiohttp
import socket
from cores.astrbot.types import GlobalObject
import platform
import logging
import json
import sys
import psutil
from PIL import Image, ImageDraw, ImageFont
from cores.astrbot.types import GlobalObject
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
def port_checker(port: int, host: str = "localhost"):
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sk.settimeout(1)
@@ -355,10 +359,23 @@ def save_temp_img(img: Image) -> str:
# 获得时间戳
timestamp = int(time.time())
p = f"temp/{timestamp}.png"
p = f"temp/{timestamp}.jpg"
img.save(p)
return p
async def download_image_by_url(url: str) -> str:
'''
下载图片
'''
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
img = Image.open(await resp.read())
p = save_temp_img(img)
return p
except Exception as e:
raise e
def create_text_image(title: str, text: str, max_width=30, font_size=20):
'''
@@ -455,6 +472,21 @@ def upload(_global_object: GlobalObject):
pass
time.sleep(10*60)
def retry(n: int = 3):
'''
重试装饰器
'''
def decorator(func):
def wrapper(*args, **kwargs):
for i in range(n):
try:
return func(*args, **kwargs)
except Exception as e:
if i == n-1: raise e
logger.warning(f"函数 {func.__name__}{i+1} 次重试... {e}")
return wrapper
return decorator
def run_monitor(global_object: GlobalObject):
'''