From e643eea365efac05f0ebc4f17c7c366dd3ba6bce Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 12 Mar 2024 18:50:50 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E7=BB=93=E6=9E=84=E5=8C=96=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E7=9A=84=E8=A1=A8=E7=A4=BA=E6=A0=BC=E5=BC=8F;=20?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=8F=92=E4=BB=B6=E5=BC=80=E5=8F=91=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + addons/dashboard/server.py | 36 +++--- addons/plugins/helloworld/README.md | 5 - addons/plugins/helloworld/helloworld.py | 22 +++- cores/qqbot/core.py | 30 +++-- cores/qqbot/global_object.py | 90 --------------- cores/qqbot/types.py | 140 ++++++++++++++++++++++++ main.py | 2 +- model/command/command.py | 65 ++++++----- model/command/openai_official.py | 3 +- model/command/rev_chatgpt.py | 3 +- model/platform/_platfrom.py | 8 +- model/provider/provider.py | 40 +++++-- util/general_utils.py | 4 +- util/plugin_dev/api/v1/bot.py | 12 +- util/plugin_dev/api/v1/config.py | 2 - util/plugin_dev/api/v1/llm.py | 6 + util/plugin_dev/api/v1/message.py | 2 +- util/plugin_dev/api/v1/platform.py | 11 ++ util/plugin_dev/api/v1/register.py | 77 +++++++++++++ util/plugin_dev/api/v1/types.py | 5 + util/plugin_util.py | 108 ++++++++++++------ 22 files changed, 457 insertions(+), 215 deletions(-) delete mode 100644 addons/plugins/helloworld/README.md delete mode 100644 cores/qqbot/global_object.py create mode 100644 cores/qqbot/types.py create mode 100644 util/plugin_dev/api/v1/llm.py create mode 100644 util/plugin_dev/api/v1/platform.py create mode 100644 util/plugin_dev/api/v1/register.py create mode 100644 util/plugin_dev/api/v1/types.py diff --git a/.gitignore b/.gitignore index 9e4c06f5..f7f601cc 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ temp cmd_config.json addons/plugins/ data/ +cookies.json diff --git a/addons/dashboard/server.py b/addons/dashboard/server.py index 1d71854a..70448d55 100644 --- a/addons/dashboard/server.py +++ b/addons/dashboard/server.py @@ -7,6 +7,7 @@ import logging from cores.database.conn import dbConn from util.cmd_config import CmdConfig from util.updator import check_update, update_project, request_release_info +from cores.qqbot.types import * import util.plugin_util as putil import websockets import json @@ -20,7 +21,7 @@ class DashBoardData(): stats: dict configs: dict logs: dict - plugins: list[dict] + plugins: List[RegisteredPlugin] @dataclass class Response(): @@ -33,7 +34,7 @@ class AstrBotDashBoard(): self.global_object = global_object self.loop = asyncio.get_event_loop() asyncio.set_event_loop(self.loop) - self.dashboard_data = global_object.dashboard_data + self.dashboard_data: DashBoardData = global_object.dashboard_data self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/") log = logging.getLogger('werkzeug') log.setLevel(logging.ERROR) @@ -151,13 +152,13 @@ class AstrBotDashBoard(): def get_plugins(): _plugin_resp = [] for plugin in self.dashboard_data.plugins: - _p = self.dashboard_data.plugins[plugin] + _p = plugin.metadata _t = { - "name": _p["info"]["name"], - "repo": '' if "repo" not in _p["info"] else _p["info"]["repo"], - "author": _p["info"]["author"], - "desc": _p["info"]["desc"], - "version": _p["info"]["version"] + "name": _p.plugin_name, + "repo": '' if _p.repo is None else _p.repo, + "author": _p.author, + "desc": _p.desc, + "version": _p.version } _plugin_resp.append(_t) return Response( @@ -359,17 +360,14 @@ class AstrBotDashBoard(): } ] for plugin in self.global_object.cached_plugins: - # 从插件信息中获取 plugin_type 字段,如果有则归类到对应的大纲中 - if "plugin_type" in self.global_object.cached_plugins[plugin]["info"]: - _t = self.global_object.cached_plugins[plugin]["info"]["plugin_type"] - for item in outline: - if item["type"] == _t: - item["body"].append({ - "title": self.global_object.cached_plugins[plugin]["info"]["name"], - "desc": self.global_object.cached_plugins[plugin]["info"]["desc"], - "namespace": plugin, - "tag": plugin, - }) + for item in outline: + if item['type'] == plugin.metadata.plugin_type: + item['body'].append({ + "title": plugin.metadata.plugin_name, + "desc": plugin.metadata.desc, + "namespace": plugin.metadata.plugin_name, + "tag": plugin.metadata.plugin_name + }) return outline def register(self, name: str): diff --git a/addons/plugins/helloworld/README.md b/addons/plugins/helloworld/README.md deleted file mode 100644 index e7f48777..00000000 --- a/addons/plugins/helloworld/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# helloworld - -QQChannelChatGPT项目的测试插件 - -A test plugin for QQChannelChatGPT plugin feature diff --git a/addons/plugins/helloworld/helloworld.py b/addons/plugins/helloworld/helloworld.py index c7aa10ec..9e74b266 100644 --- a/addons/plugins/helloworld/helloworld.py +++ b/addons/plugins/helloworld/helloworld.py @@ -1,15 +1,25 @@ +import os +import shutil from nakuru.entities.components import * from nakuru import ( GroupMessage, FriendMessage ) from botpy.message import Message, DirectMessage -from cores.qqbot.global_object import ( - AstrMessageEvent, - CommandResult -) -import os -import shutil +flag_not_support = False +try: + from util.plugin_dev.api.v1.config import * + from util.plugin_dev.api.v1.bot import ( + PluginMetadata, + PluginType, + AstrMessageEvent, + CommandResult, + ) + from util.plugin_dev.api.v1.register import register_llm, unregister_llm +except ImportError: + flag_not_support = True + print("llms: 导入接口失败。请升级到 AstrBot 最新版本。") + ''' 注意改插件名噢!格式:XXXPlugin 或 Main diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 617b5bcd..98a719dc 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -29,12 +29,13 @@ from util import general_utils as gu from util.general_utils import Logger, upload, run_monitor from util.cmd_config import CmdConfig as cc from util.cmd_config import init_astrbot_config_items -from . global_object import GlobalObject +from .types import * from addons.dashboard.helper import DashBoardHelper from addons.dashboard.server import DashBoardData from cores.database.conn import dbConn from model.platform._message_result import MessageResult + # 用户发言频率 user_frequency = {} # 时间默认值 @@ -43,7 +44,7 @@ frequency_time = 60 frequency_count = 10 # 版本 -version = '3.1.6' +version = '3.1.7' # 语言模型 REV_CHATGPT = 'rev_chatgpt' @@ -98,9 +99,6 @@ def init(cfg): _global_object = GlobalObject() _global_object.version = version _global_object.base_config = cfg - _global_object.stat['session'] = {} - _global_object.stat['message'] = {} - _global_object.stat['platform'] = {} _global_object.logger = logger logger.log("AstrBot v"+version, gu.LEVEL_INFO) @@ -125,6 +123,7 @@ def init(cfg): llm_instance[REV_CHATGPT] = ProviderRevChatGPT(cfg['rev_ChatGPT'], base_url=cc.get("CHATGPT_BASE_URL", None)) llm_command_instance[REV_CHATGPT] = CommandRevChatGPT(llm_instance[REV_CHATGPT], _global_object) chosen_provider = REV_CHATGPT + _global_object.llms.append(RegisteredLLM(llm_name=REV_CHATGPT, llm_instance=llm_instance[REV_CHATGPT], origin="internal")) else: input("请退出本程序, 然后在配置文件中填写rev_ChatGPT相关配置") if OPENAI_OFFICIAL in prov: @@ -134,6 +133,7 @@ def init(cfg): from model.command.openai_official import CommandOpenAIOfficial llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial(cfg['openai']) llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial(llm_instance[OPENAI_OFFICIAL], _global_object) + _global_object.llms.append(RegisteredLLM(llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal")) chosen_provider = OPENAI_OFFICIAL # 检查provider设置偏好 @@ -184,7 +184,7 @@ def init(cfg): _command = Command(None, _global_object) ok, err = putil.plugin_reload(_global_object.cached_plugins) if ok: - logger.log(f"成功载入{len(_global_object.cached_plugins)}个插件", gu.LEVEL_INFO) + logger.log(f"成功载入 {len(_global_object.cached_plugins)} 个插件", gu.LEVEL_INFO) else: logger.log(err, gu.LEVEL_ERROR) @@ -244,7 +244,7 @@ def run_qqchan_bot(cfg: dict, global_object: GlobalObject): try: from model.platform.qq_official import QQOfficial qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg, global_object=global_object) - global_object.platform_qqchan = qqchannel_bot + global_object.platforms.append(RegisteredPlatform(platform_name="qq_official", platform_instance=qqchannel_bot, origin="internal")) qqchannel_bot.run() except BaseException as e: logger.log("启动QQ频道机器人时出现错误, 原因如下: " + str(e), gu.LEVEL_CRITICAL, tag="QQ频道") @@ -272,7 +272,7 @@ def run_gocq_bot(cfg: dict, _global_object: GlobalObject): break try: qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, global_object=_global_object) - _global_object.platform_qq = qq_gocq + _global_object.platforms.append(RegisteredPlatform(platform_name="gocq", platform_instance=qq_gocq, origin="internal")) qq_gocq.run() except BaseException as e: input("启动QQ机器人出现错误"+str(e)) @@ -317,7 +317,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak message: 消息对象 session_id: 该消息源的唯一识别号 role: member | admin - platform: 平台(gocq, qqchan) + platform: str 所注册的平台的名称。如果没有注册,将抛出一个异常。 """ global chosen_provider, _global_object message_str = '' @@ -326,6 +326,16 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak hit = False # 是否命中指令 command_result = () # 调用指令返回的结果 + # 获取平台实例 + reg_platform: RegisteredPlatform = None + for p in _global_object.platforms: + if p.platform_name == platform: + reg_platform = p + break + if not reg_platform: + _global_object.logger.log(f"未找到平台 {platform} 的实例。", gu.LEVEL_ERROR) + raise Exception(f"未找到平台 {platform} 的实例。") + # 统计数据,如频道消息量 await record_message(platform, session_id) @@ -365,7 +375,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak message_str, session_id, role, - platform, + reg_platform, message, ) diff --git a/cores/qqbot/global_object.py b/cores/qqbot/global_object.py deleted file mode 100644 index 3a3c2453..00000000 --- a/cores/qqbot/global_object.py +++ /dev/null @@ -1,90 +0,0 @@ -from model.platform.qq_official import QQOfficial, NakuruGuildMessage -from model.platform.qq_gocq import QQGOCQ -from model.provider.provider import Provider -from addons.dashboard.server import DashBoardData -from nakuru import ( - GroupMessage, - FriendMessage, - GuildMessage, -) -from typing import Union - -class GlobalObject: - ''' - 存放一些公用的数据,用于在不同模块(如core与command)之间传递 - ''' - version: str - nick: str # gocq 的昵称 - base_config: dict # config.json - cached_plugins: dict # 缓存的插件 - web_search: bool # 是否开启了网页搜索 - reply_prefix: str - admin_qq: str - admin_qqchan: str - unique_session: bool - cnt_total: int - platform_qq: QQGOCQ - platform_qqchan: QQOfficial - default_personality: dict - dashboard_data: DashBoardData - stat: dict - logger: None - - def __init__(self): - self.nick = None # gocq 的昵称 - self.base_config = None # config.yaml - self.cached_plugins = {} # 缓存的插件 - self.web_search = False # 是否开启了网页搜索 - self.reply_prefix = None - self.admin_qq = "123456" - self.admin_qqchan = "123456" - self.unique_session = False - self.cnt_total = 0 - self.platform_qq = None - self.platform_qqchan = None - self.default_personality = None - self.dashboard_data = None - self.stat = {} - - -class AstrMessageEvent(): - message_str: str # 纯消息字符串 - message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage] # 消息对象 - gocq_platform: QQGOCQ - qq_sdk_platform: QQOfficial - platform: str # `gocq` 或 `qqchan` - role: str # `admin` 或 `member` - global_object: GlobalObject # 一些公用数据 - session_id: int # 会话id (可能是群id,也可能是某个user的id。取决于是否开启了 unique_session) - - def __init__(self, message_str: str, - message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage], - gocq_platform: QQGOCQ, - qq_sdk_platform: QQOfficial, - platform: str, - role: str, - global_object: GlobalObject, - llm_provider: Provider = None, - session_id: int = None): - self.message_str = message_str - self.message_obj = message_obj - self.gocq_platform = gocq_platform - self.qq_sdk_platform = qq_sdk_platform - self.platform = platform - self.role = role - self.global_object = global_object - self.llm_provider = llm_provider - self.session_id = session_id - -class CommandResult(): - ''' - 用于在Command中返回多个值 - ''' - def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None: - self.hit = hit - self.success = success - self.message_chain = message_chain - self.command_name = command_name - - def _result_tuple(self): - return (self.success, self.message_chain, self.command_name) \ No newline at end of file diff --git a/cores/qqbot/types.py b/cores/qqbot/types.py new file mode 100644 index 00000000..32ec0a6f --- /dev/null +++ b/cores/qqbot/types.py @@ -0,0 +1,140 @@ +from model.platform.qq_official import NakuruGuildMessage +from model.provider.provider import Provider as LLMProvider +from model.platform._platfrom import Platform +from nakuru import ( + GroupMessage, + FriendMessage, + GuildMessage, +) +from typing import Union, List, ClassVar +from types import ModuleType +from enum import Enum +from dataclasses import dataclass + +class PluginType(Enum): + PLATFORM = 'platfrom' # 平台类插件。 + LLM = 'llm' # 大语言模型类插件 + COMMON = 'common' # 其他插件 + +@dataclass +class PluginMetadata: + ''' + 插件的元数据。 + ''' + # required + plugin_name: str + plugin_type: PluginType + author: str # 插件作者 + desc: str # 插件简介 + version: str # 插件版本 + + # optional + repo: str = None # 插件仓库地址 + + def __str__(self) -> str: + return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})" + +@dataclass +class RegisteredPlugin: + ''' + 注册在 AstrBot 中的插件。 + ''' + metadata: PluginMetadata + plugin_instance: object + module_path: str + module: ModuleType + root_dir_name: str + + def __str__(self) -> str: + return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" + +RegisteredPlugins = List[RegisteredPlugin] + +@dataclass +class RegisteredPlatform: + ''' + 注册在 AstrBot 中的平台。平台应当实现 Platform 接口。 + ''' + platform_name: str + platform_instance: Platform + origin: str = None # 注册来源 + +@dataclass +class RegisteredLLM: + ''' + 注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。 + ''' + llm_name: str + llm_instance: LLMProvider + origin: str = None # 注册来源 + +class GlobalObject: + ''' + 存放一些公用的数据,用于在不同模块(如core与command)之间传递 + ''' + version: str # 机器人版本 + nick: str # 用户定义的机器人的别名 + base_config: dict # config.json 中导出的配置 + cached_plugins: List[RegisteredPlugin] # 加载的插件 + platforms: List[RegisteredPlatform] + llms: List[RegisteredLLM] + + web_search: bool # 是否开启了网页搜索 + reply_prefix: str # 回复前缀 + unique_session: bool # 是否开启了独立会话 + cnt_total: int # 总消息数 + default_personality: dict + dashboard_data = None + logger: None + + def __init__(self): + self.nick = None # gocq 的昵称 + self.base_config = None # config.yaml + self.cached_plugins = [] # 缓存的插件 + self.web_search = False # 是否开启了网页搜索 + self.reply_prefix = None + self.unique_session = False + self.cnt_total = 0 + self.platforms = [] + self.llms = [] + self.default_personality = None + self.dashboard_data = None + self.stat = {} + +class AstrMessageEvent(): + ''' + 消息事件。 + ''' + context: GlobalObject # 一些公用数据 + message_str: str # 纯消息字符串 + message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage] # 消息对象 + platform: RegisteredPlatform # 来源平台 + role: str # 基本身份。`admin` 或 `member` + session_id: int # 会话 id + + def __init__(self, + message_str: str, + message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage], + platform: RegisteredPlatform, + role: str, + context: GlobalObject, + session_id: str = None): + self.context = context + self.message_str = message_str + self.message_obj = message_obj + self.platform = platform + self.role = role + self.session_id = session_id + +class CommandResult(): + ''' + 用于在Command中返回多个值 + ''' + def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None: + self.hit = hit + self.success = success + self.message_chain = message_chain + self.command_name = command_name + + def _result_tuple(self): + return (self.success, self.message_chain, self.command_name) \ No newline at end of file diff --git a/main.py b/main.py index 3a4ce64d..ec625480 100644 --- a/main.py +++ b/main.py @@ -23,7 +23,7 @@ def main(): print(file_not_found) input("配置文件不存在,请检查是否已经下载配置文件。") except BaseException as e: - print(e) + raise e # 设置代理 if 'http_proxy' in cfg and cfg['http_proxy'] != '': diff --git a/model/command/command.py b/model/command/command.py index e0d509a0..f5ea4e76 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -8,15 +8,22 @@ import util.plugin_util as putil import util.updator from nakuru.entities.components import ( - Plain, Image ) from util import general_utils as gu from model.provider.provider import Provider from util.cmd_config import CmdConfig as cc from util.general_utils import Logger -from cores.qqbot.global_object import GlobalObject, AstrMessageEvent -from cores.qqbot.global_object import CommandResult +from cores.qqbot.types import ( + GlobalObject, + AstrMessageEvent, + PluginType, + CommandResult, + RegisteredPlugin, + RegisteredPlatform +) + +from typing import List, Tuple PLATFORM_QQCHAN = 'qqchan' PLATFORM_GOCQ = 'gocq' @@ -31,8 +38,8 @@ class Command: async def check_command(self, message, session_id: str, - role, - platform, + role: str, + platform: RegisteredPlatform, message_obj): self.platform = platform # 插件 @@ -41,23 +48,21 @@ class Command: ame = AstrMessageEvent( message_str=message, message_obj=message_obj, - gocq_platform=self.global_object.platform_qq, - qq_sdk_platform=self.global_object.platform_qqchan, platform=platform, role=role, - global_object=self.global_object, + context=self.global_object, session_id = session_id ) # 从已启动的插件中查找是否有匹配的指令 - for k, v in cached_plugins.items(): + for plugin in cached_plugins: # 过滤掉平台类插件 - if "type" in v["info"] and v["info"]["plugin_type"] == "platform": + if plugin.metadata.plugin_type == PluginType.PLATFORM: continue try: - if inspect.iscoroutinefunction(v["clsobj"].run): - result = await v["clsobj"].run(ame) + if inspect.iscoroutinefunction(plugin.plugin_instance.run): + result = await plugin.plugin_instance.run(ame) else: - result = await asyncio.to_thread(v["clsobj"].run, ame) + result = await asyncio.to_thread(plugin.plugin_instance.run, ame) if isinstance(result, CommandResult): hit = result.hit res = result._result_tuple() @@ -71,16 +76,16 @@ class Command: except TypeError as e: # 参数不匹配,尝试使用旧的参数方案 try: - if inspect.iscoroutinefunction(v["clsobj"].run): - hit, res = await v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq) + if inspect.iscoroutinefunction(plugin.plugin_instance.run): + hit, res = await plugin.plugin_instance.run(message, role, platform, message_obj, self.global_object.platform_qq) else: - hit, res = await asyncio.to_thread(v["clsobj"].run, message, role, platform, message_obj, self.global_object.platform_qq) + hit, res = await asyncio.to_thread(plugin.plugin_instance.run, message, role, platform, message_obj, self.global_object.platform_qq) if hit: return True, res except BaseException as e: - self.logger.log(f"{k} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) + self.logger.log(f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) except BaseException as e: - self.logger.log(f"{k} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) + self.logger.log(f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) if self.command_start_with(message, "nick"): return True, self.set_nick(message, platform, role) @@ -125,7 +130,7 @@ class Command: ''' 插件指令 ''' - def plugin_oper(self, message: str, role: str, cached_plugins: dict, platform: str): + def plugin_oper(self, message: str, role: str, cached_plugins: List[RegisteredPlugin], platform: str): l = message.split(" ") if len(l) < 2: p = gu.create_text_image("【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n") @@ -155,25 +160,27 @@ class Command: return False, f"更新插件失败,原因: {str(e)}。\n建议: 使用 plugin i 指令进行覆盖安装(插件数据可能会丢失)", "plugin" elif l[1] == "l": try: - plugin_list_info = "\n".join([f"{k}: \n名称: {v['info']['name']}\n简介: {v['info']['desc']}\n版本: {v['info']['version']}\n作者: {v['info']['author']}\n" for k, v in cached_plugins.items()]) + plugin_list_info = "" + for plugin in cached_plugins: + plugin_list_info += f"{plugin.metadata.plugin_name}: \n名称: {plugin.metadata.plugin_name}\n简介: {plugin.metadata.plugin_desc}\n版本: {plugin.metadata.version}\n作者: {plugin.metadata.author}\n" p = gu.create_text_image("【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n") return True, [Image.fromFileSystem(p)], "plugin" except BaseException as e: return False, f"获取插件列表失败,原因: {str(e)}", "plugin" elif l[1] == "v": try: - if l[2] in cached_plugins: - info = cached_plugins[l[2]]["info"] + info = None + for i in cached_plugins: + if i.metadata.plugin_name == l[2]: + info = i.metadata + break + if info: p = gu.create_text_image(f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}") return True, [Image.fromFileSystem(p)], "plugin" else: return False, "未找到该插件", "plugin" except BaseException as e: return False, f"获取插件信息失败,原因: {str(e)}", "plugin" - elif l[1] == "dev": - if role != "admin": - return False, f"你的身份组{role}没有权限开发者模式", "plugin" - return True, "cached_plugins: \n" + str(cached_plugins), "plugin" ''' nick: 存储机器人的昵称 @@ -206,7 +213,7 @@ class Command: "/revgpt": "切换到网页版ChatGPT", } - async def help_messager(self, commands: dict, platform: str, cached_plugins: dict = None): + async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None): try: async with aiohttp.ClientSession() as session: async with session.get("https://soulter.top/channelbot/notice.json") as resp: @@ -218,7 +225,9 @@ class Command: msg += f"`{key}` - {value}\n" # plugins if cached_plugins != None: - plugin_list_info = "\n".join([f"`{k}` {v['info']['name']}\n{v['info']['desc']}\n" for k, v in cached_plugins.items()]) + plugin_list_info = "" + for plugin in cached_plugins: + plugin_list_info += f"`{plugin.metadata.plugin_name}` {plugin.metadata.desc}\n" if plugin_list_info.strip() != "": msg += "\n## 插件列表\n> 使用plugin v 插件名 查看插件帮助\n" msg += plugin_list_info diff --git a/model/command/openai_official.py b/model/command/openai_official.py index 9d632790..472e3fd5 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -1,12 +1,11 @@ from model.command.command import Command from model.provider.openai_official import ProviderOpenAIOfficial from cores.qqbot.personality import personalities -from cores.qqbot.global_object import GlobalObject +from cores.qqbot.types import GlobalObject class CommandOpenAIOfficial(Command): def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject): self.provider = provider - self.cached_plugins = {} self.global_object = global_object self.personality_str = "" super().__init__(provider, global_object) diff --git a/model/command/rev_chatgpt.py b/model/command/rev_chatgpt.py index 1e7b91aa..f540c994 100644 --- a/model/command/rev_chatgpt.py +++ b/model/command/rev_chatgpt.py @@ -1,12 +1,11 @@ from model.command.command import Command from model.provider.rev_chatgpt import ProviderRevChatGPT from cores.qqbot.personality import personalities -from cores.qqbot.global_object import GlobalObject +from cores.qqbot.types import GlobalObject class CommandRevChatGPT(Command): def __init__(self, provider: ProviderRevChatGPT, global_object: GlobalObject): self.provider = provider - self.cached_plugins = {} self.global_object = global_object self.personality_str = "" super().__init__(provider, global_object) diff --git a/model/platform/_platfrom.py b/model/platform/_platfrom.py index 106f94df..f3e1c027 100644 --- a/model/platform/_platfrom.py +++ b/model/platform/_platfrom.py @@ -20,28 +20,28 @@ class Platform(): pass @abc.abstractmethod - def handle_msg(): + async def handle_msg(): ''' 处理到来的消息 ''' pass @abc.abstractmethod - def reply_msg(): + async def reply_msg(): ''' 回复消息(被动发送) ''' pass @abc.abstractmethod - def send_msg(): + async def send_msg(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): ''' 发送消息(主动发送) ''' pass @abc.abstractmethod - def send(): + async def send(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): ''' 发送消息(主动发送)同 send_msg() ''' diff --git a/model/provider/provider.py b/model/provider/provider.py index 6f4157e6..69e3e3fe 100644 --- a/model/provider/provider.py +++ b/model/provider/provider.py @@ -1,13 +1,35 @@ -import abc - class Provider: - def __init__(self, cfg): - pass + async def text_chat(self, + prompt: str, + session_id: str, + image_url: None, + function_call: None, + extra_conf: dict = None, + default_personality: dict = None, + **kwargs) -> str: + ''' + [require] + prompt: 提示词 + session_id: 会话id + + [optional] + image_url: 图片url(识图) + function_call: 函数调用 + extra_conf: 额外配置 + default_personality: 默认人格 + ''' + raise NotImplementedError - @abc.abstractmethod - async def text_chat(self, prompt, session_id, image_url: None, function_call: None, extra_conf: dict = None, default_personality: dict = None) -> str: - pass + async def image_generate(self, prompt, session_id, **kwargs) -> str: + ''' + [require] + prompt: 提示词 + session_id: 会话id + ''' + raise NotImplementedError - @abc.abstractmethod async def forget(self, session_id = None) -> bool: - pass \ No newline at end of file + ''' + 重置会话 + ''' + raise NotImplementedError \ No newline at end of file diff --git a/util/general_utils.py b/util/general_utils.py index d24cecbc..0dad7ecd 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -7,7 +7,7 @@ import re import requests from util.cmd_config import CmdConfig import socket -from cores.qqbot.global_object import GlobalObject +from cores.qqbot.types import GlobalObject import platform import logging import json @@ -537,7 +537,7 @@ def upload(_global_object: GlobalObject): "count": _global_object.cnt_total, "ip": addr_ip, "sys": sys.platform, - "admin": _global_object.admin_qq, + "admin": "null", } resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5) if resp.status_code == 200: diff --git a/util/plugin_dev/api/v1/bot.py b/util/plugin_dev/api/v1/bot.py index 588d6685..ee3f085d 100644 --- a/util/plugin_dev/api/v1/bot.py +++ b/util/plugin_dev/api/v1/bot.py @@ -1 +1,11 @@ -from cores.qqbot.global_object import GlobalObject \ No newline at end of file +from cores.qqbot.types import ( + PluginMetadata, + RegisteredLLM, + RegisteredPlugin, + RegisteredPlatform, + RegisteredPlugins, + PluginType, + GlobalObject, + AstrMessageEvent, + CommandResult +) \ No newline at end of file diff --git a/util/plugin_dev/api/v1/config.py b/util/plugin_dev/api/v1/config.py index 39a0858f..914b2498 100644 --- a/util/plugin_dev/api/v1/config.py +++ b/util/plugin_dev/api/v1/config.py @@ -1,4 +1,3 @@ -from cores.qqbot.global_object import GlobalObject from typing import Union import os import json @@ -19,7 +18,6 @@ def load_config(namespace: str) -> Union[dict, bool]: ret[k] = data[k]["value"] return ret - def put_config(namespace: str, name: str, key: str, value, description: str): ''' 将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。 diff --git a/util/plugin_dev/api/v1/llm.py b/util/plugin_dev/api/v1/llm.py new file mode 100644 index 00000000..e585689c --- /dev/null +++ b/util/plugin_dev/api/v1/llm.py @@ -0,0 +1,6 @@ +''' +大语言模型. + +插件开发者可以继承这个类来做实现。 +''' +from model.provider.provider import Provider as LLMProvider \ No newline at end of file diff --git a/util/plugin_dev/api/v1/message.py b/util/plugin_dev/api/v1/message.py index 8431952e..39ce128d 100644 --- a/util/plugin_dev/api/v1/message.py +++ b/util/plugin_dev/api/v1/message.py @@ -1,5 +1,5 @@ from cores.qqbot.core import oper_msg -from cores.qqbot.global_object import AstrMessageEvent, CommandResult +from cores.qqbot.types import AstrMessageEvent, CommandResult from model.platform._message_result import MessageResult ''' diff --git a/util/plugin_dev/api/v1/platform.py b/util/plugin_dev/api/v1/platform.py new file mode 100644 index 00000000..bf93986d --- /dev/null +++ b/util/plugin_dev/api/v1/platform.py @@ -0,0 +1,11 @@ +''' +消息平台。 + +Platform类是消息平台的抽象类,定义了消息平台的基本接口。 +消息平台的具体实现类需要继承Platform类,并实现其中的抽象方法。 +''' + +from model.platform._platfrom import Platform + +from model.platform.qq_gocq import QQGOCQ +from model.platform.qq_official import QQOfficial \ No newline at end of file diff --git a/util/plugin_dev/api/v1/register.py b/util/plugin_dev/api/v1/register.py new file mode 100644 index 00000000..ee198345 --- /dev/null +++ b/util/plugin_dev/api/v1/register.py @@ -0,0 +1,77 @@ +''' +允许开发者注册某一个类的实例到 LLM 或者 PLATFORM 中,方便其他插件调用。 + +必须分别实现 Platform 和 LLMProvider 中涉及的接口 +''' +from model.provider.provider import Provider as LLMProvider +from model.platform._platfrom import Platform +from cores.qqbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM + +def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None: + ''' + 注册一个消息平台。 + + Args: + platform_name: 平台名称。 + platform_instance: 平台实例。 + ''' + + # check 是否已经注册 + for platform in context.platforms: + if platform.platform_name == platform_name: + raise ValueError(f"Platform {platform_name} has been registered.") + + # check + should_attrs = Platform.__dir__() + has_attrs = platform_instance.__dir__() + + if not all([attr in has_attrs for attr in should_attrs]): + raise ValueError(f"Platform {platform_name} should implement all methods in LLMProvider.") + + context.platforms.append(RegisteredPlatform(platform_name, platform_instance)) + +def register_llm(llm_name: str, llm_instance: LLMProvider, context: GlobalObject) -> None: + ''' + 注册一个大语言模型。 + + Args: + llm_name: 大语言模型名称。 + llm_instance: 大语言模型实例。 + ''' + # check 是否已经注册 + for llm in context.llms: + if llm.llm_name == llm_name: + raise ValueError(f"LLMProvider {llm_name} has been registered.") + + # check + should_attrs = LLMProvider.__dir__() + has_attrs = llm_instance.__dir__() + + if not all([attr in has_attrs for attr in should_attrs]): + raise ValueError(f"LLMProvider {llm_name} should implement all methods in LLMProvider.") + + context.llms.append(RegisteredLLM(llm_name, llm_instance)) + +def unregister_platform(platform_name: str, context: GlobalObject) -> None: + ''' + 注销一个消息平台。 + + Args: + platform_name: 平台名称。 + ''' + for i, platform in enumerate(context.platforms): + if platform.platform_name == platform_name: + context.platforms.pop(i) + return + +def unregister_llm(llm_name: str, context: GlobalObject) -> None: + ''' + 注销一个大语言模型。 + + Args: + llm_name: 大语言模型名称。 + ''' + for i, llm in enumerate(context.llms): + if llm.llm_name == llm_name: + context.llms.pop(i) + return \ No newline at end of file diff --git a/util/plugin_dev/api/v1/types.py b/util/plugin_dev/api/v1/types.py new file mode 100644 index 00000000..8649b1f2 --- /dev/null +++ b/util/plugin_dev/api/v1/types.py @@ -0,0 +1,5 @@ +''' +插件类型 +''' + +from cores.qqbot.types import PluginType \ No newline at end of file diff --git a/util/plugin_util.py b/util/plugin_util.py index 855d908e..e67d019d 100644 --- a/util/plugin_util.py +++ b/util/plugin_util.py @@ -9,11 +9,19 @@ try: except ImportError: pass import shutil -from pip._internal import main as pipmain import importlib import stat import traceback from types import ModuleType +from typing import List +from pip._internal import main as pipmain +from cores.qqbot.types import ( + PluginMetadata, + PluginType, + RegisteredPlugin, + RegisteredPlugins +) + # 找出模块里所有的类名 def get_classes(p_name, arg: ModuleType): @@ -45,7 +53,8 @@ def get_modules(path): if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(os.path.join(path, d, d + ".py")): modules.append({ "pname": d, - "module": module_str + "module": module_str, + "module_path": os.path.join(path, d, module_str) }) return modules @@ -73,39 +82,62 @@ def get_plugin_modules(): except BaseException as e: raise e -def plugin_reload(cached_plugins: dict, target: str = None, all: bool = False): +def plugin_reload(cached_plugins: RegisteredPlugins): plugins = get_plugin_modules() if plugins is None: return False, "未找到任何插件模块" fail_rec = "" + + registered_map = {} + for p in cached_plugins: + registered_map[p.module_path] = None + for plugin in plugins: try: p = plugin['module'] + module_path = plugin['module_path'] root_dir_name = plugin['pname'] - if p not in cached_plugins or p == target or all: + + if module_path in registered_map: + # 之前注册过 + module = importlib.reload(module) + else: module = __import__("addons.plugins." + root_dir_name + "." + p, fromlist=[p]) - if p in cached_plugins: - module = importlib.reload(module) - cls = get_classes(p, module) - obj = getattr(module, cls[0])() - try: - info = obj.info() + + cls = get_classes(p, module) + obj = getattr(module, cls[0])() + + metadata = None + try: + info = obj.info() + if isinstance(info, dict): if 'name' not in info or 'desc' not in info or 'version' not in info or 'author' not in info: - fail_rec += f"载入插件{p}失败,原因: 插件信息不完整\n" + fail_rec += f"注册插件 {module_path} 失败,原因: 插件信息不完整\n" continue - if isinstance(info, dict) == False: - fail_rec += f"载入插件{p}失败,原因: 插件信息格式不正确\n" - continue - except BaseException as e: - fail_rec += f"调用插件{p} info失败, 原因: {str(e)}\n" + else: + metadata = PluginMetadata( + plugin_name=info['name'], + plugin_type=PluginType.COMMON if 'plugin_type' not in info else PluginType(info['plugin_type']), + author=info['author'], + desc=info['desc'], + version=info['version'], + repo=info['repo'] if 'repo' in info else None + ) + elif isinstance(info, PluginMetadata): + metadata = info + else: + fail_rec += f"注册插件 {module_path} 失败,原因: info 函数返回值类型错误\n" continue - cached_plugins[info['name']] = { - "module": module, - "clsobj": obj, - "info": info, - "name": info['name'], - "root_dir_name": root_dir_name, - } + except BaseException as e: + fail_rec += f"注册插件 {module_path} 失败, 原因: {str(e)}\n" + continue + cached_plugins.append(RegisteredPlugin( + metadata=metadata, + plugin_instance=obj, + module=module, + module_path=module_path, + root_dir_name=root_dir_name + )) except BaseException as e: traceback.print_exc() fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n" @@ -114,7 +146,7 @@ def plugin_reload(cached_plugins: dict, target: str = None, all: bool = False): else: return False, fail_rec -def install_plugin(repo_url: str, cached_plugins: dict): +def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins): ppath = get_plugin_store_path() # 删除末尾的 / if repo_url.endswith("/"): @@ -132,23 +164,33 @@ def install_plugin(repo_url: str, cached_plugins: dict): if os.path.exists(os.path.join(plugin_path, "requirements.txt")): if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0: raise Exception("插件的依赖安装失败, 需要您手动 pip 安装对应插件的依赖。") - ok, err = plugin_reload(cached_plugins, target=d) + ok, err = plugin_reload(cached_plugins) if not ok: raise Exception(err) + +def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) -> RegisteredPlugin: + ret = None + for p in cached_plugins: + if p.metadata.plugin_name == plugin_name: + ret = p + break + return ret -def uninstall_plugin(plugin_name: str, cached_plugins: dict): - if plugin_name not in cached_plugins: +def uninstall_plugin(plugin_name: str, cached_plugins: RegisteredPlugins): + plugin = get_registered_plugin(plugin_name, cached_plugins) + if not plugin: raise Exception("插件不存在。") - root_dir_name = cached_plugins[plugin_name]["root_dir_name"] + root_dir_name = plugin.root_dir_name ppath = get_plugin_store_path() - del cached_plugins[plugin_name] + cached_plugins.remove(plugin) if not remove_dir(os.path.join(ppath, root_dir_name)): raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") -def update_plugin(plugin_name: str, cached_plugins: dict): - if plugin_name not in cached_plugins: +def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins): + plugin = get_registered_plugin(plugin_name, cached_plugins) + if not plugin: raise Exception("插件不存在。") ppath = get_plugin_store_path() - root_dir_name = cached_plugins[plugin_name]["root_dir_name"] + root_dir_name = plugin.root_dir_name plugin_path = os.path.join(ppath, root_dir_name) repo = Repo(path = plugin_path) repo.remotes.origin.pull() @@ -156,7 +198,7 @@ def update_plugin(plugin_name: str, cached_plugins: dict): if os.path.exists(os.path.join(plugin_path, "requirements.txt")): if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0: raise Exception("插件依赖安装失败, 需要您手动pip安装对应插件的依赖。") - ok, err = plugin_reload(cached_plugins, target=plugin_name) + ok, err = plugin_reload(cached_plugins) if not ok: raise Exception(err) def remove_dir(file_path) -> bool: