Files
AstrBot/type/types.py
2024-09-11 16:47:44 +08:00

155 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio, os
from asyncio import Task
from type.register import *
from typing import List, Awaitable
from logging import Logger
from util.cmd_config import AstrBotConfig
from util.t2i.renderer import TextToImageRenderer
from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator
from type.command import CommandResult
from type.middleware import Middleware
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
from util.agent.func_call import FuncCall
class Context:
'''
存放一些公用的数据,用于在不同模块(如core与command)之间传递
'''
def __init__(self):
self.running = True
self.logger: Logger = None
self.config_helper: AstrBotConfig = None
self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件
self.platforms: List[RegisteredPlatform] = []
self.llms: List[RegisteredLLM] = []
self.default_personality: dict = None
# self.unique_session = False # 独立会话
# self.version: str = None # 机器人版本
# self.nick: tuple = None # gocq 的唤醒词
# self.t2i_mode = False
# self.web_search = False # 是否开启了网页搜索
self.metrics_uploader = None
self.updator: AstrBotUpdator = None
self.plugin_updator: PluginUpdator = None
self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins)
self.image_renderer = TextToImageRenderer()
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []
self.middlewares: List[Middleware] = []
self.command_manager = None
self.running = True
# useless
# self.reply_prefix = ""
def register_commands(self,
plugin_name: str,
command_name: str,
description: str,
priority: int,
handler: callable,
use_regex: bool = False,
ignore_prefix: bool = False):
'''
注册插件指令。
@param plugin_name: 插件名,注意需要和你的 metadata 中的一致。
@param command_name: 指令名,如 "help"。不需要带前缀。
@param description: 指令描述。
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
@param handler: 指令处理函数。函数参数message: AstrMessageEvent, context: Context
@param use_regex: 是否使用正则表达式匹配指令名。
@param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
.. Example::
ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True则用户输入 "help" 也会被识别为 "help" 指令。
'''
self.plugin_command_bridge.register_command(plugin_name,
command_name,
description,
priority,
handler,
use_regex,
ignore_prefix)
def register_task(self, coro: Awaitable, task_name: str):
'''
注册任务。适用于需要长时间运行的插件。
`coro`: 协程对象
`task_name`: 任务名,用于标识任务。自定义即可。
'''
task = asyncio.create_task(coro, name=task_name)
self.ext_tasks.append(task)
def register_provider(self, llm_name: str, provider: Provider, origin: str = ''):
'''
注册一个提供 LLM 资源的 Provider。
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
'''
self.llms.append(RegisteredLLM(llm_name, provider, origin))
def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable):
'''
为函数调用function-calling / tools-use添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 处理函数
'''
self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
def unregister_llm_tool(self, tool_name: str):
'''
删除一个函数调用工具。
'''
self.message_handler.llm_tools.remove_func(tool_name)
def register_middleware(self, middleware: Middleware):
'''
注册一个中间件。所有的消息事件都会经过中间件处理,然后再进入 LLM 聊天模块。
在 AstrBot 中,会对到来的消息事件首先检查指令,然后再检查中间件。触发指令后将不会进入 LLM 聊天模块,而中间件会。
'''
self.middlewares.append(middleware)
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
if platform_name == platform.platform_name:
return platform
if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错
raise ValueError("couldn't find the platform you specified")
async def send_message(self, unified_msg_origin: str, message: CommandResult):
'''
发送消息。
`unified_msg_origin`: 统一消息来源
`message`: 消息内容
'''
l = unified_msg_origin.split(":")
if len(l) != 3:
raise ValueError("Invalid unified_msg_origin")
platform_name, message_type, id = l
platform = self.find_platform(platform_name)
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
def get_current_llm_provider(self) -> Provider:
'''
获取当前的 LLM Provider。
'''
return self.message_handler.provider