Compare commits
66 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a8b2b09e0f | ||
|
|
6858b8c555 | ||
|
|
0e493b1a0e | ||
|
|
37d478f970 | ||
|
|
7d0d42a49f | ||
|
|
0eb1684ef1 | ||
|
|
9b0b723143 | ||
|
|
532bc6e1e6 | ||
|
|
fe3ed4c454 | ||
|
|
b5ec89e586 | ||
|
|
895e7397c2 | ||
|
|
59b767957a | ||
|
|
17d4bf8f22 | ||
|
|
836be3b097 | ||
|
|
310415bea9 | ||
|
|
aafc1276a9 | ||
|
|
2993e794cc | ||
|
|
58cb9cfb2d | ||
|
|
fbdf0901d5 | ||
|
|
af8c81b621 | ||
|
|
06b5275e48 | ||
|
|
ad95572d5f | ||
|
|
aebc7850f4 | ||
|
|
1b7efbc607 | ||
|
|
3800e96d14 | ||
|
|
461f1bb07c | ||
|
|
7d4c07e4f6 | ||
|
|
31b788f463 | ||
|
|
96ab761f73 | ||
|
|
2b3f05c039 | ||
|
|
f2e8303b66 | ||
|
|
2a614b545b | ||
|
|
5c0ab21f68 | ||
|
|
689d109438 | ||
|
|
2a6934b283 | ||
|
|
760cb94e9a | ||
|
|
2a6cff0013 | ||
|
|
ce578f0417 | ||
|
|
1745bdb9e2 | ||
|
|
3f90b89c3c | ||
|
|
f343e40d15 | ||
|
|
5cc4be9e65 | ||
|
|
da5aada002 | ||
|
|
07f2ee9ad9 | ||
|
|
12f4e1146f | ||
|
|
92c57e5476 | ||
|
|
a923baacd8 | ||
|
|
999b094d55 | ||
|
|
d4213f2352 | ||
|
|
3f65c9a066 | ||
|
|
1d427e2645 | ||
|
|
36414c4b00 | ||
|
|
47e253d76c | ||
|
|
b73cf84df0 | ||
|
|
a5b885a774 | ||
|
|
0c785413da | ||
|
|
482d7ef5f7 | ||
|
|
9f9073c0ff | ||
|
|
ef05ff4abd | ||
|
|
5848aae435 | ||
|
|
fb06f33de0 | ||
|
|
0d7ddb149e | ||
|
|
4f2d7b9c4e | ||
|
|
c02ed96f6f | ||
|
|
3b2ac891b2 | ||
|
|
ef0108881b |
@@ -12,8 +12,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
|
||||
|
||||
RUN python -m pip install -r requirements.txt
|
||||
|
||||
EXPOSE 6185
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
|
||||
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
@@ -70,7 +72,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 微信(企业微信) | 🚧 | 计划内 | - |
|
||||
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
| 飞书 | 🚧 | 计划内 | - |
|
||||
| Discord | 🚧 | 计划内 | - |
|
||||
|
||||
@@ -11,7 +11,8 @@ from astrbot.core.config import AstrBotConfig
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
html_renderer = HtmlRenderer()
|
||||
t2i_base_url = astrbot_config.get('t2i_endpoint', 'https://t2i.soulter.top/text2img')
|
||||
html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
if os.environ.get('TESTING', ""):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.18"
|
||||
VERSION = "3.4.22"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -30,7 +30,8 @@ DEFAULT_CONFIG = {
|
||||
"only_llm_result": True,
|
||||
"interval": "1.5,3.5",
|
||||
"regex": ".*?[。?!~…]+|.+$"
|
||||
}
|
||||
},
|
||||
"no_permission_reply": True,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
@@ -55,14 +56,14 @@ DEFAULT_CONFIG = {
|
||||
"group_icl_enable": False,
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_provider_id": "",
|
||||
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||
"active_reply": {
|
||||
"enable": False,
|
||||
"method": "possibility_reply",
|
||||
"possibility_reply": 0.1,
|
||||
"prompt": "",
|
||||
},
|
||||
"put_history_to_prompt": True,
|
||||
}
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
@@ -70,6 +71,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
"admins_id": [],
|
||||
"t2i": False,
|
||||
"t2i_word_threshold": 150,
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
@@ -109,7 +111,7 @@ CONFIG_METADATA_2 = {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_host": "0.0.0.0",
|
||||
"ws_reverse_port": 6199,
|
||||
},
|
||||
"gewechat(微信)": {
|
||||
@@ -194,6 +196,11 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"no_permission_reply": {
|
||||
"description": "无权限回复",
|
||||
"type": "bool",
|
||||
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
|
||||
},
|
||||
"segmented_reply": {
|
||||
"description": "分段回复",
|
||||
"type": "object",
|
||||
@@ -238,7 +245,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
||||
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单",
|
||||
},
|
||||
"id_whitelist_log": {
|
||||
"description": "打印白名单日志",
|
||||
@@ -321,11 +328,24 @@ CONFIG_METADATA_2 = {
|
||||
"type": "list",
|
||||
"config_template": {
|
||||
"openai": {
|
||||
"id": "default",
|
||||
"id": "openai",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"azure_openai": {
|
||||
"id": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
@@ -346,6 +366,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
},
|
||||
@@ -356,6 +377,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
},
|
||||
@@ -366,6 +388,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek-chat",
|
||||
},
|
||||
@@ -375,6 +398,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "zhipu_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
@@ -385,6 +409,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"model_config": {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
@@ -436,6 +461,11 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"timeout": {
|
||||
"description": "超时时间",
|
||||
"type": "int",
|
||||
"hint": "超时时间,单位为秒。",
|
||||
},
|
||||
"openai-tts-voice": {
|
||||
"description": "voice",
|
||||
"type": "string",
|
||||
@@ -687,6 +717,12 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"description": "图像转述提供商 ID",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"description": "图像转述提示词",
|
||||
"type": "string"
|
||||
@@ -717,15 +753,9 @@ CONFIG_METADATA_2 = {
|
||||
"description": "提示词",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,prompt是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"put_history_to_prompt": {
|
||||
"description": "将群聊历史记录作为 prompt",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "需要先启用 group_icl_enable。此功能会将群聊历史记录放到 prompt 再请求。如果关闭,则是放在 system_prompt。如果开启了主动回复,建议启用,模型能够更好地完成回复任务。",
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -746,11 +776,16 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。",
|
||||
},
|
||||
"t2i_word_threshold": {
|
||||
"description": "文本转图像字数阈值",
|
||||
"type": "int",
|
||||
"hint": "超出此字符长度的文本将会被转换成图片。字数不能低于 50。",
|
||||
},
|
||||
"admins_id": {
|
||||
"description": "管理员 ID",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
|
||||
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/sid` 指令获得。回车添加,可添加多个。",
|
||||
},
|
||||
"http_proxy": {
|
||||
"description": "HTTP 代理",
|
||||
|
||||
118
astrbot/core/conversation_mgr.py
Normal file
118
astrbot/core/conversation_mgr.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Conversation
|
||||
|
||||
class ConversationManager():
|
||||
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
self._start_periodic_save()
|
||||
|
||||
def _start_periodic_save(self):
|
||||
asyncio.create_task(self._periodic_save())
|
||||
|
||||
async def _periodic_save(self):
|
||||
while True:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
self._save_to_storage()
|
||||
|
||||
def _save_to_storage(self):
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||
'''新建对话,并将当前会话的对话转移到新对话'''
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.new_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id
|
||||
)
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
return conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
'''切换会话的对话'''
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
|
||||
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.delete_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id
|
||||
)
|
||||
del self.session_conversations[unified_msg_origin]
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||
'''获取会话当前的对话 ID'''
|
||||
return self.session_conversations.get(unified_msg_origin, None)
|
||||
|
||||
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
|
||||
'''获取会话的对话'''
|
||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
|
||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||
'''获取会话的所有对话'''
|
||||
return self.db.get_conversations(unified_msg_origin)
|
||||
|
||||
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
|
||||
'''更新会话的对话'''
|
||||
if conversation_id:
|
||||
self.db.update_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
history=json.dumps(history)
|
||||
)
|
||||
|
||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||
'''更新会话的对话标题'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_title(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
title=title
|
||||
)
|
||||
|
||||
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
|
||||
'''更新会话的对话 Persona ID'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_persona_id(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
persona_id=persona_id
|
||||
)
|
||||
|
||||
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
|
||||
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
||||
history = json.loads(conversation.history)
|
||||
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in history:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
@@ -18,7 +18,7 @@ from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
self.log_broker = log_broker
|
||||
@@ -43,12 +43,15 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self.astrbot_config,
|
||||
self.db,
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager
|
||||
)
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
@@ -79,25 +79,35 @@ class BaseDatabase(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
||||
'''通过 user_id 和 cid 获取 WebChatConversation'''
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
'''通过 user_id 和 cid 获取 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
||||
'''新建 WebChatConversation'''
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
'''新建 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
|
||||
def get_conversations(self, user_id: str) -> List[Conversation]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新 WebChatConversation'''
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
||||
'''删除 WebChatConversation'''
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
'''删除 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
'''更新 Conversation 标题'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
'''更新 Conversation Persona ID'''
|
||||
raise NotImplementedError
|
||||
@@ -33,16 +33,16 @@ class Stats():
|
||||
command: List[Command] = field(default_factory=list)
|
||||
llm: List[Provider] = field(default_factory=list)
|
||||
|
||||
'''LLM 聊天时持久化的信息'''
|
||||
|
||||
@dataclass
|
||||
class LLMHistory():
|
||||
'''LLM 聊天时持久化的信息'''
|
||||
provider_type: str
|
||||
session_id: str
|
||||
content: str
|
||||
|
||||
@dataclass
|
||||
class ATRIVision():
|
||||
'''Deprecated'''
|
||||
id: str
|
||||
url_or_path: str
|
||||
caption: str
|
||||
@@ -53,13 +53,18 @@ class ATRIVision():
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebChatConversation():
|
||||
class Conversation():
|
||||
'''LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
'''
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
'''字符串格式的列表。'''
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
@@ -6,7 +6,7 @@ from astrbot.core.db.po import (
|
||||
Stats,
|
||||
LLMHistory,
|
||||
ATRIVision,
|
||||
WebChatConversation
|
||||
Conversation
|
||||
)
|
||||
from . import BaseDatabase
|
||||
from typing import Tuple
|
||||
@@ -25,6 +25,37 @@ class SQLiteDatabase(BaseDatabase):
|
||||
c = self.conn.cursor()
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
# 检查 webchat_conversation 的 title 字段是否存在
|
||||
c.execute(
|
||||
'''
|
||||
PRAGMA table_info(webchat_conversation)
|
||||
'''
|
||||
)
|
||||
res = c.fetchall()
|
||||
has_title = False
|
||||
has_persona_id = False
|
||||
for row in res:
|
||||
if row[1] == "title":
|
||||
has_title = True
|
||||
if row[1] == "persona_id":
|
||||
has_persona_id = True
|
||||
if not has_title:
|
||||
c.execute(
|
||||
'''
|
||||
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
||||
'''
|
||||
)
|
||||
self.conn.commit()
|
||||
if not has_persona_id:
|
||||
c.execute(
|
||||
'''
|
||||
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
||||
'''
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
c.close()
|
||||
|
||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
@@ -202,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
return Stats(platform, [], [])
|
||||
|
||||
|
||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -216,9 +247,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return WebChatConversation(*res)
|
||||
return Conversation(*res)
|
||||
|
||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
@@ -228,7 +259,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
''', (user_id, cid, history, updated_at, created_at)
|
||||
)
|
||||
|
||||
def get_webchat_conversations(self, user_id: str) -> Tuple:
|
||||
def get_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -236,7 +267,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
''', (user_id,)
|
||||
)
|
||||
|
||||
@@ -247,24 +278,42 @@ class SQLiteDatabase(BaseDatabase):
|
||||
cid = row[0]
|
||||
created_at = row[1]
|
||||
updated_at = row[2]
|
||||
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
|
||||
title = row[3]
|
||||
persona_id = row[4]
|
||||
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
|
||||
return conversations
|
||||
|
||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新对话,并且同时更新时间'''
|
||||
updated_at = int(time.time())
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
|
||||
''', (history, user_id, cid)
|
||||
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
||||
''', (history, updated_at, user_id, cid)
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||
''', (title, user_id, cid)
|
||||
)
|
||||
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||
''', (persona_id, user_id, cid)
|
||||
)
|
||||
|
||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
)
|
||||
|
||||
|
||||
def insert_atri_vision_data(self, vision: ATRIVision):
|
||||
ts = int(time.time())
|
||||
keywords = ",".join(vision.keywords)
|
||||
|
||||
@@ -42,5 +42,7 @@ CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
cid TEXT,
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
@@ -2,6 +2,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
|
||||
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
@@ -11,7 +12,7 @@ from .respond.stage import RespondStage
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"RateLimitCheckStage", # 检查会话是否超过频率限制
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
@@ -22,6 +23,7 @@ STAGES_ORDER = [
|
||||
__all__ = [
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
'''
|
||||
import traceback
|
||||
import json
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
@@ -10,7 +11,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -24,6 +25,8 @@ class LLMRequestSubStage(Stage):
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
|
||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
@@ -46,10 +49,17 @@ class LLMRequestSubStage(Stage):
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
req.image_urls.append(image_url)
|
||||
req.session_id = event.session_id
|
||||
|
||||
# 获取对话上下文
|
||||
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
|
||||
if not conversation_id:
|
||||
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
|
||||
req.session_id = conversation_id
|
||||
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
req.contexts = session_provider_context if session_provider_context else []
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
@@ -62,9 +72,12 @@ class LLMRequestSubStage(Stage):
|
||||
await handler.handler(event, req)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
try:
|
||||
logger.debug(f"提供商请求 Payload: {req.__dict__}")
|
||||
logger.debug(f"提供商请求 Payload: {req}")
|
||||
if _nested:
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
@@ -77,6 +90,9 @@ class LLMRequestSubStage(Stage):
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, llm_response)
|
||||
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
@@ -113,8 +129,31 @@ class LLMRequestSubStage(Stage):
|
||||
req.prompt += extra_prompt
|
||||
async for _ in self.process(event, _nested=True):
|
||||
yield
|
||||
else:
|
||||
if llm_response.completion_text:
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text))
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
|
||||
return
|
||||
return
|
||||
|
||||
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
contexts = req.contexts
|
||||
new_record = {
|
||||
"role": "user",
|
||||
"content": req.prompt
|
||||
}
|
||||
contexts.append(new_record)
|
||||
contexts.append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.session_id,
|
||||
history=contexts_to_save
|
||||
)
|
||||
@@ -61,11 +61,12 @@ class RateLimitStage(Stage):
|
||||
stall_duration = (next_window_time - now).total_seconds()
|
||||
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
|
||||
await asyncio.sleep(stall_duration)
|
||||
case RateLimitStrategy.DISCARD:
|
||||
event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
logger.info(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。")
|
||||
return event.stop_event()
|
||||
|
||||
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))
|
||||
|
||||
@@ -19,6 +19,13 @@ class ResultDecorateStage:
|
||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
|
||||
self.t2i_word_threshold = ctx.astrbot_config['t2i_word_threshold']
|
||||
try:
|
||||
self.t2i_word_threshold = int(self.t2i_word_threshold)
|
||||
if self.t2i_word_threshold < 50:
|
||||
self.t2i_word_threshold = 50
|
||||
except BaseException:
|
||||
self.t2i_word_threshold = 150
|
||||
|
||||
# 分段回复
|
||||
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
|
||||
@@ -49,12 +56,13 @@ class ResultDecorateStage:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
|
||||
split_response = re.findall(self.regex, comp.text)
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
for seg in split_response:
|
||||
new_chain.append(Plain(seg))
|
||||
if seg:
|
||||
new_chain.append(Plain(seg))
|
||||
else:
|
||||
# 非 Plain 类型的消息段不分段
|
||||
new_chain.append(comp)
|
||||
@@ -91,7 +99,7 @@ class ResultDecorateStage:
|
||||
plain_str += "\n\n" + comp.text
|
||||
else:
|
||||
break
|
||||
if plain_str and len(plain_str) > 150:
|
||||
if plain_str and len(plain_str) > self.t2i_word_threshold:
|
||||
render_start = time.time()
|
||||
try:
|
||||
url = await html_renderer.render_t2i(plain_str, return_url=True)
|
||||
|
||||
@@ -2,11 +2,11 @@ from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
|
||||
@register_stage
|
||||
class WakingCheckStage(Stage):
|
||||
@@ -21,6 +21,9 @@ class WakingCheckStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"no_permission_reply", True
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -77,7 +80,9 @@ class WakingCheckStage(Stage):
|
||||
# filter 需要满足 AND 的逻辑关系
|
||||
passed = True
|
||||
child_command_handler_md = None
|
||||
|
||||
|
||||
permission_not_pass = False
|
||||
|
||||
if len(handler.event_filters) == 0:
|
||||
# 不可能有这种情况, 也不允许有这种情况
|
||||
continue
|
||||
@@ -94,6 +99,9 @@ class WakingCheckStage(Stage):
|
||||
else:
|
||||
handler = child_command_handler_md # handler 覆盖
|
||||
break
|
||||
elif isinstance(filter, PermissionTypeFilter):
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
permission_not_pass = True
|
||||
else:
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
passed = False
|
||||
@@ -111,6 +119,13 @@ class WakingCheckStage(Stage):
|
||||
break
|
||||
|
||||
if passed:
|
||||
|
||||
if permission_not_pass:
|
||||
if self.no_permission_reply:
|
||||
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"))
|
||||
event.stop_event()
|
||||
return
|
||||
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
|
||||
|
||||
@@ -31,12 +31,19 @@ class AstrMessageEvent(abc.ABC):
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,):
|
||||
self.message_str = message_str
|
||||
'''纯文本的消息'''
|
||||
self.message_obj = message_obj
|
||||
'''消息对象,AstrBotMessage。带有完整的消息结构。'''
|
||||
self.platform_meta = platform_meta
|
||||
'''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp'''
|
||||
self.session_id = session_id
|
||||
'''用户的会话 ID。可以直接使用下面的 unified_msg_origin'''
|
||||
self.role = "member"
|
||||
'''用户是否是管理员。如果是管理员,这里是 admin'''
|
||||
self.is_wake = False # 是否通过 WakingStage
|
||||
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True)
|
||||
'''是否唤醒'''
|
||||
self.is_at_or_wake_command = False
|
||||
'''是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True,但是不会让这个属性置为 True)'''
|
||||
self._extras = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.name,
|
||||
@@ -44,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
session_id=session_id
|
||||
)
|
||||
self.unified_msg_origin = str(self.session)
|
||||
|
||||
'''统一的消息来源字符串。格式为 platform_name:message_type:session_id'''
|
||||
self._result: MessageEventResult = None
|
||||
'''消息事件的结果'''
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ class AiocqhttpAdapter(Platform):
|
||||
if not ret.get('file', None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret['file']):
|
||||
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
|
||||
raise FileNotFoundError(f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot")
|
||||
|
||||
m['data'] = {
|
||||
"file": ret['file'],
|
||||
@@ -122,7 +122,10 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
if not self.host or not self.port:
|
||||
return
|
||||
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199")
|
||||
self.host = "0.0.0.0"
|
||||
self.port = 6199
|
||||
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
|
||||
@@ -194,7 +194,7 @@ class SimpleGewechatClient():
|
||||
async def start_polling(self):
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
await self.server.run_task(
|
||||
host=self.host,
|
||||
host='0.0.0.0',
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger_placeholder
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from astrbot.core.db.po import Conversation
|
||||
|
||||
|
||||
class ProviderType(enum.Enum):
|
||||
@@ -38,10 +39,15 @@ class ProviderRequest():
|
||||
'''上下文。格式与 openai 的上下文格式一致:
|
||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||
'''
|
||||
|
||||
system_prompt: str = ""
|
||||
'''系统提示词'''
|
||||
conversation: Conversation = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
|
||||
@@ -121,7 +121,8 @@ class FuncCall:
|
||||
|
||||
tools.append(func_declaration)
|
||||
|
||||
declarations["function_declarations"] = tools
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,16 @@ class ProviderManager():
|
||||
if not self.selected_default_persona and len(self.personas) > 0:
|
||||
# 默认选择第一个
|
||||
self.selected_default_persona = self.personas[0]
|
||||
|
||||
|
||||
if not self.selected_default_persona:
|
||||
self.selected_default_persona = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
_begin_dialogs_processed=[],
|
||||
_mood_imitation_dialogs_processed=""
|
||||
)
|
||||
self.personas.append(self.selected_default_persona)
|
||||
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
|
||||
@@ -8,6 +8,8 @@ from typing import TypedDict
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
@@ -15,8 +17,8 @@ class Personality(TypedDict):
|
||||
mood_imitation_dialogs: List[str] = []
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: List[dict]
|
||||
_mood_imitation_dialogs_processed: str
|
||||
_begin_dialogs_processed: List[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -60,25 +62,11 @@ class Provider(AbstractProvider):
|
||||
) -> None:
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.session_memory = defaultdict(list)
|
||||
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality: Personality = default_persona
|
||||
'''维护了当前的使用的 persona,即人格。可能为 None'''
|
||||
|
||||
self.db_helper = db_helper
|
||||
'''用于持久化的数据库操作对象。'''
|
||||
|
||||
if persistant_history:
|
||||
# 读取历史记录
|
||||
try:
|
||||
for history in db_helper.get_llm_history(provider_type=provider_config['id']):
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
raise NotImplementedError()
|
||||
@@ -96,22 +84,6 @@ class Provider(AbstractProvider):
|
||||
'''获得支持的模型列表'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
|
||||
'''获取人类可读的上下文
|
||||
|
||||
page 从 1 开始
|
||||
|
||||
Example:
|
||||
|
||||
["User: 你好", "Assistant: 你好!"]
|
||||
|
||||
Return:
|
||||
contexts: List[str]: 上下文列表
|
||||
total_pages: int: 总页数
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
@@ -125,26 +97,35 @@ class Provider(AbstractProvider):
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
session_id: 会话 ID
|
||||
session_id: 会话 ID(此属性已经被废弃)
|
||||
image_urls: 图片 URL 列表
|
||||
tools: Function-calling 工具
|
||||
contexts: 上下文
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
- 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。
|
||||
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
|
||||
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
'''重置某一个 session_id 的上下文'''
|
||||
raise NotImplementedError()
|
||||
|
||||
async def pop_record(self, context: List):
|
||||
'''
|
||||
弹出 context 第一条非系统提示词对话记录
|
||||
'''
|
||||
poped = 0
|
||||
indexs_to_pop = []
|
||||
for idx, record in enumerate(context):
|
||||
if record["role"] == "system":
|
||||
continue
|
||||
else:
|
||||
indexs_to_pop.append(idx)
|
||||
poped += 1
|
||||
if poped == 2:
|
||||
break
|
||||
|
||||
for idx in reversed(indexs_to_pop):
|
||||
context.pop(idx)
|
||||
|
||||
|
||||
class STTProvider(AbstractProvider):
|
||||
|
||||
@@ -39,7 +39,7 @@ class ProviderDify(Provider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
@@ -64,8 +64,6 @@ class ProviderDify(Provider):
|
||||
else:
|
||||
# TODO: 处理更多情况
|
||||
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
|
||||
|
||||
logger.debug(files_payload)
|
||||
|
||||
# 获得会话变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
@@ -115,7 +113,6 @@ class ProviderDify(Provider):
|
||||
result = chunk['data']['outputs'][self.workflow_output_key]
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=result)
|
||||
|
||||
async def forget(self, session_id):
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
import aiohttp
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -12,17 +10,18 @@ from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
|
||||
class SimpleGoogleGenAIClient():
|
||||
def __init__(self, api_key: str, api_base: str):
|
||||
def __init__(self, api_key: str, api_base: str, timeout: int=120) -> None:
|
||||
self.api_key = api_key
|
||||
if api_base.endswith("/"):
|
||||
self.api_base = api_base[:-1]
|
||||
else:
|
||||
self.api_base = api_base
|
||||
self.client = aiohttp.ClientSession(trust_env=True)
|
||||
self.timeout = timeout
|
||||
|
||||
async def models_list(self) -> List[str]:
|
||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
||||
async with self.client.get(request_url, timeout=10) as resp:
|
||||
async with self.client.get(request_url, timeout=self.timeout) as resp:
|
||||
response = await resp.json()
|
||||
|
||||
models = []
|
||||
@@ -48,7 +47,7 @@ class SimpleGoogleGenAIClient():
|
||||
payload["contents"] = contents
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
async with self.client.post(request_url, json=payload, timeout=10) as resp:
|
||||
async with self.client.post(request_url, json=payload, timeout=self.timeout) as resp:
|
||||
response = await resp.json()
|
||||
return response
|
||||
|
||||
@@ -67,66 +66,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
|
||||
self.timeout = provider_config.get("timeout", 180)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.client = SimpleGoogleGenAIClient(
|
||||
api_key=self.chosen_api_key,
|
||||
api_base=provider_config.get("api_base", None)
|
||||
api_base=provider_config.get("api_base", None),
|
||||
timeout=self.timeout
|
||||
)
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
'''
|
||||
弹出第一条记录
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
return record
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
tool = None
|
||||
if tools:
|
||||
@@ -181,6 +133,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
if "candidates" not in result:
|
||||
raise Exception("Gemini 返回异常结果: " + str(result))
|
||||
|
||||
candidates = result["candidates"][0]['content']['parts']
|
||||
llm_response = LLMResponse("assistant")
|
||||
for candidate in candidates:
|
||||
@@ -194,46 +149,43 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.completion_text = llm_response.completion_text.strip()
|
||||
return llm_response
|
||||
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if '_no_save' in part:
|
||||
del part['_no_save']
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config['model'] = self.get_model()
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**self.provider_config.get("model_config", {})
|
||||
**model_config
|
||||
}
|
||||
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 10
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -241,32 +193,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
if retry_cnt == 0:
|
||||
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
|
||||
elif "Function calling is not enabled" in str(e):
|
||||
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
|
||||
if 'tools' in payloads:
|
||||
del payloads['tools']
|
||||
llm_response = await self._query(payloads, None)
|
||||
else:
|
||||
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
raise e
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
return llm_response
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
@@ -287,8 +226,14 @@ class ProviderGoogleGenAI(Provider):
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
|
||||
return user_content
|
||||
else:
|
||||
|
||||
@@ -57,20 +57,13 @@ class LLMTunerModelLoader(Provider):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
contexts: List = [],
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
system_prompt = ""
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
if not contexts:
|
||||
query_context = [
|
||||
*self.session_memory[session_id],
|
||||
new_record,
|
||||
]
|
||||
system_prompt = self.curr_personality["prompt"]
|
||||
else:
|
||||
query_context = [*contexts, new_record]
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
@@ -96,34 +89,8 @@ class LLMTunerModelLoader(Provider):
|
||||
responses = await self.model.achat(**conf)
|
||||
|
||||
llm_response = LLMResponse("assistant", responses[-1].response_text)
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
@@ -132,28 +99,4 @@ class LLMTunerModelLoader(Provider):
|
||||
pass
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record["role"] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record["role"] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
return [self.get_model()]
|
||||
@@ -2,9 +2,9 @@ import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI, NOT_GIVEN
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import NotFoundError
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -29,37 +29,27 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
# 适配 azure openai #332
|
||||
if "api_version" in provider_config:
|
||||
# 使用 azure api
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
api_version=provider_config.get("api_version", None),
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=self.timeout
|
||||
)
|
||||
else:
|
||||
# 使用 openai api
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
@@ -72,22 +62,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def pop_record(self, session_id: str):
|
||||
'''
|
||||
弹出最早的一个对话
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) < 2:
|
||||
return
|
||||
|
||||
try:
|
||||
self.session_memory[session_id].pop(0)
|
||||
self.session_memory[session_id].pop(0)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
tool_list = tools.get_func_desc_openai_style()
|
||||
@@ -106,12 +80,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if choice.message.content:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
|
||||
return LLMResponse("assistant", completion_text, raw_completion=completion)
|
||||
elif choice.message.tool_calls:
|
||||
llm_response.completion_text = completion_text
|
||||
|
||||
if choice.message.tool_calls:
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
@@ -121,49 +97,61 @@ class ProviderOpenAIOfficial(Provider):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls, raw_completion=completion)
|
||||
else:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args = args_ls
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception("Internal Error")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
session_id: str=None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if '_no_save' in part:
|
||||
del part['_no_save']
|
||||
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config['model'] = self.get_model()
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**self.provider_config.get("model_config", {})
|
||||
**model_config
|
||||
}
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except UnprocessableEntityError as e:
|
||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads['messages'] = new_contexts
|
||||
context_query = new_contexts
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
# 重试 10 次
|
||||
retry_cnt = 10
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning("上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
|
||||
try:
|
||||
await self.pop_record(session_id)
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -175,17 +163,18 @@ class ProviderOpenAIOfficial(Provider):
|
||||
llm_response = LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
|
||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||
# 尝试删除所有 image
|
||||
print(context_query)
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
print(new_contexts)
|
||||
payloads['messages'] = new_contexts
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
|
||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||
elif 'does not support Function Calling' in str(e) \
|
||||
or 'does not support tools' in str(e) \
|
||||
or 'Function call is not supported' in str(e) \
|
||||
or 'Tool calling is not supported' in str(e): # siliconcloud
|
||||
logger.info(f"{self.get_model()} 不支持函数调用工具调用,已经自动去除")
|
||||
or 'Function calling is not enabled' in str(e) \
|
||||
or 'Tool calling is not supported' in str(e) \
|
||||
or 'No endpoints found that support tool use' in str(e): # siliconcloud
|
||||
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
|
||||
if 'tools' in payloads:
|
||||
del payloads['tools']
|
||||
llm_response = await self._query(payloads, None)
|
||||
@@ -193,7 +182,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
if 'tool' in str(e).lower() and 'support' in str(e).lower():
|
||||
logger.error(f"疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
|
||||
if 'Connection error.' in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
@@ -202,9 +191,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
raise e
|
||||
|
||||
if kwargs.get("persist", True) and llm_response:
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def _remove_image_from_context(self, contexts: List):
|
||||
@@ -232,32 +218,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
context['content'] = new_content
|
||||
new_contexts.append(context)
|
||||
return new_contexts
|
||||
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
|
||||
@@ -277,10 +238,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
if image_url.startswith("file:///"):
|
||||
image_url = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
|
||||
return user_content
|
||||
else:
|
||||
|
||||
@@ -22,24 +22,21 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
context_query = [*contexts, new_record]
|
||||
|
||||
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||
model = self.get_model()
|
||||
# glm-4v-flash 只支持一张图片
|
||||
model: str = model_cfgs.get("model", "")
|
||||
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
|
||||
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
||||
logger.debug(context_query)
|
||||
@@ -62,7 +59,6 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
}
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
|
||||
@@ -16,6 +16,7 @@ from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
|
||||
class Context:
|
||||
'''
|
||||
@@ -44,6 +45,7 @@ class Context:
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
@@ -52,6 +54,7 @@ class Context:
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
'''根据插件名获取插件的 Metadata'''
|
||||
|
||||
@@ -46,7 +46,11 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
if not event.is_wake_up():
|
||||
return False
|
||||
|
||||
message_str = event.get_message_str().strip()
|
||||
if event.get_extra("parsing_command"):
|
||||
message_str = event.get_extra("parsing_command").strip()
|
||||
else:
|
||||
message_str = event.get_message_str().strip()
|
||||
|
||||
# 分割为列表(每个参数之间可能会有多个空格)
|
||||
ls = re.split(r"\s+", message_str)
|
||||
if self.command_name != ls[0]:
|
||||
|
||||
@@ -40,17 +40,24 @@ class CommandGroupFilter(HandlerFilter):
|
||||
if not event.is_wake_up():
|
||||
return False, None
|
||||
|
||||
message_str = event.get_message_str().strip()
|
||||
if event.get_extra("parsing_command"):
|
||||
message_str = event.get_extra("parsing_command").strip()
|
||||
else:
|
||||
message_str = event.get_message_str().strip()
|
||||
|
||||
ls = re.split(r"\s+", message_str)
|
||||
|
||||
if ls[0] != self.group_name:
|
||||
return False, None
|
||||
# 改写 message_str
|
||||
ls = ls[1:]
|
||||
event.message_str = " ".join(ls)
|
||||
event.message_str = event.message_str.strip()
|
||||
# event.message_str = " ".join(ls)
|
||||
# event.message_str = event.message_str.strip()
|
||||
parsing_command = " ".join(ls)
|
||||
parsing_command = parsing_command.strip()
|
||||
event.set_extra("parsing_command", parsing_command)
|
||||
|
||||
if event.message_str == "":
|
||||
if parsing_command == "":
|
||||
# 当前还是指令组
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
|
||||
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
|
||||
|
||||
@@ -19,7 +19,8 @@ class PermissionTypeFilter(HandlerFilter):
|
||||
'''
|
||||
if self.permission_type == PermissionType.ADMIN:
|
||||
if not event.is_admin():
|
||||
event.stop_event()
|
||||
raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限执行此操作。")
|
||||
# event.stop_event()
|
||||
# raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限操作管理员指令。")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.config import AstrBotConfig
|
||||
class RegexFilter(HandlerFilter):
|
||||
'''正则表达式过滤器'''
|
||||
def __init__(self, regex: str):
|
||||
self.regex_str = regex
|
||||
self.regex = re.compile(regex)
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
|
||||
@@ -17,7 +17,12 @@ def get_handler_full_name(awaitable: Awaitable) -> str:
|
||||
'''获取 Handler 的全名'''
|
||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||
|
||||
def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False) -> StarHandlerMetadata:
|
||||
def get_handler_or_create(
|
||||
handler: Awaitable,
|
||||
event_type: EventType,
|
||||
dont_add = False,
|
||||
**kwargs
|
||||
) -> StarHandlerMetadata:
|
||||
'''获取 Handler 或者创建一个新的 Handler'''
|
||||
handler_full_name = get_handler_full_name(handler)
|
||||
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
|
||||
@@ -32,12 +37,24 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
|
||||
handler=handler,
|
||||
event_filters=[]
|
||||
)
|
||||
|
||||
# 插件handler的附加额外信息
|
||||
if handler.__doc__:
|
||||
md.desc = handler.__doc__.strip()
|
||||
if 'desc' in kwargs:
|
||||
md.desc = kwargs['desc']
|
||||
del kwargs['desc']
|
||||
md.extras_configs = kwargs
|
||||
|
||||
if not dont_add:
|
||||
star_handlers_registry.append(md)
|
||||
return md
|
||||
|
||||
def register_command(command_name: str = None, *args):
|
||||
'''注册一个 Command'''
|
||||
def register_command(command_name: str = None, *args, **kwargs):
|
||||
'''注册一个 Command.
|
||||
'''
|
||||
|
||||
# print("command: ", command_name, args, kwargs)
|
||||
|
||||
new_command = None
|
||||
add_to_event_filters = False
|
||||
@@ -51,7 +68,7 @@ def register_command(command_name: str = None, *args):
|
||||
add_to_event_filters = True
|
||||
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
new_command.init_handler_md(handler_md)
|
||||
if add_to_event_filters:
|
||||
# 裸指令
|
||||
@@ -61,8 +78,11 @@ def register_command(command_name: str = None, *args):
|
||||
|
||||
return decorator
|
||||
|
||||
def register_command_group(command_group_name: str = None, *args):
|
||||
'''注册一个 CommandGroup'''
|
||||
def register_command_group(command_group_name: str = None, *args, **kwargs):
|
||||
'''注册一个 CommandGroup
|
||||
'''
|
||||
|
||||
# print("commandgroup: ", command_group_name,args, kwargs)
|
||||
|
||||
new_group = None
|
||||
add_to_event_filters = False
|
||||
@@ -78,7 +98,7 @@ def register_command_group(command_group_name: str = None, *args):
|
||||
def decorator(obj):
|
||||
if add_to_event_filters:
|
||||
# 根指令组
|
||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent)
|
||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(new_group)
|
||||
|
||||
return RegisteringCommandable(new_group)
|
||||
@@ -93,16 +113,16 @@ class RegisteringCommandable():
|
||||
def __init__(self, parent_group: CommandGroupFilter):
|
||||
self.parent_group = parent_group
|
||||
|
||||
def register_event_message_type(event_message_type: EventMessageType):
|
||||
def register_event_message_type(event_message_type: EventMessageType, **kwargs):
|
||||
'''注册一个 EventMessageType'''
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, kwargs)
|
||||
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
|
||||
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType, **kwargs):
|
||||
'''注册一个 PlatformAdapterType'''
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
@@ -111,10 +131,10 @@ def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
|
||||
|
||||
return decorator
|
||||
|
||||
def register_regex(regex: str):
|
||||
def register_regex(regex: str, **kwargs):
|
||||
'''注册一个 Regex'''
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(RegexFilter(regex))
|
||||
return awaitable
|
||||
|
||||
@@ -134,7 +154,7 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_llm_request():
|
||||
def register_on_llm_request(**kwargs):
|
||||
'''当有 LLM 请求时的事件
|
||||
|
||||
Examples:
|
||||
@@ -149,12 +169,12 @@ def register_on_llm_request():
|
||||
请务必接收两个参数:event, request
|
||||
'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent)
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_llm_response():
|
||||
def register_on_llm_response(**kwargs):
|
||||
'''当有 LLM 请求后的事件
|
||||
|
||||
Examples:
|
||||
@@ -169,7 +189,7 @@ def register_on_llm_response():
|
||||
请务必接收两个参数:event, request
|
||||
'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent)
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
@@ -215,18 +235,18 @@ def register_llm_tool(name: str = None):
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_decorating_result():
|
||||
def register_on_decorating_result(**kwargs):
|
||||
'''在发送消息前的事件'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent)
|
||||
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_after_message_sent():
|
||||
def register_after_message_sent(**kwargs):
|
||||
'''在消息发送后的事件'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent)
|
||||
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
star_registry: List[StarMetadata] = []
|
||||
@@ -39,6 +39,9 @@ class StarMetadata:
|
||||
|
||||
config: AstrBotConfig = None
|
||||
'''插件配置'''
|
||||
|
||||
star_handler_full_names: List[str] = field(default_factory=list)
|
||||
'''注册的 Handler 的全名列表'''
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
@@ -1,34 +1,41 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
import heapq
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
from .star import star_map
|
||||
|
||||
T = TypeVar('T', bound='StarHandlerMetadata')
|
||||
class StarHandlerRegistry(Generic[T], List[T]):
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
'''用于存储所有的 Star Handler'''
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
'''用于快速查找。key 是 handler_full_name'''
|
||||
_handlers = []
|
||||
|
||||
def append(self, handler: StarHandlerMetadata):
|
||||
'''添加一个 Handler'''
|
||||
super().append(handler)
|
||||
if 'priority' not in handler.extras_configs:
|
||||
handler.extras_configs['priority'] = 0
|
||||
|
||||
heapq.heappush(self._handlers, (-handler.extras_configs['priority'], handler))
|
||||
self.star_handlers_map[handler.handler_full_name] = handler
|
||||
|
||||
def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]:
|
||||
def _print_handlers(self):
|
||||
'''打印所有的 Handler'''
|
||||
for _, handler in self._handlers:
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(self, event_type: EventType, only_activated=True) -> List[StarHandlerMetadata]:
|
||||
'''通过事件类型获取 Handler'''
|
||||
if only_activated:
|
||||
return [
|
||||
handler
|
||||
for handler in self
|
||||
if handler.event_type == event_type and
|
||||
star_map[handler.handler_module_path] and
|
||||
star_map[handler.handler_module_path].activated
|
||||
]
|
||||
else:
|
||||
return [handler for handler in self if handler.event_type == event_type]
|
||||
handlers = [
|
||||
handler
|
||||
for _, handler in self._handlers
|
||||
if handler.event_type == event_type and
|
||||
(not only_activated or (star_map[handler.handler_module_path] and star_map[handler.handler_module_path].activated))
|
||||
]
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
'''通过 Handler 的全名获取 Handler'''
|
||||
@@ -36,7 +43,32 @@ class StarHandlerRegistry(Generic[T], List[T]):
|
||||
|
||||
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
|
||||
'''通过模块名获取 Handler'''
|
||||
return [handler for handler in self if handler.handler_module_path == module_name]
|
||||
return [handler for _, handler in self._handlers if handler.handler_module_path == module_name]
|
||||
|
||||
def clear(self):
|
||||
'''清空所有的 Handler'''
|
||||
self.star_handlers_map.clear()
|
||||
self._handlers.clear()
|
||||
|
||||
def remove(self, handler: StarHandlerMetadata):
|
||||
'''删除一个 Handler'''
|
||||
# self._handlers.remove(handler)
|
||||
for i, h in enumerate(self._handlers):
|
||||
if h[1] == handler:
|
||||
self._handlers.pop(i)
|
||||
break
|
||||
try:
|
||||
del self.star_handlers_map[handler.handler_full_name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
'''使 StarHandlerRegistry 支持迭代'''
|
||||
return (handler for _, handler in self._handlers)
|
||||
|
||||
def __len__(self):
|
||||
'''返回 Handler 的数量'''
|
||||
return len(self._handlers)
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
@@ -76,3 +108,10 @@ class StarHandlerMetadata():
|
||||
|
||||
desc: str = ""
|
||||
'''Handler 的描述信息'''
|
||||
|
||||
extras_configs: dict = field(default_factory=dict)
|
||||
'''插件注册的一些其他的信息, 如 priority 等'''
|
||||
|
||||
def __lt__(self, other: StarHandlerMetadata):
|
||||
'''定义小于运算符以支持优先队列'''
|
||||
return self.extras_configs.get('priority', 0) < other.extras_configs.get('priority', 0)
|
||||
@@ -9,7 +9,6 @@ import logging
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.default import DEFAULT_VALUE_MAP
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
@@ -19,6 +18,8 @@ from .star import star_registry, star_map
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .filter.permission import PermissionTypeFilter, PermissionType
|
||||
|
||||
class PluginManager:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -39,6 +40,8 @@ class PluginManager:
|
||||
'''保留插件的路径。在 packages 目录下'''
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
'''插件配置 Schema 文件名'''
|
||||
|
||||
self.failed_plugin_info = ""
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
|
||||
@@ -125,7 +128,7 @@ class PluginManager:
|
||||
|
||||
if isinstance(metadata, dict):
|
||||
if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata:
|
||||
raise Exception("插件元数据信息不完整。")
|
||||
raise Exception("插件元数据信息不完整。name, desc, version, author 是必须的字段。")
|
||||
metadata = StarMetadata(
|
||||
name=metadata['name'],
|
||||
author=metadata['author'],
|
||||
@@ -135,30 +138,52 @@ class PluginManager:
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
async def reload(self):
|
||||
'''扫描并加载所有的插件'''
|
||||
for smd in star_registry:
|
||||
logger.debug(f"尝试终止插件 {smd.name} ...")
|
||||
if hasattr(smd.star_cls, "__del__"):
|
||||
smd.star_cls.__del__()
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_handlers_registry.star_handlers_map.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
||||
del sys.modules[key]
|
||||
|
||||
async def reload(self, specified_plugin_name=None):
|
||||
'''扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件'''
|
||||
|
||||
specified_module_path = None
|
||||
if specified_plugin_name:
|
||||
for smd in star_registry:
|
||||
if smd.name == specified_plugin_name:
|
||||
specified_module_path = smd.module_path
|
||||
break
|
||||
|
||||
# 终止插件
|
||||
if not specified_module_path:
|
||||
for smd in star_registry:
|
||||
logger.debug(f"尝试终止插件 {smd.name} ...")
|
||||
if hasattr(smd.star_cls, "__del__"):
|
||||
smd.star_cls.__del__()
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
||||
del sys.modules[key]
|
||||
else:
|
||||
# 只重载指定插件
|
||||
smd = star_map.get(specified_module_path)
|
||||
if smd:
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
try:
|
||||
del sys.modules[specified_module_path]
|
||||
except KeyError:
|
||||
logger.warning(f"模块 {specified_module_path} 未载入")
|
||||
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
return False, "未找到任何插件模块"
|
||||
|
||||
fail_rec = ""
|
||||
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
alter_cmd = sp.get("alter_cmd", {})
|
||||
|
||||
# 导入插件模块,并尝试实例化插件类
|
||||
for plugin_module in plugin_modules:
|
||||
try:
|
||||
@@ -167,11 +192,15 @@ class PluginManager:
|
||||
root_dir_name = plugin_module['pname'] # 插件的目录名
|
||||
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
|
||||
|
||||
path = "data.plugins." if not reserved else "packages."
|
||||
path += root_dir_name + "." + module_str
|
||||
|
||||
if specified_module_path and path != specified_module_path:
|
||||
continue
|
||||
|
||||
logger.info(f"正在载入插件 {root_dir_name} ...")
|
||||
|
||||
# 尝试导入模块
|
||||
path = "data.plugins." if not reserved else "packages."
|
||||
path += root_dir_name + "." + module_str
|
||||
try:
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
@@ -200,6 +229,18 @@ class PluginManager:
|
||||
# 通过装饰器的方式注册插件
|
||||
metadata = star_map[path]
|
||||
|
||||
try:
|
||||
# yaml 文件的元数据优先
|
||||
metadata_yaml = self._load_plugin_metadata(plugin_path=plugin_dir_path)
|
||||
if metadata_yaml:
|
||||
metadata.name = metadata_yaml.name
|
||||
metadata.author = metadata_yaml.author
|
||||
metadata.desc = metadata_yaml.desc
|
||||
metadata.version = metadata_yaml.version
|
||||
metadata.repo = metadata_yaml.repo
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if plugin_config:
|
||||
metadata.config = plugin_config
|
||||
try:
|
||||
@@ -213,12 +254,11 @@ class PluginManager:
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
|
||||
# 绑定 handler
|
||||
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
|
||||
for handler in related_handlers:
|
||||
logger.debug(f"bind handler {handler.handler_name} to {metadata.name}")
|
||||
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
|
||||
handler.handler = functools.partial(handler.handler, metadata.star_cls)
|
||||
# llm_tool
|
||||
# 绑定 llm_tool handler
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler.__module__ == metadata.module_path:
|
||||
func_tool.handler_module_path = metadata.module_path
|
||||
@@ -240,8 +280,7 @@ class PluginManager:
|
||||
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
|
||||
|
||||
metadata = None
|
||||
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
|
||||
metadata = self._load_plugin_metadata(plugin_path=plugin_dir_path, plugin_obj=obj)
|
||||
metadata.star_cls = obj
|
||||
metadata.config = plugin_config
|
||||
metadata.module = module
|
||||
@@ -251,26 +290,54 @@ class PluginManager:
|
||||
metadata.module_path = path
|
||||
star_map[path] = metadata
|
||||
star_registry.append(metadata)
|
||||
logger.debug(f"插件 {root_dir_name} 载入成功。")
|
||||
|
||||
# 禁用/启用插件
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
full_names = []
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(metadata.module_path):
|
||||
full_names.append(handler.handler_full_name)
|
||||
|
||||
# 检查并且植入自定义的权限过滤器(alter_cmd)
|
||||
if metadata.name in alter_cmd and handler.handler_name in alter_cmd[metadata.name]:
|
||||
cmd_type = alter_cmd[metadata.name][handler.handler_name].get("permission", "member")
|
||||
found_permission_filter = False
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, PermissionTypeFilter):
|
||||
if cmd_type == "admin":
|
||||
filter_.permission_type = PermissionType.ADMIN
|
||||
else:
|
||||
filter_.permission_type = PermissionType.MEMBER
|
||||
found_permission_filter = True
|
||||
break
|
||||
if not found_permission_filter:
|
||||
handler.event_filters.append(PermissionTypeFilter(PermissionType.ADMIN if cmd_type == "admin" else PermissionType.MEMBER))
|
||||
|
||||
logger.debug(f"插入权限过滤器 {cmd_type} 到 {metadata.name} 的 {handler.handler_name} 方法。")
|
||||
|
||||
metadata.star_handler_full_names = full_names
|
||||
|
||||
# 执行 initialize() 方法
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n"
|
||||
logger.error(f"----- 插件 {root_dir_name} 载入失败 -----")
|
||||
errors = traceback.format_exc()
|
||||
for line in errors.split('\n'):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("----------------------------------")
|
||||
fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {str(e)}。\n"
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
|
||||
if not fail_rec:
|
||||
return True, None
|
||||
else:
|
||||
self.failed_plugin_info = fail_rec
|
||||
return False, fail_rec
|
||||
|
||||
async def install_plugin(self, repo_url: str):
|
||||
|
||||
@@ -27,22 +27,38 @@ class DifyAPIClient:
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout")
|
||||
logger.info(f"chat_messages payload: {payload}")
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
|
||||
|
||||
buffer = ""
|
||||
while True:
|
||||
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
|
||||
if not data:
|
||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
if not data.strip():
|
||||
continue
|
||||
elif data.startswith(b"data:"):
|
||||
try:
|
||||
json_ = json.loads(data[5:])
|
||||
yield json_
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
buffer += chunk.decode('utf-8')
|
||||
blocks = buffer.split('\n\n')
|
||||
|
||||
# 处理完整的数据块
|
||||
for block in blocks[:-1]:
|
||||
if block.strip() and block.startswith('data:'):
|
||||
try:
|
||||
json_str = block[5:] # 移除 "data:" 前缀
|
||||
json_obj = json.loads(json_str)
|
||||
yield json_obj
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}")
|
||||
logger.error(f"原始数据块: {json_str}")
|
||||
|
||||
# 保留最后一个可能不完整的块
|
||||
buffer = blocks[-1] if blocks else ""
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
inputs: Dict,
|
||||
@@ -55,22 +71,38 @@ class DifyAPIClient:
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout")
|
||||
logger.info(f"workflow_run payload: {payload}")
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
|
||||
|
||||
buffer = ""
|
||||
while True:
|
||||
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
|
||||
if not data:
|
||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
if not data.strip():
|
||||
continue
|
||||
elif data.startswith(b"data:"):
|
||||
try:
|
||||
json_ = json.loads(data[5:])
|
||||
yield json_
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
buffer += chunk.decode('utf-8')
|
||||
blocks = buffer.split('\n\n')
|
||||
|
||||
# 处理完整的数据块
|
||||
for block in blocks[:-1]:
|
||||
if block.strip() and block.startswith('data:'):
|
||||
try:
|
||||
json_str = block[5:] # 移除 "data:" 前缀
|
||||
json_obj = json.loads(json_str)
|
||||
yield json_obj
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}")
|
||||
logger.error(f"原始数据块: {json_str}")
|
||||
|
||||
# 保留最后一个可能不完整的块
|
||||
buffer = blocks[-1] if blocks else ""
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
file_path: str,
|
||||
|
||||
@@ -25,6 +25,10 @@ class ParameterValidationMixin:
|
||||
elif isinstance(param_type_or_default_val, str):
|
||||
# 如果 param_type_or_default_val 是字符串,直接赋值
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, int):
|
||||
result[param_name] = int(params[i])
|
||||
elif isinstance(param_type_or_default_val, float):
|
||||
result[param_name] = float(params[i])
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
except ValueError:
|
||||
|
||||
@@ -14,11 +14,22 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
|
||||
|
||||
if self.BASE_RENDER_URL.endswith("/"):
|
||||
self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1]
|
||||
if not self.BASE_RENDER_URL.endswith("text2img"):
|
||||
self.BASE_RENDER_URL += "/text2img"
|
||||
|
||||
def set_endpoint(self, base_url: str):
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
|
||||
if self.BASE_RENDER_URL.endswith("/"):
|
||||
self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1]
|
||||
if not self.BASE_RENDER_URL.endswith("text2img"):
|
||||
self.BASE_RENDER_URL += "/text2img"
|
||||
|
||||
|
||||
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str:
|
||||
'''使用自定义文转图模板'''
|
||||
|
||||
@@ -121,7 +121,7 @@ class ChatRoute(Route):
|
||||
}))
|
||||
|
||||
# 持久化
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
@@ -136,7 +136,7 @@ class ChatRoute(Route):
|
||||
if audio_url:
|
||||
new_his['audio_url'] = audio_url
|
||||
history.append(new_his)
|
||||
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
|
||||
self.db.update_conversation(username, conversation_id, history=json.dumps(history))
|
||||
|
||||
return Response().ok().__dict__
|
||||
|
||||
@@ -168,7 +168,7 @@ class ChatRoute(Route):
|
||||
continue
|
||||
yield result_text + '\n'
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, cid)
|
||||
conversation = self.db.get_conversation_by_user_id(username, cid)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
@@ -178,7 +178,7 @@ class ChatRoute(Route):
|
||||
'type': 'bot',
|
||||
'message': result_text
|
||||
})
|
||||
self.db.update_webchat_conversation(username, cid, history=json.dumps(history))
|
||||
self.db.update_conversation(username, cid, history=json.dumps(history))
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
except BaseException as e:
|
||||
@@ -204,20 +204,20 @@ class ChatRoute(Route):
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
self.db.delete_webchat_conversation(username, conversation_id)
|
||||
self.db.delete_conversation(username, conversation_id)
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def new_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.webchat_new_conversation(username, conversation_id)
|
||||
self.db.new_conversation(username, conversation_id)
|
||||
return Response().ok(data={
|
||||
'conversation_id': conversation_id
|
||||
}).__dict__
|
||||
|
||||
async def get_conversations(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversations = self.db.get_webchat_conversations(username)
|
||||
conversations = self.db.get_conversations(username)
|
||||
return Response().ok(data=conversations).__dict__
|
||||
|
||||
async def get_conversation(self):
|
||||
@@ -226,7 +226,7 @@ class ChatRoute(Route):
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
||||
|
||||
self.curr_user_cid[username] = conversation_id
|
||||
|
||||
|
||||
@@ -6,6 +6,12 @@ from astrbot.core import logger
|
||||
from quart import request
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.filter.regex import RegexFilter
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None:
|
||||
@@ -18,22 +24,58 @@ class PluginRoute(Route):
|
||||
'/plugin/uninstall': ('POST', self.uninstall_plugin),
|
||||
'/plugin/market_list': ('GET', self.get_online_plugins),
|
||||
'/plugin/off': ('POST', self.off_plugin),
|
||||
'/plugin/on': ('POST', self.on_plugin)
|
||||
'/plugin/on': ('POST', self.on_plugin),
|
||||
'/plugin/reload': ('POST', self.reload_plugins),
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.plugin_manager = plugin_manager
|
||||
self.register_routes()
|
||||
|
||||
self.translated_event_type = {
|
||||
EventType.AdapterMessageEvent: "平台消息下发时",
|
||||
EventType.OnLLMRequestEvent: "LLM 请求时",
|
||||
EventType.OnLLMResponseEvent: "LLM 响应后",
|
||||
EventType.OnDecoratingResultEvent: "回复消息前",
|
||||
EventType.OnCallingFuncToolEvent: "函数工具",
|
||||
EventType.OnAfterMessageSentEvent: "发送消息后"
|
||||
}
|
||||
|
||||
async def reload_plugins(self):
|
||||
data = await request.json
|
||||
plugin_name = data.get("name", None)
|
||||
try:
|
||||
success, message = await self.plugin_manager.reload(plugin_name)
|
||||
if not success:
|
||||
return Response().error(message).__dict__
|
||||
return Response().ok(None, "重载成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/reload: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_online_plugins(self):
|
||||
url = "https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json"
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as response:
|
||||
result = await response.json()
|
||||
return Response().ok(result).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件列表失败:{e}")
|
||||
return Response().error(str(e)).__dict__
|
||||
custom = request.args.get("custom_registry")
|
||||
|
||||
if custom:
|
||||
urls = [custom]
|
||||
else:
|
||||
urls = [
|
||||
"https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json",
|
||||
"https://api.soulter.top/astrbot/plugins"
|
||||
]
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return Response().ok(result).__dict__
|
||||
else:
|
||||
logger.error(f"请求 {url} 失败,状态码:{response.status}")
|
||||
except Exception as e:
|
||||
logger.error(f"请求 {url} 失败,错误:{e}")
|
||||
|
||||
return Response().error("获取插件列表失败").__dict__
|
||||
|
||||
async def get_plugins(self):
|
||||
_plugin_resp = []
|
||||
@@ -45,10 +87,58 @@ class PluginRoute(Route):
|
||||
"desc": plugin.desc,
|
||||
"version": plugin.version,
|
||||
"reserved": plugin.reserved,
|
||||
"activated": plugin.activated
|
||||
"activated": plugin.activated,
|
||||
"handlers": await self.get_plugin_handlers_info(plugin.star_handler_full_names),
|
||||
}
|
||||
_plugin_resp.append(_t)
|
||||
return Response().ok(_plugin_resp).__dict__
|
||||
return Response().ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info).__dict__
|
||||
|
||||
async def get_plugin_handlers_info(self, handler_full_names: list[str]):
|
||||
'''解析插件行为'''
|
||||
handlers = []
|
||||
|
||||
for handler_full_name in handler_full_names:
|
||||
info = {}
|
||||
handler = star_handlers_registry.star_handlers_map.get(handler_full_name, None)
|
||||
if handler is None:
|
||||
continue
|
||||
info["event_type"] = handler.event_type.name
|
||||
info["event_type_h"] = self.translated_event_type.get(handler.event_type, handler.event_type.name)
|
||||
info["handler_full_name"] = handler.handler_full_name
|
||||
info["desc"] = handler.desc
|
||||
info["handler_name"] = handler.handler_name
|
||||
|
||||
if handler.event_type == EventType.AdapterMessageEvent:
|
||||
# 处理平台适配器消息事件
|
||||
has_admin = False
|
||||
for filter in handler.event_filters: # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
|
||||
if isinstance(filter, CommandFilter):
|
||||
info["type"] = "指令"
|
||||
info["cmd"] = filter.command_name
|
||||
elif isinstance(filter, CommandGroupFilter):
|
||||
info["type"] = "指令组"
|
||||
info["cmd"] = filter.group_name
|
||||
info["sub_command"] = filter.print_cmd_tree(filter.sub_command_filters)
|
||||
elif isinstance(filter, RegexFilter):
|
||||
info["type"] = "正则匹配"
|
||||
info["cmd"] = filter.regex_str
|
||||
elif isinstance(filter, PermissionTypeFilter):
|
||||
has_admin = True
|
||||
info["has_admin"] = has_admin
|
||||
if "cmd" not in info:
|
||||
info["cmd"] = "未知"
|
||||
if "type" not in info:
|
||||
info["type"] = "事件监听器"
|
||||
else:
|
||||
info["cmd"] = "自动触发"
|
||||
info["type"] = "无"
|
||||
|
||||
if not info["desc"]:
|
||||
info["desc"] = "无描述"
|
||||
|
||||
handlers.append(info)
|
||||
|
||||
return handlers
|
||||
|
||||
async def install_plugin(self):
|
||||
post_data = await request.json
|
||||
|
||||
13
changelogs/v3.4.19.md
Normal file
13
changelogs/v3.4.19.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# What's Changed
|
||||
|
||||
1. 支持接入企业微信(测试)
|
||||
2. 修复速率限制不可用的问题
|
||||
3. gewechat 回调接口默认暴露在所有 IP
|
||||
4. 适配 Azure OpenAI
|
||||
5. 修复请求 gemini 出现 KeyError 'candidates' 的错误
|
||||
6. 将 /reset /persona 挪入管理员指令 #308
|
||||
7. 支持通过 /alter_cmd 设置所有指令是否只能管理员操作
|
||||
8. /plugin 指令支持查看插件注册的指令和指令组
|
||||
9. 插件注册指令支持传入指令的描述以方便 /plugin 查看。需要写在函数的第一行的 docstring 中。
|
||||
10. 修复 schema 中 object hint 不显示 #290
|
||||
11. feat: 优化插件市场的访问速度
|
||||
15
changelogs/v3.4.20.md
Normal file
15
changelogs/v3.4.20.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# What's Changed
|
||||
|
||||
> 由于重写了会话记录部分,更新此版本后,将会造成之前的对话记录清空(但没有被删除)。
|
||||
> 关于更好的对话管理,如果有任何报错或者优化建议,请直接提交 issue~
|
||||
|
||||
1. 更好的对话管理,支持 /ls, /del, /new, /switch, /rename 指令来操作对话。
|
||||
2. 人格情境跟随对话。每个对话支持独立设置人格情境,只需要 /persona 指令切换即可。
|
||||
3. 支持使用 LLM 辅助分段回复 #338
|
||||
4. 优化 aiocqhttp 适配器对用户非法输入的处理
|
||||
5. 优化插件页面
|
||||
6. 修复权限过滤算子导致的问题 #350
|
||||
7. 修复级联指令组时出现载入错误的问题 #366
|
||||
8. 修复代码执行器的一个typo by @eltociear
|
||||
9. 修复指令组情况下可能造成多指令出触发的问题
|
||||
10. 添加屏蔽无权限指令回复的功能 #361
|
||||
19
changelogs/v3.4.21.md
Normal file
19
changelogs/v3.4.21.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# What's Changed
|
||||
|
||||
> 由于重写了会话记录部分,更新此版本后,将会造成之前的对话记录清空(但没有被删除)。
|
||||
> 关于更好的对话管理,如果有任何报错或者优化建议,请直接提交 issue~
|
||||
|
||||
1. 修复 reminder 时区问题
|
||||
2. 面板支持重载单个插件 #297
|
||||
3. 面板支持列表展示插件市场
|
||||
4. 文字转图片支持自定义字数阈值(配置->其他配置)
|
||||
5. 面板更好的列表可视化 #274
|
||||
6. 面板支持查看插件行为
|
||||
7. 支持设置 timeout 超时时间参数,防止思考模型太长达到超时时间。(需要重新配置服务提供商或者在服务提供商 config 中配置 timeout 参数) #378
|
||||
8. openrouter 报错 no endpoints found that support tool use #371
|
||||
9. 修复插件 metadata 不生效的问题
|
||||
10. 修复不支持图片的模型请求异常
|
||||
11. 修复 reminder 无法删除的问题
|
||||
12. 修复 /model 切换不了模型的问题
|
||||
13. 插件支持设置优先级
|
||||
14. 聊天增强图像转述支持自定义 provider id。#274
|
||||
12
changelogs/v3.4.22.md
Normal file
12
changelogs/v3.4.22.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# What's Changed
|
||||
|
||||
1. fix: 400 Bad Request: The browser (or proxy) sent a request that this server could not understand. #396
|
||||
2. remove: 移除了 put_history_to_prompt。当主动回复时,将群聊记录将自动放入prompt,当未主动回复但是开启群聊增强时,群聊记录将放入system prompt
|
||||
3. fix: 插件错误信息点击关闭没反应 #394
|
||||
4. fix: 自部署文转图不生效 #352
|
||||
5. fix: Google Search 报 429 错误时,放宽 Exception 至其他搜索引擎 #405
|
||||
6. fix: 使用 Google Gemini (OpenAI 兼容)的部分情况下联网搜索等函数调用工具没被调用 #342
|
||||
7. fix: 修复尝试弹出最早的记录失效的问题
|
||||
8. fix: 移除了分段回复llm提示词辅助
|
||||
9. perf: 当图片数据为空时不加入上下文 #379
|
||||
10. 修复 dify 返回的结果带有多行数据时的 json 解析异常导致返回值为空的问题 #298 by @zhaolj
|
||||
@@ -32,7 +32,7 @@
|
||||
"vue-router": "4.2.4",
|
||||
"vue3-apexcharts": "1.4.4",
|
||||
"vue3-print-nb": "0.1.4",
|
||||
"vuetify": "3.3.14",
|
||||
"vuetify": "3.7.11",
|
||||
"yup": "1.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
<template>
|
||||
<h3 style="margin-bottom: 8px;" v-if="iterable && metadata[metadataKey]?.type === 'object'">
|
||||
{{metadata[metadataKey]?.description }}
|
||||
{{ metadata[metadataKey]?.description }}
|
||||
</h3>
|
||||
<v-card-text>
|
||||
<div v-for="(index, key) in iterable" :key="key" style="margin-bottom: 0.5px;" v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template">
|
||||
<v-alert v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object'" style="margin-bottom: 16px"
|
||||
:text="metadata[metadataKey].items[key]?.hint" :title="'💡 关于' + metadata[metadataKey].items[key]?.description"
|
||||
type="info" variant="tonal">
|
||||
<div v-for="(index, key) in iterable" :key="key" style="margin-bottom: 0.5px;"
|
||||
v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template">
|
||||
<v-alert v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint"
|
||||
style="margin-bottom: 16px" :text="metadata[metadataKey].items[key]?.hint"
|
||||
:title="'💡 关于' + metadata[metadataKey].items[key]?.description" type="info" variant="tonal">
|
||||
</v-alert>
|
||||
|
||||
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
|
||||
<div style="width: 100%;" v-if="metadata[metadataKey].items[key]">
|
||||
<v-select v-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible" v-model="iterable[key]"
|
||||
variant="outlined" :items="metadata[metadataKey].items[key]?.options"
|
||||
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" dense :disabled="metadata[metadataKey].items[key]?.readonly"></v-select>
|
||||
<v-text-field v-else-if="metadata[metadataKey].items[key]?.type === 'string' && !metadata[metadataKey].items[key]?.invisible"
|
||||
<v-select
|
||||
v-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]" variant="outlined" :items="metadata[metadataKey].items[key]?.options"
|
||||
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" dense
|
||||
:disabled="metadata[metadataKey].items[key]?.readonly"></v-select>
|
||||
<v-text-field
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'string' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]" :label="metadata[metadataKey].items[key]?.description + '(' + key + ')'"
|
||||
variant="outlined" dense ></v-text-field>
|
||||
variant="outlined" dense></v-text-field>
|
||||
<v-text-field
|
||||
v-else-if="(metadata[metadataKey].items[key]?.type === 'int' || metadata[metadataKey].items[key]?.type === 'float') && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]" :label="metadata[metadataKey].items[key]?.description + '(' + key + ')'"
|
||||
@@ -27,17 +31,11 @@
|
||||
<v-switch v-else-if="metadata[metadataKey].items[key]?.type === 'bool' && !metadata[metadataKey].items[key]?.invisible" v-model="iterable[key]"
|
||||
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" color="primary"
|
||||
inset></v-switch>
|
||||
<v-combobox variant="outlined" v-else-if="metadata[metadataKey].items[key]?.type === 'list' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]" chips clearable
|
||||
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" multiple
|
||||
prepend-icon="mdi-tag-multiple-outline">
|
||||
<template v-slot:selection="{ attrs, item, select, selected }">
|
||||
<v-chip v-bind="attrs" :model-value="selected" closable @click="select"
|
||||
@click:close="remove(item)">
|
||||
<strong>{{ item }}</strong>
|
||||
</v-chip>
|
||||
</template>
|
||||
</v-combobox>
|
||||
<ListConfigItem
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'list' && !metadata[metadataKey].items[key]?.invisible"
|
||||
:value="iterable[key]"
|
||||
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'"/>
|
||||
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.type === 'object' && !metadata[metadataKey].items[key]?.invisible"
|
||||
style="border: 1px solid #e0e0e0; padding: 8px; margin-bottom: 16px; border-radius: 10px;">
|
||||
<AstrBotConfig :metadata="metadata[metadataKey].items" :iterable="iterable[key]"
|
||||
@@ -52,51 +50,54 @@
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object' && !metadata[metadataKey].items[key]?.invisible">
|
||||
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && !metadata[metadataKey].items[key]?.invisible">
|
||||
<v-btn icon size="x-small" style="margin-bottom: 22px;">
|
||||
<v-icon size="x-small">mdi-help</v-icon>
|
||||
<v-tooltip activator="parent" location="start">{{ metadata[metadataKey].items[key]?.hint
|
||||
}}</v-tooltip>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<v-chip v-if="!metadata[metadataKey].items[key]?.invisible" color="primary">{{ metadata[metadataKey].items[key]?.type }}</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
<div v-else>
|
||||
<v-alert v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && metadata[metadataKey]?.type !== 'object'" style="margin-bottom: 16px"
|
||||
:text="metadata[metadataKey]?.hint" :title="'💡 关于' + metadata[metadataKey]?.description"
|
||||
type="info" variant="tonal">
|
||||
<v-alert v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint"
|
||||
style="margin-bottom: 16px" :text="metadata[metadataKey]?.hint"
|
||||
:title="'💡 关于' + metadata[metadataKey]?.description" type="info" variant="tonal">
|
||||
</v-alert>
|
||||
|
||||
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
|
||||
<div style="width: 100%;">
|
||||
<v-select v-if="metadata[metadataKey]?.options && !metadata[metadataKey]?.invisible" v-model="iterable[metadataKey]"
|
||||
variant="outlined" :items="metadata[metadataKey]?.options"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'" dense :disabled="metadata[metadataKey]?.readonly"></v-select>
|
||||
<v-text-field v-else-if="metadata[metadataKey]?.type === 'string' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]" :label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'"
|
||||
variant="outlined" dense ></v-text-field>
|
||||
<v-select v-if="metadata[metadataKey]?.options && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]" variant="outlined" :items="metadata[metadataKey]?.options"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" dense
|
||||
:disabled="metadata[metadataKey]?.readonly"></v-select>
|
||||
<v-text-field
|
||||
v-else-if="metadata[metadataKey]?.type === 'string' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" variant="outlined"
|
||||
dense></v-text-field>
|
||||
<v-text-field
|
||||
v-else-if="(metadata[metadataKey]?.type === 'int' || metadata[metadataKey]?.type === 'float') && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]" :label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'"
|
||||
variant="outlined" dense></v-text-field>
|
||||
<v-textarea v-else-if="metadata[metadataKey]?.type === 'text' && !metadata[metadataKey]?.invisible" v-model="iterable[metadataKey]"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'" variant="outlined"
|
||||
v-model="iterable[metadataKey]"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" variant="outlined"
|
||||
dense></v-text-field>
|
||||
<v-textarea v-else-if="metadata[metadataKey]?.type === 'text' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" variant="outlined"
|
||||
dense></v-textarea>
|
||||
<v-switch v-else-if="metadata[metadataKey]?.type === 'bool' && !metadata[metadataKey]?.invisible" v-model="iterable[metadataKey]"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'" color="primary"
|
||||
<v-switch v-else-if="metadata[metadataKey]?.type === 'bool' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" color="primary"
|
||||
inset></v-switch>
|
||||
<v-combobox variant="outlined" v-else-if="metadata[metadataKey]?.type === 'list' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]" chips clearable
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'" multiple
|
||||
prepend-icon="mdi-tag-multiple-outline">
|
||||
<template v-slot:selection="{ attrs, item, select, selected }">
|
||||
<v-chip v-bind="attrs" :model-value="selected" closable @click="select"
|
||||
@click:close="remove(item)">
|
||||
<strong>{{ item }}</strong>
|
||||
</v-chip>
|
||||
</template>
|
||||
</v-combobox>
|
||||
<ListConfigItem
|
||||
v-else-if="metadata[metadataKey]?.type === 'list' && !metadata[metadataKey]?.invisible"
|
||||
:value="iterable[metadataKey]"
|
||||
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'"/>
|
||||
<div v-else-if="metadata[metadataKey]?.type === 'object' && !metadata[metadataKey]?.invisible"
|
||||
style="border: 1px solid #e0e0e0; padding: 8px; margin-bottom: 16px; border-radius: 10px;">
|
||||
<AstrBotConfig :metadata="metadata[metadataKey].items" :iterable="iterable[metadataKey]"
|
||||
@@ -106,13 +107,17 @@
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="!metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && metadata[metadataKey]?.type !== 'object' && !metadata[metadataKey]?.invisible">
|
||||
v-if="!metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && !metadata[metadataKey]?.invisible">
|
||||
<v-btn icon size="x-small" style="margin-bottom: 22px;">
|
||||
<v-icon size="x-small">mdi-help</v-icon>
|
||||
<v-tooltip activator="parent" location="start">{{ metadata[metadataKey]?.hint
|
||||
}}</v-tooltip>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<v-chip v-if="!metadata[metadataKey]?.invisible" color="primary">{{ metadata[metadataKey]?.type }}</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
@@ -120,8 +125,12 @@
|
||||
|
||||
<script>
|
||||
import { readonly } from 'vue';
|
||||
import ListConfigItem from './ListConfigItem.vue';
|
||||
|
||||
export default {
|
||||
components: {
|
||||
ListConfigItem
|
||||
},
|
||||
props: {
|
||||
metadata: Object,
|
||||
iterable: Object,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
<script setup lang="ts">
|
||||
const props = defineProps({
|
||||
title: String,
|
||||
link: String
|
||||
link: String,
|
||||
logo: String
|
||||
});
|
||||
|
||||
const open = (link: string | undefined) => {
|
||||
@@ -11,15 +12,16 @@ const open = (link: string | undefined) => {
|
||||
|
||||
<template>
|
||||
<v-card variant="outlined" elevation="0" class="withbg">
|
||||
<v-card-item style="padding: 10px 14px">
|
||||
<v-card-item style="padding: 10px 12px">
|
||||
<div class="d-sm-flex align-center justify-space-between">
|
||||
<v-card-title style="font-size: 17px;">{{ props.title }}</v-card-title>
|
||||
<img v-if="logo" :src="logo" alt="logo" style="width: 40px; height: 40px; margin-right: 8px;">
|
||||
<v-card-title style="font-size: 16px;">{{ props.title }}</v-card-title>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="plain" @click="open(props.link)">仓库</v-btn>
|
||||
<v-btn size="small" text="Read" variant="flat" border @click="open(props.link)">帮助</v-btn>
|
||||
</div>
|
||||
</v-card-item>
|
||||
<v-divider></v-divider>
|
||||
<v-card-text>
|
||||
<v-card-text style="padding: 16px;">
|
||||
<slot />
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
93
dashboard/src/components/shared/ListConfigItem.vue
Normal file
93
dashboard/src/components/shared/ListConfigItem.vue
Normal file
@@ -0,0 +1,93 @@
|
||||
<template>
|
||||
<div class="list-config-item">
|
||||
<h3>{{ label }}</h3>
|
||||
<v-list dense style="background-color: transparent;max-height: 300px; overflow-y: scroll;" >
|
||||
<v-list-item v-for="(item, index) in items" :key="index">
|
||||
<v-list-item-content style="display: flex; justify-content: space-between;">
|
||||
<v-list-item-title>
|
||||
<v-chip>{{ item }}</v-chip>
|
||||
</v-list-item-title>
|
||||
<v-btn @click="removeItem(index)" variant="plain">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</v-list-item-content>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
<v-text-field
|
||||
v-model="newItem"
|
||||
label="添加新项,按回车确认添加"
|
||||
@keyup.enter="addItem"
|
||||
clearable
|
||||
dense
|
||||
hide-details
|
||||
variant="outlined"
|
||||
></v-text-field>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
export default {
|
||||
name: 'ListConfigItem',
|
||||
props: {
|
||||
value: {
|
||||
type: Array,
|
||||
default: () => [],
|
||||
},
|
||||
label: {
|
||||
type: String,
|
||||
default: '',
|
||||
},
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
newItem: '',
|
||||
items: this.value,
|
||||
};
|
||||
},
|
||||
watch: {
|
||||
items(newVal) {
|
||||
this.$emit('input', newVal);
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
addItem() {
|
||||
if (this.newItem.trim() !== '') {
|
||||
this.items.push(this.newItem.trim());
|
||||
this.newItem = '';
|
||||
}
|
||||
},
|
||||
removeItem(index) {
|
||||
this.items.splice(index, 1);
|
||||
},
|
||||
},
|
||||
};
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.list-config-item {
|
||||
border: 1px solid #e0e0e0;
|
||||
padding: 16px;
|
||||
margin-bottom: 16px;
|
||||
border-radius: 10px;
|
||||
background-color: #ffffff;
|
||||
}
|
||||
|
||||
.list-config-item h3 {
|
||||
margin-top: 0;
|
||||
margin-bottom: 16px;
|
||||
font-size: 18px;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.v-list-item {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.v-list-item-title {
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.v-btn {
|
||||
margin-left: 8px;
|
||||
}
|
||||
</style>
|
||||
@@ -9,7 +9,7 @@ const sidebarMenu = shallowRef(sidebarItems);
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<v-navigation-drawer left v-model="customizer.Sidebar_drawer" elevation="0" rail-width="80" mobile-breakpoint="960"
|
||||
<v-navigation-drawer left v-model="customizer.Sidebar_drawer" elevation="0" rail-width="80"
|
||||
app class="leftSidebar" :rail="customizer.mini_sidebar">
|
||||
<v-list class="pa-4 listitem" style="height: auto">
|
||||
<template v-for="(item, i) in sidebarMenu" :key="i">
|
||||
|
||||
@@ -68,7 +68,7 @@ import config from '@/config';
|
||||
<v-tabs-window-item v-for="(config_item, index) in config_data[key2]" v-show="config_template_tab === index"
|
||||
:key="index" :value="index">
|
||||
<v-container>
|
||||
<v-btn variant="tonal" rounded="xl" color="error" @click="config_data[key2].splice(index, 1)">
|
||||
<v-btn variant="tonal" rounded="xl" color="error" @click="deleteItem(key2, index)">
|
||||
删除这项
|
||||
</v-btn>
|
||||
|
||||
@@ -215,6 +215,20 @@ export default {
|
||||
// new_tmpl_cfg.id = "new_" + val + "_" + this.config_data[config_item_name].length;
|
||||
this.config_data[config_item_name].push(new_tmpl_cfg);
|
||||
this.config_template_tab = this.config_data[config_item_name].length - 1;
|
||||
},
|
||||
deleteItem(config_item_name, index) {
|
||||
console.log(config_item_name, index);
|
||||
let new_list = [];
|
||||
for (let i = 0; i < this.config_data[config_item_name].length; i++) {
|
||||
if (i !== index) {
|
||||
new_list.push(this.config_data[config_item_name][i]);
|
||||
}
|
||||
}
|
||||
this.config_data[config_item_name] = new_list;
|
||||
|
||||
if (this.config_template_tab > 0) {
|
||||
this.config_template_tab -= 1;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -4,59 +4,123 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import axios from 'axios';
|
||||
import { max } from 'date-fns';
|
||||
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<v-row>
|
||||
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
|
||||
title="💡提示" type="info" variant="tonal">
|
||||
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README" title="💡提示"
|
||||
type="info" variant="tonal">
|
||||
</v-alert>
|
||||
<v-col cols="12" md="12">
|
||||
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
|
||||
<h3>🧩 已安装的插件</h3>
|
||||
<div style="display: flex; align-items: center;">
|
||||
<h3>🧩 已安装的插件</h3>
|
||||
|
||||
<v-dialog max-width="500px">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" v-if="extension_data.message" icon size="small" color="error"
|
||||
style="margin-left: auto;" variant="plain">
|
||||
<v-icon>mdi-alert-circle</v-icon>
|
||||
</v-btn>
|
||||
</template>
|
||||
|
||||
<template v-slot:default="{ isActive }">
|
||||
<v-card>
|
||||
<v-card-title class="headline">错误信息</v-card-title>
|
||||
<v-card-text>{{ extension_data.message }}
|
||||
<br>
|
||||
<small>详情请检查控制台</small>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="primary" text @click="isActive.value = false">关闭</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</template>
|
||||
|
||||
</v-dialog>
|
||||
</div>
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="12" md="6" lg="4" v-for="extension in extension_data.data">
|
||||
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" style="margin-bottom: 4px;">
|
||||
<p style="min-height: 130px; max-height: 130px; overflow: none;">{{ extension.desc }}</p>
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-icon>mdi-account</v-icon>
|
||||
<span>{{ extension.author }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<div v-if="!extension.reserved">
|
||||
<v-btn variant="plain" @click="openExtensionConfig(extension.name)">配置</v-btn>
|
||||
<v-btn variant="plain" @click="updateExtension(extension.name)">更新</v-btn>
|
||||
<v-btn variant="plain" @click="uninstallExtension(extension.name)">卸载</v-btn>
|
||||
</div>
|
||||
<v-col cols="12" md="6" lg="3" v-for="extension in extension_data.data">
|
||||
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" :logo="extension?.logo"
|
||||
style="margin-bottom: 4px;">
|
||||
<div style="min-height: 135px; max-height: 135px; overflow: none;">
|
||||
<span style="font-weight: bold;">By @{{ extension.author }}</span>
|
||||
<span> | 插件有 {{ extension.handlers.length }} 个行为</span>
|
||||
<p style="margin-top: 8px;">{{ extension.desc }}</p>
|
||||
<a style="font-size: 12px; cursor: pointer; text-decoration: underline; color: #555;"
|
||||
@click="reloadPlugin(extension.name)">重载插件</a>
|
||||
</div>
|
||||
<div class="d-flex align-center gap-2 " style="overflow-x: auto;">
|
||||
<v-btn v-if="!extension.reserved" class="text-none mr-2" size="small" text="Read" variant="flat" border
|
||||
@click="openExtensionConfig(extension.name)">配置</v-btn>
|
||||
<v-btn v-if="!extension.reserved" class="text-none mr-2" size="small" text="Read" variant="flat" border
|
||||
@click="updateExtension(extension.name)">更新</v-btn>
|
||||
<v-btn v-if="!extension.reserved" class="text-none mr-2" size="small" text="Read" variant="flat" border
|
||||
@click="uninstallExtension(extension.name)">卸载</v-btn>
|
||||
<!-- <span v-else>保留插件</span> -->
|
||||
<v-btn variant="plain" v-if="extension.activated" @click="pluginOff(extension)">禁用</v-btn>
|
||||
<v-btn variant="plain" v-else @click="pluginOn(extension)">启用</v-btn>
|
||||
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border v-if="extension.activated"
|
||||
@click="pluginOff(extension)">禁用</v-btn>
|
||||
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border v-else
|
||||
@click="pluginOn(extension)">启用</v-btn>
|
||||
|
||||
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border
|
||||
@click="showPluginInfo(extension)">行为</v-btn>
|
||||
</div>
|
||||
</ExtensionCard>
|
||||
</v-col>
|
||||
<v-col cols="12" md="12">
|
||||
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
|
||||
<h3>🧩 插件市场</h3>
|
||||
<div style="display: flex; align-items: center;">
|
||||
<h3>🧩 插件市场</h3>
|
||||
<small style="margin-left: 16px;">如无法显示,请打开 <a
|
||||
href="https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json">链接</a> 复制想安装插件对应的 `repo`
|
||||
链接然后点击右下角 + 号安装,或打开链接下载压缩包安装。</small>
|
||||
<v-btn icon @click="isListView = !isListView" size="small" style="margin-left: auto;" variant="plain">
|
||||
<v-icon>{{ isListView ? 'mdi-view-grid' : 'mdi-view-list' }}</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="12" md="6" lg="4" v-for="plugin in pluginMarketData">
|
||||
<ExtensionCard :key="plugin.name" :title="plugin.name" :link="plugin.repo" style="margin-bottom: 4px;">
|
||||
<p style="min-height: 130px; max-height: 130px; overflow: hidden;">{{ plugin.desc }}</p>
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-icon>mdi-account</v-icon>
|
||||
<span>{{ plugin.author }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn v-if="!plugin.installed" variant="plain"
|
||||
@click="extension_url = plugin.repo; newExtension()">安装</v-btn>
|
||||
<v-btn v-else variant="plain" disabled>已安装</v-btn>
|
||||
</div>
|
||||
</ExtensionCard>
|
||||
|
||||
|
||||
<v-col cols="12" md="12" v-if="announcement">
|
||||
<v-banner color="success" lines="one" :text="announcement" :stacked="false">
|
||||
</v-banner>
|
||||
</v-col>
|
||||
|
||||
<template v-if="isListView">
|
||||
<v-col cols="12" md="12">
|
||||
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name">
|
||||
<template v-slot:item.actions="{ item }">
|
||||
<v-btn v-if="!item.installed" class="text-none mr-2" size="small" text="Read" variant="flat" border
|
||||
@click="extension_url = item.repo; newExtension()">安装</v-btn>
|
||||
<v-btn v-else class="text-none mr-2" size="small" text="Read" variant="flat" border disabled>已安装</v-btn>
|
||||
</template>
|
||||
</v-data-table>
|
||||
</v-col>
|
||||
</template>
|
||||
<template v-else>
|
||||
<v-col cols="12" md="6" lg="3" v-for="plugin in pluginMarketData">
|
||||
<ExtensionCard :key="plugin.name" :title="plugin.name" :link="plugin.repo" style="margin-bottom: 4px;">
|
||||
<div style="min-height: 130px; max-height: 130px; overflow: hidden;">
|
||||
<p style="font-weight: bold;">By @{{ plugin.author }}</p>
|
||||
{{ plugin.desc }}
|
||||
</div>
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-btn v-if="!plugin.installed" class="text-none mr-2" size="small" text="Read" variant="flat" border
|
||||
@click="extension_url = plugin.repo; newExtension()">安装</v-btn>
|
||||
<v-btn v-else class="text-none mr-2" size="small" text="Read" variant="flat" border disabled>已安装</v-btn>
|
||||
</div>
|
||||
</ExtensionCard>
|
||||
</v-col>
|
||||
</template>
|
||||
|
||||
<v-col style="margin-bottom: 16px;" cols="12" md="12">
|
||||
<small ><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
|
||||
<small><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
|
||||
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
|
||||
</v-col>
|
||||
|
||||
@@ -71,7 +135,8 @@ import axios from 'axios';
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-container>
|
||||
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata" :iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
|
||||
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata"
|
||||
:iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
|
||||
<p v-else>这个插件没有配置</p>
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
@@ -166,6 +231,44 @@ import axios from 'axios';
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<v-dialog v-model="showPluginInfoDialog" width="1200">
|
||||
<template v-slot:activator="{ props }">
|
||||
</template>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
<span class="text-h5">{{ selectedPlugin.name }} 插件行为</span>
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-data-table style="font-size: 17px;" :headers="plugin_handler_info_headers" :items="selectedPlugin.handlers"
|
||||
item-key="name">
|
||||
<template v-slot:header.id="{ column }">
|
||||
<p style="font-weight: bold;">{{ column.title }}</p>
|
||||
</template>
|
||||
<template v-slot:item.event_type="{ item }">
|
||||
{{ item.event_type }}
|
||||
</template>
|
||||
<template v-slot:item.desc="{ item }">
|
||||
{{ item.desc }}
|
||||
</template>
|
||||
<template v-slot:item.type="{ item }">
|
||||
<v-chip color="success">
|
||||
{{ item.type }}
|
||||
</v-chip>
|
||||
</template>
|
||||
<template v-slot:item.cmd="{ item }">
|
||||
<span style="font-weight: bold;">{{ item.cmd }}</span>
|
||||
</template>
|
||||
</v-data-table>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="blue-darken-1" variant="text" @click="showPluginInfoDialog = false">
|
||||
关闭
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<v-snackbar :timeout="2000" elevation="24" :color="snack_success" v-model="snack_show">
|
||||
{{ snack_message }}
|
||||
</v-snackbar>
|
||||
@@ -186,7 +289,8 @@ export default {
|
||||
data() {
|
||||
return {
|
||||
extension_data: {
|
||||
"data": []
|
||||
"data": [],
|
||||
"message": ""
|
||||
},
|
||||
extension_url: "",
|
||||
status: "",
|
||||
@@ -207,12 +311,34 @@ export default {
|
||||
title: "加载中...",
|
||||
statusCode: 0, // 0: loading, 1: success, 2: error,
|
||||
result: ""
|
||||
}
|
||||
},
|
||||
|
||||
announcement: "",
|
||||
showPluginInfoDialog: false,
|
||||
selectedPlugin: {},
|
||||
plugin_handler_info_headers: [
|
||||
{ title: '行为类型', key: 'event_type_h' },
|
||||
{ title: '描述', key: 'desc', maxWidth: '250px' },
|
||||
{ title: '具体类型', key: 'type' },
|
||||
{ title: '触发方式', key: 'cmd' },
|
||||
],
|
||||
isListView: false,
|
||||
pluginMarketHeaders: [
|
||||
{ title: '名称', value: 'name' },
|
||||
{ title: '描述', value: 'desc' },
|
||||
{ title: '作者', value: 'author' },
|
||||
{ title: '操作', value: 'actions', sortable: false }
|
||||
],
|
||||
}
|
||||
},
|
||||
mounted() {
|
||||
this.getExtensions();
|
||||
this.fetchPluginCollection();
|
||||
|
||||
axios.get('https://api.soulter.top/astrbot-announcement-plugin-market').then((res) => {
|
||||
let data = res.data.data;
|
||||
this.announcement = data.text;
|
||||
});
|
||||
},
|
||||
methods: {
|
||||
toast(message, success) {
|
||||
@@ -240,7 +366,8 @@ export default {
|
||||
},
|
||||
getExtensions() {
|
||||
axios.get('/api/plugin/get').then((res) => {
|
||||
this.extension_data.data = res.data.data;
|
||||
this.extension_data = res.data;
|
||||
|
||||
this.checkAlreadyInstalled();
|
||||
});
|
||||
},
|
||||
@@ -259,7 +386,7 @@ export default {
|
||||
if (this.upload_file !== null) {
|
||||
this.toast("正在从文件安装插件", "primary");
|
||||
const formData = new FormData();
|
||||
formData.append('file', this.upload_file[0]);
|
||||
formData.append('file', this.upload_file);
|
||||
axios.post('/api/plugin/install-upload', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data'
|
||||
@@ -270,7 +397,7 @@ export default {
|
||||
this.onLoadingDialogResult(2, res.data.message, -1);
|
||||
return;
|
||||
}
|
||||
this.extension_data.data = res.data.data;
|
||||
this.extension_data = res.data;
|
||||
this.upload_file = "";
|
||||
this.onLoadingDialogResult(1, res.data.message);
|
||||
this.dialog = false;
|
||||
@@ -291,8 +418,7 @@ export default {
|
||||
this.onLoadingDialogResult(2, res.data.message, -1);
|
||||
return;
|
||||
}
|
||||
this.extension_data.data = res.data.data;
|
||||
console.log(this.extension_data);
|
||||
this.extension_data = res.data;
|
||||
this.extension_url = "";
|
||||
this.onLoadingDialogResult(1, res.data.message);
|
||||
this.dialog = false;
|
||||
@@ -314,7 +440,7 @@ export default {
|
||||
this.toast(res.data.message, "error");
|
||||
return;
|
||||
}
|
||||
this.extension_data.data = res.data.data;
|
||||
this.extension_data = res.data;
|
||||
this.toast(res.data.message, "success");
|
||||
this.dialog = false;
|
||||
this.getExtensions();
|
||||
@@ -332,7 +458,7 @@ export default {
|
||||
this.onLoadingDialogResult(2, res.data.message, -1);
|
||||
return;
|
||||
}
|
||||
this.extension_data.data = res.data.data;
|
||||
this.extension_data = res.data;
|
||||
console.log(this.extension_data);
|
||||
this.onLoadingDialogResult(1, res.data.message);
|
||||
this.dialog = false;
|
||||
@@ -382,7 +508,7 @@ export default {
|
||||
});
|
||||
},
|
||||
updateConfig() {
|
||||
axios.post('/api/config/plugin/update?plugin_name='+this.curr_namespace, this.extension_config.config).then((res) => {
|
||||
axios.post('/api/config/plugin/update?plugin_name=' + this.curr_namespace, this.extension_config.config).then((res) => {
|
||||
if (res.data.status === "ok") {
|
||||
this.toast(res.data.message, "success");
|
||||
this.$refs.wfr.check();
|
||||
@@ -418,11 +544,42 @@ export default {
|
||||
}
|
||||
for (let i = 0; i < this.pluginMarketData.length; i++) {
|
||||
for (let j = 0; j < this.extension_data.data.length; j++) {
|
||||
if (this.pluginMarketData[i].repo === this.extension_data.data[j].repo) {
|
||||
if (this.pluginMarketData[i].repo === this.extension_data.data[j].repo || this.pluginMarketData[i].name === this.extension_data.data[j].name) {
|
||||
this.pluginMarketData[i].installed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 将已安装的插件移动到最后面
|
||||
let installed = [];
|
||||
let notInstalled = [];
|
||||
for (let i = 0; i < this.pluginMarketData.length; i++) {
|
||||
if (this.pluginMarketData[i].installed) {
|
||||
installed.push(this.pluginMarketData[i]);
|
||||
} else {
|
||||
notInstalled.push(this.pluginMarketData[i]);
|
||||
}
|
||||
}
|
||||
this.pluginMarketData = notInstalled.concat(installed);
|
||||
},
|
||||
showPluginInfo(plugin) {
|
||||
this.selectedPlugin = plugin;
|
||||
this.showPluginInfoDialog = true;
|
||||
},
|
||||
reloadPlugin(plugin_name) {
|
||||
axios.post('/api/plugin/reload',
|
||||
{
|
||||
name: plugin_name
|
||||
}).then((res) => {
|
||||
if (res.data.status === "error") {
|
||||
this.onLoadingDialogResult(2, res.data.message, -1);
|
||||
return;
|
||||
}
|
||||
this.toast("重载成功", "success");
|
||||
this.getExtensions();
|
||||
}).catch((err) => {
|
||||
this.toast(err, "error");
|
||||
});
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ class LongTermMemory:
|
||||
self.max_cnt = 300
|
||||
self.image_caption = self.config["image_caption"]
|
||||
self.image_caption_prompt = self.config["image_caption_prompt"]
|
||||
self.image_caption_provider_id = self.config["image_caption_provider_id"]
|
||||
|
||||
self.active_reply = self.config["active_reply"]
|
||||
self.enable_active_reply = self.active_reply.get("enable", False)
|
||||
@@ -32,7 +33,7 @@ class LongTermMemory:
|
||||
self.ar_possibility = self.active_reply["possibility_reply"]
|
||||
self.ar_prompt = self.active_reply.get("prompt", "")
|
||||
|
||||
self.put_history_to_prompt = self.config["put_history_to_prompt"]
|
||||
# self.put_history_to_prompt = self.config["put_history_to_prompt"]
|
||||
|
||||
async def remove_session(self, event: AstrMessageEvent) -> int:
|
||||
cnt = 0
|
||||
@@ -42,7 +43,13 @@ class LongTermMemory:
|
||||
return cnt
|
||||
|
||||
async def get_image_caption(self, image_url: str) -> str:
|
||||
provider = self.context.get_using_provider()
|
||||
|
||||
if not self.image_caption_provider_id:
|
||||
provider = self.context.get_using_provider()
|
||||
else:
|
||||
provider = self.context.get_provider_by_id(self.image_caption_provider_id)
|
||||
if not provider:
|
||||
raise Exception(f"没有找到 ID 为 {self.image_caption_provider_id} 的提供商")
|
||||
response = await provider.text_chat(
|
||||
prompt=self.image_caption_prompt,
|
||||
session_id=uuid.uuid4().hex,
|
||||
@@ -103,11 +110,11 @@ class LongTermMemory:
|
||||
|
||||
chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin])
|
||||
|
||||
if self.put_history_to_prompt:
|
||||
if self.enable_active_reply:
|
||||
prompt = req.prompt
|
||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
||||
req.contexts = [] # 清空上下文,当使用了群聊增强,所有聊天记录都在一个prompt中。
|
||||
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
||||
else:
|
||||
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
|
||||
req.system_prompt += chats_str
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
import aiohttp
|
||||
import datetime
|
||||
import builtins
|
||||
import json
|
||||
import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.config.default import VERSION
|
||||
from collections import defaultdict
|
||||
from .long_term_memory import LongTermMemory
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -41,6 +45,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command("help")
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
'''查看帮助'''
|
||||
notice = ""
|
||||
try:
|
||||
notice = await self._query_astrbot_notice()
|
||||
@@ -50,31 +55,38 @@ class Main(star.Star):
|
||||
dashboard_version = await get_dashboard_version()
|
||||
|
||||
msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version})
|
||||
已注册的 AstrBot 内置指令:
|
||||
AstrBot 指令:
|
||||
[System]
|
||||
/plugin: 查看注册的插件、插件帮助
|
||||
/t2i: 开启/关闭文本转图片模式
|
||||
/sid: 获取当前会话的 ID
|
||||
/op <admin_id>: 授权管理员
|
||||
/deop <admin_id>: 取消管理员
|
||||
/wl <sid>: 添加会话白名单
|
||||
/dwl <sid>: 删除会话白名单
|
||||
/dashboard_update: 更新管理面板
|
||||
/plugin: 查看插件、插件帮助
|
||||
/t2i: 开关文本转图片
|
||||
/sid: 获取会话 ID
|
||||
/op <admin_id>: 授权管理员(op)
|
||||
/deop <admin_id>: 取消管理员(op)
|
||||
/wl <sid>: 添加白名单(op)
|
||||
/dwl <sid>: 删除白名单(op)
|
||||
/dashboard_update: 更新管理面板(op)
|
||||
/alter_cmd: 设置指令权限(op)
|
||||
|
||||
[大模型]
|
||||
/provider: 查看、切换大模型提供商
|
||||
/model: 查看、切换提供商模型列表
|
||||
/key: 查看、切换 API Key
|
||||
/reset: 重置 LLM 会话
|
||||
/history: 获取会话历史记录
|
||||
/persona: 情境人格设置
|
||||
/tool ls: 查看、激活、停用当前注册的函数工具
|
||||
/provider: 大模型提供商
|
||||
/model: 模型列表
|
||||
/ls: 对话列表
|
||||
/new: 创建新对话
|
||||
/switch: 切换对话
|
||||
/rename: 重命名对话
|
||||
/del: 删除当前会话对话(op)
|
||||
/reset: 重置 LLM 会话(op)
|
||||
/history: 当前对话的对话记录
|
||||
/persona: 人格情景(op)
|
||||
/tool ls: 函数工具
|
||||
/key: API Key(op)
|
||||
/websearch: 网页搜索
|
||||
|
||||
[其他]
|
||||
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
|
||||
/unset <变量名>: 删除当前会话的变量。
|
||||
/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入。
|
||||
/unset <变量名>: 删除会话的变量。
|
||||
|
||||
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
|
||||
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
|
||||
{notice}"""
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
@@ -124,7 +136,7 @@ class Main(star.Star):
|
||||
if plugin_list_info.strip() == "":
|
||||
plugin_list_info = "没有加载任何插件。"
|
||||
|
||||
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
|
||||
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
|
||||
event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False))
|
||||
else:
|
||||
if oper1 == "off":
|
||||
@@ -147,10 +159,34 @@ class Main(star.Star):
|
||||
plugin = self.context.get_registered_star(oper1)
|
||||
if plugin is None:
|
||||
event.set_result(MessageEventResult().message("未找到此插件。"))
|
||||
else:
|
||||
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
|
||||
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
return
|
||||
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "帮助信息: 未提供"
|
||||
help_msg += f"\n\n作者: {plugin.author}\n版本: {plugin.version}"
|
||||
command_handlers = []
|
||||
command_names = []
|
||||
for handler in star_handlers_registry:
|
||||
assert isinstance(handler, StarHandlerMetadata)
|
||||
if handler.handler_module_path != plugin.module_path:
|
||||
continue
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, CommandFilter):
|
||||
command_handlers.append(handler)
|
||||
command_names.append(filter_.command_name)
|
||||
break
|
||||
elif isinstance(filter_, CommandGroupFilter):
|
||||
command_handlers.append(handler)
|
||||
command_names.append(filter_.group_name)
|
||||
|
||||
if len(command_handlers) > 0:
|
||||
help_msg += "\n\n指令列表:\n"
|
||||
for i in range(len(command_handlers)):
|
||||
help_msg += f"{command_names[i]}: {command_handlers[i].desc}\n"
|
||||
|
||||
help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。"
|
||||
|
||||
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
|
||||
ret += "更多帮助信息请查看插件仓库 README。"
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@filter.command("t2i")
|
||||
async def t2i(self, event: AstrMessageEvent):
|
||||
@@ -192,6 +228,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("wl")
|
||||
async def wl(self, event: AstrMessageEvent, sid: str):
|
||||
'''添加白名单。wl <sid>'''
|
||||
self.context.get_config()['platform_settings']['id_whitelist'].append(sid)
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("添加白名单成功。"))
|
||||
@@ -199,6 +236,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dwl")
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str):
|
||||
'''删除白名单。dwl <sid>'''
|
||||
try:
|
||||
self.context.get_config()['platform_settings']['id_whitelist'].remove(sid)
|
||||
self.context.get_config().save_config()
|
||||
@@ -236,14 +274,24 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
|
||||
'''重置 LLM 会话'''
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
await self.context.get_using_provider().forget(message.session_id)
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
|
||||
if not cid:
|
||||
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.update_conversation(
|
||||
message.unified_msg_origin, cid, []
|
||||
)
|
||||
|
||||
ret = "清除会话 LLM 聊天历史成功。"
|
||||
if self.ltm:
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
@@ -253,7 +301,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("model")
|
||||
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
|
||||
|
||||
'''查看或者切换模型'''
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
@@ -294,25 +342,34 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
else:
|
||||
self.context.get_using_provider().set_model(idx_or_name)
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换模型成功。 \n模型信息: {idx_or_name}"))
|
||||
MessageEventResult().message(f"切换模型到 {self.context.get_using_provider().get_model()}。"))
|
||||
|
||||
|
||||
@filter.command("history")
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
|
||||
|
||||
'''查看对话记录'''
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
size_per_page = 3
|
||||
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
|
||||
size_per_page = 6
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
|
||||
if not session_curr_cid:
|
||||
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
|
||||
return
|
||||
|
||||
contexts, total_pages = await self.context.conversation_manager.get_human_readable_context(
|
||||
message.unified_msg_origin, session_curr_cid, page, size_per_page
|
||||
)
|
||||
|
||||
history = ""
|
||||
for context in contexts:
|
||||
if len(context) > 150:
|
||||
context = context[:150] + "..."
|
||||
history += f"{context}\n"
|
||||
|
||||
ret = f"""历史记录:
|
||||
ret = f"""当前对话历史记录:
|
||||
{history}
|
||||
第 {page} 页 | 共 {total_pages} 页
|
||||
|
||||
@@ -321,6 +378,88 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@filter.command("ls")
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1):
|
||||
'''查看对话列表'''
|
||||
size_per_page = 6
|
||||
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
|
||||
total_pages = len(conversations) // size_per_page
|
||||
if len(conversations) % size_per_page != 0:
|
||||
total_pages += 1
|
||||
conversations = conversations[(page-1)*size_per_page:page*size_per_page]
|
||||
|
||||
ret = "对话列表:\n---\n"
|
||||
global_index = (page - 1) * size_per_page + 1
|
||||
|
||||
_titles = {}
|
||||
for conv in conversations:
|
||||
|
||||
persona_id = conv.persona_id
|
||||
if not persona_id and not persona_id == "[%None]":
|
||||
persona_id = self.context.provider_manager.selected_default_persona['name']
|
||||
|
||||
title = conv.title if conv.title else "新对话"
|
||||
_titles[conv.cid] = title
|
||||
|
||||
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
|
||||
global_index += 1
|
||||
|
||||
ret += "---\n"
|
||||
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
if curr_cid:
|
||||
ret += f"\n当前对话: {_titles[curr_cid]}({curr_cid[:4]})"
|
||||
else:
|
||||
ret += "\n当前对话: 无"
|
||||
|
||||
unique_session = self.context.get_config()['platform_settings']['unique_session']
|
||||
if unique_session:
|
||||
ret += "\n会话隔离粒度: 个人"
|
||||
else:
|
||||
ret += "\n会话隔离粒度: 群聊"
|
||||
|
||||
ret += f"\n第 {page} 页 | 共 {total_pages} 页"
|
||||
ret += "\n*输入 /ls 2 跳转到第 2 页"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@filter.command("new")
|
||||
async def new_conv(self, message: AstrMessageEvent):
|
||||
'''创建新对话'''
|
||||
cid = await self.context.conversation_manager.new_conversation(message.unified_msg_origin)
|
||||
message.set_result(MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"))
|
||||
|
||||
@filter.command("switch")
|
||||
async def switch_conv(self, message: AstrMessageEvent, index: int):
|
||||
'''通过 /ls 前面的序号切换对话'''
|
||||
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
|
||||
if index > len(conversations) or index < 1:
|
||||
message.set_result(MessageEventResult().message("对话序号错误,请使用 /ls 查看"))
|
||||
else:
|
||||
conversation = conversations[index-1]
|
||||
title = conversation.title if conversation.title else "新对话"
|
||||
await self.context.conversation_manager.switch_conversation(message.unified_msg_origin, conversation.cid)
|
||||
message.set_result(MessageEventResult().message(f"切换到对话: {title}({conversation.cid[:4]})。"))
|
||||
|
||||
@filter.command("rename")
|
||||
async def rename_conv(self, message: AstrMessageEvent, new_name: str):
|
||||
'''重命名对话'''
|
||||
await self.context.conversation_manager.update_conversation_title(message.unified_msg_origin, new_name)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("del")
|
||||
async def del_conv(self, message: AstrMessageEvent):
|
||||
'''删除当前对话'''
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
|
||||
if not session_curr_cid:
|
||||
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid)
|
||||
message.set_result(MessageEventResult().message("删除当前对话成功。"))
|
||||
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("key")
|
||||
async def key(self, message: AstrMessageEvent, index: int=None):
|
||||
@@ -354,30 +493,35 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
MessageEventResult().message("切换 Key 未知错误: "+str(e)))
|
||||
message.set_result(MessageEventResult().message("切换 Key 成功。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("persona")
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
|
||||
l = message.message_str.split(" ")
|
||||
|
||||
curr_persona_name = "无"
|
||||
if self.context.get_using_provider().curr_personality:
|
||||
curr_persona_name = self.context.get_using_provider().curr_personality['name']
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
curr_cid_title = "无"
|
||||
if cid:
|
||||
conversation = await self.context.conversation_manager.get_conversation(message.unified_msg_origin, cid)
|
||||
if not conversation.persona_id and not conversation.persona_id == "[%None]":
|
||||
curr_persona_name = self.context.provider_manager.selected_default_persona['name']
|
||||
else:
|
||||
curr_persona_name = conversation.persona_id
|
||||
|
||||
curr_cid_title = conversation.title if conversation.title else "新对话"
|
||||
curr_cid_title += f"({cid[:4]})"
|
||||
|
||||
if len(l) == 1:
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"""[Persona]
|
||||
|
||||
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
|
||||
- 人格情景列表: `/persona list`
|
||||
- 人格情景详细信息: `/persona view 人格名`
|
||||
- 设置人格情景: `/persona 人格`
|
||||
- 人格情景详细信息: `/persona view 人格`
|
||||
- 取消人格: `/persona unset`
|
||||
|
||||
当前人格情景: {curr_persona_name}
|
||||
默认人格情景: {self.context.provider_manager.selected_default_persona['name']}
|
||||
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
|
||||
|
||||
配置人格情景请前往管理面板-配置页
|
||||
""").use_t2i(False))
|
||||
@@ -402,7 +546,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
msg = f"人格{ps}不存在"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "unset":
|
||||
self.context.get_using_provider().curr_personality = None
|
||||
if not cid:
|
||||
message.set_result(MessageEventResult().message("当前没有对话,无法取消人格。"))
|
||||
return
|
||||
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, "[%None]")
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
@@ -410,7 +557,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
lambda persona: persona['name'] == ps,
|
||||
self.context.provider_manager.personas
|
||||
), None):
|
||||
self.context.get_using_provider().curr_personality = persona
|
||||
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, ps)
|
||||
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
|
||||
else:
|
||||
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
|
||||
@@ -482,7 +629,17 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(event.unified_msg_origin)
|
||||
|
||||
if not session_curr_cid:
|
||||
logger.error("当前未处于对话状态,无法主动回复,请使用 /switch 切换或者 /new 创建。")
|
||||
return
|
||||
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
session_curr_cid
|
||||
)
|
||||
history = json.loads(conv.history)
|
||||
|
||||
prompt = self.ltm.ar_prompt
|
||||
if not prompt:
|
||||
@@ -492,7 +649,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
contexts=session_provider_context if session_provider_context else []
|
||||
contexts=history if history else []
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
@@ -501,7 +658,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
'''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt'''
|
||||
provider = self.context.get_using_provider()
|
||||
if self.prompt_prefix:
|
||||
req.prompt = self.prompt_prefix + req.prompt
|
||||
|
||||
@@ -512,16 +668,27 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
req.prompt = user_info + req.prompt
|
||||
|
||||
if self.enable_datetime:
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
|
||||
|
||||
if persona := provider.curr_personality:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
tz_offset = datetime.timedelta(hours=8)
|
||||
tz = datetime.timezone(tz_offset)
|
||||
current_time = datetime.datetime.now(tz).strftime('%Y-%m-%d %H:%M')
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
|
||||
if req.conversation:
|
||||
persona_id = req.conversation.persona_id
|
||||
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
|
||||
persona_id = self.context.provider_manager.selected_default_persona['name']
|
||||
persona = next(builtins.filter(
|
||||
lambda persona: persona['name'] == persona_id,
|
||||
self.context.provider_manager.personas
|
||||
), None)
|
||||
if persona:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
if self.ltm:
|
||||
try:
|
||||
@@ -538,6 +705,66 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
await self.ltm.after_req_llm(event)
|
||||
except BaseException as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("alter_cmd")
|
||||
async def alter_cmd(self, event: AstrMessageEvent):
|
||||
# token = event.message_str.split(" ")
|
||||
token = self.parse_commands(event.message_str)
|
||||
if token.len < 2:
|
||||
yield event.plain_result("可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令")
|
||||
return
|
||||
|
||||
cmd_name = token.get(1)
|
||||
cmd_type = token.get(2)
|
||||
|
||||
if cmd_type not in ["admin", "member"]:
|
||||
yield event.plain_result("指令类型错误,可选类型有 admin, member")
|
||||
return
|
||||
|
||||
# 查找指令
|
||||
found_command = None
|
||||
for handler in star_handlers_registry:
|
||||
assert isinstance(handler, StarHandlerMetadata)
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, CommandFilter):
|
||||
if filter_.command_name == cmd_name:
|
||||
found_command = handler
|
||||
break
|
||||
elif isinstance(filter_, CommandGroupFilter):
|
||||
if cmd_name == filter_.group_name:
|
||||
found_command = handler
|
||||
break
|
||||
|
||||
if not found_command:
|
||||
yield event.plain_result("未找到该指令")
|
||||
return
|
||||
|
||||
found_plugin = star_map[found_command.handler_module_path]
|
||||
|
||||
alter_cmd_cfg = sp.get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
|
||||
cfg = plugin_.get(found_command.handler_name, {})
|
||||
cfg["permission"] = cmd_type
|
||||
plugin_[found_command.handler_name] = cfg
|
||||
alter_cmd_cfg[found_plugin.name] = plugin_
|
||||
|
||||
sp.put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
# 注入权限过滤器
|
||||
found_permission_filter = False
|
||||
for filter_ in found_command.event_filters:
|
||||
if isinstance(filter_, PermissionTypeFilter):
|
||||
if cmd_type == "admin":
|
||||
filter_.permission_type = filter.PermissionType.ADMIN
|
||||
else:
|
||||
filter_.permission_type = filter.PermissionType.MEMBER
|
||||
found_permission_filter = True
|
||||
break
|
||||
if not found_permission_filter:
|
||||
found_command.event_filters.insert(0, PermissionTypeFilter(filter.PermissionType.ADMIN if cmd_type == "admin" else filter.PermissionType.MEMBER))
|
||||
|
||||
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
|
||||
|
||||
# @filter.command_group("kdb")
|
||||
# def kdb(self):
|
||||
|
||||
@@ -358,7 +358,7 @@ class Main(star.Star):
|
||||
|
||||
if not ok:
|
||||
if traceback:
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{traceback}\n Need to improve/fix the code."
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
|
||||
else:
|
||||
logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
|
||||
break
|
||||
@@ -393,4 +393,4 @@ class Main(star.Star):
|
||||
await container.kill()
|
||||
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
|
||||
finally:
|
||||
await container.delete()
|
||||
await container.delete()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import datetime
|
||||
import uuid
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import filter
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
@@ -28,11 +29,18 @@ class Main(star.Star):
|
||||
'''Initialize the scheduler.'''
|
||||
for group in self.reminder_data:
|
||||
for reminder in self.reminder_data[group]:
|
||||
if 'id' not in reminder:
|
||||
id_ = str(uuid.uuid4())
|
||||
reminder['id'] = id_
|
||||
else:
|
||||
id_ = reminder['id']
|
||||
|
||||
if "datetime" in reminder:
|
||||
if self.check_is_outdated(reminder):
|
||||
continue
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
id=id_,
|
||||
trigger='date',
|
||||
args=[group, reminder],
|
||||
run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"),
|
||||
@@ -42,6 +50,7 @@ class Main(star.Star):
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
trigger='cron',
|
||||
id=id_,
|
||||
args=[group, reminder],
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(reminder["cron"])
|
||||
@@ -69,11 +78,11 @@ class Main(star.Star):
|
||||
}
|
||||
|
||||
@llm_tool("reminder")
|
||||
async def reminder_tool(self, event: AstrMessageEvent, text: str, datetime_str: str = None, cron_expression: str = None, human_readable_cron: str = None):
|
||||
async def reminder_tool(self, event: AstrMessageEvent, text: str=None, datetime_str: str = None, cron_expression: str = None, human_readable_cron: str = None):
|
||||
'''Call this function when user ask for setting a reminder.
|
||||
|
||||
Args:
|
||||
text(string): The content of the reminder.
|
||||
text(string): Must Required. The content of the reminder.
|
||||
datetime_str(string): Required when user's reminder is a single reminder. The datetime string of the reminder, Must format with %Y-%m-%d %H:%M
|
||||
cron_expression(string): Required when user's reminder is a repeated reminder. The cron expression of the reminder.
|
||||
human_readable_cron(string): Optional. The human readable cron expression of the reminder.
|
||||
@@ -88,65 +97,88 @@ class Main(star.Star):
|
||||
if not cron_expression and not datetime_str:
|
||||
raise ValueError("The cron_expression and datetime_str cannot be both None.")
|
||||
reminder_time = ""
|
||||
|
||||
if not text:
|
||||
text = "未命名待办事项"
|
||||
|
||||
if cron_expression:
|
||||
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
|
||||
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron, "id": str(uuid.uuid4()) }
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
'cron',
|
||||
id=d["id"],
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d]
|
||||
)
|
||||
if human_readable_cron:
|
||||
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
|
||||
else:
|
||||
d = { "text": text, "datetime": datetime_str }
|
||||
d = { "text": text, "datetime": datetime_str, "id": str(uuid.uuid4()) }
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
'date',
|
||||
id=d["id"],
|
||||
args=[event.unified_msg_origin, d],
|
||||
run_date=datetime_scheduled,
|
||||
misfire_grace_time=60
|
||||
)
|
||||
reminder_time = datetime_str
|
||||
await self._save_data()
|
||||
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
|
||||
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。\n使用 /tool off reminder 关闭此功能。")
|
||||
|
||||
@filter.command_group("reminder")
|
||||
def reminder(self):
|
||||
'''The command group of the reminder.'''
|
||||
pass
|
||||
|
||||
async def get_upcoming_reminders(self, unified_msg_origin: str):
|
||||
'''Get upcoming reminders.'''
|
||||
reminders = self.reminder_data.get(unified_msg_origin, [])
|
||||
if not reminders:
|
||||
return []
|
||||
now = datetime.datetime.now()
|
||||
upcoming_reminders = [
|
||||
reminder for reminder in reminders
|
||||
if "datetime" not in reminder or datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") >= now
|
||||
]
|
||||
return upcoming_reminders
|
||||
|
||||
@reminder.command("ls")
|
||||
async def reminder_ls(self, event: AstrMessageEvent):
|
||||
'''List all reminders.'''
|
||||
reminders = self.reminder_data.get(event.unified_msg_origin, [])
|
||||
'''List upcoming reminders.'''
|
||||
reminders = await self.get_upcoming_reminders(event.unified_msg_origin)
|
||||
if not reminders:
|
||||
yield event.plain_result("没有待办事项。")
|
||||
yield event.plain_result("没有正在进行的待办事项。")
|
||||
else:
|
||||
reminder_str = "待办事项:\n"
|
||||
reminder_str = "正在进行的待办事项:\n"
|
||||
for i, reminder in enumerate(reminders):
|
||||
time_ = reminder.get("datetime", "")
|
||||
if not time_:
|
||||
cron_expr = reminder.get("cron", "")
|
||||
time_ = reminder.get("cron_h", "") + f"(Cron: {cron_expr})"
|
||||
reminder_str += f"{i + 1}. {reminder['text']} - {time_}\n"
|
||||
reminder_str += "\n使用 /reminder rm <index> 删除待办事项。"
|
||||
reminder_str += "\n使用 /reminder rm <id> 删除待办事项。\n"
|
||||
yield event.plain_result(reminder_str)
|
||||
|
||||
|
||||
@reminder.command("rm")
|
||||
async def reminder_rm(self, event: AstrMessageEvent, index: int):
|
||||
'''Remove a reminder by index.'''
|
||||
reminders = self.reminder_data.get(event.unified_msg_origin, [])
|
||||
reminders = await self.get_upcoming_reminders(event.unified_msg_origin)
|
||||
|
||||
if not reminders:
|
||||
yield event.plain_result("没有待办事项。")
|
||||
elif index < 1 or index > len(reminders):
|
||||
yield event.plain_result("索引越界。")
|
||||
else:
|
||||
reminder = reminders.pop(index - 1)
|
||||
self.scheduler.remove_job(event.unified_msg_origin)
|
||||
job_id = reminder.get("id")
|
||||
try:
|
||||
self.scheduler.remove_job(job_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Remove job error: {e}")
|
||||
await self._save_data()
|
||||
yield event.plain_result("成功删除待办事项:\n" + reminder["text"])
|
||||
|
||||
|
||||
@@ -1,9 +1,30 @@
|
||||
import random
|
||||
from .config import HEADERS, USER_AGENTS
|
||||
from bs4 import BeautifulSoup
|
||||
from aiohttp import ClientSession
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
import urllib.parse
|
||||
|
||||
HEADERS = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0',
|
||||
'Accept': '*/*',
|
||||
'Connection': 'keep-alive',
|
||||
'Accept-Language': 'en-GB,en;q=0.5'
|
||||
}
|
||||
|
||||
USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0'
|
||||
USER_AGENTS = [
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36',
|
||||
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0',
|
||||
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0'
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,11 +59,13 @@ class SearchEngine():
|
||||
if data:
|
||||
async with ClientSession() as session:
|
||||
async with session.post(url, headers=headers, data=data, timeout=self.TIMEOUT) as resp:
|
||||
return await resp.text(encoding="utf-8")
|
||||
ret = await resp.text(encoding="utf-8")
|
||||
return ret
|
||||
else:
|
||||
async with ClientSession() as session:
|
||||
async with session.get(url, headers=headers, timeout=self.TIMEOUT) as resp:
|
||||
return await resp.text(encoding="utf-8")
|
||||
ret = await resp.text(encoding="utf-8")
|
||||
return ret
|
||||
|
||||
|
||||
def tidy_text(self, text: str) -> str:
|
||||
@@ -53,6 +76,8 @@ class SearchEngine():
|
||||
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
query = urllib.parse.quote(query)
|
||||
|
||||
try:
|
||||
resp = await self._get_next_page(query)
|
||||
soup = BeautifulSoup(resp, 'html.parser')
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import List
|
||||
from .engine import SearchEngine, SearchResult
|
||||
from .config import USER_AGENT_BING
|
||||
from . import SearchEngine, SearchResult
|
||||
from . import USER_AGENT_BING
|
||||
|
||||
class Bing(SearchEngine):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.base_url = "https://www.bing.com"
|
||||
self.base_urls = ["https://cn.bing.com", "https://www.bing.com"]
|
||||
self.headers.update({'User-Agent': USER_AGENT_BING})
|
||||
|
||||
def _set_selector(self, selector: str):
|
||||
@@ -19,11 +19,17 @@ class Bing(SearchEngine):
|
||||
return selectors[selector]
|
||||
|
||||
async def _get_next_page(self, query) -> str:
|
||||
if self.page == 1:
|
||||
await self._get_html(self.base_url)
|
||||
url = f'{self.base_url}/search?q={query}&form=QBLH&sp=-1&lq=0&pq=hi&sc=10-2&qs=n&sk=&cvid=DE75965E2D6346D681288933984DE48F&ghsh=0&ghacc=0&ghpl='
|
||||
return await self._get_html(url, None)
|
||||
|
||||
# if self.page == 1:
|
||||
# await self._get_html(self.base_url)
|
||||
for base_url in self.base_urls:
|
||||
try:
|
||||
url = f'{base_url}/search?q={query}'
|
||||
return await self._get_html(url, None)
|
||||
except Exception as _:
|
||||
self.base_url = base_url
|
||||
continue
|
||||
raise Exception("Bing search failed")
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
results = await super().search(query, num_results)
|
||||
for result in results:
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
HEADERS = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0',
|
||||
'Accept': '*/*',
|
||||
'Connection': 'keep-alive',
|
||||
'Accept-Language': 'en-GB,en;q=0.5'
|
||||
}
|
||||
|
||||
USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0'
|
||||
USER_AGENTS = [
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36',
|
||||
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0',
|
||||
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0'
|
||||
]
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from googlesearch import search
|
||||
|
||||
from .engine import SearchEngine, SearchResult
|
||||
from . import SearchEngine, SearchResult
|
||||
|
||||
from typing import List
|
||||
|
||||
@@ -9,7 +9,6 @@ class Google(SearchEngine):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.proxy = os.environ.get("https_proxy")
|
||||
print(f"Google Search using proxy: {self.proxy}")
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
results = []
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import random
|
||||
import re
|
||||
from bs4 import BeautifulSoup
|
||||
from .engine import SearchEngine, SearchResult
|
||||
from .config import USER_AGENTS
|
||||
from . import SearchEngine, SearchResult
|
||||
from . import USER_AGENTS
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from .engines.sogo import Sogo
|
||||
from .engines.google import Google
|
||||
from readability import Document
|
||||
from bs4 import BeautifulSoup
|
||||
from .engines.config import HEADERS, USER_AGENTS
|
||||
from .engines import HEADERS, USER_AGENTS
|
||||
|
||||
|
||||
@star.register(name="astrbot-web-searcher", desc="让 LLM 具有网页检索能力", author="Soulter", version="1.14.514")
|
||||
@@ -85,19 +85,19 @@ class Main(star.Star):
|
||||
RESULT_NUM = 5
|
||||
try:
|
||||
results = await self.google.search(query, RESULT_NUM)
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"google search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search google failed")
|
||||
try:
|
||||
results = await self.bing_search.search(query, RESULT_NUM)
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"bing search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search bing failed")
|
||||
try:
|
||||
results = await self.sogo_search.search(query, RESULT_NUM)
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"sogo search error: {e}")
|
||||
if len(results) == 0:
|
||||
logger.debug("search sogo failed")
|
||||
|
||||
@@ -30,7 +30,6 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
success, err_message = await plugin_manager_pm.reload()
|
||||
assert success is True
|
||||
assert err_message is None
|
||||
assert len(star_handlers_registry) > 0 # package
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_crud(plugin_manager_pm: PluginManager):
|
||||
|
||||
Reference in New Issue
Block a user