Compare commits
3 Commits
feat/agent
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
850b8c64d4 | ||
|
|
4de7a8a7a7 | ||
|
|
b8a10ccf08 |
@@ -12,6 +12,8 @@ from astrbot.core.star.register import (
|
||||
register_llm_tool as llm_tool,
|
||||
register_on_decorating_result as on_decorating_result,
|
||||
register_after_message_sent as after_message_sent,
|
||||
register_on_tool_start as on_tool_start,
|
||||
register_on_tool_end as on_tool_end,
|
||||
)
|
||||
|
||||
from astrbot.core.star.filter.event_message_type import (
|
||||
@@ -46,4 +48,6 @@ __all__ = [
|
||||
"on_decorating_result",
|
||||
"after_message_sent",
|
||||
"on_llm_response",
|
||||
"on_tool_start",
|
||||
"on_tool_end",
|
||||
]
|
||||
|
||||
@@ -124,15 +124,17 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
result.append({
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
}
|
||||
)
|
||||
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
@@ -142,15 +144,17 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for plugin_id, plugin_info in data.items():
|
||||
online_plugins.append({
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
})
|
||||
online_plugins.append(
|
||||
{
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
import typing as T
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@@ -14,4 +14,5 @@ class ContextWrapper(Generic[TContext]):
|
||||
context: TContext
|
||||
event: AstrMessageEvent
|
||||
|
||||
|
||||
NoContext = ContextWrapper[None]
|
||||
|
||||
@@ -272,7 +272,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context,
|
||||
func_tool_name,
|
||||
func_tool,
|
||||
func_tool_args,
|
||||
resp,
|
||||
)
|
||||
@@ -291,7 +291,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
self.run_context, func_tool, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -304,7 +304,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
self.run_context, func_tool, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@@ -70,7 +70,7 @@ DEFAULT_CONFIG = {
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"streaming_segmented": False,
|
||||
"max_agent_step": 30
|
||||
"max_agent_step": 30,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
|
||||
@@ -53,7 +53,7 @@ async def do_migration_v4(
|
||||
await migration_webchat_data(db_helper, platform_id_map)
|
||||
|
||||
# 执行偏好设置迁移
|
||||
await migration_preferences(db_helper,platform_id_map)
|
||||
await migration_preferences(db_helper, platform_id_map)
|
||||
|
||||
# 执行平台统计表迁移
|
||||
await migration_platform_table(db_helper, platform_id_map)
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
@@ -42,4 +43,5 @@ class SharedPreferences:
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
|
||||
|
||||
sp = SharedPreferences()
|
||||
|
||||
@@ -4,6 +4,7 @@ from astrbot.core.db.po import Platform, Stats
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话存储
|
||||
@@ -76,7 +77,7 @@ PRAGMA encoding = 'UTF-8';
|
||||
"""
|
||||
|
||||
|
||||
class SQLiteDatabase():
|
||||
class SQLiteDatabase:
|
||||
def __init__(self, db_path: str) -> None:
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .vec_db import FaissVecDB
|
||||
|
||||
__all__ = ["FaissVecDB"]
|
||||
__all__ = ["FaissVecDB"]
|
||||
|
||||
@@ -128,6 +128,7 @@ class Plain(BaseMessageComponent):
|
||||
async def to_dict(self):
|
||||
return {"type": "text", "data": {"text": self.text}}
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
id: int
|
||||
|
||||
@@ -200,6 +200,18 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
|
||||
async def on_tool_start(self, run_context, tool, tool_args):
|
||||
# 执行 Tool 开始事件钩子
|
||||
await call_event_hook(
|
||||
run_context.event, EventType.OnToolStartEvent, tool, tool_args
|
||||
)
|
||||
|
||||
async def on_tool_end(self, run_context, tool, tool_args, tool_result):
|
||||
# 执行 Tool 完成事件钩子
|
||||
await call_event_hook(
|
||||
run_context.event, EventType.OnToolEndEvent, tool, tool_args, tool_result
|
||||
)
|
||||
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
@@ -423,7 +435,9 @@ class LLMRequestSubStage(Stage):
|
||||
if req.func_tool:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||
if "tool_use" not in provider_cfg:
|
||||
logger.debug(f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。")
|
||||
logger.debug(
|
||||
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
|
||||
)
|
||||
req.func_tool = None
|
||||
|
||||
# run agent
|
||||
|
||||
@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
|
||||
@@ -55,8 +55,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"钉钉图片处理失败: {e}")
|
||||
logger.warning(f"跳过图片发送: {image_path}")
|
||||
logger.warning(f"跳过图片发送: {segment.file}")
|
||||
continue
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
@@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot):
|
||||
await self.on_ready_once_callback()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
|
||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
|
||||
)
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
@@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot):
|
||||
message_data = self._create_message_data(message)
|
||||
await self.on_message_received(message_data)
|
||||
|
||||
|
||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||
"""从交互中提取内容"""
|
||||
interaction_type = interaction.type
|
||||
|
||||
@@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent):
|
||||
self.url = url
|
||||
self.disabled = disabled
|
||||
|
||||
|
||||
class DiscordReference(BaseMessageComponent):
|
||||
"""Discord引用组件"""
|
||||
|
||||
type: str = "discord_reference"
|
||||
|
||||
def __init__(self, message_id: str, channel_id: str):
|
||||
self.message_id = message_id
|
||||
self.channel_id = channel_id
|
||||
@@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
def to_discord_view(self) -> discord.ui.View:
|
||||
"""转换为Discord View对象"""
|
||||
view = discord.ui.View(timeout=self.timeout)
|
||||
|
||||
@@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
|
||||
# 解析消息链为 Discord 所需的对象
|
||||
try:
|
||||
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
|
||||
(
|
||||
content,
|
||||
files,
|
||||
view,
|
||||
embeds,
|
||||
reference_message_id,
|
||||
) = await self._parse_to_discord(message)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
||||
return
|
||||
@@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
files.append(
|
||||
discord.File(BytesIO(file_bytes),
|
||||
filename=i.name)
|
||||
discord.File(BytesIO(file_bytes), filename=i.name)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
@@ -308,7 +308,9 @@ class SlackAdapter(Platform):
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
return base64_content
|
||||
else:
|
||||
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
|
||||
logger.error(
|
||||
f"Failed to download slack file: {resp.status} {await resp.text()}"
|
||||
)
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
|
||||
async def run(self) -> Awaitable[Any]:
|
||||
|
||||
@@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||
}
|
||||
file_url = response["files"][0]["permalink"]
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
return chunks
|
||||
|
||||
@classmethod
|
||||
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str):
|
||||
async def send_with_client(
|
||||
cls, client: ExtBot, message: MessageChain, user_name: str
|
||||
):
|
||||
image_path = None
|
||||
|
||||
has_reply = False
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
class WebChatQueueMgr:
|
||||
def __init__(self) -> None:
|
||||
self.queues = {}
|
||||
@@ -30,4 +31,5 @@ class WebChatQueueMgr:
|
||||
"""Check if a queue exists for the given conversation ID"""
|
||||
return conversation_id in self.queues
|
||||
|
||||
|
||||
webchat_queue_mgr = WebChatQueueMgr()
|
||||
|
||||
@@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform):
|
||||
def _extract_auth_key(self, data):
|
||||
"""Helper method to extract auth_key from response data."""
|
||||
if isinstance(data, dict):
|
||||
auth_keys = data.get("authKeys") # 新接口
|
||||
auth_keys = data.get("authKeys") # 新接口
|
||||
if isinstance(auth_keys, list) and auth_keys:
|
||||
return auth_keys[0]
|
||||
elif isinstance(data, list) and data: # 旧接口
|
||||
elif isinstance(data, list) and data: # 旧接口
|
||||
return data[0]
|
||||
return None
|
||||
|
||||
@@ -234,7 +234,9 @@ class WeChatPadProAdapter(Platform):
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
|
||||
logger.error(
|
||||
f"生成授权码失败: {response.status}, {await response.text()}"
|
||||
)
|
||||
return
|
||||
|
||||
response_data = await response.json()
|
||||
@@ -245,7 +247,9 @@ class WeChatPadProAdapter(Platform):
|
||||
if self.auth_key:
|
||||
logger.info("成功获取授权码")
|
||||
else:
|
||||
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
|
||||
logger.error(
|
||||
f"生成授权码成功但未找到授权码: {response_data}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"生成授权码失败: {response_data}")
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
|
||||
@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
|
||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
||||
data = {
|
||||
"token": token,
|
||||
"cursor": cursor,
|
||||
"limit": limit,
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
return self._post("kf/sync_msg", data=data)
|
||||
|
||||
def get_service_state(self, open_kfid, external_userid):
|
||||
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
}
|
||||
return self._post("kf/service_state/get", data=data)
|
||||
|
||||
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
||||
def trans_service_state(
|
||||
self, open_kfid, external_userid, service_state, servicer_userid=""
|
||||
):
|
||||
"""
|
||||
变更会话状态
|
||||
|
||||
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
"""
|
||||
return self._get("kf/customer/get_upgrade_service_config")
|
||||
|
||||
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
||||
def upgrade_service(
|
||||
self, open_kfid, external_userid, service_type, member=None, groupchat=None
|
||||
):
|
||||
"""
|
||||
为客户升级为专员或客户群服务
|
||||
|
||||
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||
return self._post("kf/get_corp_statistic", data=data)
|
||||
|
||||
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
||||
def get_servicer_statistic(
|
||||
self, start_time, end_time, open_kfid=None, servicer_userid=None
|
||||
):
|
||||
"""
|
||||
获取「客户数据统计」接待人员明细数据
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from optionaldict import optionaldict
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
|
||||
class WeChatKFMessage(BaseWeChatAPI):
|
||||
"""
|
||||
发送微信客服消息
|
||||
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
|
||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||
)
|
||||
|
||||
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
||||
def send_msgmenu(
|
||||
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "msgmenu",
|
||||
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
||||
"msgmenu": {
|
||||
"head_content": head_content,
|
||||
"list": menu_list,
|
||||
"tail_content": tail_content,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
||||
def send_location(
|
||||
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "location",
|
||||
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
||||
"msgmenu": {
|
||||
"name": name,
|
||||
"address": address,
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
||||
def send_miniprogram(
|
||||
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "miniprogram",
|
||||
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
||||
"msgmenu": {
|
||||
"appid": appid,
|
||||
"title": title,
|
||||
"thumb_media_id": thumb_media_id,
|
||||
"pagepath": pagepath,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||
result = await asyncio.wait_for(
|
||||
asyncio.shield(future), 60
|
||||
) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
|
||||
@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from asyncio import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import (
|
||||
Provider,
|
||||
TTSProvider,
|
||||
|
||||
@@ -113,8 +113,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
)
|
||||
raise ValueError(
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
||||
+ tree
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||
)
|
||||
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
|
||||
@@ -14,6 +14,8 @@ from .star_handler import (
|
||||
register_agent,
|
||||
register_on_decorating_result,
|
||||
register_after_message_sent,
|
||||
register_on_tool_start,
|
||||
register_on_tool_end,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -32,4 +34,6 @@ __all__ = [
|
||||
"register_agent",
|
||||
"register_on_decorating_result",
|
||||
"register_after_message_sent",
|
||||
"register_on_tool_start",
|
||||
"register_on_tool_end",
|
||||
]
|
||||
|
||||
@@ -376,9 +376,11 @@ def register_llm_tool(name: str = None, **kwargs):
|
||||
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
||||
if registering_agent._agent.tools is None:
|
||||
registering_agent._agent.tools = []
|
||||
registering_agent._agent.tools.append(llm_tools.spec_to_func(
|
||||
llm_tool_name, args, docstring.description.strip(), awaitable
|
||||
))
|
||||
registering_agent._agent.tools.append(
|
||||
llm_tools.spec_to_func(
|
||||
llm_tool_name, args, docstring.description.strip(), awaitable
|
||||
)
|
||||
)
|
||||
|
||||
return awaitable
|
||||
|
||||
@@ -448,3 +450,52 @@ def register_after_message_sent(**kwargs):
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_tool_start(**kwargs):
|
||||
"""当 Tool 请求开始时的事件
|
||||
|
||||
Examples:
|
||||
```py
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
|
||||
@register_on_tool_start()
|
||||
async def test(event: AstrMessageEvent, tool: FunctionTool, tool_args: dict) -> None:
|
||||
# 在工具调用开始时执行的逻辑
|
||||
logger.info(f"Tool {tool.name} started with args: {tool_args}")
|
||||
```
|
||||
|
||||
请务必接收三个参数:event, tool, tool_args
|
||||
"""
|
||||
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnToolStartEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_tool_end(**kwargs):
|
||||
"""当 Tool 请求完成时的事件
|
||||
|
||||
Examples:
|
||||
```py
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
import mcp.types
|
||||
|
||||
@register_on_tool_end()
|
||||
async def test(event: AstrMessageEvent, tool: FunctionTool, tool_args: dict, tool_result: mcp.types.CallToolResult | None) -> None:
|
||||
# 在工具调用完成时执行的逻辑
|
||||
logger.info(f"Tool {tool.name} finished with result: {tool_result}")
|
||||
```
|
||||
|
||||
请务必接收四个参数:event, tool, tool_args, tool_result
|
||||
"""
|
||||
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnToolEndEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -84,7 +84,10 @@ class SessionPluginManager:
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put(
|
||||
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
|
||||
"session_plugin_config",
|
||||
session_plugin_config,
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -7,6 +7,7 @@ from .star import star_map
|
||||
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
def __init__(self):
|
||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
@@ -49,7 +50,8 @@ class StarHandlerRegistry(Generic[T]):
|
||||
self, module_name: str
|
||||
) -> List[StarHandlerMetadata]:
|
||||
return [
|
||||
handler for handler in self._handlers
|
||||
handler
|
||||
for handler in self._handlers
|
||||
if handler.handler_module_path == module_name
|
||||
]
|
||||
|
||||
@@ -67,6 +69,7 @@ class StarHandlerRegistry(Generic[T]):
|
||||
def __len__(self):
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
|
||||
@@ -83,6 +86,8 @@ class EventType(enum.Enum):
|
||||
OnLLMResponseEvent = enum.auto() # LLM 响应后
|
||||
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
||||
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
|
||||
OnToolStartEvent = enum.auto() # Tool 请求开始
|
||||
OnToolEndEvent = enum.auto() # Tool 请求完成
|
||||
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
||||
|
||||
|
||||
|
||||
@@ -819,11 +819,11 @@ class PluginManager:
|
||||
if star_metadata.star_cls is None:
|
||||
return
|
||||
|
||||
if '__del__' in star_metadata.star_cls_type.__dict__:
|
||||
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None, star_metadata.star_cls.__del__
|
||||
)
|
||||
elif 'terminate' in star_metadata.star_cls_type.__dict__:
|
||||
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||
await star_metadata.star_cls.terminate()
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
|
||||
@@ -182,7 +182,9 @@ class StarTools:
|
||||
|
||||
plugin_name = metadata.name
|
||||
|
||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
||||
data_dir = Path(
|
||||
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
|
||||
)
|
||||
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -56,9 +56,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
try:
|
||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||
if os.name == "nt":
|
||||
args = [
|
||||
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
|
||||
]
|
||||
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
|
||||
else:
|
||||
args = sys.argv[1:]
|
||||
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
||||
|
||||
@@ -10,7 +10,9 @@ class LogRoute(Route):
|
||||
super().__init__(context)
|
||||
self.log_broker = log_broker
|
||||
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
|
||||
self.app.add_url_rule("/api/log-history", view_func=self.log_history, methods=["GET"])
|
||||
self.app.add_url_rule(
|
||||
"/api/log-history", view_func=self.log_history, methods=["GET"]
|
||||
)
|
||||
|
||||
async def log(self):
|
||||
async def stream():
|
||||
@@ -48,9 +50,15 @@ class LogRoute(Route):
|
||||
"""获取日志历史"""
|
||||
try:
|
||||
logs = list(self.log_broker.log_cache)
|
||||
return Response().ok(data={
|
||||
"logs": logs,
|
||||
}).__dict__
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"logs": logs,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
|
||||
@@ -53,6 +53,8 @@ class PluginRoute(Route):
|
||||
EventType.OnLLMResponseEvent: "LLM 响应后",
|
||||
EventType.OnDecoratingResultEvent: "回复消息前",
|
||||
EventType.OnCallingFuncToolEvent: "函数工具",
|
||||
EventType.OnToolStartEvent: "Tool 请求开始",
|
||||
EventType.OnToolEndEvent: "Tool 请求完成",
|
||||
EventType.OnAfterMessageSentEvent: "发送消息后",
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import traceback
|
||||
|
||||
import aiohttp
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -60,9 +60,7 @@ class AstrBotDashboard:
|
||||
self.session_management_route = SessionManagementRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
self.persona_route = PersonaRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
self.persona_route = PersonaRoute(self.context, db, core_lifecycle)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
|
||||
@@ -88,7 +88,9 @@ class LongTermMemory:
|
||||
|
||||
if cfg["ar_whitelist"] and (
|
||||
event.unified_msg_origin not in cfg["ar_whitelist"]
|
||||
and (event.get_group_id() and event.get_group_id() not in cfg["ar_whitelist"])
|
||||
and (
|
||||
event.get_group_id() and event.get_group_id() not in cfg["ar_whitelist"]
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
|
||||
@@ -1110,7 +1110,9 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
@filter.command("unset")
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(umo="uid", key="session_variables", default={})
|
||||
session_var = await sp.session_get(
|
||||
umo="uid", key="session_variables", default={}
|
||||
)
|
||||
|
||||
if key not in session_var:
|
||||
yield event.plain_result("没有那个变量名。格式 /unset 变量名。")
|
||||
|
||||
@@ -5,7 +5,7 @@ import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api import llm_tool, agent, logger, AstrBotConfig
|
||||
from astrbot.api import llm_tool, logger, AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
from .engines import SearchResult
|
||||
from .engines.bing import Bing
|
||||
|
||||
Reference in New Issue
Block a user