Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dba1ed1e19 | ||
|
|
a24514876b | ||
|
|
466a1c1c41 | ||
|
|
a2d5e9f40f | ||
|
|
1bbff1d161 | ||
|
|
0948bae99b | ||
|
|
850db41596 | ||
|
|
7bafc87e2b | ||
|
|
1a0de02a15 | ||
|
|
6d5d278624 | ||
|
|
3b4cc48fa0 | ||
|
|
c908461088 | ||
|
|
53d1398d30 | ||
|
|
782c0367d0 | ||
|
|
4678222e9b | ||
|
|
f71dc3e4be | ||
|
|
f6233893bd | ||
|
|
6427bcf130 | ||
|
|
8fa41b706c | ||
|
|
4706c4438d |
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.28"
|
||||
VERSION = "3.4.30"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -409,6 +409,17 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "grok-2-latest",
|
||||
},
|
||||
},
|
||||
"ollama": {
|
||||
"id": "ollama_default",
|
||||
"type": "openai_chat_completion",
|
||||
|
||||
@@ -11,6 +11,7 @@ from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
class DifyRequestSubStage(Stage):
|
||||
|
||||
@@ -48,17 +49,40 @@ class DifyRequestSubStage(Stage):
|
||||
return
|
||||
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
# 装饰 system_prompt 等功能
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, req)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
logger.debug(f"Dify 请求 Payload: {req.__dict__}")
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
# 执行 LLM 响应后的事件。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
# text completion
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT))
|
||||
yield # rick roll
|
||||
return
|
||||
elif llm_response.role == 'err':
|
||||
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"))
|
||||
return
|
||||
elif llm_response.role == 'tool':
|
||||
event.set_result(MessageEventResult().message(f"Dify 暂不支持工具调用。"))
|
||||
yield
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -64,7 +64,7 @@ class LLMRequestSubStage(Stage):
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件。
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
# 装饰 system_prompt 等功能
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
|
||||
for handler in handlers:
|
||||
|
||||
@@ -45,12 +45,14 @@ class ProcessStage(Stage):
|
||||
else:
|
||||
yield
|
||||
|
||||
# 调用提供商相关请求
|
||||
# 调用 LLM 相关请求
|
||||
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
||||
return
|
||||
|
||||
if not event._has_send_oper and event.is_at_or_wake_command:
|
||||
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
# 事件没有终止传播
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import random
|
||||
import asyncio
|
||||
import math
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
@@ -88,7 +89,9 @@ class RespondStage(Stage):
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
||||
for handler in handlers:
|
||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
||||
await handler.handler(event)
|
||||
try:
|
||||
await handler.handler(event)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
event.clear_result()
|
||||
@@ -59,10 +59,14 @@ class ResultDecorateStage(Stage):
|
||||
async for _ in self.content_safe_check_stage.process(event, check_text=text):
|
||||
yield
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
|
||||
for handler in handlers:
|
||||
await handler.handler(event)
|
||||
|
||||
try:
|
||||
await handler.handler(event)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import inspect
|
||||
from astrbot.api import logger
|
||||
from typing import List, AsyncGenerator, Union, Awaitable
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
@@ -36,16 +37,18 @@ class Stage(abc.ABC):
|
||||
ctx: PipelineContext,
|
||||
event: AstrMessageEvent,
|
||||
handler: Awaitable,
|
||||
**params
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
'''调用 Handler。'''
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
try:
|
||||
ready_to_call = handler(event, **params)
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as e:
|
||||
# 向下兼容
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
|
||||
logger.debug(str(e))
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for ret in ready_to_call:
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At, Reply
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
|
||||
@@ -76,34 +77,17 @@ class WakingCheckStage(Stage):
|
||||
# 检查插件的 handler filter
|
||||
activated_handlers = []
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
|
||||
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
|
||||
# filter 需要满足 AND 的逻辑关系
|
||||
passed = True
|
||||
child_command_handler_md = None
|
||||
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
permission_not_pass = False
|
||||
|
||||
if len(handler.event_filters) == 0:
|
||||
# 不可能有这种情况, 也不允许有这种情况
|
||||
continue
|
||||
|
||||
if 'sub_command' in handler.extras_configs:
|
||||
# 如果是子指令
|
||||
continue
|
||||
|
||||
for filter in handler.event_filters:
|
||||
try:
|
||||
if isinstance(filter, CommandGroupFilter):
|
||||
"""如果指令组过滤成功, 会返回叶子指令的 StarHandlerMetadata"""
|
||||
ok, child_command_handler_md = filter.filter(
|
||||
event, self.ctx.astrbot_config
|
||||
)
|
||||
if not ok:
|
||||
passed = False
|
||||
else:
|
||||
handler = child_command_handler_md # handler 覆盖
|
||||
break
|
||||
elif isinstance(filter, PermissionTypeFilter):
|
||||
if isinstance(filter, PermissionTypeFilter):
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
permission_not_pass = True
|
||||
else:
|
||||
@@ -111,19 +95,15 @@ class WakingCheckStage(Stage):
|
||||
passed = False
|
||||
break
|
||||
except Exception as e:
|
||||
# event.set_result(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
|
||||
# yield
|
||||
await event.send(
|
||||
MessageEventResult().message(
|
||||
f"插件 {handler.handler_full_name} 报错:{e}"
|
||||
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
||||
)
|
||||
)
|
||||
event.stop_event()
|
||||
passed = False
|
||||
break
|
||||
|
||||
if passed:
|
||||
|
||||
if permission_not_pass:
|
||||
if self.no_permission_reply:
|
||||
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"))
|
||||
@@ -138,6 +118,7 @@ class WakingCheckStage(Stage):
|
||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||
"parsed_params"
|
||||
)
|
||||
|
||||
event.clear_extra()
|
||||
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
|
||||
@@ -51,6 +51,8 @@ class SimpleGewechatClient():
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
self.userrealnames = {}
|
||||
|
||||
async def get_token_id(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -118,10 +120,25 @@ class SimpleGewechatClient():
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0] \
|
||||
.replace('在群聊中@了你', '') \
|
||||
.replace('在群聊中发了一段语音', '') \
|
||||
.replace('在群聊中发了一张图片', '') # 真实昵称
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
if abm.group_id:
|
||||
if abm.group_id not in self.userrealnames or user_id not in self.userrealnames[abm.group_id]:
|
||||
# 获取群成员列表,并且缓存
|
||||
if abm.group_id not in self.userrealnames:
|
||||
self.userrealnames[abm.group_id] = {}
|
||||
member_list = await self.get_chatroom_member_list(abm.group_id)
|
||||
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
|
||||
if member_list and 'memberList' in member_list:
|
||||
for member in member_list['memberList']:
|
||||
self.userrealnames[abm.group_id][member['wxid']] = member['nickName']
|
||||
if user_id in self.userrealnames[abm.group_id]:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0]
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
abm.message_str = ""
|
||||
@@ -313,6 +330,23 @@ class SimpleGewechatClient():
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
'''API'''
|
||||
|
||||
async def get_chatroom_member_list(self, chatroom_wxid: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": chatroom_wxid
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob['data']
|
||||
|
||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
|
||||
@@ -75,52 +75,57 @@ class ProviderDify(Provider):
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={
|
||||
**session_var
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk['event'] == "message" or \
|
||||
chunk['event'] == "agent_message":
|
||||
result += chunk['answer']
|
||||
if not conversation_id:
|
||||
self.conversation_ids[session_id] = chunk['conversation_id']
|
||||
conversation_id = chunk['conversation_id']
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
self.dify_query_input_key: prompt,
|
||||
"astrbot_session_id": session_id,
|
||||
**session_var
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout
|
||||
):
|
||||
match chunk['event']:
|
||||
case "workflow_started":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
|
||||
case "node_finished":
|
||||
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
|
||||
case "workflow_finished":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
|
||||
if chunk['data']['error']:
|
||||
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
if self.workflow_output_key not in chunk['data']['outputs']:
|
||||
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
|
||||
result = chunk['data']['outputs'][self.workflow_output_key]
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
try:
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={
|
||||
**session_var
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk['event'] == "message" or \
|
||||
chunk['event'] == "agent_message":
|
||||
result += chunk['answer']
|
||||
if not conversation_id:
|
||||
self.conversation_ids[session_id] = chunk['conversation_id']
|
||||
conversation_id = chunk['conversation_id']
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
self.dify_query_input_key: prompt,
|
||||
"astrbot_session_id": session_id,
|
||||
**session_var
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout
|
||||
):
|
||||
match chunk['event']:
|
||||
case "workflow_started":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
|
||||
case "node_finished":
|
||||
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
|
||||
case "workflow_finished":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
|
||||
if chunk['data']['error']:
|
||||
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
if self.workflow_output_key not in chunk['data']['outputs']:
|
||||
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
|
||||
result = chunk['data']['outputs'][self.workflow_output_key]
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Dify 请求失败:{str(e)}")
|
||||
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{str(e)}")
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=result)
|
||||
|
||||
async def forget(self, session_id):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import aiohttp
|
||||
import random
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
@@ -138,8 +139,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
"role": "model",
|
||||
"parts": [{"text": message["content"]}]
|
||||
})
|
||||
|
||||
|
||||
|
||||
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
|
||||
|
||||
result = await self.client.generate_content(
|
||||
@@ -194,33 +194,49 @@ class ProviderGoogleGenAI(Provider):
|
||||
**model_config
|
||||
}
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
if retry_cnt == 0:
|
||||
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
|
||||
elif "Function calling is not enabled" in str(e):
|
||||
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
|
||||
if 'tools' in payloads:
|
||||
del payloads['tools']
|
||||
llm_response = await self._query(payloads, None)
|
||||
else:
|
||||
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
raise e
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(keys)
|
||||
|
||||
for i in range(retry):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
if retry_cnt == 0:
|
||||
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
|
||||
elif "Function calling is not enabled" in str(e):
|
||||
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
|
||||
if 'tools' in payloads:
|
||||
del payloads['tools']
|
||||
llm_response = await self._query(payloads, None)
|
||||
elif "429" in str(e) or "API key not valid" in str(e):
|
||||
keys.remove(chosen_key)
|
||||
if len(keys) > 0:
|
||||
chosen_key = random.choice(keys)
|
||||
logger.info(f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}...")
|
||||
continue
|
||||
else:
|
||||
logger.error(f"A检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}...")
|
||||
raise Exception("API 资源已耗尽,且没有可用的 Key 重试...")
|
||||
else:
|
||||
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
|
||||
import re
|
||||
import inspect
|
||||
from typing import List
|
||||
from typing import List, Any, Type, Dict
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.utils.param_validation_mixin import ParameterValidationMixin
|
||||
from .custom_filter import CustomFilter
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
class CommandFilter(HandlerFilter):
|
||||
'''标准指令过滤器'''
|
||||
def __init__(self, command_name: str, alias: set = None, handler_md: StarHandlerMetadata = None):
|
||||
def __init__(self, command_name: str, alias: set = None, handler_md: StarHandlerMetadata = None, parent_command_names: List[str] = [""]):
|
||||
self.command_name = command_name
|
||||
self.alias = alias if alias else set()
|
||||
self.parent_command_names = parent_command_names
|
||||
if handler_md:
|
||||
self.init_handler_md(handler_md)
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
@@ -26,6 +26,7 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
result += f"{k}({v.__name__}),"
|
||||
else:
|
||||
result += f"{k}({type(v).__name__})={v},"
|
||||
result = result.rstrip(",")
|
||||
return result
|
||||
|
||||
def init_handler_md(self, handle_md: StarHandlerMetadata):
|
||||
@@ -54,6 +55,39 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
if not custom_filter.filter(event, cfg):
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]:
|
||||
'''将参数列表 params 根据 param_type 转换为参数字典。
|
||||
'''
|
||||
result = {}
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
|
||||
if i >= len(params):
|
||||
if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty:
|
||||
# 是类型
|
||||
raise ValueError(f"必要参数缺失。该指令完整参数: {self.print_types()}")
|
||||
else:
|
||||
# 是默认值
|
||||
result[param_name] = param_type_or_default_val
|
||||
else:
|
||||
# 尝试强制转换
|
||||
try:
|
||||
if param_type_or_default_val is None:
|
||||
if params[i].isdigit():
|
||||
result[param_name] = int(params[i])
|
||||
else:
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, str):
|
||||
# 如果 param_type_or_default_val 是字符串,直接赋值
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, int):
|
||||
result[param_name] = int(params[i])
|
||||
elif isinstance(param_type_or_default_val, float):
|
||||
result[param_name] = float(params[i])
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
except ValueError:
|
||||
raise ValueError(f"参数 {param_name} 类型错误。完整参数: {self.print_types()}")
|
||||
return result
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
@@ -61,27 +95,31 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
|
||||
if not self.custom_filter_ok(event, cfg):
|
||||
return False
|
||||
|
||||
if event.get_extra("parsing_command"):
|
||||
message_str = event.get_extra("parsing_command").strip()
|
||||
else:
|
||||
message_str = event.get_message_str().strip()
|
||||
|
||||
# 分割为列表(每个参数之间可能会有多个空格)
|
||||
ls = re.split(r"\s+", message_str)
|
||||
if self.command_name != ls[0] and ls[0] not in self.alias:
|
||||
|
||||
# 检查是否以指令开头
|
||||
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
||||
candidates = [self.command_name] + list(self.alias)
|
||||
ok = False
|
||||
for candidate in candidates:
|
||||
for parent_command_name in self.parent_command_names:
|
||||
if parent_command_name:
|
||||
_full = f"{parent_command_name} {candidate}"
|
||||
else:
|
||||
_full = candidate
|
||||
if message_str.startswith(f"{_full} ") or message_str == _full:
|
||||
message_str = message_str[len(_full):].strip()
|
||||
ok = True
|
||||
break
|
||||
if not ok:
|
||||
return False
|
||||
# if len(self.handler_params) == 0 and len(ls) > 1:
|
||||
# # 一定程度避免 LLM 聊天时误判为指令
|
||||
# return False
|
||||
# params_str = message_str[len(self.command_name):].strip()
|
||||
ls = ls[1:]
|
||||
|
||||
# 分割为列表
|
||||
ls = message_str.split(" ")
|
||||
# 去除空字符串
|
||||
ls = [param for param in ls if param]
|
||||
params = {}
|
||||
try:
|
||||
params = self.validate_and_convert_params(ls, self.handler_params)
|
||||
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -11,11 +11,12 @@ from ..star_handler import StarHandlerMetadata
|
||||
|
||||
# 指令组受到 wake_prefix 的制约。
|
||||
class CommandGroupFilter(HandlerFilter):
|
||||
def __init__(self, group_name: str, alias: set = None):
|
||||
def __init__(self, group_name: str, alias: set = None, parent_group: CommandGroupFilter = None):
|
||||
self.group_name = group_name
|
||||
self.alias = alias if alias else set()
|
||||
self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = []
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
self.parent_group = parent_group
|
||||
|
||||
def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]):
|
||||
self.sub_command_filters.append(sub_command_filter)
|
||||
@@ -23,6 +24,24 @@ class CommandGroupFilter(HandlerFilter):
|
||||
def add_custom_filter(self, custom_filter: CustomFilter):
|
||||
self.custom_filter_list.append(custom_filter)
|
||||
|
||||
def get_complete_command_names(self) -> List[str]:
|
||||
'''遍历父节点获取完整的指令名。
|
||||
|
||||
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。'''
|
||||
parent_cmd_names = self.parent_group.get_complete_command_names() if self.parent_group else []
|
||||
|
||||
if not parent_cmd_names:
|
||||
# 根节点
|
||||
return [self.group_name] + list(self.alias)
|
||||
|
||||
result = []
|
||||
candidates = [self.group_name] + list(self.alias)
|
||||
for parent_cmd_name in parent_cmd_names:
|
||||
for candidate in candidates:
|
||||
result.append(parent_cmd_name + " " + candidate)
|
||||
return result
|
||||
|
||||
|
||||
# 以树的形式打印出来
|
||||
def print_cmd_tree(self,
|
||||
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
||||
@@ -43,6 +62,10 @@ class CommandGroupFilter(HandlerFilter):
|
||||
result += f" ({cmd_th})"
|
||||
else:
|
||||
result += " (无参数指令)"
|
||||
|
||||
if sub_filter.handler_md and sub_filter.handler_md.desc:
|
||||
result += f": {sub_filter.handler_md.desc}"
|
||||
|
||||
result += "\n"
|
||||
elif isinstance(sub_filter, CommandGroupFilter):
|
||||
custom_filter_pass = True
|
||||
@@ -61,46 +84,19 @@ class CommandGroupFilter(HandlerFilter):
|
||||
return False
|
||||
return True
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]:
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
return False, None
|
||||
|
||||
if event.get_extra("parsing_command"):
|
||||
message_str = event.get_extra("parsing_command").strip()
|
||||
else:
|
||||
message_str = event.get_message_str().strip()
|
||||
|
||||
ls = re.split(r"\s+", message_str)
|
||||
|
||||
if ls[0] != self.group_name and ls[0] not in self.alias:
|
||||
return False, None
|
||||
# 改写 message_str
|
||||
ls = ls[1:]
|
||||
# event.message_str = " ".join(ls)
|
||||
# event.message_str = event.message_str.strip()
|
||||
parsing_command = " ".join(ls)
|
||||
parsing_command = parsing_command.strip()
|
||||
event.set_extra("parsing_command", parsing_command)
|
||||
return False
|
||||
|
||||
# 判断当前指令组的自定义过滤器
|
||||
if not self.custom_filter_ok(event, cfg):
|
||||
return False, None
|
||||
return False
|
||||
|
||||
if parsing_command == "":
|
||||
# 当前还是指令组
|
||||
complete_command_names = self.get_complete_command_names()
|
||||
if event.message_str.strip() in complete_command_names:
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
|
||||
|
||||
child_command_handler_md = None
|
||||
for sub_filter in self.sub_command_filters:
|
||||
if isinstance(sub_filter, CommandFilter):
|
||||
if sub_filter.filter(event, cfg):
|
||||
child_command_handler_md = sub_filter.get_handler_md()
|
||||
return True, child_command_handler_md
|
||||
elif isinstance(sub_filter, CommandGroupFilter):
|
||||
ok, handler = sub_filter.filter(event, cfg)
|
||||
if ok:
|
||||
child_command_handler_md = handler
|
||||
return True, child_command_handler_md
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
raise ValueError(f"指令组 {self.group_name} 下没有找到对应的指令。这个指令组下有如下指令:\n"+tree)
|
||||
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
# return event.message_str.startswith(tuple(complete_command_names))
|
||||
return False
|
||||
@@ -54,14 +54,13 @@ def get_handler_or_create(
|
||||
def register_command(command_name: str = None, sub_command: str = None, alias: set = None, **kwargs):
|
||||
'''注册一个 Command.
|
||||
'''
|
||||
|
||||
# print("command: ", command_name, args, kwargs)
|
||||
|
||||
new_command = None
|
||||
add_to_event_filters = False
|
||||
if isinstance(command_name, RegisteringCommandable):
|
||||
# 子指令
|
||||
new_command = CommandFilter(sub_command, alias, None)
|
||||
parent_command_names = command_name.parent_group.get_complete_command_names()
|
||||
logger.debug(f"parent_command_names: {parent_command_names}")
|
||||
new_command = CommandFilter(sub_command, alias, None, parent_command_names=parent_command_names)
|
||||
command_name.parent_group.add_sub_command_filter(new_command)
|
||||
else:
|
||||
# 裸指令
|
||||
@@ -73,10 +72,7 @@ def register_command(command_name: str = None, sub_command: str = None, alias: s
|
||||
kwargs['sub_command'] = True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管)
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
new_command.init_handler_md(handler_md)
|
||||
if add_to_event_filters:
|
||||
# 裸指令
|
||||
handler_md.event_filters.append(new_command)
|
||||
|
||||
handler_md.event_filters.append(new_command)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
@@ -142,25 +138,19 @@ def register_command_group(
|
||||
):
|
||||
'''注册一个 CommandGroup
|
||||
'''
|
||||
|
||||
# print("commandgroup: ", command_group_name,args, kwargs)
|
||||
|
||||
new_group = None
|
||||
add_to_event_filters = False
|
||||
if isinstance(command_group_name, RegisteringCommandable):
|
||||
# 子指令组
|
||||
new_group = CommandGroupFilter(sub_command, alias)
|
||||
new_group = CommandGroupFilter(sub_command, alias, parent_group=command_group_name.parent_group)
|
||||
command_group_name.parent_group.add_sub_command_filter(new_group)
|
||||
else:
|
||||
# 根指令组
|
||||
new_group = CommandGroupFilter(command_group_name, alias)
|
||||
add_to_event_filters = True
|
||||
|
||||
def decorator(obj):
|
||||
if add_to_event_filters:
|
||||
# 根指令组
|
||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(new_group)
|
||||
# 根指令组
|
||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(new_group)
|
||||
|
||||
return RegisteringCommandable(new_group)
|
||||
|
||||
@@ -168,8 +158,8 @@ def register_command_group(
|
||||
|
||||
class RegisteringCommandable():
|
||||
'''用于指令组级联注册'''
|
||||
group = register_command_group
|
||||
command = register_command
|
||||
group: CommandGroupFilter = register_command_group
|
||||
command: CommandFilter = register_command
|
||||
custom_filter = register_custom_filter
|
||||
|
||||
def __init__(self, parent_group: CommandGroupFilter):
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import inspect
|
||||
from typing import List, Dict, Any, Type
|
||||
|
||||
class ParameterValidationMixin:
|
||||
def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]:
|
||||
'''将参数列表 params 根据 param_type 转换为参数字典。
|
||||
'''
|
||||
result = {}
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
|
||||
if i >= len(params):
|
||||
if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty:
|
||||
# 是类型
|
||||
raise ValueError(f"参数 {param_name} 缺失")
|
||||
else:
|
||||
# 是默认值
|
||||
result[param_name] = param_type_or_default_val
|
||||
else:
|
||||
# 尝试强制转换
|
||||
try:
|
||||
if param_type_or_default_val is None:
|
||||
if params[i].isdigit():
|
||||
result[param_name] = int(params[i])
|
||||
else:
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, str):
|
||||
# 如果 param_type_or_default_val 是字符串,直接赋值
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, int):
|
||||
result[param_name] = int(params[i])
|
||||
elif isinstance(param_type_or_default_val, float):
|
||||
result[param_name] = float(params[i])
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
except ValueError:
|
||||
raise ValueError(f"参数 {param_name} 类型错误")
|
||||
return result
|
||||
@@ -113,11 +113,17 @@ class PluginRoute(Route):
|
||||
for filter in handler.event_filters: # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
|
||||
if isinstance(filter, CommandFilter):
|
||||
info["type"] = "指令"
|
||||
info["cmd"] = filter.command_name
|
||||
info["cmd"] = f"{filter.parent_command_names[0]} {filter.command_name}"
|
||||
info["cmd"] = info["cmd"].strip()
|
||||
if self.core_lifecycle.astrbot_config['wake_prefix'] and len(self.core_lifecycle.astrbot_config['wake_prefix']) > 0:
|
||||
info["cmd"] = f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
|
||||
elif isinstance(filter, CommandGroupFilter):
|
||||
info["type"] = "指令组"
|
||||
info["cmd"] = filter.group_name
|
||||
info["cmd"] = filter.get_complete_command_names()[0]
|
||||
info["cmd"] = info["cmd"].strip()
|
||||
info["sub_command"] = filter.print_cmd_tree(filter.sub_command_filters)
|
||||
if self.core_lifecycle.astrbot_config['wake_prefix'] and len(self.core_lifecycle.astrbot_config['wake_prefix']) > 0:
|
||||
info["cmd"] = f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
|
||||
elif isinstance(filter, RegexFilter):
|
||||
info["type"] = "正则匹配"
|
||||
info["cmd"] = filter.regex_str
|
||||
|
||||
14
changelogs/v3.4.29.md
Normal file
14
changelogs/v3.4.29.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# What's Changed
|
||||
|
||||
1. ✨ 新增: gemini source 初步支持对 API Key 进行负载均衡请求 #534
|
||||
2. ✨ 新增: 开启对话隔离的群聊以及私聊下,非 op 可以可以使用 /del 和 /reset #519
|
||||
3. ✨ 新增: 事件钩子支持 yield 方式发送消息
|
||||
4. ⚡ 优化: 查询模型列表时,可以显示当前使用的模型名称 #523
|
||||
5. ⚡ 优化: 更换为预编译指令的方式处理指令组指令
|
||||
6. 🐛 修复: resolve KeyError when current conversation is not in paginated list
|
||||
7. 🐛 修复: 修复指令组的情况下,Permission Filter 对子指令失效的问题
|
||||
8. 🐛 修复: 🐛 fix: 修复 reminder rm失败 #529
|
||||
9. 🐛 修复: 🐛 fix: reminder 时区问题 #529
|
||||
10. 🐛 修复: 修复 Dify 下无法主动回复的问题 #494
|
||||
11. 🐛 修复: 添加代码执行器 Docker 宿主机绝对路径配置及相关功能以修复 Docker 下无法使用代码执行器的问题 #525
|
||||
12. 🐛 修复: gewechat 微信群聊情况下可能导致 unknown 的问题 #537
|
||||
5
changelogs/v3.4.30.md
Normal file
5
changelogs/v3.4.30.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# What's Changed
|
||||
|
||||
1. ‼️🐛 修复: 修复某些情况下导致插件报错 AttributeError 的问题 #549
|
||||
2. ✨ 新增: add xAI template
|
||||
3. 🐛 修复: 修复 dify 无法使用事件钩子的问题以及出现 GeneratorExit 的问题 #533 #264
|
||||
@@ -125,10 +125,6 @@ class LongTermMemory:
|
||||
else:
|
||||
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
|
||||
req.system_prompt += chats_str
|
||||
if self.image_caption:
|
||||
req.system_prompt += (
|
||||
"The images sent by the members are displayed in text form above."
|
||||
)
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
|
||||
@@ -96,6 +96,7 @@ AstrBot 指令:
|
||||
|
||||
@tool.command("ls")
|
||||
async def tool_ls(self, event: AstrMessageEvent):
|
||||
'''查看函数工具列表'''
|
||||
tm = self.context.get_llm_tool_manager()
|
||||
msg = "函数工具:\n"
|
||||
for tool in tm.func_list:
|
||||
@@ -107,6 +108,7 @@ AstrBot 指令:
|
||||
|
||||
@tool.command("on")
|
||||
async def tool_on(self, event: AstrMessageEvent, tool_name: str):
|
||||
'''启用一个函数工具'''
|
||||
if self.context.activate_llm_tool(tool_name):
|
||||
event.set_result(MessageEventResult().message(f"激活工具 {tool_name} 成功。"))
|
||||
else:
|
||||
@@ -114,6 +116,7 @@ AstrBot 指令:
|
||||
|
||||
@tool.command("off")
|
||||
async def tool_off(self, event: AstrMessageEvent, tool_name: str):
|
||||
'''停用一个函数工具'''
|
||||
if self.context.deactivate_llm_tool(tool_name):
|
||||
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 成功。"))
|
||||
else:
|
||||
@@ -121,6 +124,7 @@ AstrBot 指令:
|
||||
|
||||
@tool.command("off_all")
|
||||
async def tool_all_off(self, event: AstrMessageEvent):
|
||||
'''停用所有函数工具'''
|
||||
tm = self.context.get_llm_tool_manager()
|
||||
for tool in tm.func_list:
|
||||
self.context.deactivate_llm_tool(tool.name)
|
||||
@@ -128,6 +132,7 @@ AstrBot 指令:
|
||||
|
||||
@filter.command("plugin")
|
||||
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
|
||||
'''插件管理'''
|
||||
if oper1 is None:
|
||||
plugin_list_info = "已加载的插件:\n"
|
||||
for plugin in self.context.get_all_stars():
|
||||
@@ -189,6 +194,7 @@ AstrBot 指令:
|
||||
|
||||
@filter.command("t2i")
|
||||
async def t2i(self, event: AstrMessageEvent):
|
||||
'''开关文本转图片'''
|
||||
config = self.context.get_config()
|
||||
if config['t2i']:
|
||||
config['t2i'] = False
|
||||
@@ -201,6 +207,7 @@ AstrBot 指令:
|
||||
|
||||
@filter.command("tts")
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
'''开关文本转语音'''
|
||||
config = self.context.get_config()
|
||||
if config['provider_tts_settings']['enable']:
|
||||
config['provider_tts_settings']['enable'] = False
|
||||
@@ -213,6 +220,7 @@ AstrBot 指令:
|
||||
|
||||
@filter.command("sid")
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
'''获取会话 ID 和 管理员 ID'''
|
||||
sid = event.unified_msg_origin
|
||||
user_id = str(event.get_sender_id())
|
||||
ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。/wl <SID> 添加白名单, /dwl <SID> 删除白名单。
|
||||
@@ -222,6 +230,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("op")
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str):
|
||||
'''授权管理员。op <admin_id>'''
|
||||
self.context.get_config()['admins_id'].append(admin_id)
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("授权成功。"))
|
||||
@@ -229,6 +238,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("deop")
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str):
|
||||
'''取消授权管理员。deop <admin_id>'''
|
||||
try:
|
||||
self.context.get_config()['admins_id'].remove(admin_id)
|
||||
self.context.get_config().save_config()
|
||||
@@ -340,16 +350,20 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
else:
|
||||
event.set_result(MessageEventResult().message("无效的参数。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
'''重置 LLM 会话'''
|
||||
is_unique_session = self.context.get_config()['platform_settings']['unique_session']
|
||||
if message.get_group_id() and not is_unique_session and message.role != "admin":
|
||||
# 群聊,没开独立会话,发送人不是管理员
|
||||
message.set_result(MessageEventResult().message(f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限重置当前对话。"))
|
||||
return
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider()
|
||||
print(provider.meta())
|
||||
if provider and provider.meta().type == 'dify':
|
||||
assert isinstance(provider, ProviderDify)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
@@ -393,6 +407,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
for model in models:
|
||||
ret += f"\n{i}. {model}"
|
||||
i += 1
|
||||
|
||||
curr_model = self.context.get_using_provider().get_model() or "无"
|
||||
ret += f"\n当前模型: [{curr_model}]"
|
||||
|
||||
ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
else:
|
||||
@@ -418,7 +436,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换模型到 {self.context.get_using_provider().get_model()}。"))
|
||||
|
||||
|
||||
@filter.command("history")
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
'''查看对话记录'''
|
||||
@@ -458,6 +475,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
provider = self.context.get_using_provider()
|
||||
if provider and provider.meta().type == 'dify':
|
||||
"""原有的Dify处理逻辑保持不变"""
|
||||
ret = "Dify 对话列表:\n"
|
||||
assert isinstance(provider, ProviderDify)
|
||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||
@@ -474,32 +492,45 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
return
|
||||
|
||||
size_per_page = 6
|
||||
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
|
||||
total_pages = len(conversations) // size_per_page
|
||||
if len(conversations) % size_per_page != 0:
|
||||
total_pages += 1
|
||||
conversations = conversations[(page-1)*size_per_page:page*size_per_page]
|
||||
"""获取所有对话列表"""
|
||||
conversations_all = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
|
||||
"""计算总页数"""
|
||||
total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page
|
||||
"""确保页码有效"""
|
||||
page = max(1, min(page, total_pages))
|
||||
"""分页处理"""
|
||||
start_idx = (page - 1) * size_per_page
|
||||
end_idx = start_idx + size_per_page
|
||||
conversations_paged = conversations_all[start_idx:end_idx]
|
||||
|
||||
ret = "对话列表:\n---\n"
|
||||
global_index = (page - 1) * size_per_page + 1
|
||||
"""全局序号从当前页的第一个开始"""
|
||||
global_index = start_idx + 1
|
||||
|
||||
"""生成所有对话的标题字典"""
|
||||
_titles = {}
|
||||
for conv in conversations:
|
||||
|
||||
for conv in conversations_all:
|
||||
persona_id = conv.persona_id
|
||||
if not persona_id and not persona_id == "[%None]":
|
||||
if not persona_id or persona_id == "[%None]":
|
||||
persona_id = self.context.provider_manager.selected_default_persona['name']
|
||||
|
||||
title = conv.title if conv.title else "新对话"
|
||||
_titles[conv.cid] = title
|
||||
|
||||
|
||||
"""遍历分页后的对话生成列表显示"""
|
||||
for conv in conversations_paged:
|
||||
persona_id = conv.persona_id
|
||||
if not persona_id or persona_id == "[%None]":
|
||||
persona_id = self.context.provider_manager.selected_default_persona['name']
|
||||
title = _titles.get(conv.cid, "新对话")
|
||||
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
|
||||
global_index += 1
|
||||
|
||||
ret += "---\n"
|
||||
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
if curr_cid:
|
||||
ret += f"\n当前对话: {_titles[curr_cid]}({curr_cid[:4]})"
|
||||
"""从所有对话的标题字典中获取标题"""
|
||||
title = _titles.get(curr_cid, "新对话")
|
||||
ret += f"\n当前对话: {title}({curr_cid[:4]})"
|
||||
else:
|
||||
ret += "\n当前对话: 无"
|
||||
|
||||
@@ -508,11 +539,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
ret += "\n会话隔离粒度: 个人"
|
||||
else:
|
||||
ret += "\n会话隔离粒度: 群聊"
|
||||
|
||||
|
||||
ret += f"\n第 {page} 页 | 共 {total_pages} 页"
|
||||
ret += "\n*输入 /ls 2 跳转到第 2 页"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
return
|
||||
|
||||
@filter.command("new")
|
||||
async def new_conv(self, message: AstrMessageEvent):
|
||||
@@ -582,10 +614,14 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
await self.context.conversation_manager.update_conversation_title(message.unified_msg_origin, new_name)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("del")
|
||||
async def del_conv(self, message: AstrMessageEvent):
|
||||
'''删除当前对话'''
|
||||
is_unique_session = self.context.get_config()['platform_settings']['unique_session']
|
||||
if message.get_group_id() and not is_unique_session and message.role != "admin":
|
||||
# 群聊,没开独立会话,发送人不是管理员
|
||||
message.set_result(MessageEventResult().message(f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。"))
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider()
|
||||
if provider and provider.meta().type == 'dify':
|
||||
@@ -604,7 +640,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid)
|
||||
message.set_result(MessageEventResult().message("删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"))
|
||||
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("key")
|
||||
async def key(self, message: AstrMessageEvent, index: int=None):
|
||||
|
||||
@@ -85,7 +85,8 @@ DEFAULT_CONFIG = {
|
||||
"sandbox": {
|
||||
"image": "soulter/astrbot-code-interpreter-sandbox",
|
||||
"docker_mirror": "", # cjie.eu.org
|
||||
}
|
||||
},
|
||||
"docker_host_astrbot_abs_path": ""
|
||||
}
|
||||
PATH = "data/config/python_interpreter.json"
|
||||
|
||||
@@ -95,8 +96,14 @@ class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.workplace_path = os.path.join(self.curr_dir, "workplace")
|
||||
self.shared_path = os.path.join(self.curr_dir, "shared")
|
||||
|
||||
self.shared_path = os.path.join("data", "py_interpreter_shared")
|
||||
if not os.path.exists(self.shared_path):
|
||||
# 复制 api.py 到 shared 目录
|
||||
os.makedirs(self.shared_path, exist_ok=True)
|
||||
shared_api_file = os.path.join(self.curr_dir, "shared", "api.py")
|
||||
shutil.copy(shared_api_file, self.shared_path)
|
||||
self.workplace_path = os.path.join("data", "py_interpreter_workplace")
|
||||
os.makedirs(self.workplace_path, exist_ok=True)
|
||||
|
||||
self.user_file_msg_buffer = defaultdict(list)
|
||||
@@ -195,7 +202,16 @@ class Main(star.Star):
|
||||
@filter.command_group("pi")
|
||||
def pi(self):
|
||||
pass
|
||||
|
||||
|
||||
@pi.command("absdir")
|
||||
async def pi_absdir(self, event: AstrMessageEvent, path: str = ""):
|
||||
'''设置 Docker 宿主机绝对路径'''
|
||||
if not path:
|
||||
yield event.plain_result(f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}")
|
||||
else:
|
||||
self.config["docker_host_astrbot_abs_path"] = path
|
||||
self._save_config()
|
||||
yield event.plain_result(f"设置 Docker 宿主机绝对路径成功: {path}")
|
||||
|
||||
@pi.command("mirror")
|
||||
async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
|
||||
@@ -305,6 +321,20 @@ class Main(star.Star):
|
||||
|
||||
yield event.plain_result(f"使用沙箱执行代码中,请稍等...(尝试次数: {i+1}/{n})")
|
||||
|
||||
|
||||
self.docker_host_astrbot_abs_path = self.config.get("docker_host_astrbot_abs_path", "")
|
||||
if self.docker_host_astrbot_abs_path:
|
||||
host_shared = os.path.join(self.docker_host_astrbot_abs_path, self.shared_path)
|
||||
host_output = os.path.join(self.docker_host_astrbot_abs_path, output_path)
|
||||
host_workplace = os.path.join(self.docker_host_astrbot_abs_path, workplace_path)
|
||||
|
||||
else:
|
||||
host_shared = os.path.abspath(self.shared_path)
|
||||
host_output = os.path.abspath(output_path)
|
||||
host_workplace = os.path.abspath(workplace_path)
|
||||
|
||||
logger.debug(f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}")
|
||||
|
||||
container = await docker.containers.run({
|
||||
"Image": image_name,
|
||||
"Cmd": ["python", "exec.py"],
|
||||
@@ -312,9 +342,9 @@ class Main(star.Star):
|
||||
"NanoCPUs": 1000000000,
|
||||
"HostConfig": {
|
||||
"Binds": [
|
||||
f"{self.shared_path}:/astrbot_sandbox/shared:ro",
|
||||
f"{output_path}:/astrbot_sandbox/output:rw",
|
||||
f"{workplace_path}:/astrbot_sandbox:rw",
|
||||
f"{host_shared}:/astrbot_sandbox/shared:ro",
|
||||
f"{host_output}:/astrbot_sandbox/output:rw",
|
||||
f"{host_workplace}:/astrbot_sandbox:rw",
|
||||
]
|
||||
},
|
||||
"Env": [
|
||||
|
||||
@@ -13,7 +13,7 @@ class Main(star.Star):
|
||||
'''使用 LLM 待办提醒。只需对 LLM 说想要提醒的事情和时间即可。比如:`之后每天这个时候都提醒我做多邻国`'''
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self.scheduler = AsyncIOScheduler(timezone='Asia/Shanghai')
|
||||
|
||||
# set and load config
|
||||
if not os.path.exists("data/astrbot-reminder.json"):
|
||||
@@ -175,10 +175,18 @@ class Main(star.Star):
|
||||
else:
|
||||
reminder = reminders.pop(index - 1)
|
||||
job_id = reminder.get("id")
|
||||
|
||||
# self.reminder_data[event.unified_msg_origin] = reminder
|
||||
users_reminders = self.reminder_data.get(event.unified_msg_origin, [])
|
||||
for i, r in enumerate(users_reminders):
|
||||
if r.get("id") == job_id:
|
||||
users_reminders.pop(i)
|
||||
|
||||
try:
|
||||
self.scheduler.remove_job(job_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Remove job error: {e}")
|
||||
yield event.plain_result(f"成功移除对应的待办事项。删除定时任务失败: {str(e)} 可能需要重启 AstrBot 以取消该提醒任务。")
|
||||
await self._save_data()
|
||||
yield event.plain_result("成功删除待办事项:\n" + reminder["text"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user