diff --git a/cores/astrbot/core.py b/cores/astrbot/core.py index 611ade77..7576ca08 100644 --- a/cores/astrbot/core.py +++ b/cores/astrbot/core.py @@ -84,7 +84,7 @@ def init(cfg): global _global_object # 迁移旧配置 - gu.try_migrate_config(cfg) + gu.try_migrate_config() # 使用新配置 cfg = cc.get_all() @@ -105,6 +105,15 @@ def init(cfg): cc.put("reply_prefix", "") else: _global_object.reply_prefix = cfg['reply_prefix'] + + default_personality_str = cc.get("default_personality_str", "") + if default_personality_str == "": + _global_object.default_personality = None + else: + _global_object.default_personality = { + "name": "default", + "prompt": default_personality_str, + } # 语言模型提供商 logger.info("正在载入语言模型...") @@ -122,6 +131,10 @@ def init(cfg): llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal")) chosen_provider = OPENAI_OFFICIAL + instance = llm_instance[OPENAI_OFFICIAL] + assert isinstance(instance, ProviderOpenAIOfficial) + instance.personality_set(_global_object.default_personality, session_id=None) + # 检查provider设置偏好 p = cc.get("chosen_provider", None) if p is not None and p in llm_instance: @@ -197,14 +210,6 @@ def init(cfg): cfg, _global_object), daemon=True).start() platform_str += "QQ_OFFICIAL," - default_personality_str = cc.get("default_personality_str", "") - if default_personality_str == "": - _global_object.default_personality = None - else: - _global_object.default_personality = { - "name": "default", - "prompt": default_personality_str, - } # 初始化dashboard _global_object.dashboard_data = DashBoardData( stats={}, @@ -430,12 +435,12 @@ async def oper_msg(message: AstrBotMessage, official_fc = chosen_provider == OPENAI_OFFICIAL llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) else: - llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality=_global_object.default_personality) + llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url) llm_result_str = _global_object.reply_prefix + llm_result_str except BaseException as e: - logger.info(f"调用异常:{traceback.format_exc()}") - return MessageResult(f"调用语言模型例程时出现异常。原因: {str(e)}") + logger.error(f"调用异常:{traceback.format_exc()}") + return MessageResult(f"调用异常。详细原因:{str(e)}") # 切换回原来的语言模型 if temp_switch != "": @@ -458,14 +463,10 @@ async def oper_msg(message: AstrBotMessage, return MessageResult(f"指令调用错误: \n{str(command_result[1])}") # 画图指令 - if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw': - for i in command_result[1]: - # 保存到本地 - async with aiohttp.ClientSession() as session: - async with session.get(i) as resp: - if resp.status == 200: - image = PILImage.open(io.BytesIO(await resp.read())) - return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))]) + if command == 'draw': + # 保存到本地 + path = await gu.download_image_by_url(command_result[1]) + return MessageResult([Image.fromFileSystem(path)]) # 其他指令 else: try: diff --git a/model/command/command.py b/model/command/command.py index 362cc54c..99eebb4a 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -223,8 +223,6 @@ class Command: "nick": "设置机器人昵称", "plugin": "插件安装、卸载和重载", "web on/off": "LLM 网页搜索能力", - "reset": "重置 LLM 对话", - "/gpt": "切换到 OpenAI 官方接口" } async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None): diff --git a/model/command/openai_official.py b/model/command/openai_official.py index 27e8de3e..fe686dbc 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -16,8 +16,7 @@ class CommandOpenAIOfficial(Command): self.commands = [ CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"), CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"), - CommandItem("status", self.gpt, "查看 GPT 配置信息和用量状态。", "内置"), - + CommandItem("status", self.status, "查看 GPT 配置信息和用量状态。", "内置"), ] super().__init__(provider, global_object) @@ -59,8 +58,6 @@ class CommandOpenAIOfficial(Command): return True, self.update(message, role) elif self.command_start_with(message, "画", "draw"): return True, await self.draw(message) - elif self.command_start_with(message, "key"): - return True, self.key(message) elif self.command_start_with(message, "switch"): return True, await self.switch(message) elif self.command_start_with(message, "models"): @@ -87,12 +84,13 @@ class CommandOpenAIOfficial(Command): async def help(self): commands = super().general_commands() - commands['画'] = '画画' - commands['key'] = '添加OpenAI key' + commands['画'] = '调用 OpenAI DallE 模型生成图片' commands['set'] = '人格设置面板' - commands['gpt'] = '查看gpt配置信息' - commands['status'] = '查看key使用状态' - commands['token'] = '查看本轮会话token' + commands['status'] = '查看 Api Key 状态和配置信息' + commands['token'] = '查看本轮会话 token' + commands['reset'] = '重置当前与 LLM 的会话' + commands['reset p'] = '重置当前与 LLM 的会话,但保留人格(system prompt)' + return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" async def reset(self, session_id: str, message: str = "reset"): @@ -103,66 +101,34 @@ class CommandOpenAIOfficial(Command): await self.provider.forget(session_id) return True, "重置成功", "reset" if len(l) == 2 and l[1] == "p": - self.provider.forget(session_id) - if self.personality_str != "": - self.set(self.personality_str, session_id) # 重新设置人格 - return True, "重置成功", "reset" + await self.provider.forget(session_id, keep_system_prompt=True) def his(self, message: str, session_id: str): if self.provider is None: return False, "未启用 OpenAI 官方 API", "his" - # 分页,每页5条 - msg = '' size_per_page = 3 page = 1 - if message[4:]: - page = int(message[4:]) - # 检查是否有过历史记录 - if session_id not in self.provider.session_dict: - msg = f"历史记录为空" - return True, msg, "his" - l = self.provider.session_dict[session_id] - max_page = len(l)//size_per_page + \ - 1 if len(l) % size_per_page != 0 else len(l)//size_per_page - p = self.provider.get_prompts_by_cache_list( - self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page) - return True, f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页", "his" + l = message.split(" ") + if len(l) == 2: + try: + page = int(l[1]) + except BaseException as e: + return True, "页码不合法", "his" + contexts, total_num = self.provider.dump_contexts_page(size_per_page, page=page) + t_pages = total_num // size_per_page + 1 + return True, f"历史记录如下:\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n*输入 /his 2 跳转到第 2 页", "his" def status(self): if self.provider is None: return False, "未启用 OpenAI 官方 API", "status" - chatgpt_cfg_str = "" - key_stat = self.provider.get_key_stat() - index = 1 - max = 9000000 - gg_count = 0 - total = 0 - tag = '' - for key in key_stat.keys(): - sponsor = '' - total += key_stat[key]['used'] - if key_stat[key]['exceed']: - gg_count += 1 - continue - if 'sponsor' in key_stat[key]: - sponsor = key_stat[key]['sponsor'] - chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n" - index += 1 - return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status" + keys_data = self.provider.get_keys_data() + ret = "OpenAI Key" + for k in keys_data: + status = "🟢" if keys_data[k]['status'] == 0 else "🔴" + ret += "\n|- " + k[:8] + " " + status - def key(self, message: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "reset" - l = message.split(" ") - if len(l) == 1: - msg = "感谢您赞助key,key为官方API使用,请以以下格式赞助:\n/key xxxxx" - return True, msg, "key" - key = l[1] - if self.provider.check_key(key): - self.provider.append_key(key) - return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~" - else: - return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key" + conf = self.provider.get_configs() + ret += "\n当前模型:" + conf['model'] async def switch(self, message: str): ''' @@ -179,14 +145,13 @@ class CommandOpenAIOfficial(Command): return True, ret, "switch" elif len(l) == 2: try: - key_stat = self.provider.get_key_stat() + key_stat = self.provider.get_keys_data() index = int(l[1]) if index > len(key_stat) or index < 1: return True, "账号序号不合法。", "switch" else: try: new_key = list(key_stat.keys())[index-1] - ret = await self.provider.check_key(new_key) self.provider.set_key(new_key) except BaseException as e: return True, "账号切换失败,原因: " + str(e), "switch" @@ -235,58 +200,22 @@ class CommandOpenAIOfficial(Command): 'name': ps, 'prompt': personalities[ps] } - self.provider.session_dict[session_id] = [] - new_record = { - "user": { - "role": "user", - "content": personalities[ps], - }, - "AI": { - "role": "assistant", - "content": "好的,接下来我会扮演这个角色。" - }, - 'type': "personality", - 'usage_tokens': 0, - 'single-tokens': 0 - } - self.provider.session_dict[session_id].append(new_record) - self.personality_str = message + self.provider.personality_set(ps, session_id) return True, f"人格{ps}已设置。", "set" else: self.provider.curr_personality = { 'name': '自定义人格', 'prompt': ps } - new_record = { - "user": { - "role": "user", - "content": ps, - }, - "AI": { - "role": "assistant", - "content": "好的,接下来我会扮演这个角色。" - }, - 'type': "personality", - 'usage_tokens': 0, - 'single-tokens': 0 - } - self.provider.session_dict[session_id] = [] - self.provider.session_dict[session_id].append(new_record) - self.personality_str = message + self.provider.personality_set(ps, session_id) return True, f"自定义人格已设置。 \n人格信息: {ps}", "set" - async def draw(self, message): + async def draw(self, message: str): if self.provider is None: return False, "未启用 OpenAI 官方 API", "draw" if message.startswith("/画"): message = message[2:] elif message.startswith("画"): message = message[1:] - try: - # 画图模式传回3个参数 - img_url = await self.provider.image_chat(message) - return True, img_url, "draw" - except Exception as e: - if 'exceeded' in str(e): - return f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)" - return False, f"图片生成失败: {e}", "draw" + img_url = await self.provider.image_generate(message) + return True, img_url, "draw" \ No newline at end of file diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 72f1c2f8..6ff53218 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -5,10 +5,12 @@ import time import tiktoken import threading import traceback +import base64 from openai import AsyncOpenAI from openai.types.images_response import ImagesResponse from openai.types.chat.chat_completion import ChatCompletion +from openai._exceptions import * from cores.database.conn import dbConn from model.provider.provider import Provider @@ -16,84 +18,94 @@ from util import general_utils as gu from util.cmd_config import CmdConfig from SparkleLogging.utils.core import LogManager from logging import Logger +from typing import List, Dict logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - -abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' - +MODELS = { + "gpt-4o": 128000, + "gpt-4o-2024-05-13": 128000, + "gpt-4-turbo": 128000, + "gpt-4-turbo-2024-04-09": 128000, + "gpt-4-turbo-preview": 128000, + "gpt-4-0125-preview": 128000, + "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, + "gpt-4-1106-vision-preview": 128000, + "gpt-4": 8192, + "gpt-4-0613": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0613": 32768, + "gpt-3.5-turbo-0125": 16385, + "gpt-3.5-turbo": 16385, + "gpt-3.5-turbo-1106": 16385, + "gpt-3.5-turbo-instruct": 4096, + "gpt-3.5-turbo-16k": 16385, + "gpt-3.5-turbo-0613": 16385, + "gpt-3.5-turbo-16k-0613": 16385, +} class ProviderOpenAIOfficial(Provider): - def __init__(self, cfg): - self.cc = CmdConfig() + def __init__(self, cfg) -> None: + super().__init__() - self.key_list = [] - # 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错 - for key in cfg['key']: - if len(key) == 1: - raise BaseException( - "检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。") - if cfg['key'] != '' and cfg['key'] != None: - self.key_list = cfg['key'] - if len(self.key_list) == 0: - raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。") + os.makedirs("data/openai", exist_ok=True) - self.key_stat = {} - for k in self.key_list: - self.key_stat[k] = {'exceed': False, 'used': 0} + self.cc = CmdConfig + self.key_data_path = "data/openai/keys.json" + self.api_keys = [] + self.chosen_api_key = None + self.base_url = None + self.keys_data = {} # 记录超额 - self.api_base = None - if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '': - self.api_base = cfg['api_base'] - logger.info(f"设置 api_base 为: {self.api_base}") + if cfg['key']: self.api_keys = cfg['key'] + if cfg['api_base']: self.base_url = cfg['api_base'] + if not self.api_keys: + logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。") + else: + self.chosen_api_key = self.api_keys[0] + + for key in self.api_keys: + self.keys_data[key] = True - # 创建 OpenAI Client self.client = AsyncOpenAI( - api_key=self.key_list[0], - base_url=self.api_base + api_key=self.chosen_api_key, + base_url=self.base_url ) - - self.openai_model_configs: dict = cfg['chatGPTConfigs'] - self.openai_configs = cfg - # 会话缓存 - self.session_dict = {} - # 最大缓存token - self.max_tokens = cfg['total_tokens_limit'] - # 历史记录持久化间隔时间 - self.history_dump_interval = 20 - - self.enc = tiktoken.get_encoding("cl100k_base") + self.model_configs: Dict = cfg['chatGPTConfigs'] + self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None) + self.session_memory: Dict[str, List] = {} # 会话记忆 + self.session_memory_lock = threading.Lock() + self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 + self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器 + self.curr_personality = { + "name": "default", + "prompt": "你是一个很有帮助的 AI 助手。" + } # 从 SQLite DB 读取历史记录 try: db1 = dbConn() for session in db1.get_all_session(): - self.session_dict[session[0]] = json.loads(session[1])['data'] - logger.info("读取历史记录成功。") + self.session_memory_lock.acquire() + self.session_memory[session[0]] = json.loads(session[1])['data'] + self.session_memory_lock.release() except BaseException as e: - logger.info("读取历史记录失败,但不影响使用。") - - # 创建转储定时器线程 + logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。") + + # 定时保存历史记录 threading.Thread(target=self.dump_history, daemon=True).start() - # 人格 - self.curr_personality = {} - - def make_tmp_client(self, api_key: str, base_url: str): - return AsyncOpenAI( - api_key=api_key, - base_url=base_url - ) - - # 转储历史记录 def dump_history(self): + ''' + 转储历史记录 + ''' time.sleep(10) db = dbConn() while True: try: - # print("转储历史记录...") - for key in self.session_dict: - data = self.session_dict[key] + for key in self.session_memory: + data = self.session_memory[key] data_json = { 'data': data } @@ -101,309 +113,338 @@ class ProviderOpenAIOfficial(Provider): db.update_session(key, json.dumps(data_json)) else: db.insert_session(key, json.dumps(data_json)) - # print("转储历史记录完毕") + logger.debug("已保存 OpenAI 会话历史记录") except BaseException as e: print(e) - # 每隔10分钟转储一次 - time.sleep(10*self.history_dump_interval) - + finally: + time.sleep(10*60) + def personality_set(self, default_personality: dict, session_id: str): + if not default_personality: return + if session_id not in self.session_memory: self.session_memory[session_id] = [] self.curr_personality = default_personality + encoded_prompt = self.tokenizer.encode(default_personality['prompt']) + tokens_num = len(encoded_prompt) + model = self.model_configs['model'] + if model in MODELS and tokens_num > MODELS[model] - 800: + default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 800]) + new_record = { "user": { - "role": "user", + "role": "system", "content": default_personality['prompt'], }, - "AI": { - "role": "assistant", - "content": "好的,接下来我会扮演这个角色。" - }, - 'type': "personality", - 'usage_tokens': 0, - 'single-tokens': 0 + 'usage_tokens': 0, # 到该条目的总 token 数 + 'single-tokens': 0 # 该条目的 token 数 } - self.session_dict[session_id].append(new_record) - async def text_chat(self, prompt, - session_id=None, - image_url=None, - function_call=None, - extra_conf: dict = None, - default_personality: dict = None): - if session_id is None: + self.session_memory[session_id].append(new_record) + + async def encode_image_bs64(self, image_url: str) -> str: + ''' + 将图片转换为 base64 + ''' + if image_url.startswith("http"): + image_url = await gu.download_image_by_url(image_url) + + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode() + return "data:image/jpeg;base64," + image_bs64 + + async def retrieve_context(self, session_id: str): + ''' + 根据 session_id 获取保存的 OpenAI 格式的上下文 + ''' + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + + # 转换为 openai 要求的格式 + context = [] + for record in self.session_memory[session_id]: + if "user" in record and record['user']: + context.append(record['user']) + if "AI" in record and record['AI']: + context.append(record['AI']) + + return context + + async def get_models(self): + ''' + 获取所有模型 + ''' + models = await self.client.models.list() + logger.info(f"OpenAI 模型列表:{models}") + return models + + async def assemble_context(self, session_id: str, prompt: str, image_url: str = None): + ''' + 组装上下文,并且根据当前上下文窗口大小截断 + ''' + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + + tokens_num = len(self.tokenizer.encode(prompt)) + previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens'] + + message = { + "usage_tokens": previous_total_tokens_num + tokens_num, + "single_tokens": tokens_num, + "AI": None + } + if image_url: + user_content = { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + }, + { + "type": "image_url", + "url": await self.encode_image_bs64(image_url) + } + ] + } + else: + user_content = { + "role": "user", + "content": prompt + } + + message["user"] = user_content + self.session_memory[session_id].append(message) + + # 根据 模型的上下文窗口 淘汰掉多余的记录 + curr_model = self.model_configs['model'] + if curr_model in MODELS: + maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion + # if message['usage_tokens'] > maxium_tokens_num: + # 淘汰多余的记录,使得最终的 usage_tokens 不超过 maxium_tokens_num - 300 + # contexts = self.session_memory[session_id] + # need_to_remove_idx = 0 + # freed_tokens_num = contexts[0]['single-tokens'] + # while freed_tokens_num < message['usage_tokens'] - maxium_tokens_num: + # need_to_remove_idx += 1 + # freed_tokens_num += contexts[need_to_remove_idx]['single-tokens'] + # # 更新之后的所有记录的 usage_tokens + # for i in range(len(contexts)): + # if i > need_to_remove_idx: + # contexts[i]['usage_tokens'] -= freed_tokens_num + # logger.debug(f"淘汰上下文记录 {need_to_remove_idx+1} 条,释放 {freed_tokens_num} 个 token。当前上下文总 token 为 {contexts[-1]['usage_tokens']}。") + # self.session_memory[session_id] = contexts[need_to_remove_idx+1:] + while len(self.session_memory[session_id]) and self.session_memory[session_id][-1]['usage_tokens'] > maxium_tokens_num: + self.pop_record(session_id) + + + async def pop_record(self, session_id: str, pop_system_prompt: bool = False): + ''' + 弹出第一条记录 + ''' + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + + if len(self.session_memory[session_id]) == 0: + return None + + for i in range(len(self.session_memory[session_id])): + # 检查是否是 system prompt + if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system": + continue + record = self.session_memory[session_id].pop(i) + break + + # 更新之后所有记录的 usage_tokens + for i in range(len(self.session_memory[session_id])): + self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens'] + logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}。") + return record + + async def text_chat(self, + prompt: str, + session_id: str, + image_url: None, + tools: None=None, + extra_conf: Dict = None, + **kwargs + ) -> str: + if not session_id: session_id = "unknown" - if "unknown" in self.session_dict: - del self.session_dict["unknown"] - # 会话机制 - if session_id not in self.session_dict: - self.session_dict[session_id] = [] + if "unknown" in self.session_memory: + del self.session_memory["unknown"] - if len(self.session_dict[session_id]) == 0: - # 设置默认人格 - if default_personality is not None: - self.personality_set(default_personality, session_id) + if session_id not in self.session_memory: + self.session_memory[session_id] = [] + self.personality_set(self.curr_personality, session_id) + + # 如果 prompt 超过了最大窗口,截断。 + # 1. 可以保证之后 pop 的时候不会出现问题 + # 2. 可以保证不会超过最大 token 数 + _encoded_prompt = self.tokenizer.encode(prompt) + curr_model = self.model_configs['model'] + if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300: + _encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300] + prompt = self.tokenizer.decode(_encoded_prompt) + + # 组装上下文,并且根据当前上下文窗口大小截断 + await self.assemble_context(session_id, prompt, image_url) - # 使用 tictoken 截断消息 - _encoded_prompt = self.enc.encode(prompt) - if self.openai_model_configs['max_tokens'] < len(_encoded_prompt): - prompt = self.enc.decode(_encoded_prompt[:int( - self.openai_model_configs['max_tokens']*0.80)]) - logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。") + # 获取上下文,openai 格式 + contexts = await self.retrieve_context(session_id) - cache_data_list, new_record, req = self.wrap( - prompt, session_id, image_url) - logger.debug(f"cache: {str(cache_data_list)}") - logger.debug(f"request: {str(req)}") + conf = self.model_configs + if extra_conf: conf.update(extra_conf) + + # start request retry = 0 - response = None - err = '' - - # 截断倍率 - truncate_rate = 0.75 - - conf = self.openai_model_configs - if extra_conf is not None: - conf.update(extra_conf) - - while retry < 10: + rate_limit_retry = 0 + while retry < 3 or rate_limit_retry < 5: + logger.debug(conf) + logger.debug(contexts) + if tools: + completion_coro = self.client.chat.completions.create( + messages=contexts, + tools=tools, + **conf + ) + else: + completion_coro = self.client.chat.completions.create( + messages=contexts, + **conf + ) try: - if function_call is None: - response = await self.client.chat.completions.create( - messages=req, - **conf - ) - else: - response = await self.client.chat.completions.create( - messages=req, - tools=function_call, - **conf - ) + completion = await completion_coro break + except AuthenticationError as e: + api_key = self.chosen_api_key[10:] + "..." + logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key(如果有的话)") + self.keys_data[self.chosen_api_key] = False + ok = await self.switch_to_next_key() + if ok: continue + else: raise Exception("所有 OpenAI API Key 目前都不可用。") + + except RateLimitError as e: + if "You exceeded your current quota" in e: + self.keys_data[self.chosen_api_key] = False + ok = await self.switch_to_next_key() + if ok: continue + else: raise Exception("所有 OpenAI API Key 目前都不可用。") + logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}") + await self.switch_to_next_key() + rate_limit_retry += 1 + time.sleep(1) except Exception as e: - traceback.print_exc() - if 'Invalid content type. image_url is only supported by certain models.' in str(e): - raise e - if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - logger.info("当前 Key 已超额或异常, 正在切换", - ) - self.key_stat[self.client.api_key]['exceed'] = True - is_switched = self.handle_switch_key() - if not is_switched: - raise e - retry -= 1 - elif 'maximum context length' in str(e): - logger.info("token 超限, 清空对应缓存,并进行消息截断") - self.session_dict[session_id] = [] - prompt = prompt[:int(len(prompt)*truncate_rate)] - truncate_rate -= 0.05 - cache_data_list, new_record, req = self.wrap( - prompt, session_id) - - elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e): - time.sleep(30) - continue - else: - logger.error(str(e)) - time.sleep(2) - err = str(e) retry += 1 - if retry >= 10: - logger.warning( - r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki") - raise BaseException("连接出错: "+str(err)) - assert isinstance(response, ChatCompletion) - logger.debug( - f"OPENAI RESPONSE: {response.usage}") + if retry >= 3: + logger.error(traceback.format_exc()) + raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。") + if "maximum context length" in str(e): + logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") + self.pop_record(session_id) - # 结果分类 - choice = response.choices[0] - if choice.message.content != None: - # 文本形式 - chatgpt_res = str(choice.message.content).strip() - elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0: + logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。") + time.sleep(1) + + assert isinstance(completion, ChatCompletion) + logger.debug(f"openai completion: {completion.usage}") + + choice = completion.choices[0] + + usage_tokens = completion.usage.total_tokens + completion_tokens = completion.usage.completion_tokens + self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens + self.session_memory[session_id][-1]['single_tokens'] += completion_tokens + + if choice.message.content: + # 返回文本 + completion_text = str(choice.message.content).strip() + elif choice.message.tool_calls and choice.message.tool_calls: # tools call (function calling) return choice.message.tool_calls[0].function - - self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens - current_usage_tokens = response.usage.total_tokens - - # 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens - if current_usage_tokens > self.max_tokens: - t = current_usage_tokens - index = 0 - while t > self.max_tokens: - if index >= len(cache_data_list): - break - # 保留人格信息 - if cache_data_list[index]['type'] != 'personality': - t -= int(cache_data_list[index]['single_tokens']) - del cache_data_list[index] - else: - index += 1 - # 删除完后更新相关字段 - self.session_dict[session_id] = cache_data_list - - # 添加新条目进入缓存的prompt - new_record['AI'] = { - 'role': 'assistant', - 'content': chatgpt_res, + + self.session_memory[session_id][-1]['AI'] = { + "role": "assistant", + "content": completion_text } - new_record['usage_tokens'] = current_usage_tokens - if len(cache_data_list) > 0: - new_record['single_tokens'] = current_usage_tokens - \ - int(cache_data_list[-1]['usage_tokens']) - else: - new_record['single_tokens'] = current_usage_tokens - cache_data_list.append(new_record) - - self.session_dict[session_id] = cache_data_list - - return chatgpt_res - - async def image_chat(self, prompt, img_num=1, img_size="1024x1024"): - retry = 0 - image_url = '' - - image_generate_configs = self.cc.get("openai_image_generate", None) - - while retry < 5: - try: - response: ImagesResponse = await self.client.images.generate( - prompt=prompt, - **image_generate_configs - ) - image_url = [] - for i in range(img_num): - image_url.append(response.data[i].url) - break - except Exception as e: - logger.warning(str(e)) - if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str( - e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - logger.warning("当前 Key 已超额或者不正常, 正在切换") - self.key_stat[self.client.api_key]['exceed'] = True - is_switched = self.handle_switch_key() - if not is_switched: - raise e - elif 'Your request was rejected as a result of our safety system.' in str(e): - logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试") - raise e - else: - retry += 1 - if retry >= 5: - raise BaseException("连接超时") - - return image_url - - async def forget(self, session_id=None) -> bool: - if session_id is None: + return completion_text + + async def switch_to_next_key(self): + ''' + 切换到下一个 API Key + ''' + if not self.api_keys: + logger.error("OpenAI API Key 不存在。") return False - self.session_dict[session_id] = [] - return True - def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1): + for key in self.keys_data: + if self.keys_data[key]: + # 没超额 + self.chosen_api_key = key + self.client.api_key = key + logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。") + return True + + return False + + async def image_generate(self, prompt, session_id, **kwargs) -> str: + ''' + 生成图片 + ''' + retry = 0 + conf = self.image_generator_model_configs + if not conf: + logger.error("OpenAI 图片生成模型配置不存在。") + raise Exception("OpenAI 图片生成模型配置不存在。") + + while retry < 3: + try: + images_response = await self.client.images.generate( + prompt=prompt, + **conf + ) + image_url = images_response.data[0].url + return image_url + except Exception as e: + retry += 1 + if retry >= 3: + logger.error(traceback.format_exc()) + raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。") + logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。") + time.sleep(1) + + async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool: + if session_id is None: return False + self.session_memory[session_id] = [] + if keep_system_prompt: + self.personality_set(self.curr_personality, session_id) + return True + + def dump_contexts_page(self, size=5, page=1): ''' 获取缓存的会话 ''' - prompts = "" - if paging: - page_begin = (page-1)*size - page_end = page*size - if page_begin < 0: - page_begin = 0 - if page_end > len(cache_data_list): - page_end = len(cache_data_list) - cache_data_list = cache_data_list[page_begin:page_end] - for item in cache_data_list: - prompts += str(item['user']['role']) + ":\n" + \ - str(item['user']['content']) + "\n" - prompts += str(item['AI']['role']) + ":\n" + \ - str(item['AI']['content']) + "\n" - - if divide: - prompts += "----------\n" - return prompts - - def wrap(self, prompt, session_id, image_url=None): - if image_url is not None: - prompt = [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - } - ] - # 获得缓存信息 - context = self.session_dict[session_id] - new_record = { - "user": { - "role": "user", - "content": prompt, - }, - "AI": {}, - 'type': "common", - 'usage_tokens': 0, - } - req_list = [] - for i in context: - if 'user' in i: - req_list.append(i['user']) - if 'AI' in i: - req_list.append(i['AI']) - req_list.append(new_record['user']) - return context, new_record, req_list - - def handle_switch_key(self): - is_all_exceed = True - for key in self.key_stat: - if key == None or self.key_stat[key]['exceed']: + contexts_str = "" + for i, key in enumerate(self.session_memory): + if i < (page-1)*size or i >= page*size: continue - is_all_exceed = False - self.client.api_key = key - logger.warning( - f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})") - break - if is_all_exceed: - logger.warning( - "所有 Key 已超额") - return False - return True + contexts_str += f"Session ID: {key}\n" + for record in self.session_memory[key]: + if "user" in record: + contexts_str += f"User: {record['user']['content']}\n" + if "AI" in record: + contexts_str += f"AI: {record['AI']['content']}\n" + contexts_str += "---\n" + return contexts_str, len(self.session_memory) + def get_configs(self): - return self.openai_configs + return self.model_configs - def get_key_stat(self): - return self.key_stat - - def get_key_list(self): - return self.key_list + def get_keys_data(self): + return self.keys_data def get_curr_key(self): - return self.client.api_key + return self.chosen_api_key def set_key(self, key): - self.client.api_key = key - - # 添加key - def append_key(self, key, sponsor): - self.key_list.append(key) - self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor} - - # 检查key是否可用 - async def check_key(self, key): - client_ = AsyncOpenAI( - api_key=key, - base_url=self.api_base - ) - messages = [{"role": "user", "content": "please just echo `test`"}] - await client_.chat.completions.create( - messages=messages, - **self.openai_model_configs - ) - return True + self.client.api_key = key \ No newline at end of file diff --git a/model/provider/openai_official_new.py b/model/provider/openai_official_new.py deleted file mode 100644 index 6699c8e2..00000000 --- a/model/provider/openai_official_new.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -import sys -import json -import time -import tiktoken -import threading -import traceback - -from openai import AsyncOpenAI -from openai.types.images_response import ImagesResponse -from openai.types.chat.chat_completion import ChatCompletion - -from cores.database.conn import dbConn -from model.provider.provider import Provider -from util import general_utils as gu -from util.cmd_config import CmdConfig -from SparkleLogging.utils.core import LogManager -from logging import Logger - -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - -class ProviderOpenAIOfficial(Provider): - def __init__(self, cfg) -> None: - super().__init__() - - os.makedirs("data/openai", exist_ok=True) - - self.cc = CmdConfig - self.key_data_path = "data/openai/keys.json" - self.api_keys = [] - self.chosen_api_key = None - self.base_url = None - self.keys_data = { - "keys": [] - } - - if cfg['key']: self.api_keys = cfg['key'] - if cfg['api_base']: self.base_url = cfg['api_base'] - if not self.api_keys: - logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。") - else: - self.chosen_api_key = self.api_keys[0] - - self.client = AsyncOpenAI( - api_key=self.chosen_api_key, - base_url=self.base_url - ) - self.model_configs: dict = cfg['chatGPTConfigs'] - self.session_memory = {} # 会话记忆 - self.session_memory_lock = threading.Lock() - self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 - self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器 - - # 从 SQLite DB 读取历史记录 - try: - db1 = dbConn() - for session in db1.get_all_session(): - self.session_memory_lock.acquire() - self.session_memory[session[0]] = json.loads(session[1])['data'] - self.session_memory_lock.release() - except BaseException as e: - logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。") - - # 定时保存历史记录 - threading.Thread(target=self.dump_history, daemon=True).start() - diff --git a/model/provider/openai_official_old.py b/model/provider/openai_official_old.py new file mode 100644 index 00000000..72f1c2f8 --- /dev/null +++ b/model/provider/openai_official_old.py @@ -0,0 +1,409 @@ +import os +import sys +import json +import time +import tiktoken +import threading +import traceback + +from openai import AsyncOpenAI +from openai.types.images_response import ImagesResponse +from openai.types.chat.chat_completion import ChatCompletion + +from cores.database.conn import dbConn +from model.provider.provider import Provider +from util import general_utils as gu +from util.cmd_config import CmdConfig +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot-core') + + +abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' + + +class ProviderOpenAIOfficial(Provider): + def __init__(self, cfg): + self.cc = CmdConfig() + + self.key_list = [] + # 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错 + for key in cfg['key']: + if len(key) == 1: + raise BaseException( + "检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。") + if cfg['key'] != '' and cfg['key'] != None: + self.key_list = cfg['key'] + if len(self.key_list) == 0: + raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。") + + self.key_stat = {} + for k in self.key_list: + self.key_stat[k] = {'exceed': False, 'used': 0} + + self.api_base = None + if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '': + self.api_base = cfg['api_base'] + logger.info(f"设置 api_base 为: {self.api_base}") + + # 创建 OpenAI Client + self.client = AsyncOpenAI( + api_key=self.key_list[0], + base_url=self.api_base + ) + + self.openai_model_configs: dict = cfg['chatGPTConfigs'] + self.openai_configs = cfg + # 会话缓存 + self.session_dict = {} + # 最大缓存token + self.max_tokens = cfg['total_tokens_limit'] + # 历史记录持久化间隔时间 + self.history_dump_interval = 20 + + self.enc = tiktoken.get_encoding("cl100k_base") + + # 从 SQLite DB 读取历史记录 + try: + db1 = dbConn() + for session in db1.get_all_session(): + self.session_dict[session[0]] = json.loads(session[1])['data'] + logger.info("读取历史记录成功。") + except BaseException as e: + logger.info("读取历史记录失败,但不影响使用。") + + # 创建转储定时器线程 + threading.Thread(target=self.dump_history, daemon=True).start() + + # 人格 + self.curr_personality = {} + + def make_tmp_client(self, api_key: str, base_url: str): + return AsyncOpenAI( + api_key=api_key, + base_url=base_url + ) + + # 转储历史记录 + def dump_history(self): + time.sleep(10) + db = dbConn() + while True: + try: + # print("转储历史记录...") + for key in self.session_dict: + data = self.session_dict[key] + data_json = { + 'data': data + } + if db.check_session(key): + db.update_session(key, json.dumps(data_json)) + else: + db.insert_session(key, json.dumps(data_json)) + # print("转储历史记录完毕") + except BaseException as e: + print(e) + # 每隔10分钟转储一次 + time.sleep(10*self.history_dump_interval) + + def personality_set(self, default_personality: dict, session_id: str): + self.curr_personality = default_personality + new_record = { + "user": { + "role": "user", + "content": default_personality['prompt'], + }, + "AI": { + "role": "assistant", + "content": "好的,接下来我会扮演这个角色。" + }, + 'type': "personality", + 'usage_tokens': 0, + 'single-tokens': 0 + } + self.session_dict[session_id].append(new_record) + + async def text_chat(self, prompt, + session_id=None, + image_url=None, + function_call=None, + extra_conf: dict = None, + default_personality: dict = None): + if session_id is None: + session_id = "unknown" + if "unknown" in self.session_dict: + del self.session_dict["unknown"] + # 会话机制 + if session_id not in self.session_dict: + self.session_dict[session_id] = [] + + if len(self.session_dict[session_id]) == 0: + # 设置默认人格 + if default_personality is not None: + self.personality_set(default_personality, session_id) + + # 使用 tictoken 截断消息 + _encoded_prompt = self.enc.encode(prompt) + if self.openai_model_configs['max_tokens'] < len(_encoded_prompt): + prompt = self.enc.decode(_encoded_prompt[:int( + self.openai_model_configs['max_tokens']*0.80)]) + logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。") + + cache_data_list, new_record, req = self.wrap( + prompt, session_id, image_url) + logger.debug(f"cache: {str(cache_data_list)}") + logger.debug(f"request: {str(req)}") + retry = 0 + response = None + err = '' + + # 截断倍率 + truncate_rate = 0.75 + + conf = self.openai_model_configs + if extra_conf is not None: + conf.update(extra_conf) + + while retry < 10: + try: + if function_call is None: + response = await self.client.chat.completions.create( + messages=req, + **conf + ) + else: + response = await self.client.chat.completions.create( + messages=req, + tools=function_call, + **conf + ) + break + except Exception as e: + traceback.print_exc() + if 'Invalid content type. image_url is only supported by certain models.' in str(e): + raise e + if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): + logger.info("当前 Key 已超额或异常, 正在切换", + ) + self.key_stat[self.client.api_key]['exceed'] = True + is_switched = self.handle_switch_key() + if not is_switched: + raise e + retry -= 1 + elif 'maximum context length' in str(e): + logger.info("token 超限, 清空对应缓存,并进行消息截断") + self.session_dict[session_id] = [] + prompt = prompt[:int(len(prompt)*truncate_rate)] + truncate_rate -= 0.05 + cache_data_list, new_record, req = self.wrap( + prompt, session_id) + + elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e): + time.sleep(30) + continue + else: + logger.error(str(e)) + time.sleep(2) + err = str(e) + retry += 1 + if retry >= 10: + logger.warning( + r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki") + raise BaseException("连接出错: "+str(err)) + assert isinstance(response, ChatCompletion) + logger.debug( + f"OPENAI RESPONSE: {response.usage}") + + # 结果分类 + choice = response.choices[0] + if choice.message.content != None: + # 文本形式 + chatgpt_res = str(choice.message.content).strip() + elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0: + # tools call (function calling) + return choice.message.tool_calls[0].function + + self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens + current_usage_tokens = response.usage.total_tokens + + # 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens + if current_usage_tokens > self.max_tokens: + t = current_usage_tokens + index = 0 + while t > self.max_tokens: + if index >= len(cache_data_list): + break + # 保留人格信息 + if cache_data_list[index]['type'] != 'personality': + t -= int(cache_data_list[index]['single_tokens']) + del cache_data_list[index] + else: + index += 1 + # 删除完后更新相关字段 + self.session_dict[session_id] = cache_data_list + + # 添加新条目进入缓存的prompt + new_record['AI'] = { + 'role': 'assistant', + 'content': chatgpt_res, + } + new_record['usage_tokens'] = current_usage_tokens + if len(cache_data_list) > 0: + new_record['single_tokens'] = current_usage_tokens - \ + int(cache_data_list[-1]['usage_tokens']) + else: + new_record['single_tokens'] = current_usage_tokens + + cache_data_list.append(new_record) + + self.session_dict[session_id] = cache_data_list + + return chatgpt_res + + async def image_chat(self, prompt, img_num=1, img_size="1024x1024"): + retry = 0 + image_url = '' + + image_generate_configs = self.cc.get("openai_image_generate", None) + + while retry < 5: + try: + response: ImagesResponse = await self.client.images.generate( + prompt=prompt, + **image_generate_configs + ) + image_url = [] + for i in range(img_num): + image_url.append(response.data[i].url) + break + except Exception as e: + logger.warning(str(e)) + if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str( + e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): + logger.warning("当前 Key 已超额或者不正常, 正在切换") + self.key_stat[self.client.api_key]['exceed'] = True + is_switched = self.handle_switch_key() + if not is_switched: + raise e + elif 'Your request was rejected as a result of our safety system.' in str(e): + logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试") + raise e + else: + retry += 1 + if retry >= 5: + raise BaseException("连接超时") + + return image_url + + async def forget(self, session_id=None) -> bool: + if session_id is None: + return False + self.session_dict[session_id] = [] + return True + + def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1): + ''' + 获取缓存的会话 + ''' + prompts = "" + if paging: + page_begin = (page-1)*size + page_end = page*size + if page_begin < 0: + page_begin = 0 + if page_end > len(cache_data_list): + page_end = len(cache_data_list) + cache_data_list = cache_data_list[page_begin:page_end] + for item in cache_data_list: + prompts += str(item['user']['role']) + ":\n" + \ + str(item['user']['content']) + "\n" + prompts += str(item['AI']['role']) + ":\n" + \ + str(item['AI']['content']) + "\n" + + if divide: + prompts += "----------\n" + return prompts + + def wrap(self, prompt, session_id, image_url=None): + if image_url is not None: + prompt = [ + { + "type": "text", + "text": prompt + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + } + ] + # 获得缓存信息 + context = self.session_dict[session_id] + new_record = { + "user": { + "role": "user", + "content": prompt, + }, + "AI": {}, + 'type': "common", + 'usage_tokens': 0, + } + req_list = [] + for i in context: + if 'user' in i: + req_list.append(i['user']) + if 'AI' in i: + req_list.append(i['AI']) + req_list.append(new_record['user']) + return context, new_record, req_list + + def handle_switch_key(self): + is_all_exceed = True + for key in self.key_stat: + if key == None or self.key_stat[key]['exceed']: + continue + is_all_exceed = False + self.client.api_key = key + logger.warning( + f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})") + break + if is_all_exceed: + logger.warning( + "所有 Key 已超额") + return False + return True + + def get_configs(self): + return self.openai_configs + + def get_key_stat(self): + return self.key_stat + + def get_key_list(self): + return self.key_list + + def get_curr_key(self): + return self.client.api_key + + def set_key(self, key): + self.client.api_key = key + + # 添加key + def append_key(self, key, sponsor): + self.key_list.append(key) + self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor} + + # 检查key是否可用 + async def check_key(self, key): + client_ = AsyncOpenAI( + api_key=key, + base_url=self.api_base + ) + messages = [{"role": "user", "content": "please just echo `test`"}] + await client_.chat.completions.create( + messages=messages, + **self.openai_model_configs + ) + return True diff --git a/model/provider/provider.py b/model/provider/provider.py index f2a80202..022ffbed 100644 --- a/model/provider/provider.py +++ b/model/provider/provider.py @@ -3,7 +3,7 @@ class Provider: prompt: str, session_id: str, image_url: None, - function_call: None, + tools: None, extra_conf: dict = None, default_personality: dict = None, **kwargs) -> str: @@ -14,7 +14,7 @@ class Provider: [optional] image_url: 图片url(识图) - function_call: 函数调用 + tools: 函数调用工具 extra_conf: 额外配置 default_personality: 默认人格 ''' diff --git a/util/general_utils.py b/util/general_utils.py index 33a77266..c9a3e749 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -1,19 +1,23 @@ -import datetime import time import socket -from PIL import Image, ImageDraw, ImageFont import os import re import requests -from util.cmd_config import CmdConfig +import aiohttp import socket -from cores.astrbot.types import GlobalObject import platform -import logging import json import sys import psutil +from PIL import Image, ImageDraw, ImageFont +from cores.astrbot.types import GlobalObject +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot-core') + + def port_checker(port: int, host: str = "localhost"): sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sk.settimeout(1) @@ -355,10 +359,23 @@ def save_temp_img(img: Image) -> str: # 获得时间戳 timestamp = int(time.time()) - p = f"temp/{timestamp}.png" + p = f"temp/{timestamp}.jpg" img.save(p) return p +async def download_image_by_url(url: str) -> str: + ''' + 下载图片 + ''' + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + img = Image.open(await resp.read()) + p = save_temp_img(img) + return p + except Exception as e: + raise e + def create_text_image(title: str, text: str, max_width=30, font_size=20): ''' @@ -455,6 +472,21 @@ def upload(_global_object: GlobalObject): pass time.sleep(10*60) +def retry(n: int = 3): + ''' + 重试装饰器 + ''' + def decorator(func): + def wrapper(*args, **kwargs): + for i in range(n): + try: + return func(*args, **kwargs) + except Exception as e: + if i == n-1: raise e + logger.warning(f"函数 {func.__name__} 第 {i+1} 次重试... {e}") + return wrapper + return decorator + def run_monitor(global_object: GlobalObject): '''