Compare commits
53 Commits
v4.1.6
...
refactor/l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54340cca18 | ||
|
|
7191d28ada | ||
|
|
e6b5e3d282 | ||
|
|
1413d6b5fe | ||
|
|
dcd8a1094c | ||
|
|
e64b31b9ba | ||
|
|
080f347511 | ||
|
|
eaaff4298d | ||
|
|
dd5a02e8ef | ||
|
|
3211ec57ee | ||
|
|
6796afdaee | ||
|
|
cc6fe57773 | ||
|
|
1dfc831938 | ||
|
|
cafeda4abf | ||
|
|
d951b99718 | ||
|
|
0ad87209e5 | ||
|
|
1b50c5404d | ||
|
|
3007f67cab | ||
|
|
ee08659f01 | ||
|
|
baf5ad0fab | ||
|
|
8bdd748aec | ||
|
|
cef0c22f52 | ||
|
|
13d3fc5cfe | ||
|
|
b91141e2be | ||
|
|
f8a4b54165 | ||
|
|
afe007ca0b | ||
|
|
8a9a044f95 | ||
|
|
5eaf03e227 | ||
|
|
a8437d9331 | ||
|
|
e0392fa98b | ||
|
|
68ff8951de | ||
|
|
9c6b31e71c | ||
|
|
50f74f5ba2 | ||
|
|
b9de2aef60 | ||
|
|
7a47598538 | ||
|
|
3c8c28ebd5 | ||
|
|
524285f767 | ||
|
|
c2a34475f1 | ||
|
|
a69195a02b | ||
|
|
19d7438499 | ||
|
|
ccb380ce06 | ||
|
|
a35c439bbd | ||
|
|
09d1f96603 | ||
|
|
26aa18d980 | ||
|
|
d10b542797 | ||
|
|
ce4e4fb8dd | ||
|
|
8f4a31cf8c | ||
|
|
23549f13d6 | ||
|
|
869d11f9a6 | ||
|
|
02e73b82ee | ||
|
|
f85f87f545 | ||
|
|
1fff5713f3 | ||
|
|
8453ec36f0 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -30,4 +30,6 @@ packages/python_interpreter/workplace
|
||||
.conda/
|
||||
.idea
|
||||
pytest.ini
|
||||
.astrbot
|
||||
.astrbot
|
||||
|
||||
uv.lock
|
||||
21
Dockerfile
21
Dockerfile
@@ -4,8 +4,6 @@ WORKDIR /AstrBot
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
nodejs \
|
||||
npm \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
@@ -13,23 +11,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libssl-dev \
|
||||
ca-certificates \
|
||||
bash \
|
||||
ffmpeg \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN apt-get update && apt-get install -y curl gnupg && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m pip install uv
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pilk --no-cache-dir --system
|
||||
|
||||
# 释出 ffmpeg
|
||||
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
|
||||
|
||||
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
|
||||
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD [ "python", "main.py" ]
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">查看文档</a> |
|
||||
<a href="https://astrbot.app/">文档</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
@@ -191,7 +192,7 @@ pre-commit install
|
||||
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
|
||||
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
|
||||
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
|
||||
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
|
||||
- [KroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
|
||||
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
@@ -198,6 +198,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
|
||||
if not func_tool:
|
||||
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: 未找到工具 {func_tool_name}",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_start(
|
||||
self.run_context, func_tool, func_tool_args
|
||||
@@ -210,9 +221,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context=self.run_context,
|
||||
**func_tool_args,
|
||||
)
|
||||
async for resp in executor:
|
||||
|
||||
_final_resp: CallToolResult | None = None
|
||||
async for resp in executor: # type: ignore
|
||||
if isinstance(resp, CallToolResult):
|
||||
res = resp
|
||||
_final_resp = resp
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -279,13 +293,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool, func_tool_args, None
|
||||
self.run_context, func_tool, func_tool_args, _final_resp
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.1.6"
|
||||
VERSION = "4.3.2"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -64,7 +64,7 @@ DEFAULT_CONFIG = {
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "",
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
@@ -116,6 +116,15 @@ DEFAULT_CONFIG = {
|
||||
"port": 6185,
|
||||
},
|
||||
"platform": [],
|
||||
"platform_specific": {
|
||||
# 平台特异配置:按平台分类,平台下按功能分组
|
||||
"lark": {
|
||||
"pre_ack_emoji": {"enable": False, "emojis": ["Typing"]},
|
||||
},
|
||||
"telegram": {
|
||||
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
||||
},
|
||||
},
|
||||
"wake_prefix": ["/"],
|
||||
"log_level": "INFO",
|
||||
"pip_install_arg": "",
|
||||
@@ -766,7 +775,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
@@ -812,6 +821,21 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
"id": "tokenpony",
|
||||
"provider": "tokenpony",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
"id": "compshare",
|
||||
"provider": "compshare",
|
||||
@@ -869,6 +893,18 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 60,
|
||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
"provider": "coze",
|
||||
"provider_type": "chat_completion",
|
||||
"type": "coze",
|
||||
"enable": True,
|
||||
"coze_api_key": "",
|
||||
"bot_id": "",
|
||||
"coze_api_base": "https://api.coze.cn",
|
||||
"timeout": 60,
|
||||
"auto_save_history": True,
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"provider": "dashscope",
|
||||
@@ -1735,6 +1771,26 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
||||
"obvious": True,
|
||||
},
|
||||
"coze_api_key": {
|
||||
"description": "Coze API Key",
|
||||
"type": "string",
|
||||
"hint": "Coze API 密钥,用于访问 Coze 服务。",
|
||||
},
|
||||
"bot_id": {
|
||||
"description": "Bot ID",
|
||||
"type": "string",
|
||||
"hint": "Coze 机器人的 ID,在 Coze 平台上创建机器人后获得。",
|
||||
},
|
||||
"coze_api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
|
||||
},
|
||||
"auto_save_history": {
|
||||
"description": "由 Coze 管理对话记录",
|
||||
"type": "bool",
|
||||
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_settings": {
|
||||
@@ -1944,26 +2000,28 @@ CONFIG_METADATA_3 = {
|
||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||
},
|
||||
"provider_stt_settings.enable": {
|
||||
"description": "默认启用语音转文本",
|
||||
"description": "启用语音转文本",
|
||||
"type": "bool",
|
||||
"hint": "STT 总开关。",
|
||||
},
|
||||
"provider_stt_settings.provider_id": {
|
||||
"description": "语音转文本模型",
|
||||
"description": "默认语音转文本模型",
|
||||
"type": "string",
|
||||
"hint": "留空代表不使用。",
|
||||
"hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型。",
|
||||
"_special": "select_provider_stt",
|
||||
"condition": {
|
||||
"provider_stt_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_tts_settings.enable": {
|
||||
"description": "默认启用文本转语音",
|
||||
"description": "启用文本转语音",
|
||||
"type": "bool",
|
||||
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
|
||||
},
|
||||
"provider_tts_settings.provider_id": {
|
||||
"description": "文本转语音模型",
|
||||
"description": "默认文本转语音模型",
|
||||
"type": "string",
|
||||
"hint": "留空代表不使用。",
|
||||
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
|
||||
"_special": "select_provider_tts",
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
@@ -2075,12 +2133,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
"hint": "例子: 如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
|
||||
},
|
||||
"provider_settings.prompt_prefix": {
|
||||
"description": "额外前缀提示词",
|
||||
"description": "用户提示词",
|
||||
"type": "string",
|
||||
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
|
||||
},
|
||||
"provider_settings.dual_output": {
|
||||
"provider_tts_settings.dual_output": {
|
||||
"description": "开启 TTS 时同时输出语音和文字内容",
|
||||
"type": "bool",
|
||||
},
|
||||
@@ -2261,6 +2321,32 @@ CONFIG_METADATA_3 = {
|
||||
"description": "用户权限不足时是否回复",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.lark.pre_ack_emoji.enable": {
|
||||
"description": "[飞书] 启用预回应表情",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.lark.pre_ack_emoji.emojis": {
|
||||
"description": "表情列表(飞书表情枚举名)",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "表情枚举名参考:https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce",
|
||||
"condition": {
|
||||
"platform_specific.lark.pre_ack_emoji.enable": True,
|
||||
},
|
||||
},
|
||||
"platform_specific.telegram.pre_ack_emoji.enable": {
|
||||
"description": "[Telegram] 启用预回应表情",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.telegram.pre_ack_emoji.emojis": {
|
||||
"description": "表情列表(Unicode)",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "Telegram 仅支持固定反应集合,参考:https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9",
|
||||
"condition": {
|
||||
"platform_specific.telegram.pre_ack_emoji.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -87,17 +87,25 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
f = False
|
||||
if not conversation_id:
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
f = True
|
||||
if conversation_id:
|
||||
await self.db.delete_conversation(cid=conversation_id)
|
||||
if f:
|
||||
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
|
||||
if curr_cid == conversation_id:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
|
||||
"""删除会话的所有对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
"""
|
||||
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
|
||||
@@ -154,12 +154,17 @@ class BaseDatabase(abc.ABC):
|
||||
"""Delete a conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||
"""Delete all conversations for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict],
|
||||
content: dict,
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
@@ -282,3 +287,14 @@ class BaseDatabase(abc.ABC):
|
||||
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
|
||||
# """Get all LLM messages for a specific conversation."""
|
||||
# ...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_session_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
search_query: str | None = None,
|
||||
platform: str | None = None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
|
||||
...
|
||||
|
||||
@@ -75,7 +75,9 @@ class Persona(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "personas"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
persona_id: str = Field(max_length=255, nullable=False)
|
||||
system_prompt: str = Field(sa_type=Text, nullable=False)
|
||||
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
@@ -135,7 +137,9 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "platform_message_history"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False) # An id of group, user in platform
|
||||
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
||||
@@ -158,8 +162,8 @@ class Attachment(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "attachments"
|
||||
|
||||
inner_attachment_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
attachment_id: str = Field(
|
||||
max_length=36,
|
||||
|
||||
@@ -15,10 +15,8 @@ from astrbot.core.db.po import (
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
from sqlalchemy import select, update, delete, text
|
||||
from sqlmodel import select, update, delete, text, func, or_, desc, col
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy import or_
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
|
||||
@@ -34,6 +32,12 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""Initialize the database by creating tables if they do not exist."""
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
# ====
|
||||
@@ -42,10 +46,10 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
async def insert_platform_stats(
|
||||
self,
|
||||
platform_id: str,
|
||||
platform_type: str,
|
||||
count: int = 1,
|
||||
timestamp: datetime = None,
|
||||
platform_id,
|
||||
platform_type,
|
||||
count=1,
|
||||
timestamp=None,
|
||||
) -> None:
|
||||
"""Insert a new platform statistic record."""
|
||||
async with self.get_db() as session:
|
||||
@@ -76,7 +80,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(
|
||||
select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
|
||||
select(func.count(col(PlatformStat.platform_id))).select_from(
|
||||
PlatformStat
|
||||
)
|
||||
)
|
||||
count = result.scalar_one_or_none()
|
||||
return count if count is not None else 0
|
||||
@@ -96,7 +102,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
return result.scalars().all()
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ====
|
||||
# Conversation Management
|
||||
@@ -112,7 +118,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
if platform_id:
|
||||
query = query.where(ConversationV2.platform_id == platform_id)
|
||||
# order by
|
||||
query = query.order_by(ConversationV2.created_at.desc())
|
||||
query = query.order_by(desc(ConversationV2.created_at))
|
||||
result = await session.execute(query)
|
||||
|
||||
return result.scalars().all()
|
||||
@@ -130,7 +136,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
offset = (page - 1) * page_size
|
||||
result = await session.execute(
|
||||
select(ConversationV2)
|
||||
.order_by(ConversationV2.created_at.desc())
|
||||
.order_by(desc(ConversationV2.created_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
@@ -151,25 +157,26 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
if platform_ids:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.platform_id.in_(platform_ids)
|
||||
col(ConversationV2.platform_id).in_(platform_ids)
|
||||
)
|
||||
if search_query:
|
||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
ConversationV2.title.ilike(f"%{search_query}%"),
|
||||
ConversationV2.content.ilike(f"%{search_query}%"),
|
||||
ConversationV2.user_id.ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.title).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.content).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
|
||||
)
|
||||
)
|
||||
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
||||
for msg_type in kwargs["message_types"]:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.user_id.ilike(f"%:{msg_type}:%")
|
||||
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
|
||||
)
|
||||
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.platform_id.in_(kwargs["platforms"])
|
||||
col(ConversationV2.platform_id).in_(kwargs["platforms"])
|
||||
)
|
||||
|
||||
# Get total count matching the filters
|
||||
@@ -180,7 +187,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
# Get paginated results
|
||||
offset = (page - 1) * page_size
|
||||
result_query = (
|
||||
base_query.order_by(ConversationV2.created_at.desc())
|
||||
base_query.order_by(desc(ConversationV2.created_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
@@ -226,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(ConversationV2).where(
|
||||
ConversationV2.conversation_id == cid
|
||||
col(ConversationV2.conversation_id) == cid
|
||||
)
|
||||
values = {}
|
||||
if title is not None:
|
||||
@@ -246,9 +253,126 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
|
||||
delete(ConversationV2).where(
|
||||
col(ConversationV2.conversation_id) == cid
|
||||
)
|
||||
)
|
||||
|
||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
|
||||
)
|
||||
|
||||
async def get_session_conversations(
|
||||
self,
|
||||
page=1,
|
||||
page_size=20,
|
||||
search_query=None,
|
||||
platform=None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated session conversations with joined conversation and persona details."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
base_query = (
|
||||
select(
|
||||
col(Preference.scope_id).label("session_id"),
|
||||
func.json_extract(Preference.value, "$.val").label(
|
||||
"conversation_id"
|
||||
), # type: ignore
|
||||
col(ConversationV2.persona_id).label("persona_id"),
|
||||
col(ConversationV2.title).label("title"),
|
||||
col(Persona.persona_id).label("persona_name"),
|
||||
)
|
||||
.select_from(Preference)
|
||||
.outerjoin(
|
||||
ConversationV2,
|
||||
func.json_extract(Preference.value, "$.val")
|
||||
== ConversationV2.conversation_id,
|
||||
)
|
||||
.outerjoin(
|
||||
Persona, col(ConversationV2.persona_id) == Persona.persona_id
|
||||
)
|
||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||
)
|
||||
|
||||
# 搜索筛选
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
col(Preference.scope_id).ilike(search_pattern),
|
||||
col(ConversationV2.title).ilike(search_pattern),
|
||||
col(Persona.persona_id).ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
# 平台筛选
|
||||
if platform:
|
||||
platform_pattern = f"{platform}:%"
|
||||
base_query = base_query.where(
|
||||
col(Preference.scope_id).like(platform_pattern)
|
||||
)
|
||||
|
||||
# 排序
|
||||
base_query = base_query.order_by(Preference.scope_id)
|
||||
|
||||
# 分页结果
|
||||
result_query = base_query.offset(offset).limit(page_size)
|
||||
result = await session.execute(result_query)
|
||||
rows = result.fetchall()
|
||||
|
||||
# 查询总数(应用相同的筛选条件)
|
||||
count_base_query = (
|
||||
select(func.count(col(Preference.scope_id)))
|
||||
.select_from(Preference)
|
||||
.outerjoin(
|
||||
ConversationV2,
|
||||
func.json_extract(Preference.value, "$.val")
|
||||
== ConversationV2.conversation_id,
|
||||
)
|
||||
.outerjoin(
|
||||
Persona, col(ConversationV2.persona_id) == Persona.persona_id
|
||||
)
|
||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||
)
|
||||
|
||||
# 应用相同的搜索和平台筛选条件到计数查询
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
count_base_query = count_base_query.where(
|
||||
or_(
|
||||
col(Preference.scope_id).ilike(search_pattern),
|
||||
col(ConversationV2.title).ilike(search_pattern),
|
||||
col(Persona.persona_id).ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
if platform:
|
||||
platform_pattern = f"{platform}:%"
|
||||
count_base_query = count_base_query.where(
|
||||
col(Preference.scope_id).like(platform_pattern)
|
||||
)
|
||||
|
||||
total_result = await session.execute(count_base_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
sessions_data = [
|
||||
{
|
||||
"session_id": row.session_id,
|
||||
"conversation_id": row.conversation_id,
|
||||
"persona_id": row.persona_id,
|
||||
"title": row.title,
|
||||
"persona_name": row.persona_name,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
return sessions_data, total
|
||||
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id,
|
||||
@@ -282,9 +406,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
cutoff_time = now - timedelta(seconds=offset_sec)
|
||||
await session.execute(
|
||||
delete(PlatformMessageHistory).where(
|
||||
PlatformMessageHistory.platform_id == platform_id,
|
||||
PlatformMessageHistory.user_id == user_id,
|
||||
PlatformMessageHistory.created_at < cutoff_time,
|
||||
col(PlatformMessageHistory.platform_id) == platform_id,
|
||||
col(PlatformMessageHistory.user_id) == user_id,
|
||||
col(PlatformMessageHistory.created_at) < cutoff_time,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -301,7 +425,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
PlatformMessageHistory.platform_id == platform_id,
|
||||
PlatformMessageHistory.user_id == user_id,
|
||||
)
|
||||
.order_by(PlatformMessageHistory.created_at.desc())
|
||||
.order_by(desc(PlatformMessageHistory.created_at))
|
||||
)
|
||||
result = await session.execute(query.offset(offset).limit(page_size))
|
||||
return result.scalars().all()
|
||||
@@ -323,7 +447,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""Get an attachment by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(Attachment.id == attachment_id)
|
||||
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -366,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(Persona).where(Persona.persona_id == persona_id)
|
||||
query = update(Persona).where(col(Persona.persona_id) == persona_id)
|
||||
values = {}
|
||||
if system_prompt is not None:
|
||||
values["system_prompt"] = system_prompt
|
||||
@@ -386,7 +510,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Persona).where(Persona.persona_id == persona_id)
|
||||
delete(Persona).where(col(Persona.persona_id) == persona_id)
|
||||
)
|
||||
|
||||
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
||||
@@ -441,9 +565,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
Preference.scope == scope,
|
||||
Preference.scope_id == scope_id,
|
||||
Preference.key == key,
|
||||
col(Preference.scope) == scope,
|
||||
col(Preference.scope_id) == scope_id,
|
||||
col(Preference.key) == key,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
@@ -455,7 +579,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
Preference.scope == scope, Preference.scope_id == scope_id
|
||||
col(Preference.scope) == scope,
|
||||
col(Preference.scope_id) == scope_id,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
@@ -482,7 +607,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
DeprecatedPlatformStat(
|
||||
name=data.platform_id,
|
||||
count=data.count,
|
||||
timestamp=data.timestamp.timestamp(),
|
||||
timestamp=int(data.timestamp.timestamp()),
|
||||
)
|
||||
)
|
||||
return deprecated_stats
|
||||
@@ -540,7 +665,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
DeprecatedPlatformStat(
|
||||
name=platform_id,
|
||||
count=count,
|
||||
timestamp=start_time.timestamp(),
|
||||
timestamp=int(start_time.timestamp()),
|
||||
)
|
||||
)
|
||||
return deprecated_stats
|
||||
|
||||
@@ -97,5 +97,6 @@ async def call_event_hook(
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return True
|
||||
|
||||
return event.is_stopped()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
@@ -22,6 +23,26 @@ class PreProcessStage(Stage):
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""在处理事件之前的预处理"""
|
||||
# 平台特异配置:platform_specific.<platform>.pre_ack_emoji
|
||||
supported = {"telegram", "lark"}
|
||||
platform = event.get_platform_name()
|
||||
cfg = (
|
||||
self.config.get("platform_specific", {})
|
||||
.get(platform, {})
|
||||
.get("pre_ack_emoji", {})
|
||||
) or {}
|
||||
emojis = cfg.get("emojis") or []
|
||||
if (
|
||||
cfg.get("enable", False)
|
||||
and platform in supported
|
||||
and emojis
|
||||
and event.is_at_or_wake_command
|
||||
):
|
||||
try:
|
||||
await event.react(random.choice(emojis))
|
||||
except Exception as e:
|
||||
logger.warning(f"{platform} 预回应表情发送失败: {e}")
|
||||
|
||||
# 路径映射
|
||||
if mappings := self.platform_settings.get("path_mapping", []):
|
||||
# 支持 Record,Image 消息段的路径映射。
|
||||
@@ -46,6 +67,9 @@ class PreProcessStage(Stage):
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||
if not stt_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
|
||||
)
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
|
||||
@@ -291,13 +291,6 @@ async def run_agent(
|
||||
else:
|
||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -524,6 +517,14 @@ class LLMRequestSubStage(Stage):
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_webchat(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
):
|
||||
@@ -536,7 +537,23 @@ class LLMRequestSubStage(Stage):
|
||||
latest_pair = messages[-2:]
|
||||
if not latest_pair:
|
||||
return
|
||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||
content = latest_pair[0].get("content", "")
|
||||
if isinstance(content, list):
|
||||
# 多模态
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "image":
|
||||
text_parts.append("[图片]")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
cleaned_text = "User: " + " ".join(text_parts).strip()
|
||||
elif isinstance(content, str):
|
||||
cleaned_text = "User: " + content.strip()
|
||||
else:
|
||||
return
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await prov.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
|
||||
@@ -190,6 +190,16 @@ class RespondStage(Stage):
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
# 将 Plain 为空的消息段移除
|
||||
result.chain = [
|
||||
comp
|
||||
for comp in result.chain
|
||||
if not (
|
||||
isinstance(comp, Comp.Plain)
|
||||
and (not comp.text or not comp.text.strip())
|
||||
)
|
||||
]
|
||||
|
||||
# 发送消息链
|
||||
# Record 需要强制单独发送
|
||||
need_separately = {ComponentType.Record}
|
||||
|
||||
@@ -183,56 +183,60 @@ class ResultDecorateStage(Stage):
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
if not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。"
|
||||
)
|
||||
else:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
)
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
elif (
|
||||
@@ -275,7 +279,6 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = [Image.fromFileSystem(url)]
|
||||
|
||||
# 触发转发消息
|
||||
has_forwarded = False
|
||||
if event.get_platform_name() == "aiocqhttp":
|
||||
word_cnt = 0
|
||||
for comp in result.chain:
|
||||
@@ -286,9 +289,9 @@ class ResultDecorateStage(Stage):
|
||||
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
|
||||
)
|
||||
result.chain = [node]
|
||||
has_forwarded = True
|
||||
|
||||
if not has_forwarded:
|
||||
has_plain = any(isinstance(item, Plain) for item in result.chain)
|
||||
if has_plain:
|
||||
# at 回复
|
||||
if (
|
||||
self.reply_with_mention
|
||||
|
||||
@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
|
||||
"""检查会话是否整体启用"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
pass
|
||||
self.ctx = ctx
|
||||
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
|
||||
# workaround for #2309
|
||||
conv_id = await self.conv_mgr.get_curr_conversation_id(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if not conv_id:
|
||||
await self.conv_mgr.new_conversation(
|
||||
event.unified_msg_origin, platform_id=event.get_platform_id()
|
||||
)
|
||||
|
||||
event.stop_event()
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
|
||||
activated_handlers.append(handler)
|
||||
if "parsed_params" in event.get_extra():
|
||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||
"parsed_params"
|
||||
)
|
||||
is_group_cmd_handler = any(
|
||||
isinstance(f, CommandGroupFilter) for f in handler.event_filters
|
||||
)
|
||||
if not is_group_cmd_handler:
|
||||
activated_handlers.append(handler)
|
||||
if "parsed_params" in event.get_extra(default={}):
|
||||
handlers_parsed_params[handler.handler_full_name] = (
|
||||
event.get_extra("parsed_params")
|
||||
)
|
||||
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
@@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
def __init__(
|
||||
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""是否唤醒(是否通过 WakingStage)"""
|
||||
self.is_at_or_wake_command = False
|
||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||
self._extras = {}
|
||||
self._extras: dict[str, Any] = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.id,
|
||||
message_type=message_obj.type,
|
||||
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self.unified_msg_origin = str(self.session)
|
||||
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
||||
self._result: MessageEventResult = None
|
||||
self._result: MessageEventResult | None = None
|
||||
"""消息事件的结果"""
|
||||
|
||||
self._has_send_oper = False
|
||||
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self._extras[key] = value
|
||||
|
||||
def get_extra(self, key=None):
|
||||
def get_extra(
|
||||
self, key: str | None = None, default: _VT = None
|
||||
) -> dict[str, Any] | _VT:
|
||||
"""
|
||||
获取额外的信息。
|
||||
"""
|
||||
if key is None:
|
||||
return self._extras
|
||||
return self._extras.get(key, None)
|
||||
return self._extras.get(key, default)
|
||||
|
||||
def clear_extra(self):
|
||||
"""
|
||||
@@ -412,6 +416,16 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def react(self, emoji: str):
|
||||
"""
|
||||
对消息添加表情回应。
|
||||
|
||||
默认实现为发送一条包含该表情的消息。
|
||||
注意:此实现并不一定符合所有平台的原生“表情回应”行为。
|
||||
如需支持平台原生的消息反应功能,请在对应平台的子类中重写本方法。
|
||||
"""
|
||||
await self.send(MessageChain([Plain(emoji)]))
|
||||
|
||||
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
|
||||
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
|
||||
|
||||
|
||||
@@ -14,3 +14,5 @@ class PlatformMetadata:
|
||||
"""平台的默认配置模板"""
|
||||
adapter_display_name: str = None
|
||||
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
||||
logo_path: str = None
|
||||
"""平台适配器的 logo 文件路径(相对于插件目录)"""
|
||||
|
||||
@@ -13,10 +13,12 @@ def register_platform_adapter(
|
||||
desc: str,
|
||||
default_config_tmpl: dict = None,
|
||||
adapter_display_name: str = None,
|
||||
logo_path: str = None,
|
||||
):
|
||||
"""用于注册平台适配器的带参装饰器。
|
||||
|
||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
|
||||
"""
|
||||
|
||||
def decorator(cls):
|
||||
@@ -39,6 +41,7 @@ def register_platform_adapter(
|
||||
description=desc,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -107,6 +107,22 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
request = (
|
||||
CreateMessageReactionRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
.request_body(
|
||||
CreateMessageReactionRequestBody.builder()
|
||||
.reaction_type(Emoji.builder().emoji_type(emoji).build())
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await self.bot.im.v1.message_reaction.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
||||
return None
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
|
||||
@@ -95,9 +95,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||
)
|
||||
id_ = self.config.get("id") or "telegram"
|
||||
return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
@@ -117,6 +116,10 @@ class TelegramPlatformAdapter(Platform):
|
||||
)
|
||||
self.scheduler.start()
|
||||
|
||||
if not self.application.updater:
|
||||
logger.error("Telegram Updater is not initialized. Cannot start polling.")
|
||||
return
|
||||
|
||||
queue = self.application.updater.start_polling()
|
||||
logger.info("Telegram Platform Adapter is running.")
|
||||
await queue
|
||||
@@ -194,6 +197,11 @@ class TelegramPlatformAdapter(Platform):
|
||||
return cmd_name, description
|
||||
|
||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if not update.effective_chat:
|
||||
logger.warning(
|
||||
"Received a start command without an effective chat, skipping /start reply."
|
||||
)
|
||||
return
|
||||
await context.bot.send_message(
|
||||
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||
)
|
||||
@@ -206,15 +214,20 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
async def convert_message(
|
||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
|
||||
) -> AstrBotMessage:
|
||||
) -> AstrBotMessage | None:
|
||||
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
|
||||
|
||||
@param update: Telegram 的 Update 对象。
|
||||
@param context: Telegram 的 Context 对象。
|
||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||
"""
|
||||
if not update.message:
|
||||
logger.warning("Received an update without a message.")
|
||||
return None
|
||||
|
||||
message = AstrBotMessage()
|
||||
message.session_id = str(update.message.chat.id)
|
||||
|
||||
# 获得是群聊还是私聊
|
||||
if update.message.chat.type == ChatType.PRIVATE:
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -225,10 +238,13 @@ class TelegramPlatformAdapter(Platform):
|
||||
# Topic Group
|
||||
message.group_id += "#" + str(update.message.message_thread_id)
|
||||
message.session_id = message.group_id
|
||||
|
||||
message.message_id = str(update.message.message_id)
|
||||
_from_user = update.message.from_user
|
||||
if not _from_user:
|
||||
logger.warning("[Telegram] Received a message without a from_user.")
|
||||
return None
|
||||
message.sender = MessageMember(
|
||||
str(update.message.from_user.id), update.message.from_user.username
|
||||
str(_from_user.id), _from_user.username or "Unknown"
|
||||
)
|
||||
message.self_id = str(context.bot.username)
|
||||
message.raw_message = update
|
||||
@@ -247,22 +263,32 @@ class TelegramPlatformAdapter(Platform):
|
||||
)
|
||||
reply_abm = await self.convert_message(reply_update, context, False)
|
||||
|
||||
message.message.append(
|
||||
Comp.Reply(
|
||||
id=reply_abm.message_id,
|
||||
chain=reply_abm.message,
|
||||
sender_id=reply_abm.sender.user_id,
|
||||
sender_nickname=reply_abm.sender.nickname,
|
||||
time=reply_abm.timestamp,
|
||||
message_str=reply_abm.message_str,
|
||||
text=reply_abm.message_str,
|
||||
qq=reply_abm.sender.user_id,
|
||||
if reply_abm:
|
||||
message.message.append(
|
||||
Comp.Reply(
|
||||
id=reply_abm.message_id,
|
||||
chain=reply_abm.message,
|
||||
sender_id=reply_abm.sender.user_id,
|
||||
sender_nickname=reply_abm.sender.nickname,
|
||||
time=reply_abm.timestamp,
|
||||
message_str=reply_abm.message_str,
|
||||
text=reply_abm.message_str,
|
||||
qq=reply_abm.sender.user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if update.message.text:
|
||||
# 处理文本消息
|
||||
plain_text = update.message.text
|
||||
if (
|
||||
message.type == MessageType.GROUP_MESSAGE
|
||||
and update.message
|
||||
and update.message.reply_to_message
|
||||
and update.message.reply_to_message.from_user
|
||||
and update.message.reply_to_message.from_user.id == context.bot.id
|
||||
):
|
||||
plain_text2 = f"/@{context.bot.username} " + plain_text
|
||||
plain_text = plain_text2
|
||||
|
||||
# 群聊场景命令特殊处理
|
||||
if plain_text.startswith("/"):
|
||||
@@ -328,15 +354,25 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
elif update.message.document:
|
||||
file = await update.message.document.get_file()
|
||||
message.message = [
|
||||
Comp.File(file=file.file_path, name=update.message.document.file_name),
|
||||
]
|
||||
file_name = update.message.document.file_name or uuid.uuid4().hex
|
||||
file_path = file.file_path
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
f"Telegram document file_path is None, cannot save the file {file_name}."
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
message.message = [
|
||||
Comp.Video(file=file.file_path, path=file.file_path),
|
||||
]
|
||||
file_name = update.message.video.file_name or uuid.uuid4().hex
|
||||
file_path = file.file_path
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
f"Telegram video file_path is None, cannot save the file {file_name}."
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.Video(file=file_path, path=file.file_path))
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from telegram import ReactionTypeEmoji, ReactionTypeCustomEmoji
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -135,6 +136,39 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str | None, big: bool = False):
|
||||
"""
|
||||
给原消息添加 Telegram 反应:
|
||||
- 普通 emoji:传入 '👍'、'😂' 等
|
||||
- 自定义表情:传入其 custom_emoji_id(纯数字字符串)
|
||||
- 取消本机器人的反应:传入 None 或空字符串
|
||||
"""
|
||||
try:
|
||||
# 解析 chat_id(去掉超级群的 "#<thread_id>" 片段)
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
chat_id = (self.message_obj.group_id or "").split("#")[0]
|
||||
else:
|
||||
chat_id = self.get_sender_id()
|
||||
|
||||
message_id = int(self.message_obj.message_id)
|
||||
|
||||
# 组装 reaction 参数(必须是 ReactionType 的列表)
|
||||
if not emoji: # 清空本 bot 的反应
|
||||
reaction_param = [] # 空列表表示移除本 bot 的反应
|
||||
elif emoji.isdigit(): # 自定义表情:传 custom_emoji_id
|
||||
reaction_param = [ReactionTypeCustomEmoji(emoji)]
|
||||
else: # 普通 emoji
|
||||
reaction_param = [ReactionTypeEmoji(emoji)]
|
||||
|
||||
await self.client.set_message_reaction(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
reaction=reaction_param, # 注意是列表
|
||||
is_big=big, # 可选:大动画
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] 添加反应失败: {e}")
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
message_thread_id = None
|
||||
|
||||
|
||||
@@ -185,6 +185,7 @@ class WecomPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"wecom",
|
||||
"wecom 适配器",
|
||||
id=self.config.get("id", "wecom"),
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -184,6 +184,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"weixin_official_account",
|
||||
"微信公众平台 适配器",
|
||||
id=self.config.get("id", "weixin_official_account"),
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -234,6 +234,8 @@ class ProviderManager:
|
||||
)
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "coze":
|
||||
from .sources.coze_source import ProviderCoze as ProviderCoze
|
||||
case "dashscope":
|
||||
from .sources.dashscope_source import (
|
||||
ProviderDashscope as ProviderDashscope,
|
||||
|
||||
@@ -75,7 +75,7 @@ class Provider(AbstractProvider):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_models(self) -> List[str]:
|
||||
async def get_models(self) -> List[str]:
|
||||
"""获得支持的模型列表"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
314
astrbot/core/provider/sources/coze_api_client.py
Normal file
314
astrbot/core/provider/sources/coze_api_client.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import json
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import io
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class CozeAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = None
|
||||
|
||||
async def _ensure_session(self):
|
||||
"""确保HTTP session存在"""
|
||||
if self.session is None:
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=False if self.api_base.startswith("http://") else True,
|
||||
limit=100,
|
||||
limit_per_host=30,
|
||||
keepalive_timeout=30,
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=120, # 默认超时时间
|
||||
connect=30,
|
||||
sock_read=120,
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "text/event-stream",
|
||||
}
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=headers, timeout=timeout, connector=connector
|
||||
)
|
||||
return self.session
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
file_data: bytes,
|
||||
) -> str:
|
||||
"""上传文件到 Coze 并返回 file_id
|
||||
|
||||
Args:
|
||||
file_data (bytes): 文件的二进制数据
|
||||
Returns:
|
||||
str: 上传成功后返回的 file_id
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v1/files/upload"
|
||||
|
||||
try:
|
||||
file_io = io.BytesIO(file_data)
|
||||
async with session.post(
|
||||
url,
|
||||
data={
|
||||
"file": file_io,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(
|
||||
f"文件上传响应状态: {response.status}, 内容: {response_text}"
|
||||
)
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await response.json()
|
||||
except json.JSONDecodeError:
|
||||
raise Exception(f"文件上传响应解析失败: {response_text}")
|
||||
|
||||
if result.get("code") != 0:
|
||||
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
|
||||
|
||||
file_id = result["data"]["id"]
|
||||
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
||||
return file_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("文件上传超时")
|
||||
raise Exception("文件上传超时")
|
||||
except Exception as e:
|
||||
logger.error(f"文件上传失败: {str(e)}")
|
||||
raise Exception(f"文件上传失败: {str(e)}")
|
||||
|
||||
async def download_image(self, image_url: str) -> bytes:
|
||||
"""下载图片并返回字节数据
|
||||
|
||||
Args:
|
||||
image_url (str): 图片的URL
|
||||
Returns:
|
||||
bytes: 图片的二进制数据
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
|
||||
try:
|
||||
async with session.get(image_url) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"下载图片失败,状态码: {response.status}")
|
||||
|
||||
image_data = await response.read()
|
||||
return image_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片失败 {image_url}: {str(e)}")
|
||||
raise Exception(f"下载图片失败: {str(e)}")
|
||||
|
||||
async def chat_messages(
|
||||
self,
|
||||
bot_id: str,
|
||||
user_id: str,
|
||||
additional_messages: List[Dict] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
auto_save_history: bool = True,
|
||||
stream: bool = True,
|
||||
timeout: float = 120,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""发送聊天消息并返回流式响应
|
||||
|
||||
Args:
|
||||
bot_id: Bot ID
|
||||
user_id: 用户ID
|
||||
additional_messages: 额外消息列表
|
||||
conversation_id: 会话ID
|
||||
auto_save_history: 是否自动保存历史
|
||||
stream: 是否流式响应
|
||||
timeout: 超时时间
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/chat"
|
||||
|
||||
payload = {
|
||||
"bot_id": bot_id,
|
||||
"user_id": user_id,
|
||||
"stream": stream,
|
||||
"auto_save_history": auto_save_history,
|
||||
}
|
||||
|
||||
if additional_messages:
|
||||
payload["additional_messages"] = additional_messages
|
||||
|
||||
params = {}
|
||||
if conversation_id:
|
||||
params["conversation_id"] = conversation_id
|
||||
|
||||
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
||||
|
||||
# SSE
|
||||
buffer = ""
|
||||
event_type = None
|
||||
event_data = None
|
||||
|
||||
async for chunk in response.content:
|
||||
if chunk:
|
||||
buffer += chunk.decode("utf-8", errors="ignore")
|
||||
lines = buffer.split("\n")
|
||||
buffer = lines[-1]
|
||||
|
||||
for line in lines[:-1]:
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
if event_type and event_data:
|
||||
yield {"event": event_type, "data": event_data}
|
||||
event_type = None
|
||||
event_data = None
|
||||
elif line.startswith("event:"):
|
||||
event_type = line[6:].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
if data_str and data_str != "[DONE]":
|
||||
try:
|
||||
event_data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
event_data = {"content": data_str}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
|
||||
except Exception as e:
|
||||
raise Exception(f"Coze API 流式请求失败: {str(e)}")
|
||||
|
||||
async def clear_context(self, conversation_id: str):
|
||||
"""清空会话上下文
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
Returns:
|
||||
dict: API响应结果
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/conversation/message/clear_context"
|
||||
payload = {"conversation_id": conversation_id}
|
||||
|
||||
try:
|
||||
async with session.post(url, json=payload) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
|
||||
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
raise Exception("Coze API 返回非JSON格式")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception("Coze API 请求超时")
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"Coze API 请求失败: {str(e)}")
|
||||
|
||||
async def get_message_list(
|
||||
self,
|
||||
conversation_id: str,
|
||||
order: str = "desc",
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""获取消息列表
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
order: 排序方式 (asc/desc)
|
||||
limit: 限制数量
|
||||
offset: 偏移量
|
||||
Returns:
|
||||
dict: API响应结果
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/conversation/message/list"
|
||||
params = {
|
||||
"conversation_id": conversation_id,
|
||||
"order": order,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Coze消息列表失败: {str(e)}")
|
||||
raise Exception(f"获取Coze消息列表失败: {str(e)}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭会话"""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
async def test_coze_api_client():
|
||||
api_key = os.getenv("COZE_API_KEY", "")
|
||||
bot_id = os.getenv("COZE_BOT_ID", "")
|
||||
client = CozeAPIClient(api_key=api_key)
|
||||
|
||||
try:
|
||||
with open("README.md", "rb") as f:
|
||||
file_data = f.read()
|
||||
file_id = await client.upload_file(file_data)
|
||||
print(f"Uploaded file_id: {file_id}")
|
||||
async for event in client.chat_messages(
|
||||
bot_id=bot_id,
|
||||
user_id="test_user",
|
||||
additional_messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(
|
||||
[
|
||||
{"type": "text", "text": "这是什么"},
|
||||
{"type": "file", "file_id": file_id},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
"content_type": "object_string",
|
||||
},
|
||||
],
|
||||
stream=True,
|
||||
):
|
||||
print(f"Event: {event}")
|
||||
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(test_coze_api_client())
|
||||
635
astrbot/core/provider/sources/coze_source.py
Normal file
635
astrbot/core/provider/sources/coze_source.py
Normal file
@@ -0,0 +1,635 @@
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import AsyncGenerator, Dict
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
|
||||
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
|
||||
class ProviderCoze(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
self.bot_id = provider_config.get("bot_id", "")
|
||||
if not self.bot_id:
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||
|
||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||
("http://", "https://")
|
||||
):
|
||||
raise Exception(
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
|
||||
)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.auto_save_history = provider_config.get("auto_save_history", True)
|
||||
self.conversation_ids: Dict[str, str] = {}
|
||||
self.file_id_cache: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
# 创建 API 客户端
|
||||
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
||||
|
||||
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
|
||||
"""生成统一的缓存键
|
||||
|
||||
Args:
|
||||
data: 图片数据或路径
|
||||
is_base64: 是否是 base64 数据
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
"""
|
||||
|
||||
try:
|
||||
if is_base64 and data.startswith("data:image/"):
|
||||
try:
|
||||
header, encoded = data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
cache_key = hashlib.md5(image_bytes).hexdigest()
|
||||
return cache_key
|
||||
except Exception:
|
||||
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
if data.startswith(("http://", "https://")):
|
||||
# URL图片,使用URL作为缓存键
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
clean_path = (
|
||||
data.split("_")[0]
|
||||
if "_" in data and len(data.split("_")) >= 3
|
||||
else data
|
||||
)
|
||||
|
||||
if os.path.exists(clean_path):
|
||||
with open(clean_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
cache_key = hashlib.md5(file_content).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
|
||||
return cache_key
|
||||
|
||||
async def _upload_file(
|
||||
self,
|
||||
file_data: bytes,
|
||||
session_id: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
) -> str:
|
||||
"""上传文件到 Coze 并返回 file_id"""
|
||||
# 使用 API 客户端上传文件
|
||||
file_id = await self.api_client.upload_file(file_data)
|
||||
|
||||
# 缓存 file_id
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
async def _download_and_upload_image(
|
||||
self, image_url: str, session_id: str | None = None
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
# 计算哈希实现缓存
|
||||
cache_key = self._generate_cache_key(image_url) if session_id else None
|
||||
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
return file_id
|
||||
|
||||
try:
|
||||
image_data = await self.api_client.download_image(image_url)
|
||||
|
||||
file_id = await self._upload_file(image_data, session_id, cache_key)
|
||||
|
||||
if session_id and cache_key:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {image_url}: {str(e)}")
|
||||
raise Exception(f"处理图片失败: {str(e)}")
|
||||
|
||||
async def _process_context_images(
|
||||
self, content: str | list, session_id: str
|
||||
) -> str:
|
||||
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
|
||||
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
processed_content = []
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
processed_content.append(item)
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
processed_content.append(item)
|
||||
elif item.get("type") == "image_url":
|
||||
# 处理图片逻辑
|
||||
if "file_id" in item:
|
||||
# 已经有 file_id
|
||||
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
|
||||
processed_content.append(item)
|
||||
else:
|
||||
# 获取图片数据
|
||||
image_data = ""
|
||||
if "image_url" in item and isinstance(item["image_url"], dict):
|
||||
image_data = item["image_url"].get("url", "")
|
||||
elif "data" in item:
|
||||
image_data = item.get("data", "")
|
||||
elif "url" in item:
|
||||
image_data = item.get("url", "")
|
||||
|
||||
if not image_data:
|
||||
continue
|
||||
# 计算哈希用于缓存
|
||||
cache_key = self._generate_cache_key(
|
||||
image_data, is_base64=image_data.startswith("data:image/")
|
||||
)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id}
|
||||
)
|
||||
else:
|
||||
# 上传图片并缓存
|
||||
if image_data.startswith("data:image/"):
|
||||
# base64 处理
|
||||
_, encoded = image_data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
elif image_data.startswith(("http://", "https://")):
|
||||
# URL 图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
image_data, session_id
|
||||
)
|
||||
# 为URL图片也添加缓存
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
elif os.path.exists(image_data):
|
||||
# 本地文件
|
||||
with open(image_data, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"无法处理的图片格式: {image_data[:50]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id}
|
||||
)
|
||||
|
||||
result = json.dumps(processed_content, ensure_ascii=False)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"处理上下文图片失败: {str(e)}")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
else:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""文本对话, 内部使用流式接口实现非流式
|
||||
|
||||
Args:
|
||||
prompt (str): 用户提示词
|
||||
session_id (str): 会话ID
|
||||
image_urls (List[str]): 图片URL列表
|
||||
func_tool (FuncCall): 函数调用工具(不支持)
|
||||
contexts (List): 上下文列表
|
||||
system_prompt (str): 系统提示语
|
||||
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
|
||||
model (str): 模型名称(不支持)
|
||||
Returns:
|
||||
LLMResponse: LLM响应对象
|
||||
"""
|
||||
accumulated_content = ""
|
||||
final_response = None
|
||||
|
||||
async for llm_response in self.text_chat_stream(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
model=model,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.completion_text:
|
||||
accumulated_content += llm_response.completion_text
|
||||
else:
|
||||
final_response = llm_response
|
||||
|
||||
if final_response:
|
||||
return final_response
|
||||
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
else:
|
||||
return LLMResponse(role="assistant", completion_text="")
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话接口"""
|
||||
# 用户ID参数(参考文档, 可以自定义)
|
||||
user_id = session_id or kwargs.get("user", "default_user")
|
||||
|
||||
# 获取或创建会话ID
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
# 构建消息
|
||||
additional_messages = []
|
||||
|
||||
if system_prompt:
|
||||
if not self.auto_save_history or not conversation_id:
|
||||
additional_messages.append(
|
||||
{"role": "system", "content": system_prompt, "content_type": "text"}
|
||||
)
|
||||
|
||||
if not self.auto_save_history and contexts:
|
||||
# 如果关闭了自动保存历史,传入上下文
|
||||
for ctx in contexts:
|
||||
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
||||
content = ctx["content"]
|
||||
content_type = ctx.get("content_type", "text")
|
||||
|
||||
# 处理可能包含图片的上下文
|
||||
if (
|
||||
content_type == "object_string"
|
||||
or (isinstance(content, str) and content.startswith("["))
|
||||
or (
|
||||
isinstance(content, list)
|
||||
and any(
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "image_url"
|
||||
for item in content
|
||||
)
|
||||
)
|
||||
):
|
||||
processed_content = await self._process_context_images(
|
||||
content, user_id
|
||||
)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": processed_content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 纯文本
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": (
|
||||
content
|
||||
if isinstance(content, str)
|
||||
else json.dumps(content, ensure_ascii=False)
|
||||
),
|
||||
"content_type": "text",
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
|
||||
|
||||
if prompt or image_urls:
|
||||
if image_urls:
|
||||
# 多模态
|
||||
object_string_content = []
|
||||
if prompt:
|
||||
object_string_content.append({"type": "text", "text": prompt})
|
||||
|
||||
for url in image_urls:
|
||||
try:
|
||||
if url.startswith(("http://", "https://")):
|
||||
# 网络图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
url, user_id
|
||||
)
|
||||
else:
|
||||
# 本地文件或 base64
|
||||
if url.startswith("data:image/"):
|
||||
# base64
|
||||
_, encoded = url.split(",", 1)
|
||||
image_data = base64.b64decode(encoded)
|
||||
cache_key = self._generate_cache_key(
|
||||
url, is_base64=True
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data, user_id, cache_key
|
||||
)
|
||||
else:
|
||||
# 本地文件
|
||||
if os.path.exists(url):
|
||||
with open(url, "rb") as f:
|
||||
image_data = f.read()
|
||||
# 用文件路径和修改时间来缓存
|
||||
file_stat = os.stat(url)
|
||||
cache_key = self._generate_cache_key(
|
||||
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
|
||||
is_base64=False,
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data, user_id, cache_key
|
||||
)
|
||||
else:
|
||||
logger.warning(f"图片文件不存在: {url}")
|
||||
continue
|
||||
|
||||
object_string_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"file_id": file_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {url}: {str(e)}")
|
||||
continue
|
||||
|
||||
if object_string_content:
|
||||
content = json.dumps(object_string_content, ensure_ascii=False)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 纯文本
|
||||
if prompt:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"content_type": "text",
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
accumulated_content = ""
|
||||
message_started = False
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = chunk.get("event")
|
||||
data = chunk.get("data", {})
|
||||
|
||||
if event_type == "conversation.chat.created":
|
||||
if isinstance(data, dict) and "conversation_id" in data:
|
||||
self.conversation_ids[user_id] = data["conversation_id"]
|
||||
|
||||
elif event_type == "conversation.message.delta":
|
||||
if isinstance(data, dict):
|
||||
content = data.get("content", "")
|
||||
if not content and "delta" in data:
|
||||
content = data["delta"].get("content", "")
|
||||
if not content and "text" in data:
|
||||
content = data.get("text", "")
|
||||
|
||||
if content:
|
||||
message_started = True
|
||||
accumulated_content += content
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=content,
|
||||
is_chunk=True,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.message.completed":
|
||||
if isinstance(data, dict):
|
||||
msg_type = data.get("type")
|
||||
if msg_type == "answer" and data.get("role") == "assistant":
|
||||
final_content = data.get("content", "")
|
||||
if not accumulated_content and final_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(final_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.chat.completed":
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
elif event_type == "done":
|
||||
break
|
||||
|
||||
elif event_type == "error":
|
||||
error_msg = (
|
||||
data.get("message", "未知错误")
|
||||
if isinstance(data, dict)
|
||||
else str(data)
|
||||
)
|
||||
logger.error(f"Coze 流式响应错误: {error_msg}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 错误: {error_msg}",
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
if not message_started and not accumulated_content:
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="LLM 未响应任何内容。",
|
||||
is_chunk=False,
|
||||
)
|
||||
elif message_started and accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Coze 流式请求失败: {str(e)}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 流式请求失败: {str(e)}",
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
async def forget(self, session_id: str):
|
||||
"""清空指定会话的上下文"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if user_id in self.file_id_cache:
|
||||
self.file_id_cache.pop(user_id, None)
|
||||
|
||||
if not conversation_id:
|
||||
return True
|
||||
|
||||
try:
|
||||
response = await self.api_client.clear_context(conversation_id)
|
||||
|
||||
if "code" in response and response["code"] == 0:
|
||||
self.conversation_ids.pop(user_id, None)
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"清空 Coze 会话上下文失败: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空 Coze 会话失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_current_key(self):
|
||||
"""获取当前API Key"""
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key: str):
|
||||
"""设置新的API Key"""
|
||||
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
"""获取可用模型列表"""
|
||||
return [f"bot_{self.bot_id}"]
|
||||
|
||||
def get_model(self):
|
||||
"""获取当前模型"""
|
||||
return f"bot_{self.bot_id}"
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""设置模型(在Coze中是Bot ID)"""
|
||||
if model.startswith("bot_"):
|
||||
self.bot_id = model[4:]
|
||||
else:
|
||||
self.bot_id = model
|
||||
|
||||
async def get_human_readable_context(
|
||||
self, session_id: str, page: int = 1, page_size: int = 10
|
||||
):
|
||||
"""获取人类可读的上下文历史"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if not conversation_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await self.api_client.get_message_list(
|
||||
conversation_id=conversation_id,
|
||||
order="desc",
|
||||
limit=page_size,
|
||||
offset=(page - 1) * page_size,
|
||||
)
|
||||
|
||||
if data.get("code") != 0:
|
||||
logger.warning(f"获取 Coze 消息历史失败: {data}")
|
||||
return []
|
||||
|
||||
messages = data.get("data", {}).get("messages", [])
|
||||
|
||||
readable_history = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
msg_type = msg.get("type", "")
|
||||
|
||||
if role == "user":
|
||||
readable_history.append(f"用户: {content}")
|
||||
elif role == "assistant" and msg_type == "answer":
|
||||
readable_history.append(f"助手: {content}")
|
||||
|
||||
return readable_history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
|
||||
return []
|
||||
|
||||
async def terminate(self):
|
||||
"""清理资源"""
|
||||
await self.api_client.close()
|
||||
@@ -1,12 +1,12 @@
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
# This file was originally created to adapt to glm-4v-flash, which only supports one image in the context.
|
||||
# It is no longer specifically adapted to Zhipu's models. To ensure compatibility, this
|
||||
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
|
||||
@register_provider_adapter("zhipu_chat_completion", "智谱 Chat Completion 提供商适配器")
|
||||
class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -19,63 +19,3 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
context_query = [*contexts, new_record]
|
||||
|
||||
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||
model = model or self.get_model()
|
||||
# glm-4v-flash 只支持一张图片
|
||||
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
|
||||
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
||||
logger.debug(context_query)
|
||||
new_context_query_ = []
|
||||
for i in range(0, len(context_query) - 1, 2):
|
||||
if isinstance(context_query[i].get("content", ""), list):
|
||||
continue
|
||||
new_context_query_.append(context_query[i])
|
||||
new_context_query_.append(context_query[i + 1])
|
||||
new_context_query_.append(context_query[-1]) # 保留最后一条记录
|
||||
context_query = new_context_query_
|
||||
logger.debug(context_query)
|
||||
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
payloads = {"messages": context_query, **model_cfgs}
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(
|
||||
f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。"
|
||||
)
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import re
|
||||
import inspect
|
||||
import types
|
||||
import typing
|
||||
from typing import List, Any, Type, Dict
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -14,6 +16,18 @@ class GreedyStr(str):
|
||||
pass
|
||||
|
||||
|
||||
def unwrap_optional(annotation) -> tuple:
|
||||
"""去掉 Optional[T] / Union[T, None] / T|None,返回 T"""
|
||||
args = typing.get_args(annotation)
|
||||
non_none_args = [a for a in args if a is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
return (non_none_args[0],)
|
||||
elif len(non_none_args) > 1:
|
||||
return tuple(non_none_args)
|
||||
else:
|
||||
return ()
|
||||
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter):
|
||||
"""标准指令过滤器"""
|
||||
@@ -32,11 +46,16 @@ class CommandFilter(HandlerFilter):
|
||||
self.init_handler_md(handler_md)
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
|
||||
# Cache for complete command names list
|
||||
self._cmpl_cmd_names: list | None = None
|
||||
|
||||
def print_types(self):
|
||||
result = ""
|
||||
for k, v in self.handler_params.items():
|
||||
if isinstance(v, type):
|
||||
result += f"{k}({v.__name__}),"
|
||||
elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union:
|
||||
result += f"{k}({v}),"
|
||||
else:
|
||||
result += f"{k}({type(v).__name__})={v},"
|
||||
result = result.rstrip(",")
|
||||
@@ -92,7 +111,8 @@ class CommandFilter(HandlerFilter):
|
||||
# 没有 GreedyStr 的情况
|
||||
if i >= len(params):
|
||||
if (
|
||||
isinstance(param_type_or_default_val, Type)
|
||||
isinstance(param_type_or_default_val, (Type, types.UnionType))
|
||||
or typing.get_origin(param_type_or_default_val) is typing.Union
|
||||
or param_type_or_default_val is inspect.Parameter.empty
|
||||
):
|
||||
# 是类型
|
||||
@@ -129,13 +149,42 @@ class CommandFilter(HandlerFilter):
|
||||
elif isinstance(param_type_or_default_val, float):
|
||||
result[param_name] = float(params[i])
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
origin = typing.get_origin(param_type_or_default_val)
|
||||
if origin in (typing.Union, types.UnionType):
|
||||
# 注解是联合类型
|
||||
# NOTE: 目前没有处理联合类型嵌套相关的注解写法
|
||||
nn_types = unwrap_optional(param_type_or_default_val)
|
||||
if len(nn_types) == 1:
|
||||
# 只有一个非 NoneType 类型
|
||||
result[param_name] = nn_types[0](params[i])
|
||||
else:
|
||||
# 没有或者有多个非 NoneType 类型,这里我们暂时直接赋值为原始值。
|
||||
# NOTE: 目前还没有做类型校验
|
||||
result[param_name] = params[i]
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"
|
||||
)
|
||||
return result
|
||||
|
||||
def get_complete_command_names(self):
|
||||
if self._cmpl_cmd_names is not None:
|
||||
return self._cmpl_cmd_names
|
||||
self._cmpl_cmd_names = [
|
||||
f"{parent} {cmd}" if parent else cmd
|
||||
for cmd in [self.command_name] + list(self.alias)
|
||||
for parent in self.parent_command_names or [""]
|
||||
]
|
||||
return self._cmpl_cmd_names
|
||||
|
||||
def equals(self, message_str: str) -> bool:
|
||||
for full_cmd in self.get_complete_command_names():
|
||||
if message_str == full_cmd:
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
@@ -145,18 +194,11 @@ class CommandFilter(HandlerFilter):
|
||||
|
||||
# 检查是否以指令开头
|
||||
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
||||
candidates = [self.command_name] + list(self.alias)
|
||||
ok = False
|
||||
for candidate in candidates:
|
||||
for parent_command_name in self.parent_command_names:
|
||||
if parent_command_name:
|
||||
_full = f"{parent_command_name} {candidate}"
|
||||
else:
|
||||
_full = candidate
|
||||
if message_str.startswith(f"{_full} ") or message_str == _full:
|
||||
message_str = message_str[len(_full) :].strip()
|
||||
ok = True
|
||||
break
|
||||
for full_cmd in self.get_complete_command_names():
|
||||
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
|
||||
ok = True
|
||||
message_str = message_str[len(full_cmd) :].strip()
|
||||
if not ok:
|
||||
return False
|
||||
|
||||
|
||||
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
self.parent_group = parent_group
|
||||
|
||||
# Cache for complete command names list
|
||||
self._cmpl_cmd_names: list | None = None
|
||||
|
||||
def add_sub_command_filter(
|
||||
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
|
||||
):
|
||||
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
|
||||
"""遍历父节点获取完整的指令名。
|
||||
|
||||
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
|
||||
if self._cmpl_cmd_names is not None:
|
||||
return self._cmpl_cmd_names
|
||||
|
||||
parent_cmd_names = (
|
||||
self.parent_group.get_complete_command_names() if self.parent_group else []
|
||||
)
|
||||
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
for parent_cmd_name in parent_cmd_names:
|
||||
for candidate in candidates:
|
||||
result.append(parent_cmd_name + " " + candidate)
|
||||
self._cmpl_cmd_names = result
|
||||
return result
|
||||
|
||||
# 以树的形式打印出来
|
||||
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
|
||||
return False
|
||||
return True
|
||||
|
||||
def startswith(self, message_str: str) -> bool:
|
||||
return message_str.startswith(tuple(self.get_complete_command_names()))
|
||||
|
||||
def equals(self, message_str: str) -> bool:
|
||||
return message_str in self.get_complete_command_names()
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
if not self.custom_filter_ok(event, cfg):
|
||||
return False
|
||||
|
||||
complete_command_names = self.get_complete_command_names()
|
||||
if event.message_str.strip() in complete_command_names:
|
||||
if self.equals(event.message_str.strip()):
|
||||
tree = (
|
||||
self.group_name
|
||||
+ "\n"
|
||||
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||
)
|
||||
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
# return event.message_str.startswith(tuple(complete_command_names))
|
||||
return False
|
||||
return self.startswith(event.message_str)
|
||||
|
||||
@@ -205,7 +205,6 @@ def register_command_group(
|
||||
new_group = CommandGroupFilter(command_group_name, alias)
|
||||
|
||||
def decorator(obj):
|
||||
# 根指令组
|
||||
if new_group:
|
||||
handler_md = get_handler_or_create(
|
||||
obj, EventType.AdapterMessageEvent, **kwargs
|
||||
@@ -213,6 +212,7 @@ def register_command_group(
|
||||
handler_md.event_filters.append(new_group)
|
||||
|
||||
return RegisteringCommandable(new_group)
|
||||
raise ValueError("注册指令组失败。")
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -220,9 +220,11 @@ def register_command_group(
|
||||
class RegisteringCommandable:
|
||||
"""用于指令组级联注册"""
|
||||
|
||||
group: CommandGroupFilter = register_command_group
|
||||
command: CommandFilter = register_command
|
||||
custom_filter = register_custom_filter
|
||||
group: Callable[..., Callable[..., "RegisteringCommandable"]] = (
|
||||
register_command_group
|
||||
)
|
||||
command: Callable[..., Callable[..., None]] = register_command
|
||||
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
|
||||
|
||||
def __init__(self, parent_group: CommandGroupFilter):
|
||||
self.parent_group = parent_group
|
||||
|
||||
@@ -52,10 +52,6 @@ class SessionServiceManager:
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
@@ -6,7 +6,7 @@ class CommandTokens:
|
||||
self.tokens = []
|
||||
self.len = 0
|
||||
|
||||
def get(self, idx: int):
|
||||
def get(self, idx: int) -> str | None:
|
||||
if idx >= self.len:
|
||||
return None
|
||||
return self.tokens[idx].strip()
|
||||
|
||||
@@ -1,9 +1,33 @@
|
||||
import codecs
|
||||
import json
|
||||
from astrbot.core import logger
|
||||
from aiohttp import ClientSession
|
||||
from aiohttp import ClientSession, ClientResponse
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
|
||||
async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
|
||||
decoder = codecs.getincrementaldecoder("utf-8")()
|
||||
buffer = ""
|
||||
async for chunk in resp.content.iter_chunked(8192):
|
||||
buffer += decoder.decode(chunk)
|
||||
while "\n\n" in buffer:
|
||||
block, buffer = buffer.split("\n\n", 1)
|
||||
if block.strip().startswith("data:"):
|
||||
try:
|
||||
yield json.loads(block[5:])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Drop invalid dify json data: {block[5:]}")
|
||||
continue
|
||||
# flush any remaining text
|
||||
buffer += decoder.decode(b"", final=True)
|
||||
if buffer.strip().startswith("data:"):
|
||||
try:
|
||||
yield json.loads(buffer[5:])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Drop invalid dify json data: {buffer[5:]}")
|
||||
pass
|
||||
|
||||
|
||||
class DifyAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
|
||||
self.api_key = api_key
|
||||
@@ -33,31 +57,11 @@ class DifyAPIClient:
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
|
||||
|
||||
buffer = ""
|
||||
while True:
|
||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk.decode("utf-8")
|
||||
blocks = buffer.split("\n\n")
|
||||
|
||||
# 处理完整的数据块
|
||||
for block in blocks[:-1]:
|
||||
if block.strip() and block.startswith("data:"):
|
||||
try:
|
||||
json_str = block[5:] # 移除 "data:" 前缀
|
||||
json_obj = json.loads(json_str)
|
||||
yield json_obj
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}")
|
||||
logger.error(f"原始数据块: {json_str}")
|
||||
|
||||
# 保留最后一个可能不完整的块
|
||||
buffer = blocks[-1] if blocks else ""
|
||||
raise Exception(
|
||||
f"Dify /chat-messages 接口请求失败:{resp.status}. {text}"
|
||||
)
|
||||
async for event in _stream_sse(resp):
|
||||
yield event
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
@@ -77,31 +81,11 @@ class DifyAPIClient:
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
|
||||
|
||||
buffer = ""
|
||||
while True:
|
||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk.decode("utf-8")
|
||||
blocks = buffer.split("\n\n")
|
||||
|
||||
# 处理完整的数据块
|
||||
for block in blocks[:-1]:
|
||||
if block.strip() and block.startswith("data:"):
|
||||
try:
|
||||
json_str = block[5:] # 移除 "data:" 前缀
|
||||
json_obj = json.loads(json_str)
|
||||
yield json_obj
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}")
|
||||
logger.error(f"原始数据块: {json_str}")
|
||||
|
||||
# 保留最后一个可能不完整的块
|
||||
buffer = blocks[-1] if blocks else ""
|
||||
raise Exception(
|
||||
f"Dify /workflows/run 接口请求失败:{resp.status}. {text}"
|
||||
)
|
||||
async for event in _stream_sse(resp):
|
||||
yield event
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
@@ -109,12 +93,15 @@ class DifyAPIClient:
|
||||
user: str,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.api_base}/files/upload"
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": open(file_path, "rb"),
|
||||
}
|
||||
async with self.session.post(url, data=payload, headers=self.headers) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
with open(file_path, "rb") as f:
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": f,
|
||||
}
|
||||
async with self.session.post(
|
||||
url, data=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
|
||||
@@ -1,17 +1,27 @@
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from quart import request, Response as QuartResponse, g, make_response
|
||||
from astrbot.core.db import BaseDatabase
|
||||
import asyncio
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def track_conversation(convs: dict, conv_id: str):
|
||||
convs[conv_id] = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
convs.pop(conv_id, None)
|
||||
|
||||
|
||||
class ChatRoute(Route):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -40,6 +50,8 @@ class ChatRoute(Route):
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
|
||||
self.running_convs: dict[str, bool] = {}
|
||||
|
||||
async def get_file(self):
|
||||
filename = request.args.get("filename")
|
||||
if not filename:
|
||||
@@ -139,42 +151,63 @@ class ChatRoute(Route):
|
||||
)
|
||||
|
||||
async def stream():
|
||||
client_disconnected = False
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
async with track_conversation(self.running_convs, webchat_conv_id):
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||
client_disconnected = True
|
||||
except Exception as e:
|
||||
logger.error(f"WebChat stream error: {e}")
|
||||
|
||||
if not result:
|
||||
continue
|
||||
if not result:
|
||||
continue
|
||||
|
||||
result_text = result["data"]
|
||||
type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.05)
|
||||
result_text = result["data"]
|
||||
type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
|
||||
if type == "end":
|
||||
break
|
||||
elif (
|
||||
(streaming and type == "complete")
|
||||
or not streaming
|
||||
or type == "break"
|
||||
):
|
||||
# append bot message
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
)
|
||||
try:
|
||||
if not client_disconnected:
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
if not client_disconnected:
|
||||
logger.debug(
|
||||
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
|
||||
)
|
||||
client_disconnected = True
|
||||
|
||||
except BaseException as _:
|
||||
logger.debug(f"用户 {username} 断开聊天长连接。")
|
||||
return
|
||||
try:
|
||||
if not client_disconnected:
|
||||
await asyncio.sleep(0.05)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||
client_disconnected = True
|
||||
|
||||
if type == "end":
|
||||
break
|
||||
elif (
|
||||
(streaming and type == "complete")
|
||||
or not streaming
|
||||
or type == "break"
|
||||
):
|
||||
# append bot message
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||
|
||||
# Put message to conversation-specific queue
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
||||
@@ -291,6 +324,7 @@ class ChatRoute(Route):
|
||||
.ok(
|
||||
data={
|
||||
"history": history_res,
|
||||
"is_running": self.running_convs.get(webchat_conv_id, False),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import typing
|
||||
import traceback
|
||||
import os
|
||||
import inspect
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
@@ -13,10 +14,10 @@ from astrbot.core.config.default import (
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
from astrbot.core.platform.register import platform_registry, platform_cls_map
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, file_token_service
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
import asyncio
|
||||
@@ -51,24 +52,6 @@ def validate_config(
|
||||
def validate(data: dict, metadata: dict = schema, path=""):
|
||||
for key, value in data.items():
|
||||
if key not in metadata:
|
||||
# 无 schema 的配置项,执行类型猜测
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
data[key] = int(value)
|
||||
continue
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
data[key] = float(value)
|
||||
continue
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if value.lower() == "true":
|
||||
data[key] = True
|
||||
elif value.lower() == "false":
|
||||
data[key] = False
|
||||
continue
|
||||
meta = metadata[key]
|
||||
if "type" not in meta:
|
||||
@@ -127,12 +110,12 @@ def validate_config(
|
||||
)
|
||||
|
||||
if is_core:
|
||||
for key, group in schema.items():
|
||||
group_meta = group.get("metadata")
|
||||
if not group_meta:
|
||||
continue
|
||||
# logger.info(f"验证配置: 组 {key} ...")
|
||||
validate(data, group_meta, path=f"{key}.")
|
||||
meta_all = {
|
||||
**schema["platform_group"]["metadata"],
|
||||
**schema["provider_group"]["metadata"],
|
||||
**schema["misc_config_group"]["metadata"],
|
||||
}
|
||||
validate(data, meta_all)
|
||||
else:
|
||||
validate(data, schema)
|
||||
|
||||
@@ -142,6 +125,7 @@ def validate_config(
|
||||
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
|
||||
"""验证并保存配置"""
|
||||
errors = None
|
||||
logger.info(f"Saving config, is_core={is_core}")
|
||||
try:
|
||||
if is_core:
|
||||
errors, post_config = validate_config(
|
||||
@@ -166,6 +150,7 @@ class ConfigRoute(Route):
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config: AstrBotConfig = core_lifecycle.astrbot_config
|
||||
self._logo_token_cache = {} # 缓存logo token,避免重复注册
|
||||
self.acm = core_lifecycle.astrbot_config_mgr
|
||||
self.routes = {
|
||||
"/config/abconf/new": ("POST", self.create_abconf),
|
||||
@@ -672,6 +657,78 @@ class ConfigRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "删除成功,已经实时生效~").__dict__
|
||||
|
||||
async def get_llm_tools(self):
|
||||
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
|
||||
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
tools = tool_mgr.get_func_desc_openai_style()
|
||||
return Response().ok(tools).__dict__
|
||||
|
||||
async def _register_platform_logo(self, platform, platform_default_tmpl):
|
||||
"""注册平台logo文件并生成访问令牌"""
|
||||
if not platform.logo_path:
|
||||
return
|
||||
|
||||
try:
|
||||
# 检查缓存
|
||||
cache_key = f"{platform.name}:{platform.logo_path}"
|
||||
if cache_key in self._logo_token_cache:
|
||||
cached_token = self._logo_token_cache[cache_key]
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl:
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
elif not isinstance(platform_default_tmpl[platform.name], dict):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
platform_default_tmpl[platform.name]["logo_token"] = cached_token
|
||||
logger.debug(f"Using cached logo token for platform {platform.name}")
|
||||
return
|
||||
|
||||
# 获取平台适配器类
|
||||
platform_cls = platform_cls_map.get(platform.name)
|
||||
if not platform_cls:
|
||||
logger.warning(f"Platform class not found for {platform.name}")
|
||||
return
|
||||
|
||||
# 获取插件目录路径
|
||||
module_file = inspect.getfile(platform_cls)
|
||||
plugin_dir = os.path.dirname(module_file)
|
||||
|
||||
# 解析logo文件路径
|
||||
logo_file_path = os.path.join(plugin_dir, platform.logo_path)
|
||||
|
||||
# 检查文件是否存在并注册令牌
|
||||
if os.path.exists(logo_file_path):
|
||||
logo_token = await file_token_service.register_file(
|
||||
logo_file_path, timeout=3600
|
||||
)
|
||||
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl:
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
elif not isinstance(platform_default_tmpl[platform.name], dict):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
|
||||
platform_default_tmpl[platform.name]["logo_token"] = logo_token
|
||||
|
||||
# 缓存token
|
||||
self._logo_token_cache[cache_key] = logo_token
|
||||
|
||||
logger.debug(f"Logo token registered for platform {platform.name}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Platform {platform.name} logo file not found: {logo_file_path}"
|
||||
)
|
||||
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to import required modules for platform {platform.name}: {e}"
|
||||
)
|
||||
except (OSError, IOError) as e:
|
||||
logger.warning(f"File system error for platform {platform.name} logo: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unexpected error registering logo for platform {platform.name}: {e}"
|
||||
)
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config
|
||||
|
||||
@@ -679,9 +736,21 @@ class ConfigRoute(Route):
|
||||
platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][
|
||||
"platform"
|
||||
]["config_template"]
|
||||
|
||||
# 收集需要注册logo的平台
|
||||
logo_registration_tasks = []
|
||||
for platform in platform_registry:
|
||||
if platform.default_config_tmpl:
|
||||
platform_default_tmpl[platform.name] = platform.default_config_tmpl
|
||||
# 收集logo注册任务
|
||||
if platform.logo_path:
|
||||
logo_registration_tasks.append(
|
||||
self._register_platform_logo(platform, platform_default_tmpl)
|
||||
)
|
||||
|
||||
# 并行执行logo注册
|
||||
if logo_registration_tasks:
|
||||
await asyncio.gather(*logo_registration_tasks, return_exceptions=True)
|
||||
|
||||
# 服务提供商的默认配置模板注入
|
||||
provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][
|
||||
|
||||
@@ -169,15 +169,65 @@ class ConversationRoute(Route):
|
||||
"""删除对话"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
user_id = data.get("user_id")
|
||||
cid = data.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||
unified_msg_origin=user_id, conversation_id=cid
|
||||
)
|
||||
return Response().ok({"message": "对话删除成功"}).__dict__
|
||||
# 检查是否是批量删除
|
||||
if "conversations" in data:
|
||||
# 批量删除
|
||||
conversations = data.get("conversations", [])
|
||||
if not conversations:
|
||||
return (
|
||||
Response().error("批量删除时conversations参数不能为空").__dict__
|
||||
)
|
||||
|
||||
deleted_count = 0
|
||||
failed_items = []
|
||||
|
||||
for conv in conversations:
|
||||
user_id = conv.get("user_id")
|
||||
cid = conv.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
failed_items.append(
|
||||
f"user_id:{user_id}, cid:{cid} - 缺少必要参数"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||
unified_msg_origin=user_id, conversation_id=cid
|
||||
)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - {str(e)}")
|
||||
|
||||
message = f"成功删除 {deleted_count} 个对话"
|
||||
if failed_items:
|
||||
message += f",失败 {len(failed_items)} 个"
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": message,
|
||||
"deleted_count": deleted_count,
|
||||
"failed_count": len(failed_items),
|
||||
"failed_items": failed_items,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
# 单个删除
|
||||
user_id = data.get("user_id")
|
||||
cid = data.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
|
||||
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||
unified_msg_origin=user_id, conversation_id=cid
|
||||
)
|
||||
return Response().ok({"message": "对话删除成功"}).__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")
|
||||
|
||||
@@ -20,6 +20,7 @@ class SessionManagementRoute(Route):
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.db_helper = db_helper
|
||||
self.routes = {
|
||||
"/session/list": ("GET", self.list_sessions),
|
||||
"/session/update_persona": ("POST", self.update_session_persona),
|
||||
@@ -30,6 +31,7 @@ class SessionManagementRoute(Route):
|
||||
"/session/update_tts": ("POST", self.update_session_tts),
|
||||
"/session/update_name": ("POST", self.update_session_name),
|
||||
"/session/update_status": ("POST", self.update_session_status),
|
||||
"/session/delete": ("POST", self.delete_session),
|
||||
}
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.core_lifecycle = core_lifecycle
|
||||
@@ -38,22 +40,42 @@ class SessionManagementRoute(Route):
|
||||
async def list_sessions(self):
|
||||
"""获取所有会话的列表,包括 persona 和 provider 信息"""
|
||||
try:
|
||||
preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[])
|
||||
session_conversations = {}
|
||||
for pref in preferences:
|
||||
session_conversations[pref.scope_id] = pref.value["val"]
|
||||
page = int(request.args.get("page", 1))
|
||||
page_size = int(request.args.get("page_size", 20))
|
||||
search_query = request.args.get("search", "")
|
||||
platform = request.args.get("platform", "")
|
||||
|
||||
# 获取活跃的会话数据(处于对话内的会话)
|
||||
sessions_data, total = await self.db_helper.get_session_conversations(
|
||||
page, page_size, search_query, platform
|
||||
)
|
||||
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
persona_mgr = self.core_lifecycle.persona_mgr
|
||||
personas = persona_mgr.personas_v3
|
||||
|
||||
sessions = []
|
||||
|
||||
# 构建会话信息
|
||||
for session_id, conversation_id in session_conversations.items():
|
||||
# 循环补充非数据库信息,如 provider 和 session 状态
|
||||
for data in sessions_data:
|
||||
session_id = data["session_id"]
|
||||
conversation_id = data["conversation_id"]
|
||||
conv_persona_id = data["persona_id"]
|
||||
title = data["title"]
|
||||
persona_name = data["persona_name"]
|
||||
|
||||
# 处理 persona 显示
|
||||
if conv_persona_id == "[%None]":
|
||||
persona_name = "无人格"
|
||||
else:
|
||||
default_persona = persona_mgr.selected_default_persona_v3
|
||||
if default_persona:
|
||||
persona_name = default_persona["name"]
|
||||
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_id": persona_name,
|
||||
"chat_provider_id": None,
|
||||
"stt_provider_id": None,
|
||||
"tts_provider_id": None,
|
||||
@@ -78,31 +100,10 @@ class SessionManagementRoute(Route):
|
||||
"session_raw_name": session_id.split(":")[2]
|
||||
if session_id.count(":") >= 2
|
||||
else session_id,
|
||||
"title": title,
|
||||
}
|
||||
|
||||
# 获取对话信息
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=session_id, conversation_id=conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
|
||||
# 查找 persona 名称
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_id"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_id"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = persona_mgr.selected_default_persona_v3
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
chat_provider = provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.CHAT_COMPLETION, umo=session_id
|
||||
)
|
||||
@@ -171,6 +172,14 @@ class SessionManagementRoute(Route):
|
||||
"available_chat_providers": available_chat_providers,
|
||||
"available_stt_providers": available_stt_providers,
|
||||
"available_tts_providers": available_tts_providers,
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"total_pages": (total + page_size - 1) // page_size
|
||||
if page_size > 0
|
||||
else 0,
|
||||
},
|
||||
}
|
||||
|
||||
return Response().ok(result).__dict__
|
||||
@@ -180,60 +189,132 @@ class SessionManagementRoute(Route):
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话列表失败: {str(e)}").__dict__
|
||||
|
||||
async def _update_single_session_persona(self, session_id: str, persona_name: str):
|
||||
"""更新单个会话的 persona 的内部方法"""
|
||||
conversation_manager = self.core_lifecycle.star_context.conversation_manager
|
||||
conversation_id = await conversation_manager.get_curr_conversation_id(
|
||||
session_id
|
||||
)
|
||||
|
||||
conv = None
|
||||
if conversation_id:
|
||||
conv = await conversation_manager.get_conversation(
|
||||
unified_msg_origin=session_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if not conv or not conversation_id:
|
||||
conversation_id = await conversation_manager.new_conversation(session_id)
|
||||
|
||||
# 更新 persona
|
||||
await conversation_manager.update_conversation_persona_id(
|
||||
session_id, persona_name
|
||||
)
|
||||
|
||||
async def _handle_batch_operation(
|
||||
self, session_ids: list, operation_func, operation_name: str, **kwargs
|
||||
):
|
||||
"""通用的批量操作处理方法"""
|
||||
success_count = 0
|
||||
error_sessions = []
|
||||
|
||||
for session_id in session_ids:
|
||||
try:
|
||||
await operation_func(session_id, **kwargs)
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"批量{operation_name} 会话 {session_id} 失败: {str(e)}")
|
||||
error_sessions.append(session_id)
|
||||
|
||||
if error_sessions:
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"批量更新完成,成功: {success_count},失败: {len(error_sessions)}",
|
||||
"success_count": success_count,
|
||||
"error_count": len(error_sessions),
|
||||
"error_sessions": error_sessions,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功批量{operation_name} {success_count} 个会话",
|
||||
"success_count": success_count,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def update_session_persona(self):
|
||||
"""更新指定会话的 persona"""
|
||||
"""更新指定会话的 persona,支持批量操作"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
is_batch = data.get("is_batch", False)
|
||||
persona_name = data.get("persona_name")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if persona_name is None:
|
||||
return Response().error("缺少必要参数: persona_name").__dict__
|
||||
|
||||
# 获取会话当前的对话 ID
|
||||
conversation_manager = self.core_lifecycle.star_context.conversation_manager
|
||||
conversation_id = await conversation_manager.get_curr_conversation_id(
|
||||
session_id
|
||||
)
|
||||
if is_batch:
|
||||
session_ids = data.get("session_ids", [])
|
||||
if not session_ids:
|
||||
return Response().error("缺少必要参数: session_ids").__dict__
|
||||
|
||||
if not conversation_id:
|
||||
# 如果没有对话,创建一个新的对话
|
||||
conversation_id = await conversation_manager.new_conversation(
|
||||
session_id
|
||||
return await self._handle_batch_operation(
|
||||
session_ids,
|
||||
self._update_single_session_persona,
|
||||
"更新人格",
|
||||
persona_name=persona_name,
|
||||
)
|
||||
else:
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 更新 persona
|
||||
await conversation_manager.update_conversation_persona_id(
|
||||
session_id, persona_name
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"})
|
||||
.__dict__
|
||||
)
|
||||
await self._update_single_session_persona(session_id, persona_name)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话人格失败: {str(e)}").__dict__
|
||||
|
||||
async def _update_single_session_provider(
|
||||
self, session_id: str, provider_id: str, provider_type_enum
|
||||
):
|
||||
"""更新单个会话的 provider 的内部方法"""
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
await provider_manager.set_provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type_enum,
|
||||
umo=session_id,
|
||||
)
|
||||
|
||||
async def update_session_provider(self):
|
||||
"""更新指定会话的 provider"""
|
||||
"""更新指定会话的 provider,支持批量操作"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
is_batch = data.get("is_batch", False)
|
||||
provider_id = data.get("provider_id")
|
||||
# "chat_completion", "speech_to_text", "text_to_speech"
|
||||
provider_type = data.get("provider_type")
|
||||
|
||||
if not session_id or not provider_id or not provider_type:
|
||||
if not provider_id or not provider_type:
|
||||
return (
|
||||
Response()
|
||||
.error("缺少必要参数: session_id, provider_id, provider_type")
|
||||
.error("缺少必要参数: provider_id, provider_type")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
@@ -251,23 +332,35 @@ class SessionManagementRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 设置 provider
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
await provider_manager.set_provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type_enum,
|
||||
umo=session_id,
|
||||
)
|
||||
if is_batch:
|
||||
session_ids = data.get("session_ids", [])
|
||||
if not session_ids:
|
||||
return Response().error("缺少必要参数: session_ids").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}"
|
||||
}
|
||||
return await self._handle_batch_operation(
|
||||
session_ids,
|
||||
self._update_single_session_provider,
|
||||
f"更新 {provider_type} 提供商",
|
||||
provider_id=provider_id,
|
||||
provider_type_enum=provider_type_enum,
|
||||
)
|
||||
else:
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
await self._update_single_session_provider(
|
||||
session_id, provider_id, provider_type_enum
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}"
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}"
|
||||
@@ -376,66 +469,98 @@ class SessionManagementRoute(Route):
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
|
||||
|
||||
async def _update_single_session_llm(self, session_id: str, enabled: bool):
|
||||
"""更新单个会话的LLM状态的内部方法"""
|
||||
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
|
||||
|
||||
async def update_session_llm(self):
|
||||
"""更新指定会话的LLM启停状态"""
|
||||
"""更新指定会话的LLM启停状态,支持批量操作"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
is_batch = data.get("is_batch", False)
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新LLM状态
|
||||
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
|
||||
if is_batch:
|
||||
session_ids = data.get("session_ids", [])
|
||||
if not session_ids:
|
||||
return Response().error("缺少必要参数: session_ids").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"LLM已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"llm_enabled": enabled,
|
||||
}
|
||||
result = await self._handle_batch_operation(
|
||||
session_ids,
|
||||
self._update_single_session_llm,
|
||||
f"{'启用' if enabled else '禁用'}LLM",
|
||||
enabled=enabled,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
await self._update_single_session_llm(session_id, enabled)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"LLM已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"llm_enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
|
||||
|
||||
async def _update_single_session_tts(self, session_id: str, enabled: bool):
|
||||
"""更新单个会话的TTS状态的内部方法"""
|
||||
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
|
||||
|
||||
async def update_session_tts(self):
|
||||
"""更新指定会话的TTS启停状态"""
|
||||
"""更新指定会话的TTS启停状态,支持批量操作"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
is_batch = data.get("is_batch", False)
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新TTS状态
|
||||
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
|
||||
if is_batch:
|
||||
session_ids = data.get("session_ids", [])
|
||||
if not session_ids:
|
||||
return Response().error("缺少必要参数: session_ids").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"TTS已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"tts_enabled": enabled,
|
||||
}
|
||||
result = await self._handle_batch_operation(
|
||||
session_ids,
|
||||
self._update_single_session_tts,
|
||||
f"{'启用' if enabled else '禁用'}TTS",
|
||||
enabled=enabled,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
await self._update_single_session_tts(session_id, enabled)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"TTS已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"tts_enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话TTS状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
@@ -507,3 +632,43 @@ class SessionManagementRoute(Route):
|
||||
error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__
|
||||
|
||||
async def delete_session(self):
|
||||
"""删除指定会话及其所有相关数据"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 删除会话的所有相关数据
|
||||
conversation_manager = self.core_lifecycle.conversation_manager
|
||||
|
||||
# 1. 删除会话的所有对话
|
||||
try:
|
||||
await conversation_manager.delete_conversations_by_user_id(session_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"删除会话 {session_id} 的对话失败: {str(e)}")
|
||||
|
||||
# 2. 清除会话的偏好设置数据(清空该会话的所有配置)
|
||||
try:
|
||||
await sp.clear_async("umo", session_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"清除会话 {session_id} 的偏好设置失败: {str(e)}")
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"会话 {session_id} 及其相关所有对话数据已成功删除",
|
||||
"session_id": session_id,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"删除会话失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"删除会话失败: {str(e)}").__dict__
|
||||
|
||||
@@ -9,6 +9,8 @@ from astrbot.core.config.default import VERSION
|
||||
from astrbot.core import DEMO_MODE
|
||||
from astrbot.core.db.migration.helper import do_migration_v4, check_migration_needed_v4
|
||||
|
||||
CLEAR_SITE_DATA_HEADERS = {"Clear-Site-Data": '"cache"'}
|
||||
|
||||
|
||||
class UpdateRoute(Route):
|
||||
def __init__(
|
||||
@@ -113,17 +115,19 @@ class UpdateRoute(Route):
|
||||
|
||||
if reboot:
|
||||
await self.core_lifecycle.restart()
|
||||
return (
|
||||
ret = (
|
||||
Response()
|
||||
.ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。")
|
||||
.__dict__
|
||||
)
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
else:
|
||||
return (
|
||||
ret = (
|
||||
Response()
|
||||
.ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。")
|
||||
.__dict__
|
||||
)
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -135,9 +139,8 @@ class UpdateRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
return Response().error(f"下载管理面板文件失败: {e}").__dict__
|
||||
return (
|
||||
Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
|
||||
)
|
||||
ret = Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
8
changelogs/v4.1.7.md
Normal file
8
changelogs/v4.1.7.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# What's Changed
|
||||
|
||||
1. perf: 优化 WebChat 等组件的 UI 风格
|
||||
2. fix: 修复 4.1.6 版本可能无法点击更新按钮的问题
|
||||
3. fix: 修复更新开发版的时候,可能无法同时更新 WebUI 的问题
|
||||
4. feat: 支持在「对话数据」页批量删除对话
|
||||
5. fix: 修复部分错误地显示「格式校验未通过」的问题
|
||||
6. perf: WebChat 支持手动填写模型名称
|
||||
1
changelogs/v4.2.0.md
Normal file
1
changelogs/v4.2.0.md
Normal file
@@ -0,0 +1 @@
|
||||
# What's Changed
|
||||
1
changelogs/v4.2.1.md
Normal file
1
changelogs/v4.2.1.md
Normal file
@@ -0,0 +1 @@
|
||||
# What's Changed
|
||||
14
changelogs/v4.3.0.md
Normal file
14
changelogs/v4.3.0.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# What's Changed
|
||||
|
||||
1. fix: 修复"开启 TTS 时同时输出语音和文字内容"功能不可用的问题 ([#2900](https://github.com/AstrBotDevs/AstrBot/issues/2900))
|
||||
2. feat: 优化了会话管理页的数据查询逻辑,添加分页和搜索功能,大幅度提高响应速度 ([#2906](https://github.com/AstrBotDevs/AstrBot/issues/2906))
|
||||
3. fix: 用 mi-googlesearch-python 库代替失效的 googlesearch-python 库 ([#2909](https://github.com/AstrBotDevs/AstrBot/issues/2909))
|
||||
4. feat: 支持在 Telegram 和飞书下请求 LLM 前预表态功能 ([#2737](https://github.com/AstrBotDevs/AstrBot/issues/2737))
|
||||
5. perf: 对于 Telegram 群聊,将回复机器人的消息视为唤醒机器人 ([#2926](https://github.com/AstrBotDevs/AstrBot/issues/2926))
|
||||
6. feat: 提示词前缀配置项升级为“用户提示词”,支持 `{{prompt}}` 作为用户输入的占位符。
|
||||
7. fix: 增加知识库插件的启用检查,避免部分情况下导致知识库页面白屏的问题。
|
||||
8. fix: 修复接入智谱提供商后,工具调用无限循环的问题,并停止支持 glm-4v-flash ([#2931](https://github.com/AstrBotDevs/AstrBot/issues/2931))
|
||||
9. fix: 修复注册指令组指令时的 Pyright 类型检查提示 ([#2923](https://github.com/AstrBotDevs/AstrBot/issues/2923))
|
||||
10. refactor: 优化 packages/astrbot 内置插件的代码结构以提高可维护性和可读性 ([#2924](https://github.com/AstrBotDevs/AstrBot/issues/2924))
|
||||
11. fix: 修复插件指令注解为联合类型时处理异常的问题 ([#2925](https://github.com/AstrBotDevs/AstrBot/issues/2925))
|
||||
12. feat: 支持注册消息平台适配器的 logo ([#2109](https://github.com/AstrBotDevs/AstrBot/issues/2109))
|
||||
1
changelogs/v4.3.1.md
Normal file
1
changelogs/v4.3.1.md
Normal file
@@ -0,0 +1 @@
|
||||
# What's Changed
|
||||
7
changelogs/v4.3.2.md
Normal file
7
changelogs/v4.3.2.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# What's Changed
|
||||
|
||||
1. fix: 修复 /reset 指令没有清除群聊上下文感知数据的问题 ([#2954](https://github.com/AstrBotDevs/AstrBot/issues/2954))
|
||||
2. fix: 修复自带的 WebSearch 插件可能在部分场景下无法使用的问题
|
||||
3. fix: 发送阶段强行将 Plain 为空的消息段移除
|
||||
4. fix: on_tool_end无法获得工具返回的结果 ([#2956](https://github.com/AstrBotDevs/AstrBot/issues/2956))
|
||||
5. feat: 为插件市场的搜索增加拼音与首字母搜索功能 ([#2936](https://github.com/AstrBotDevs/AstrBot/issues/2936))
|
||||
@@ -27,6 +27,7 @@
|
||||
"lodash": "4.17.21",
|
||||
"marked": "^15.0.7",
|
||||
"markdown-it": "^14.1.0",
|
||||
"pinyin-pro": "^3.26.0",
|
||||
"pinia": "2.1.6",
|
||||
"remixicon": "3.5.0",
|
||||
"vee-validate": "4.11.3",
|
||||
|
||||
@@ -1,7 +1,28 @@
|
||||
<template>
|
||||
<RouterView></RouterView>
|
||||
|
||||
<!-- 全局唯一 snackbar -->
|
||||
<v-snackbar v-if="toastStore.current" v-model="snackbarShow" :color="toastStore.current.color"
|
||||
:timeout="toastStore.current.timeout" :multi-line="toastStore.current.multiLine"
|
||||
:location="toastStore.current.location" close-on-back>
|
||||
{{ toastStore.current.message }}
|
||||
<template #actions v-if="toastStore.current.closable">
|
||||
<v-btn variant="text" @click="snackbarShow = false">关闭</v-btn>
|
||||
</template>
|
||||
</v-snackbar>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
<script setup>
|
||||
import { RouterView } from 'vue-router';
|
||||
import { computed } from 'vue'
|
||||
import { useToastStore } from '@/stores/toast'
|
||||
|
||||
const toastStore = useToastStore()
|
||||
|
||||
const snackbarShow = computed({
|
||||
get: () => !!toastStore.current,
|
||||
set: (val) => {
|
||||
if (!val) toastStore.shift()
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
<div style="display: flex; align-items: center; justify-content: center; padding: 16px; padding-bottom: 0px;"
|
||||
v-if="chatboxMode">
|
||||
<img width="50" src="@/assets/images/astrbot_logo_mini.webp" alt="AstrBot Logo">
|
||||
<span v-if="!sidebarCollapsed" style="font-weight: 1000; font-size: 26px; margin-left: 8px;">AstrBot</span>
|
||||
<span v-if="!sidebarCollapsed"
|
||||
style="font-weight: 1000; font-size: 26px; margin-left: 8px;">AstrBot</span>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -46,7 +47,7 @@
|
||||
|| tm('conversation.newConversation') }}</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="!sidebarCollapsed" class="timestamp">{{
|
||||
formatDate(item.updated_at)
|
||||
}}</v-list-item-subtitle>
|
||||
}}</v-list-item-subtitle>
|
||||
|
||||
<template v-if="!sidebarCollapsed" v-slot:append>
|
||||
<div class="conversation-actions">
|
||||
@@ -118,8 +119,9 @@
|
||||
</div>
|
||||
<v-divider v-if="currCid && getCurrentConversation" class="conversation-divider"></v-divider>
|
||||
|
||||
<MessageList v-if="messages && messages.length > 0" :messages="messages" :isDark="isDark" :isStreaming="isStreaming"
|
||||
@openImagePreview="openImagePreview" ref="messageList" />
|
||||
<MessageList v-if="messages && messages.length > 0" :messages="messages" :isDark="isDark"
|
||||
:isStreaming="isStreaming || isConvRunning" @openImagePreview="openImagePreview"
|
||||
ref="messageList" />
|
||||
<div class="welcome-container fade-in" v-else>
|
||||
<div class="welcome-title">
|
||||
<span>Hello, I'm</span>
|
||||
@@ -145,9 +147,10 @@
|
||||
<!-- 输入区域 -->
|
||||
<div class="input-area fade-in">
|
||||
<div
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px; padding: 4px;">
|
||||
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown" :disabled="isStreaming"
|
||||
@click:clear="clearMessage" placeholder="Ask AstrBot..."
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px;">
|
||||
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown"
|
||||
:disabled="isStreaming || isConvRunning" @click:clear="clearMessage"
|
||||
placeholder="Ask AstrBot..."
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
<div
|
||||
style="display: flex; justify-content: space-between; align-items: center; padding: 0px 8px;">
|
||||
@@ -155,18 +158,21 @@
|
||||
<!-- 选择提供商和模型 -->
|
||||
<ProviderModelSelector ref="providerModelSelector" />
|
||||
</div>
|
||||
<div style="display: flex; justify-content: flex-end; margin-top: 8px;">
|
||||
<div
|
||||
style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;">
|
||||
<input type="file" ref="imageInput" @change="handleFileSelect" accept="image/*"
|
||||
style="display: none" multiple />
|
||||
<v-progress-circular v-if="isStreaming || isConvRunning" indeterminate size="16"
|
||||
class="mr-1" width="1.5" />
|
||||
<v-btn @click="triggerImageInput" icon="mdi-plus" variant="text" color="deep-purple"
|
||||
class="add-btn" size="small" />
|
||||
<v-btn @click="sendMessage" icon="mdi-send" variant="text" color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl"
|
||||
class="send-btn" size="small" />
|
||||
<v-btn @click="isRecording ? stopRecording() : startRecording()"
|
||||
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
||||
:color="isRecording ? 'error' : 'deep-purple'" class="record-btn"
|
||||
size="small" />
|
||||
<v-btn @click="sendMessage" icon="mdi-send" variant="text" color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl"
|
||||
class="send-btn" size="small" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -235,6 +241,7 @@ import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue';
|
||||
import ProviderModelSelector from '@/components/chat/ProviderModelSelector.vue';
|
||||
import MessageList from '@/components/chat/MessageList.vue';
|
||||
import 'highlight.js/styles/github.css';
|
||||
import { useToast } from '@/utils/toast';
|
||||
|
||||
export default {
|
||||
name: 'ChatPage',
|
||||
@@ -301,7 +308,10 @@ export default {
|
||||
imagePreviewDialog: false,
|
||||
previewImageUrl: '',
|
||||
|
||||
isStreaming: false
|
||||
isStreaming: false,
|
||||
isConvRunning: false, // Track if the current conversation is running
|
||||
|
||||
isToastedRunningInfo: false, // To avoid multiple toasts
|
||||
}
|
||||
},
|
||||
|
||||
@@ -379,7 +389,7 @@ export default {
|
||||
} else {
|
||||
this.sidebarCollapsed = true; // 默认折叠状态
|
||||
}
|
||||
|
||||
|
||||
// 设置输入框标签
|
||||
this.inputFieldLabel = this.tm('input.chatPrompt');
|
||||
this.getConversations();
|
||||
@@ -662,6 +672,25 @@ export default {
|
||||
// Update the selected conversation in the sidebar
|
||||
this.selectedConversations = [cid[0]];
|
||||
let history = response.data.data.history;
|
||||
this.isConvRunning = response.data.data.is_running || false;
|
||||
|
||||
if (this.isConvRunning) {
|
||||
if (!this.isToastedRunningInfo) {
|
||||
useToast().info("该对话正在运行中。", { timeout: 5000 });
|
||||
this.isToastedRunningInfo = true;
|
||||
}
|
||||
|
||||
// 如果对话还在运行,3秒后重新获取消息
|
||||
setTimeout(() => {
|
||||
this.getConversationMessages([this.currCid]);
|
||||
}, 3000);
|
||||
}
|
||||
|
||||
// 滚动到底部
|
||||
this.$nextTick(() => {
|
||||
this.$refs.messageList.scrollToBottom();
|
||||
});
|
||||
|
||||
for (let i = 0; i < history.length; i++) {
|
||||
let content = history[i].content;
|
||||
if (content.message.startsWith('[IMAGE]')) {
|
||||
|
||||
@@ -29,12 +29,11 @@
|
||||
|
||||
<!-- Bot Messages -->
|
||||
<div v-else class="bot-message">
|
||||
<div v-if="isStreaming && index === messages.length - 1" style="width: 36px; height: 36px;">
|
||||
<v-progress-circular indeterminate size="28" width="2"
|
||||
style="margin-top: 16px;"></v-progress-circular>
|
||||
</div>
|
||||
<v-avatar v-else class="bot-avatar" size="36">
|
||||
<span class="text-h2">✨</span>
|
||||
|
||||
<v-avatar class="bot-avatar" size="36">
|
||||
<v-progress-circular :index="index" v-if="isStreaming && index === messages.length - 1" indeterminate size="28"
|
||||
width="2"></v-progress-circular>
|
||||
<span v-else-if="messages[index - 1]?.content.type !== 'bot'" class="text-h2">✨</span>
|
||||
</v-avatar>
|
||||
<div class="bot-message-content">
|
||||
<div class="message-bubble bot-bubble">
|
||||
@@ -437,13 +436,13 @@ export default {
|
||||
}
|
||||
|
||||
.message-bubble {
|
||||
padding: 8px 16px;
|
||||
padding: 2px 16px;
|
||||
border-radius: 12px;
|
||||
}
|
||||
|
||||
.user-bubble {
|
||||
color: var(--v-theme-primaryText);
|
||||
padding: 18px 20px;
|
||||
padding: 12px 18px;
|
||||
font-size: 15px;
|
||||
max-width: 60%;
|
||||
border-radius: 1.5rem;
|
||||
@@ -459,7 +458,7 @@ export default {
|
||||
.user-avatar,
|
||||
.bot-avatar {
|
||||
align-self: flex-start;
|
||||
margin-top: 12px;
|
||||
margin-top: 6px;
|
||||
}
|
||||
|
||||
/* 附件样式 */
|
||||
|
||||
@@ -1,21 +1,11 @@
|
||||
<template>
|
||||
<div>
|
||||
<!-- 选择提供商和模型按钮 -->
|
||||
<v-btn
|
||||
class="text-none"
|
||||
variant="tonal"
|
||||
rounded="xl"
|
||||
size="small"
|
||||
v-if="selectedProviderId && selectedModelName"
|
||||
@click="showDialog = true">
|
||||
<v-btn class="text-none" variant="tonal" rounded="xl" size="small"
|
||||
v-if="selectedProviderId && selectedModelName" @click="openDialog">
|
||||
{{ selectedProviderId }} / {{ selectedModelName }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
variant="tonal"
|
||||
rounded="xl"
|
||||
size="small"
|
||||
v-else
|
||||
@click="showDialog = true">
|
||||
<v-btn variant="tonal" rounded="xl" size="small" v-else @click="openDialog">
|
||||
选择模型
|
||||
</v-btn>
|
||||
|
||||
@@ -33,16 +23,12 @@
|
||||
<h4>提供商</h4>
|
||||
</div>
|
||||
<v-list density="compact" nav class="provider-list">
|
||||
<v-list-item
|
||||
v-for="provider in providerConfigs"
|
||||
:key="provider.id"
|
||||
:value="provider.id"
|
||||
@click="selectProvider(provider)"
|
||||
:active="selectedProviderId === provider.id"
|
||||
rounded="lg"
|
||||
class="provider-item">
|
||||
<v-list-item v-for="provider in providerConfigs" :key="provider.id" :value="provider.id"
|
||||
@click="selectProvider(provider)" :active="tempSelectedProviderId === provider.id"
|
||||
rounded="lg" class="provider-item">
|
||||
<v-list-item-title>{{ provider.id }}</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="provider.api_base">{{ provider.api_base }}</v-list-item-subtitle>
|
||||
<v-list-item-subtitle v-if="provider.api_base">{{ provider.api_base
|
||||
}}</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
<div v-if="providerConfigs.length === 0" class="empty-state">
|
||||
@@ -55,33 +41,28 @@
|
||||
<div class="model-list-panel">
|
||||
<div class="panel-header">
|
||||
<h4>模型</h4>
|
||||
<v-btn
|
||||
v-if="selectedProviderId"
|
||||
icon="mdi-refresh"
|
||||
size="small"
|
||||
variant="text"
|
||||
@click="refreshModels"
|
||||
:loading="loadingModels">
|
||||
<v-btn v-if="tempSelectedProviderId" icon="mdi-refresh" size="small" variant="text"
|
||||
@click="refreshModels" :loading="loadingModels">
|
||||
</v-btn>
|
||||
</div>
|
||||
<v-list density="compact" nav class="model-list" v-if="selectedProviderId">
|
||||
<v-list-item
|
||||
v-for="model in modelList"
|
||||
:key="model"
|
||||
:value="model"
|
||||
@click="selectModel(model)"
|
||||
:active="selectedModelName === model"
|
||||
rounded="lg"
|
||||
<v-list density="compact" nav class="model-list" v-if="tempSelectedProviderId">
|
||||
|
||||
<v-text-field v-model="tempSelectedModelName" placeholder="自定义模型" hide-details solo variant="outlined" density="compact" class="mb-2 mx-2"></v-text-field>
|
||||
|
||||
<v-list-item v-for="model in modelList" :key="model" :value="model"
|
||||
@click="selectModel(model)" :active="tempSelectedModelName === model" rounded="lg"
|
||||
class="model-item">
|
||||
<v-list-item-title>{{ model }}</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="model.description">{{ model.description }}</v-list-item-subtitle>
|
||||
<v-list-item-subtitle v-if="model.description">{{ model.description
|
||||
}}</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
<div v-else class="empty-state">
|
||||
<v-icon icon="mdi-robot-outline" size="large" color="grey-lighten-1"></v-icon>
|
||||
<div class="empty-text">请先选择提供商</div>
|
||||
</div>
|
||||
<div v-if="selectedProviderId && modelList.length === 0 && !loadingModels" class="empty-state">
|
||||
<div v-if="tempSelectedProviderId && modelList.length === 0 && !loadingModels"
|
||||
class="empty-state">
|
||||
<v-icon icon="mdi-robot-off-outline" size="large" color="grey-lighten-1"></v-icon>
|
||||
<div class="empty-text">该提供商暂无可用模型</div>
|
||||
</div>
|
||||
@@ -91,11 +72,8 @@
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="closeDialog" color="grey-darken-1">取消</v-btn>
|
||||
<v-btn
|
||||
text
|
||||
@click="confirmSelection"
|
||||
color="primary"
|
||||
:disabled="!selectedProviderId || !selectedModelName">
|
||||
<v-btn text @click="confirmSelection" color="primary"
|
||||
:disabled="!tempSelectedProviderId || !tempSelectedModelName">
|
||||
确认选择
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
@@ -127,12 +105,17 @@ export default {
|
||||
modelList: [],
|
||||
selectedProviderId: '',
|
||||
selectedModelName: '',
|
||||
// 临时选择状态,用于对话框内的选择
|
||||
tempSelectedProviderId: '',
|
||||
tempSelectedModelName: '',
|
||||
loadingModels: false
|
||||
};
|
||||
},
|
||||
mounted() {
|
||||
// 从localStorage加载保存的选择
|
||||
this.loadFromStorage();
|
||||
// 初始化临时选择
|
||||
this.resetTempSelection();
|
||||
// 获取提供商列表
|
||||
this.loadProviderConfigs();
|
||||
// 如果有保存的选择,加载对应的模型列表
|
||||
@@ -145,13 +128,13 @@ export default {
|
||||
loadFromStorage() {
|
||||
const savedProvider = localStorage.getItem('selectedProvider');
|
||||
const savedModel = localStorage.getItem('selectedModel');
|
||||
|
||||
|
||||
if (savedProvider) {
|
||||
this.selectedProviderId = savedProvider;
|
||||
} else if (this.initialProvider) {
|
||||
this.selectedProviderId = this.initialProvider;
|
||||
}
|
||||
|
||||
|
||||
if (savedModel) {
|
||||
this.selectedModelName = savedModel;
|
||||
} else if (this.initialModel) {
|
||||
@@ -215,36 +198,40 @@ export default {
|
||||
|
||||
// 选择提供商
|
||||
selectProvider(provider) {
|
||||
this.selectedProviderId = provider.id;
|
||||
this.selectedModelName = ''; // 清空已选择的模型
|
||||
this.tempSelectedProviderId = provider.id;
|
||||
this.tempSelectedModelName = ''; // 清空已选择的模型
|
||||
this.modelList = []; // 清空模型列表
|
||||
this.getProviderModels(provider.id); // 获取该提供商的模型列表
|
||||
},
|
||||
|
||||
// 选择模型
|
||||
selectModel(model) {
|
||||
this.selectedModelName = model;
|
||||
this.tempSelectedModelName = model;
|
||||
},
|
||||
|
||||
// 刷新模型列表
|
||||
refreshModels() {
|
||||
if (this.selectedProviderId) {
|
||||
this.getProviderModels(this.selectedProviderId);
|
||||
if (this.tempSelectedProviderId) {
|
||||
this.getProviderModels(this.tempSelectedProviderId);
|
||||
}
|
||||
},
|
||||
|
||||
// 确认选择
|
||||
confirmSelection() {
|
||||
if (this.selectedProviderId && this.selectedModelName) {
|
||||
if (this.tempSelectedProviderId && this.tempSelectedModelName) {
|
||||
// 将临时选择应用到正式选择
|
||||
this.selectedProviderId = this.tempSelectedProviderId;
|
||||
this.selectedModelName = this.tempSelectedModelName;
|
||||
|
||||
// 保存到localStorage
|
||||
this.saveToStorage();
|
||||
|
||||
|
||||
// 触发事件通知父组件
|
||||
this.$emit('selection-changed', {
|
||||
providerId: this.selectedProviderId,
|
||||
modelName: this.selectedModelName
|
||||
});
|
||||
|
||||
|
||||
this.closeDialog();
|
||||
}
|
||||
},
|
||||
@@ -252,6 +239,24 @@ export default {
|
||||
// 关闭对话框
|
||||
closeDialog() {
|
||||
this.showDialog = false;
|
||||
// 重置临时选择为当前选择
|
||||
this.resetTempSelection();
|
||||
},
|
||||
|
||||
// 重置临时选择
|
||||
resetTempSelection() {
|
||||
this.tempSelectedProviderId = this.selectedProviderId;
|
||||
this.tempSelectedModelName = this.selectedModelName;
|
||||
// 如果有临时选择的提供商,重新加载模型列表
|
||||
if (this.tempSelectedProviderId) {
|
||||
this.getProviderModels(this.tempSelectedProviderId);
|
||||
}
|
||||
},
|
||||
|
||||
// 打开对话框
|
||||
openDialog() {
|
||||
this.resetTempSelection();
|
||||
this.showDialog = true;
|
||||
},
|
||||
|
||||
// 公开方法:获取当前选择
|
||||
|
||||
173
dashboard/src/components/platform/AddNewPlatform.vue
Normal file
173
dashboard/src/components/platform/AddNewPlatform.vue
Normal file
@@ -0,0 +1,173 @@
|
||||
<template>
|
||||
<v-dialog v-model="showDialog" max-width="900px" min-height="80%">
|
||||
<v-card class="platform-selection-dialog" :title="tm('dialog.addPlatform')">
|
||||
<v-card-text class="pa-4" style="overflow-y: auto;">
|
||||
<v-row style="padding: 0px 8px;">
|
||||
<v-col v-for="(template, name) in platformTemplates"
|
||||
:key="name" cols="12" sm="6" md="6">
|
||||
<v-card variant="outlined" hover class="platform-card" @click="selectTemplate(name)">
|
||||
<div class="platform-card-content">
|
||||
<div class="platform-card-text">
|
||||
<v-card-title class="platform-card-title">{{ tm('dialog.connectTitle', { name }) }}</v-card-title>
|
||||
<v-card-text class="text-caption text-medium-emphasis platform-card-description">
|
||||
{{ getPlatformDescription(template, name) }}
|
||||
</v-card-text>
|
||||
</div>
|
||||
<div class="platform-card-logo">
|
||||
<img :src="getPlatformIcon(template.type)" v-if="getPlatformIcon(template.type)" class="platform-logo-img">
|
||||
<div v-else class="platform-logo-fallback">
|
||||
{{ name[0].toUpperCase() }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</v-card>
|
||||
</v-col>
|
||||
<v-col
|
||||
v-if="Object.keys(platformTemplates).length === 0"
|
||||
cols="12">
|
||||
<v-alert type="info" variant="tonal">
|
||||
{{ tm('dialog.noTemplates') }}
|
||||
</v-alert>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="closeDialog">{{ tm('dialog.cancel') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import { getPlatformIcon, getPlatformDescription } from '@/utils/platformUtils';
|
||||
|
||||
export default {
|
||||
name: 'AddNewPlatform',
|
||||
emits: ['update:show', 'select-template'],
|
||||
props: {
|
||||
show: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
metadata: {
|
||||
type: Object,
|
||||
default: () => ({})
|
||||
}
|
||||
},
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/platform');
|
||||
return { tm };
|
||||
},
|
||||
computed: {
|
||||
showDialog: {
|
||||
get() {
|
||||
return this.show;
|
||||
},
|
||||
set(value) {
|
||||
this.$emit('update:show', value);
|
||||
}
|
||||
},
|
||||
platformTemplates() {
|
||||
return this.metadata['platform_group']?.metadata?.platform?.config_template || {};
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
// 从工具函数导入
|
||||
getPlatformIcon,
|
||||
getPlatformDescription,
|
||||
|
||||
selectTemplate(name) {
|
||||
this.$emit('select-template', name);
|
||||
this.closeDialog();
|
||||
},
|
||||
|
||||
closeDialog() {
|
||||
this.showDialog = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.platform-selection-dialog .v-card-title {
|
||||
border-top-left-radius: 4px;
|
||||
border-top-right-radius: 4px;
|
||||
}
|
||||
|
||||
.platform-card {
|
||||
transition: all 0.3s ease;
|
||||
height: 100%;
|
||||
cursor: pointer;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.platform-card:hover {
|
||||
transform: translateY(-4px);
|
||||
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
|
||||
border-color: var(--v-primary-base);
|
||||
}
|
||||
|
||||
.platform-card-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
height: 100px;
|
||||
padding: 16px;
|
||||
position: relative;
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.platform-card-text {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.platform-card-title {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 4px;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.platform-card-description {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.platform-card-logo {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
width: 80px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.platform-logo-img {
|
||||
max-width: 60px;
|
||||
max-height: 60px;
|
||||
opacity: 0.6;
|
||||
object-fit: contain;
|
||||
}
|
||||
|
||||
.platform-logo-fallback {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border-radius: 50%;
|
||||
background-color: var(--v-primary-base);
|
||||
color: white;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
opacity: 0.3;
|
||||
}
|
||||
</style>
|
||||
237
dashboard/src/components/provider/AddNewProvider.vue
Normal file
237
dashboard/src/components/provider/AddNewProvider.vue
Normal file
@@ -0,0 +1,237 @@
|
||||
<template>
|
||||
<v-dialog v-model="showDialog" max-width="1100px" min-height="95%">
|
||||
<v-card :title="tm('dialogs.addProvider.title')">
|
||||
<v-card-text style="overflow-y: auto;">
|
||||
<v-tabs v-model="activeProviderTab" grow>
|
||||
<v-tab value="chat_completion" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-message-text</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.basic') }}
|
||||
</v-tab>
|
||||
<v-tab value="speech_to_text" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-microphone-message</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.speechToText') }}
|
||||
</v-tab>
|
||||
<v-tab value="text_to_speech" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-volume-high</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.textToSpeech') }}
|
||||
</v-tab>
|
||||
<v-tab value="embedding" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-code-json</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.embedding') }}
|
||||
</v-tab>
|
||||
<v-tab value="rerank" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-compare-vertical</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.rerank') }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
|
||||
<v-window v-model="activeProviderTab" class="mt-4">
|
||||
<v-window-item
|
||||
v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
|
||||
:key="tabType" :value="tabType">
|
||||
<v-row class="mt-1">
|
||||
<v-col v-for="(template, name) in getTemplatesByType(tabType)" :key="name" cols="12" sm="6"
|
||||
md="4">
|
||||
<v-card variant="outlined" hover class="provider-card"
|
||||
@click="selectProviderTemplate(name)">
|
||||
<div class="provider-card-content">
|
||||
<div class="provider-card-text">
|
||||
<v-card-title class="provider-card-title">接入 {{ name }}</v-card-title>
|
||||
<v-card-text
|
||||
class="text-caption text-medium-emphasis provider-card-description">
|
||||
{{ getProviderDescription(template, name) }}
|
||||
</v-card-text>
|
||||
</div>
|
||||
<div class="provider-card-logo">
|
||||
<img :src="getProviderIcon(template.provider)"
|
||||
v-if="getProviderIcon(template.provider)" class="provider-logo-img">
|
||||
<div v-else class="provider-logo-fallback">
|
||||
{{ name[0].toUpperCase() }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</v-card>
|
||||
</v-col>
|
||||
<v-col v-if="Object.keys(getTemplatesByType(tabType)).length === 0" cols="12">
|
||||
<v-alert type="info" variant="tonal">
|
||||
{{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
|
||||
</v-alert>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-window-item>
|
||||
</v-window>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="closeDialog">{{ tm('dialogs.config.cancel') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import { getProviderIcon, getProviderDescription } from '@/utils/providerUtils';
|
||||
|
||||
export default {
|
||||
name: 'AddNewProvider',
|
||||
props: {
|
||||
show: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
metadata: {
|
||||
type: Object,
|
||||
default: () => ({})
|
||||
}
|
||||
},
|
||||
emits: ['update:show', 'select-template'],
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/provider');
|
||||
return { tm };
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
activeProviderTab: 'chat_completion'
|
||||
};
|
||||
},
|
||||
computed: {
|
||||
showDialog: {
|
||||
get() {
|
||||
return this.show;
|
||||
},
|
||||
set(value) {
|
||||
this.$emit('update:show', value);
|
||||
}
|
||||
},
|
||||
|
||||
// 翻译消息的计算属性
|
||||
messages() {
|
||||
return {
|
||||
tabTypes: {
|
||||
'chat_completion': this.tm('providers.tabs.chatCompletion'),
|
||||
'speech_to_text': this.tm('providers.tabs.speechToText'),
|
||||
'text_to_speech': this.tm('providers.tabs.textToSpeech'),
|
||||
'embedding': this.tm('providers.tabs.embedding'),
|
||||
'rerank': this.tm('providers.tabs.rerank')
|
||||
}
|
||||
};
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
closeDialog() {
|
||||
this.showDialog = false;
|
||||
},
|
||||
|
||||
// 按提供商类型获取模板列表
|
||||
getTemplatesByType(type) {
|
||||
const templates = this.metadata['provider_group']?.metadata?.provider?.config_template || {};
|
||||
const filtered = {};
|
||||
|
||||
for (const [name, template] of Object.entries(templates)) {
|
||||
if (template.provider_type === type) {
|
||||
filtered[name] = template;
|
||||
}
|
||||
}
|
||||
|
||||
return filtered;
|
||||
},
|
||||
|
||||
// 从工具函数导入
|
||||
getProviderIcon,
|
||||
|
||||
// 获取Tab类型的中文名称
|
||||
getTabTypeName(tabType) {
|
||||
return this.messages.tabTypes[tabType] || tabType;
|
||||
},
|
||||
|
||||
// 获取提供商简介
|
||||
getProviderDescription(template, name) {
|
||||
return getProviderDescription(template, name, this.tm);
|
||||
},
|
||||
|
||||
// 选择提供商模板
|
||||
selectProviderTemplate(name) {
|
||||
this.$emit('select-template', name);
|
||||
this.closeDialog();
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.provider-card {
|
||||
transition: all 0.3s ease;
|
||||
height: 100%;
|
||||
cursor: pointer;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.provider-card:hover {
|
||||
transform: translateY(-4px);
|
||||
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
|
||||
border-color: var(--v-primary-base);
|
||||
}
|
||||
|
||||
.provider-card-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
height: 100px;
|
||||
padding: 16px;
|
||||
position: relative;
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.provider-card-text {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.provider-card-title {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 4px;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.provider-card-description {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.provider-card-logo {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
width: 80px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.provider-logo-img {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
opacity: 0.6;
|
||||
object-fit: contain;
|
||||
}
|
||||
|
||||
.provider-logo-fallback {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border-radius: 50%;
|
||||
background-color: var(--v-primary-base);
|
||||
color: white;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
opacity: 0.3;
|
||||
}
|
||||
</style>
|
||||
@@ -101,6 +101,7 @@
|
||||
},
|
||||
"messages": {
|
||||
"pluginNotAvailable": "Plugin not installed or unavailable",
|
||||
"pluginNotActivated": "astrbot_plugin_knowledge_base plugin not activated, please activate it in the plugin management page and restart AstrBot",
|
||||
"checkPluginFailed": "Failed to check plugin",
|
||||
"installFailed": "Installation failed",
|
||||
"installPluginFailed": "Failed to install plugin",
|
||||
|
||||
@@ -12,6 +12,13 @@
|
||||
"title": "Conversation History",
|
||||
"refresh": "Refresh"
|
||||
},
|
||||
"batch": {
|
||||
"deleteSelected": "Delete Selected ({count})"
|
||||
},
|
||||
"pagination": {
|
||||
"itemsPerPage": "Items per page",
|
||||
"showingItems": "Showing {start}-{end} of {total} items"
|
||||
},
|
||||
"table": {
|
||||
"headers": {
|
||||
"title": "Conversation Title",
|
||||
@@ -61,6 +68,13 @@
|
||||
"message": "Are you sure you want to delete conversation {title}? This action cannot be undone.",
|
||||
"cancel": "Cancel",
|
||||
"confirm": "Delete"
|
||||
},
|
||||
"batchDelete": {
|
||||
"title": "Batch Delete Confirmation",
|
||||
"message": "Are you sure you want to delete the selected {count} conversations? This action cannot be undone, please proceed with caution!",
|
||||
"andMore": "and {count} more",
|
||||
"cancel": "Cancel",
|
||||
"confirm": "Batch Delete"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
@@ -72,6 +86,10 @@
|
||||
"historyError": "Failed to fetch conversation history",
|
||||
"historySaveSuccess": "Conversation history saved successfully",
|
||||
"historySaveError": "Failed to save conversation history",
|
||||
"invalidJson": "Invalid JSON format"
|
||||
"invalidJson": "Invalid JSON format",
|
||||
"noItemSelected": "Please select conversations to delete first",
|
||||
"batchDeleteSuccess": "Successfully deleted {count} conversations",
|
||||
"batchDeleteError": "Batch delete failed",
|
||||
"batchDeletePartial": "Delete completed: {deleted} successful, {failed} failed"
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,8 @@
|
||||
"apply": "Apply Batch Settings",
|
||||
"editName": "Edit Session Name",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel"
|
||||
"cancel": "Cancel",
|
||||
"delete": "Delete"
|
||||
},
|
||||
"sessions": {
|
||||
"activeSessions": "Active Sessions",
|
||||
@@ -29,7 +30,8 @@
|
||||
"ttsProvider": "TTS Provider",
|
||||
"llmStatus": "LLM Status",
|
||||
"ttsStatus": "TTS Status",
|
||||
"pluginManagement": "Plugin Management"
|
||||
"pluginManagement": "Plugin Management",
|
||||
"actions": "Actions"
|
||||
}
|
||||
},
|
||||
"status": {
|
||||
@@ -65,6 +67,10 @@
|
||||
"fullSessionId": "Full Session ID",
|
||||
"hint": "Custom names help you easily identify sessions. The small information icon (!) will show the actual UMO when hovering."
|
||||
},
|
||||
"deleteConfirm": {
|
||||
"message": "Are you sure you want to delete session {sessionName}?",
|
||||
"warning": "This action will permanently delete all chat history and preference settings for this session (except for data linked via plugins), and this cannot be undone. Continue?"
|
||||
},
|
||||
"messages": {
|
||||
"refreshSuccess": "Session list refreshed",
|
||||
"personaUpdateSuccess": "Persona updated successfully",
|
||||
@@ -82,6 +88,8 @@
|
||||
"pluginStatusSuccess": "Plugin {name} {status}",
|
||||
"pluginStatusError": "Failed to update plugin status",
|
||||
"nameUpdateSuccess": "Session name updated successfully",
|
||||
"nameUpdateError": "Failed to update session name"
|
||||
"nameUpdateError": "Failed to update session name",
|
||||
"deleteSuccess": "Session deleted successfully",
|
||||
"deleteError": "Failed to delete session"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,6 +101,7 @@
|
||||
},
|
||||
"messages": {
|
||||
"pluginNotAvailable": "插件未安装或不可用",
|
||||
"pluginNotActivated": "astrbot_plugin_knowledge_base 插件未启用,请前往插件管理页面启用,然后重启 AstrBot。",
|
||||
"checkPluginFailed": "检查插件失败",
|
||||
"installFailed": "安装失败",
|
||||
"installPluginFailed": "安装插件失败",
|
||||
|
||||
@@ -12,6 +12,13 @@
|
||||
"title": "对话历史",
|
||||
"refresh": "刷新"
|
||||
},
|
||||
"batch": {
|
||||
"deleteSelected": "删除选中 ({count})"
|
||||
},
|
||||
"pagination": {
|
||||
"itemsPerPage": "每页",
|
||||
"showingItems": "显示 {start}-{end} 项,共 {total} 项"
|
||||
},
|
||||
"table": {
|
||||
"headers": {
|
||||
"title": "对话标题",
|
||||
@@ -61,6 +68,13 @@
|
||||
"message": "确定要删除对话 {title} 吗?此操作不可恢复。",
|
||||
"cancel": "取消",
|
||||
"confirm": "删除"
|
||||
},
|
||||
"batchDelete": {
|
||||
"title": "批量删除确认",
|
||||
"message": "确定要删除选中的 {count} 个对话吗?此操作不可恢复,请谨慎操作!",
|
||||
"andMore": "等 {count} 个",
|
||||
"cancel": "取消",
|
||||
"confirm": "批量删除"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
@@ -72,6 +86,10 @@
|
||||
"historyError": "获取对话历史失败",
|
||||
"historySaveSuccess": "对话历史保存成功",
|
||||
"historySaveError": "对话历史保存失败",
|
||||
"invalidJson": "JSON格式无效"
|
||||
"invalidJson": "JSON格式无效",
|
||||
"noItemSelected": "请先选择要删除的对话",
|
||||
"batchDeleteSuccess": "成功删除 {count} 个对话",
|
||||
"batchDeleteError": "批量删除失败",
|
||||
"batchDeletePartial": "删除完成:成功 {deleted} 个,失败 {failed} 个"
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,8 @@
|
||||
"apply": "应用批量设置",
|
||||
"editName": "备注",
|
||||
"save": "保存",
|
||||
"cancel": "取消"
|
||||
"cancel": "取消",
|
||||
"delete": "删除"
|
||||
},
|
||||
"sessions": {
|
||||
"activeSessions": "活跃会话",
|
||||
@@ -29,7 +30,8 @@
|
||||
"ttsProvider": "语音合成模型",
|
||||
"llmStatus": "启用 LLM",
|
||||
"ttsStatus": "启用 TTS",
|
||||
"pluginManagement": "插件管理"
|
||||
"pluginManagement": "插件管理",
|
||||
"actions": "操作"
|
||||
}
|
||||
},
|
||||
"status": {
|
||||
@@ -65,6 +67,10 @@
|
||||
"fullSessionId": "完整会话ID",
|
||||
"hint": "自定义名称帮助您轻松识别会话。当设置了自定义名称时,会显示一个小感叹号标识(!),鼠标悬停时会显示实际的UMO。"
|
||||
},
|
||||
"deleteConfirm": {
|
||||
"message": "确定要删除会话 {sessionName} 吗?",
|
||||
"warning": "此操作将永久删除本次会话的「全部对话记录」与「偏好设置」(插件对会话的关联数据除外),且无法恢复。确认继续?"
|
||||
},
|
||||
"messages": {
|
||||
"refreshSuccess": "会话列表已刷新",
|
||||
"personaUpdateSuccess": "人格更新成功",
|
||||
@@ -82,6 +88,8 @@
|
||||
"pluginStatusSuccess": "插件 {name} {status}",
|
||||
"pluginStatusError": "插件状态更新失败",
|
||||
"nameUpdateSuccess": "会话名称更新成功",
|
||||
"nameUpdateError": "会话名称更新失败"
|
||||
"nameUpdateError": "会话名称更新失败",
|
||||
"deleteSuccess": "会话删除成功",
|
||||
"deleteError": "会话删除失败"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,15 +342,10 @@ commonStore.getStartTime();
|
||||
<v-dialog v-model="updateStatusDialog" :width="$vuetify.display.smAndDown ? '100%' : '1200'"
|
||||
:fullscreen="$vuetify.display.xs">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-tooltip>
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn size="small" @click="checkUpdate(); getReleases(); getDevCommits();" class="action-btn"
|
||||
color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props" icon>
|
||||
<v-icon>mdi-arrow-up-circle</v-icon>
|
||||
</v-btn>
|
||||
</template>
|
||||
{{ t('core.header.buttons.update') }}
|
||||
</v-tooltip>
|
||||
<v-btn size="small" @click="checkUpdate(); getReleases(); getDevCommits();" class="action-btn"
|
||||
color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props" icon>
|
||||
<v-icon>mdi-arrow-up-circle</v-icon>
|
||||
</v-btn>
|
||||
</template>
|
||||
<v-card>
|
||||
<v-card-title class="mobile-card-title">
|
||||
@@ -473,7 +468,7 @@ commonStore.getStartTime();
|
||||
<h3 class="mb-4">{{ t('core.header.updateDialog.dashboardUpdate.title') }}</h3>
|
||||
<div class="mb-4">
|
||||
<small>{{ t('core.header.updateDialog.dashboardUpdate.currentVersion') }} {{ dashboardCurrentVersion
|
||||
}}</small>
|
||||
}}</small>
|
||||
<br>
|
||||
|
||||
</div>
|
||||
|
||||
31
dashboard/src/stores/toast.js
Normal file
31
dashboard/src/stores/toast.js
Normal file
@@ -0,0 +1,31 @@
|
||||
import { defineStore } from 'pinia'
|
||||
import { ref, computed } from 'vue'
|
||||
|
||||
export const useToastStore = defineStore('toast', () => {
|
||||
const queue = ref([])
|
||||
const current = computed(() => queue.value[0])
|
||||
|
||||
function add({
|
||||
message,
|
||||
color = 'info', // Vuetify 颜色
|
||||
timeout = 3000,
|
||||
closable = true,
|
||||
multiLine = false,
|
||||
location = 'top center'
|
||||
}) {
|
||||
queue.value.push({
|
||||
message,
|
||||
color,
|
||||
timeout,
|
||||
closable,
|
||||
multiLine,
|
||||
location
|
||||
})
|
||||
}
|
||||
|
||||
function shift() {
|
||||
queue.value.shift()
|
||||
}
|
||||
|
||||
return { current, add, shift }
|
||||
})
|
||||
78
dashboard/src/utils/platformUtils.js
Normal file
78
dashboard/src/utils/platformUtils.js
Normal file
@@ -0,0 +1,78 @@
|
||||
/**
|
||||
* 平台相关工具函数
|
||||
*/
|
||||
|
||||
/**
|
||||
* 获取平台图标
|
||||
* @param {string} name - 平台名称或类型
|
||||
* @returns {string|undefined} 图标URL
|
||||
*/
|
||||
export function getPlatformIcon(name) {
|
||||
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
|
||||
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
||||
} else if (name === 'wecom') {
|
||||
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
||||
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
||||
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
||||
} else if (name === 'lark') {
|
||||
return new URL('@/assets/images/platform_logos/lark.png', import.meta.url).href
|
||||
} else if (name === 'dingtalk') {
|
||||
return new URL('@/assets/images/platform_logos/dingtalk.svg', import.meta.url).href
|
||||
} else if (name === 'telegram') {
|
||||
return new URL('@/assets/images/platform_logos/telegram.svg', import.meta.url).href
|
||||
} else if (name === 'discord') {
|
||||
return new URL('@/assets/images/platform_logos/discord.svg', import.meta.url).href
|
||||
} else if (name === 'slack') {
|
||||
return new URL('@/assets/images/platform_logos/slack.svg', import.meta.url).href
|
||||
} else if (name === 'kook') {
|
||||
return new URL('@/assets/images/platform_logos/kook.png', import.meta.url).href
|
||||
} else if (name === 'vocechat') {
|
||||
return new URL('@/assets/images/platform_logos/vocechat.png', import.meta.url).href
|
||||
} else if (name === 'satori' || name === 'Satori') {
|
||||
return new URL('@/assets/images/platform_logos/satori.png', import.meta.url).href
|
||||
} else if (name === 'misskey') {
|
||||
return new URL('@/assets/images/platform_logos/misskey.png', import.meta.url).href
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取平台教程链接
|
||||
* @param {string} platformType - 平台类型
|
||||
* @returns {string} 教程链接
|
||||
*/
|
||||
export function getTutorialLink(platformType) {
|
||||
const tutorialMap = {
|
||||
"qq_official_webhook": "https://docs.astrbot.app/deploy/platform/qqofficial/webhook.html",
|
||||
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
|
||||
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
||||
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
|
||||
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
|
||||
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
|
||||
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
|
||||
"wechatpadpro": "https://docs.astrbot.app/deploy/platform/wechat/wechatpadpro.html",
|
||||
"weixin_official_account": "https://docs.astrbot.app/deploy/platform/weixin-official-account.html",
|
||||
"discord": "https://docs.astrbot.app/deploy/platform/discord.html",
|
||||
"slack": "https://docs.astrbot.app/deploy/platform/slack.html",
|
||||
"kook": "https://docs.astrbot.app/deploy/platform/kook.html",
|
||||
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
|
||||
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
|
||||
"misskey": "https://docs.astrbot.app/deploy/platform/misskey.html",
|
||||
}
|
||||
return tutorialMap[platformType] || "https://docs.astrbot.app";
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取平台描述
|
||||
* @param {Object} template - 平台模板
|
||||
* @param {string} name - 平台名称
|
||||
* @returns {string} 平台描述
|
||||
*/
|
||||
export function getPlatformDescription(template, name) {
|
||||
// special judge for community platforms
|
||||
if (name.includes('vocechat')) {
|
||||
return "由 @HikariFroya 提供。";
|
||||
} else if (name.includes('kook')) {
|
||||
return "由 @wuyan1003 提供。"
|
||||
}
|
||||
return '';
|
||||
}
|
||||
52
dashboard/src/utils/providerUtils.js
Normal file
52
dashboard/src/utils/providerUtils.js
Normal file
@@ -0,0 +1,52 @@
|
||||
/**
|
||||
* 提供商相关的工具函数
|
||||
*/
|
||||
|
||||
/**
|
||||
* 获取提供商类型对应的图标
|
||||
* @param {string} type - 提供商类型
|
||||
* @returns {string} 图标 URL
|
||||
*/
|
||||
export function getProviderIcon(type) {
|
||||
const icons = {
|
||||
'openai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
|
||||
'azure': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/azure.svg',
|
||||
'xai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/xai.svg',
|
||||
'anthropic': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/anthropic.svg',
|
||||
'ollama': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ollama.svg',
|
||||
'google': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
|
||||
'deepseek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
|
||||
'modelscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/modelscope.svg',
|
||||
'zhipu': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
|
||||
'siliconflow': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
|
||||
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
|
||||
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
|
||||
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
|
||||
"coze": "https://registry.npmmirror.com/@lobehub/icons-static-svg/1.66.0/files/icons/coze.svg",
|
||||
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
|
||||
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
|
||||
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
|
||||
'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',
|
||||
'minimax': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/minimax.svg',
|
||||
'302ai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
|
||||
'microsoft': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/microsoft.svg',
|
||||
'vllm': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/vllm.svg',
|
||||
};
|
||||
return icons[type] || '';
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取提供商简介
|
||||
* @param {Object} template - 模板对象
|
||||
* @param {string} name - 提供商名称
|
||||
* @param {Function} tm - 翻译函数
|
||||
* @returns {string} 提供商描述
|
||||
*/
|
||||
export function getProviderDescription(template, name, tm) {
|
||||
if (name == 'OpenAI') {
|
||||
return tm('providers.description.openai', { type: template.type });
|
||||
} else if (name == 'vLLM Rerank') {
|
||||
return tm('providers.description.vllm_rerank', { type: template.type });
|
||||
}
|
||||
return tm('providers.description.default', { type: template.type });
|
||||
}
|
||||
16
dashboard/src/utils/toast.js
Normal file
16
dashboard/src/utils/toast.js
Normal file
@@ -0,0 +1,16 @@
|
||||
import { useToastStore } from '@/stores/toast'
|
||||
|
||||
export function useToast() {
|
||||
const store = useToastStore()
|
||||
|
||||
const toast = (message, color = 'info', opts = {}) =>
|
||||
store.add({ message, color, ...opts })
|
||||
|
||||
return {
|
||||
toast,
|
||||
success: (msg, opts) => toast(msg, 'success', opts),
|
||||
error: (msg, opts) => toast(msg, 'error', opts),
|
||||
info: (msg, opts) => toast(msg, 'primary', opts),
|
||||
warning: (msg, opts) => toast(msg, 'warning', opts)
|
||||
}
|
||||
}
|
||||
@@ -37,18 +37,29 @@
|
||||
</v-col>
|
||||
</v-row>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="fetchConversations"
|
||||
:loading="loading" size="small">
|
||||
:loading="loading" size="small" class="mr-2">
|
||||
{{ tm('history.refresh') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
v-if="selectedItems.length > 0"
|
||||
color="error"
|
||||
prepend-icon="mdi-delete"
|
||||
variant="tonal"
|
||||
@click="confirmBatchDelete"
|
||||
:disabled="loading"
|
||||
size="small">
|
||||
{{ tm('batch.deleteSelected', { count: selectedItems.length }) }}
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-text class="pa-0">
|
||||
<v-data-table :headers="tableHeaders" :items="conversations" :loading="loading"
|
||||
style="font-size: 12px;" density="comfortable" hide-default-footer items-per-page="10"
|
||||
<v-data-table v-model="selectedItems" :headers="tableHeaders" :items="conversations"
|
||||
:loading="loading" style="font-size: 12px;" density="comfortable" hide-default-footer
|
||||
class="elevation-0" :items-per-page="pagination.page_size"
|
||||
:items-per-page-options="[10, 20, 50, 100]" @update:options="handleTableOptions">
|
||||
:items-per-page-options="pageSizeOptions" show-select return-object
|
||||
:disabled="loading" @update:options="handleTableOptions">
|
||||
<template v-slot:item.title="{ item }">
|
||||
<div class="d-flex align-center">
|
||||
<span>{{ item.title || tm('status.noTitle') }}</span>
|
||||
@@ -67,6 +78,10 @@
|
||||
</v-chip>
|
||||
</template>
|
||||
|
||||
<template v-slot:item.cid="{ item }">
|
||||
<span class="text-truncate">{{ item.cid || tm('status.unknown') }}</span>
|
||||
</template>
|
||||
|
||||
<template v-slot:item.sessionId="{ item }">
|
||||
<span>{{ item.sessionInfo.sessionId || tm('status.unknown') }}</span>
|
||||
</template>
|
||||
@@ -82,15 +97,15 @@
|
||||
<template v-slot:item.actions="{ item }">
|
||||
<div class="actions-wrapper">
|
||||
<v-btn icon variant="plain" size="x-small" class="action-button"
|
||||
@click="viewConversation(item)">
|
||||
@click="viewConversation(item)" :disabled="loading">
|
||||
<v-icon>mdi-eye</v-icon>
|
||||
</v-btn>
|
||||
<v-btn icon variant="plain" size="x-small" class="action-button"
|
||||
@click="editConversation(item)">
|
||||
@click="editConversation(item)" :disabled="loading">
|
||||
<v-icon>mdi-pencil</v-icon>
|
||||
</v-btn>
|
||||
<v-btn icon color="error" variant="plain" size="x-small" class="action-button"
|
||||
@click="confirmDeleteConversation(item)">
|
||||
@click="confirmDeleteConversation(item)" :disabled="loading">
|
||||
<v-icon>mdi-delete</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
@@ -105,9 +120,25 @@
|
||||
</v-data-table>
|
||||
|
||||
<!-- 分页控制 -->
|
||||
<div class="d-flex justify-end">
|
||||
<div class="d-flex justify-center py-3">
|
||||
<!-- 每页大小选择器 -->
|
||||
<div class="d-flex justify-between align-center px-4 py-2 bg-grey-lighten-5">
|
||||
<div class="d-flex align-center">
|
||||
<span class="text-caption mr-2">{{ tm('pagination.itemsPerPage') }}:</span>
|
||||
<v-select v-model="pagination.page_size" :items="pageSizeOptions" variant="outlined"
|
||||
density="compact" hide-details style="max-width: 100px;"
|
||||
:disabled="loading" @update:model-value="onPageSizeChange"></v-select>
|
||||
</div>
|
||||
<div class="text-caption ml-4">
|
||||
{{ tm('pagination.showingItems', {
|
||||
start: Math.min((pagination.page - 1) * pagination.page_size + 1, pagination.total),
|
||||
end: Math.min(pagination.page * pagination.page_size, pagination.total),
|
||||
total: pagination.total
|
||||
}) }}
|
||||
</div>
|
||||
</div>
|
||||
<v-pagination v-model="pagination.page" :length="pagination.total_pages" :disabled="loading"
|
||||
@update:model-value="fetchConversations" rounded="circle"></v-pagination>
|
||||
@update:model-value="fetchConversations" rounded="circle" :total-visible="7"></v-pagination>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
@@ -116,24 +147,20 @@
|
||||
<!-- 对话详情对话框 -->
|
||||
<v-dialog v-model="dialogView" max-width="900px" scrollable>
|
||||
<v-card class="conversation-detail-card">
|
||||
<v-card-title class="bg-primary text-white py-3 d-flex align-center">
|
||||
<v-icon color="white" class="me-2">mdi-eye</v-icon>
|
||||
<v-card-title class="ml-2 mt-2 d-flex align-center">
|
||||
<span class="text-truncate">{{ selectedConversation?.title || tm('status.noTitle') }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
|
||||
<div class="d-flex align-center" v-if="selectedConversation?.sessionInfo">
|
||||
<v-chip color="white" text-color="primary" size="small" class="mr-2">
|
||||
<v-chip text-color="primary" size="small" class="mr-2" rounded="md">
|
||||
{{ selectedConversation.sessionInfo.platform }}
|
||||
</v-chip>
|
||||
<v-chip color="white" text-color="secondary" size="small">
|
||||
<v-chip text-color="secondary" size="small" rounded="md">
|
||||
{{ getMessageTypeDisplay(selectedConversation.sessionInfo.messageType) }}
|
||||
</v-chip>
|
||||
</div>
|
||||
</v-card-title>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-text class="py-4">
|
||||
<v-card-text>
|
||||
<div class="mb-4 d-flex align-center">
|
||||
<v-btn color="secondary" variant="tonal" size="small" class="mr-2"
|
||||
@click="isEditingHistory = !isEditingHistory">
|
||||
@@ -168,16 +195,10 @@
|
||||
</div>
|
||||
|
||||
<!-- 消息列表组件 -->
|
||||
<MessageList
|
||||
v-else
|
||||
:messages="formattedMessages"
|
||||
:isDark="false"
|
||||
/>
|
||||
<MessageList v-else :messages="formattedMessages" :isDark="false" />
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="closeHistoryDialog">
|
||||
@@ -227,7 +248,7 @@
|
||||
|
||||
<v-card-text class="py-4">
|
||||
<p>{{ tm('dialogs.delete.message', { title: selectedConversation?.title || tm('status.noTitle') })
|
||||
}}</p>
|
||||
}}</p>
|
||||
</v-card-text>
|
||||
|
||||
<v-divider></v-divider>
|
||||
@@ -244,6 +265,48 @@
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 批量删除确认对话框 -->
|
||||
<v-dialog v-model="dialogBatchDelete" max-width="600px">
|
||||
<v-card>
|
||||
<v-card-title class="bg-error text-white py-3">
|
||||
<v-icon color="white" class="me-2">mdi-delete</v-icon>
|
||||
<span>{{ tm('dialogs.batchDelete.title') }}</span>
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="py-4">
|
||||
<p class="mb-3">{{ tm('dialogs.batchDelete.message', { count: selectedItems.length }) }}</p>
|
||||
|
||||
<!-- 显示前几个要删除的对话 -->
|
||||
<div v-if="selectedItems.length > 0" class="mb-3">
|
||||
<v-chip v-for="(item, index) in selectedItems.slice(0, 5)" :key="`${item.user_id}-${item.cid}`"
|
||||
size="small" class="mr-1 mb-1" closable @click:close="removeFromSelection(item)"
|
||||
:disabled="loading">
|
||||
{{ item.title || tm('status.noTitle') }}
|
||||
</v-chip>
|
||||
<v-chip v-if="selectedItems.length > 5" size="small" class="mr-1 mb-1">
|
||||
{{ tm('dialogs.batchDelete.andMore', { count: selectedItems.length - 5 }) }}
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<v-alert type="warning" variant="tonal" class="mb-3">
|
||||
{{ tm('dialogs.batchDelete.warning') }}
|
||||
</v-alert>
|
||||
</v-card-text>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="dialogBatchDelete = false" :disabled="loading">
|
||||
{{ tm('dialogs.batchDelete.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn color="error" @click="batchDeleteConversations" :loading="loading">
|
||||
{{ tm('dialogs.batchDelete.confirm') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 消息提示 -->
|
||||
<v-snackbar :timeout="3000" elevation="24" :color="messageType" v-model="showMessage" location="top">
|
||||
{{ message }}
|
||||
@@ -253,6 +316,7 @@
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import { debounce } from 'lodash';
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor';
|
||||
import MarkdownIt from 'markdown-it';
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
@@ -291,32 +355,13 @@ export default {
|
||||
conversations: [],
|
||||
search: '',
|
||||
headers: [],
|
||||
selectedItems: [], // 批量选择的项目
|
||||
|
||||
// 筛选条件
|
||||
platformFilter: [],
|
||||
messageTypeFilter: [],
|
||||
lastAppliedFilters: null, // 记录上次应用的筛选条件
|
||||
|
||||
// 平台颜色映射
|
||||
platformColors: {
|
||||
'telegram': 'blue-lighten-1',
|
||||
'qq_official': 'purple-lighten-1',
|
||||
'qq_official_webhook': 'purple-lighten-2',
|
||||
'aiocqhttp': 'deep-purple-lighten-1',
|
||||
'lark': 'cyan-darken-1',
|
||||
'wecom': 'green-darken-1',
|
||||
'dingtalk': 'blue-darken-2',
|
||||
'default': 'grey-lighten-1'
|
||||
},
|
||||
|
||||
// 消息类型颜色映射
|
||||
messageTypeColors: {
|
||||
'GroupMessage': 'green',
|
||||
'FriendMessage': 'blue',
|
||||
'GuildMessage': 'purple',
|
||||
'default': 'grey'
|
||||
},
|
||||
|
||||
// 分页数据
|
||||
pagination: {
|
||||
page: 1,
|
||||
@@ -324,11 +369,13 @@ export default {
|
||||
total: 0,
|
||||
total_pages: 0
|
||||
},
|
||||
pageSizeOptions: [10, 20, 50, 100], // 每页大小选项
|
||||
|
||||
// 对话框控制
|
||||
dialogView: false,
|
||||
dialogEdit: false,
|
||||
dialogDelete: false,
|
||||
dialogBatchDelete: false, // 批量删除对话框
|
||||
|
||||
// 选中的对话
|
||||
selectedConversation: null,
|
||||
@@ -340,11 +387,6 @@ export default {
|
||||
cid: '',
|
||||
title: ''
|
||||
},
|
||||
defaultItem: {
|
||||
user_id: '',
|
||||
cid: '',
|
||||
title: ''
|
||||
},
|
||||
|
||||
// 表单验证
|
||||
valid: true,
|
||||
@@ -379,8 +421,7 @@ export default {
|
||||
},
|
||||
|
||||
created() {
|
||||
// 创建一个防抖函数,避免频繁请求
|
||||
this.debouncedApplyFilters = this.debounce(() => {
|
||||
this.debouncedApplyFilters = debounce(() => {
|
||||
// 重置到第一页
|
||||
this.pagination.page = 1;
|
||||
this.fetchConversations();
|
||||
@@ -392,13 +433,14 @@ export default {
|
||||
tableHeaders() {
|
||||
return [
|
||||
{ title: this.tm('table.headers.title'), key: 'title', sortable: true },
|
||||
{ title: '会话 ID', key: 'cid', sortable: true, width: '100px' },
|
||||
{
|
||||
title: this.tm('table.headers.sessionId'),
|
||||
align: 'center',
|
||||
children: [
|
||||
{ title: this.tm('table.headers.platform'), key: 'platform', sortable: true, width: '120px' },
|
||||
{ title: this.tm('table.headers.type'), key: 'messageType', sortable: true, width: '100px' },
|
||||
{ title: '会话 ID', key: 'sessionId', sortable: true, width: '100px' },
|
||||
{ title: '用户 ID', key: 'sessionId', sortable: true, width: '100px' },
|
||||
],
|
||||
},
|
||||
{ title: this.tm('table.headers.createdAt'), key: 'created_at', sortable: true, width: '180px' },
|
||||
@@ -431,17 +473,6 @@ export default {
|
||||
];
|
||||
},
|
||||
|
||||
// 筛选后的对话 - 现在只用于额外的客户端筛选(排除astrbot和webchat)
|
||||
filteredConversations() {
|
||||
return this.conversations.filter(conv => {
|
||||
// 排除 user_id 为 astrbot 或 platform 为 webchat 的对话
|
||||
if (conv.user_id === 'astrbot' || conv.sessionInfo?.platform === 'webchat') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
},
|
||||
|
||||
// 当前的筛选条件对象
|
||||
currentFilters() {
|
||||
const platforms = this.platformFilter.map(item =>
|
||||
@@ -499,19 +530,6 @@ export default {
|
||||
});
|
||||
},
|
||||
|
||||
// 添加防抖函数
|
||||
debounce(func, wait) {
|
||||
let timeout;
|
||||
return function () {
|
||||
const context = this;
|
||||
const args = arguments;
|
||||
clearTimeout(timeout);
|
||||
timeout = setTimeout(() => {
|
||||
func.apply(context, args);
|
||||
}, wait);
|
||||
};
|
||||
},
|
||||
|
||||
// 处理表格选项变更(页面大小等)
|
||||
handleTableOptions(options) {
|
||||
// 处理页面大小变更
|
||||
@@ -552,87 +570,93 @@ export default {
|
||||
},
|
||||
|
||||
// 获取对话列表
|
||||
async fetchConversations() {
|
||||
this.loading = true;
|
||||
try {
|
||||
// 准备请求参数,包含分页和筛选条件
|
||||
const params = {
|
||||
page: this.pagination.page,
|
||||
page_size: this.pagination.page_size
|
||||
};
|
||||
fetchConversations: (() => {
|
||||
let controller = new AbortController();
|
||||
|
||||
// 添加筛选条件 - 处理combobox的混合数据格式
|
||||
if (this.platformFilter.length > 0) {
|
||||
const platforms = this.platformFilter.map(item =>
|
||||
typeof item === 'object' ? item.value : item
|
||||
);
|
||||
params.platforms = platforms.join(',');
|
||||
}
|
||||
return async function () {
|
||||
// 新请求前停止之前的请求
|
||||
controller?.abort()
|
||||
controller = new AbortController();
|
||||
|
||||
if (this.messageTypeFilter.length > 0) {
|
||||
params.message_types = this.messageTypeFilter.join(',');
|
||||
}
|
||||
this.loading = true;
|
||||
try {
|
||||
// 准备请求参数,包含分页和筛选条件
|
||||
const params = {
|
||||
page: this.pagination.page,
|
||||
page_size: this.pagination.page_size
|
||||
};
|
||||
|
||||
if (this.search) {
|
||||
params.search = this.search;
|
||||
}
|
||||
|
||||
// 添加排除条件
|
||||
params.exclude_ids = 'astrbot';
|
||||
params.exclude_platforms = 'webchat';
|
||||
|
||||
console.log(`正在请求对话列表: /api/conversation/list 参数:`, params);
|
||||
|
||||
const response = await axios.get('/api/conversation/list', { params });
|
||||
|
||||
console.log('收到对话列表响应:', response.data);
|
||||
|
||||
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
|
||||
|
||||
if (response.data.status === "ok") {
|
||||
const data = response.data.data;
|
||||
|
||||
if (!data || !data.conversations) {
|
||||
console.error('API 返回数据格式不符合预期:', data);
|
||||
this.showErrorMessage(this.tm('messages.fetchError'));
|
||||
return;
|
||||
// 添加筛选条件 - 处理combobox的混合数据格式
|
||||
if (this.platformFilter.length > 0) {
|
||||
const platforms = this.platformFilter.map(item =>
|
||||
typeof item === 'object' ? item.value : item
|
||||
);
|
||||
params.platforms = platforms.join(',');
|
||||
}
|
||||
|
||||
// 处理会话数据,解析sessionId
|
||||
this.conversations = (data.conversations || []).map(conv => {
|
||||
// 为每个会话添加会话信息
|
||||
conv.sessionInfo = this.parseSessionId(conv.user_id);
|
||||
return conv;
|
||||
if (this.messageTypeFilter.length > 0) {
|
||||
params.message_types = this.messageTypeFilter.join(',');
|
||||
}
|
||||
|
||||
if (this.search) {
|
||||
params.search = this.search.trim();
|
||||
}
|
||||
|
||||
// 添加排除条件
|
||||
params.exclude_ids = 'astrbot';
|
||||
params.exclude_platforms = 'webchat';
|
||||
|
||||
const response = await axios.get('/api/conversation/list', {
|
||||
signal: controller.signal,
|
||||
params
|
||||
});
|
||||
|
||||
// 更新分页信息
|
||||
if (data.pagination) {
|
||||
this.pagination = {
|
||||
page: data.pagination.page || 1,
|
||||
page_size: data.pagination.page_size || 20,
|
||||
total: data.pagination.total || 0,
|
||||
total_pages: data.pagination.total_pages || 1
|
||||
};
|
||||
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
|
||||
|
||||
if (response.data.status === "ok") {
|
||||
const data = response.data.data;
|
||||
|
||||
if (!data || !data.conversations) {
|
||||
console.error('API 返回数据格式不符合预期:', data);
|
||||
this.showErrorMessage(this.tm('messages.fetchError'));
|
||||
return;
|
||||
}
|
||||
|
||||
// 处理会话数据,解析sessionId
|
||||
this.conversations = (data.conversations || []).map(conv => {
|
||||
// 为每个会话添加会话信息
|
||||
conv.sessionInfo = this.parseSessionId(conv.user_id);
|
||||
return conv;
|
||||
});
|
||||
|
||||
// 更新分页信息
|
||||
if (data.pagination) {
|
||||
this.pagination = {
|
||||
page: data.pagination.page || 1,
|
||||
page_size: data.pagination.page_size || 20,
|
||||
total: data.pagination.total || 0,
|
||||
total_pages: data.pagination.total_pages || 1
|
||||
};
|
||||
} else {
|
||||
console.warn('API 响应中没有分页信息');
|
||||
}
|
||||
} else {
|
||||
console.warn('API 响应中没有分页信息');
|
||||
this.showErrorMessage(response.data.message || this.tm('messages.fetchError'));
|
||||
}
|
||||
} else {
|
||||
this.showErrorMessage(response.data.message || this.tm('messages.fetchError'));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取对话列表出错:', error);
|
||||
if (error.response) {
|
||||
console.error('错误响应数据:', error.response.data);
|
||||
console.error('错误状态码:', error.response.status);
|
||||
}
|
||||
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.fetchError'));
|
||||
} finally {
|
||||
// this.loading = false;
|
||||
setTimeout(() => {
|
||||
} catch (error) {
|
||||
if (axios.isCancel(error)) return;
|
||||
|
||||
console.error('获取对话列表出错:', error);
|
||||
if (error.response) {
|
||||
console.error('错误响应数据:', error.response.data);
|
||||
console.error('错误状态码:', error.response.status);
|
||||
}
|
||||
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.fetchError'));
|
||||
} finally {
|
||||
this.loading = false;
|
||||
}, 200);
|
||||
}
|
||||
}
|
||||
},
|
||||
})(),
|
||||
|
||||
// 查看对话详情
|
||||
async viewConversation(item) {
|
||||
@@ -790,6 +814,88 @@ export default {
|
||||
}
|
||||
} catch (error) {
|
||||
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.deleteError'));
|
||||
} finally {
|
||||
this.loading = false;
|
||||
this.selectedItems = this.selectedItems.filter(item =>
|
||||
!(item.user_id === this.selectedConversation.user_id && item.cid === this.selectedConversation.cid)
|
||||
);
|
||||
this.selectedConversation = null;
|
||||
}
|
||||
},
|
||||
|
||||
// 处理页面大小变更
|
||||
onPageSizeChange() {
|
||||
this.pagination.page = 1; // 重置到第一页
|
||||
this.fetchConversations();
|
||||
},
|
||||
|
||||
// 确认批量删除
|
||||
confirmBatchDelete() {
|
||||
if (this.selectedItems.length === 0) {
|
||||
this.showErrorMessage(this.tm('messages.noItemSelected'));
|
||||
return;
|
||||
}
|
||||
this.dialogBatchDelete = true;
|
||||
},
|
||||
|
||||
// 从选择中移除项目
|
||||
removeFromSelection(item) {
|
||||
const index = this.selectedItems.findIndex(selected =>
|
||||
selected.user_id === item.user_id && selected.cid === item.cid
|
||||
);
|
||||
if (index !== -1) {
|
||||
this.selectedItems.splice(index, 1);
|
||||
}
|
||||
},
|
||||
|
||||
// 批量删除对话
|
||||
async batchDeleteConversations() {
|
||||
if (this.selectedItems.length === 0) {
|
||||
this.showErrorMessage(this.tm('messages.noItemSelected'));
|
||||
return;
|
||||
}
|
||||
|
||||
this.loading = true;
|
||||
try {
|
||||
// 准备批量删除的数据
|
||||
const conversations = this.selectedItems.map(item => ({
|
||||
user_id: item.user_id,
|
||||
cid: item.cid
|
||||
}));
|
||||
|
||||
const response = await axios.post('/api/conversation/delete', {
|
||||
conversations: conversations
|
||||
});
|
||||
|
||||
if (response.data.status === "ok") {
|
||||
const result = response.data.data;
|
||||
this.dialogBatchDelete = false;
|
||||
this.selectedItems = []; // 清空选择
|
||||
|
||||
// 显示结果消息
|
||||
if (result.failed_count > 0) {
|
||||
this.showErrorMessage(
|
||||
this.tm('messages.batchDeletePartial', {
|
||||
deleted: result.deleted_count,
|
||||
failed: result.failed_count
|
||||
})
|
||||
);
|
||||
} else {
|
||||
this.showSuccessMessage(
|
||||
this.tm('messages.batchDeleteSuccess', {
|
||||
count: result.deleted_count
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// 刷新列表
|
||||
this.fetchConversations();
|
||||
} else {
|
||||
this.showErrorMessage(response.data.message || this.tm('messages.batchDeleteError'));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('批量删除对话出错:', error);
|
||||
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.batchDeleteError'));
|
||||
} finally {
|
||||
this.loading = false;
|
||||
}
|
||||
@@ -812,35 +918,6 @@ export default {
|
||||
}).format(date);
|
||||
},
|
||||
|
||||
// 格式化消息内容
|
||||
formatMessage(content) {
|
||||
|
||||
// content 可能是数组
|
||||
// [{"type": "image_url", "image_url": {"url": url_or_base64}}, {"type": "text", "text": "text"}]
|
||||
|
||||
let final_content = content;
|
||||
if (Array.isArray(content)) {
|
||||
// 处理数组内容
|
||||
final_content = content.map(item => {
|
||||
if (item.type === 'image_url') {
|
||||
return `<img src="${item.image_url.url}" alt="Image" />`;
|
||||
} else if (item.type === 'text') {
|
||||
return item.text;
|
||||
}
|
||||
return '';
|
||||
}).join('\n');
|
||||
} else if (typeof content === 'object') {
|
||||
// 处理对象内容
|
||||
final_content = Object.values(content).join('');
|
||||
} else if (typeof content === 'string') {
|
||||
// 处理字符串内容
|
||||
final_content = content;
|
||||
} else if (!final_content) return this.tm('status.emptyContent');
|
||||
|
||||
// 使用markdown-it处理,默认安全(html: false会禁用HTML标签)
|
||||
return md.render(final_content);
|
||||
},
|
||||
|
||||
// 显示成功消息
|
||||
showSuccessMessage(message) {
|
||||
this.message = message;
|
||||
@@ -917,6 +994,14 @@ export default {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.text-truncate {
|
||||
display: inline-block;
|
||||
max-width: 100px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
/* 动画 */
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
|
||||
@@ -5,6 +5,7 @@ import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import ReadmeDialog from '@/components/shared/ReadmeDialog.vue';
|
||||
import ProxySelector from '@/components/shared/ProxySelector.vue';
|
||||
import axios from 'axios';
|
||||
import { pinyin } from 'pinyin-pro';
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
@@ -65,6 +66,32 @@ const marketSearch = ref("");
|
||||
const filterKeys = ['name', 'desc', 'author'];
|
||||
const refreshingMarket = ref(false);
|
||||
|
||||
// 插件市场拼音搜索
|
||||
const normalizeStr = (s) => (s ?? '').toString().toLowerCase().trim();
|
||||
const toPinyinText = (s) => pinyin(s ?? '', { toneType: 'none' }).toLowerCase().replace(/\s+/g, '');
|
||||
const toInitials = (s) => pinyin(s ?? '', { pattern: 'first', toneType: 'none' }).toLowerCase().replace(/\s+/g, '');
|
||||
const marketCustomFilter = (value, query, item) => {
|
||||
const q = normalizeStr(query);
|
||||
if (!q) return true;
|
||||
|
||||
const candidates = new Set();
|
||||
if (value != null) candidates.add(String(value));
|
||||
if (item?.name) candidates.add(String(item.name));
|
||||
if (item?.trimmedName) candidates.add(String(item.trimmedName));
|
||||
if (item?.desc) candidates.add(String(item.desc));
|
||||
if (item?.author) candidates.add(String(item.author));
|
||||
|
||||
for (const v of candidates) {
|
||||
const nv = normalizeStr(v);
|
||||
if (nv.includes(q)) return true;
|
||||
const pv = toPinyinText(v);
|
||||
if (pv.includes(q)) return true;
|
||||
const iv = toInitials(v);
|
||||
if (iv.includes(q)) return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const plugin_handler_info_headers = computed(() => [
|
||||
{ title: tm('table.headers.eventType'), key: 'event_type_h' },
|
||||
{ title: tm('table.headers.description'), key: 'desc', maxWidth: '250px' },
|
||||
@@ -772,7 +799,7 @@ onMounted(async () => {
|
||||
|
||||
<v-col cols="12" md="12" style="padding: 0px;">
|
||||
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name"
|
||||
:loading="loading_" v-model:search="marketSearch" :filter-keys="filterKeys">
|
||||
:loading="loading_" v-model:search="marketSearch" :filter-keys="filterKeys" :custom-filter="marketCustomFilter">
|
||||
<template v-slot:item.name="{ item }">
|
||||
<div class="d-flex align-center"
|
||||
style="overflow-x: auto; scrollbar-width: thin; scrollbar-track-color: transparent;">
|
||||
|
||||
@@ -10,7 +10,8 @@
|
||||
{{ tm('subtitle') }}
|
||||
</p>
|
||||
</div>
|
||||
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddPlatformDialog = true" rounded="xl" size="x-large">
|
||||
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddPlatformDialog = true"
|
||||
rounded="xl" size="x-large">
|
||||
{{ tm('addAdapter') }}
|
||||
</v-btn>
|
||||
</v-row>
|
||||
@@ -25,14 +26,9 @@
|
||||
|
||||
<v-row v-else>
|
||||
<v-col v-for="(platform, index) in config_data.platform || []" :key="index" cols="12" md="6" lg="4" xl="3">
|
||||
<item-card
|
||||
:item="platform"
|
||||
title-field="id"
|
||||
enabled-field="enable"
|
||||
:bglogo="getPlatformIcon(platform.type || platform.id)"
|
||||
@toggle-enabled="platformStatusChange"
|
||||
@delete="deletePlatform"
|
||||
@edit="editPlatform">
|
||||
<item-card :item="platform" title-field="id" enabled-field="enable"
|
||||
:bglogo="getPlatformIcon(platform.type || platform.id)" @toggle-enabled="platformStatusChange"
|
||||
@delete="deletePlatform" @edit="editPlatform">
|
||||
</item-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
@@ -61,59 +57,13 @@
|
||||
</v-container>
|
||||
|
||||
<!-- 添加平台适配器对话框 -->
|
||||
<v-dialog v-model="showAddPlatformDialog" max-width="900px" min-height="80%">
|
||||
<v-card class="platform-selection-dialog">
|
||||
<v-card-title class="bg-primary text-white py-3 px-4" style="display: flex; align-items: center;">
|
||||
<v-icon color="white" class="me-2">mdi-plus-circle</v-icon>
|
||||
<span>{{ tm('dialog.addPlatform') }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn icon variant="text" color="white" @click="showAddPlatformDialog = false">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-4" style="overflow-y: auto;">
|
||||
<v-row class="mt-1">
|
||||
<v-col v-for="(template, name) in metadata['platform_group']?.metadata?.platform?.config_template || {}"
|
||||
:key="name" cols="12" sm="6" md="6">
|
||||
<v-card variant="outlined" hover class="platform-card" @click="selectPlatformTemplate(name)">
|
||||
<div class="platform-card-content">
|
||||
<div class="platform-card-text">
|
||||
<v-card-title class="platform-card-title">{{ tm('dialog.connectTitle', { name }) }}</v-card-title>
|
||||
<v-card-text class="text-caption text-medium-emphasis platform-card-description">
|
||||
{{ getPlatformDescription(template, name) }}
|
||||
</v-card-text>
|
||||
</div>
|
||||
<div class="platform-card-logo">
|
||||
<img :src="getPlatformIcon(template.type)" v-if="getPlatformIcon(template.type)" class="platform-logo-img">
|
||||
<div v-else class="platform-logo-fallback">
|
||||
{{ name[0].toUpperCase() }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</v-card>
|
||||
</v-col>
|
||||
<v-col
|
||||
v-if="Object.keys(metadata['platform_group']?.metadata?.platform?.config_template || {}).length === 0"
|
||||
cols="12">
|
||||
<v-alert type="info" variant="tonal">
|
||||
{{ tm('dialog.noTemplates') }}
|
||||
</v-alert>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
<AddNewPlatform v-model:show="showAddPlatformDialog" :metadata="metadata"
|
||||
@select-template="selectPlatformTemplate" />
|
||||
|
||||
<!-- 配置对话框 -->
|
||||
<v-dialog v-model="showPlatformCfg" persistent width="900px" max-width="90%">
|
||||
<v-card>
|
||||
<v-card-title class="bg-primary text-white py-3">
|
||||
<v-icon color="white" class="me-2">{{ updatingMode ? 'mdi-pencil' : 'mdi-plus' }}</v-icon>
|
||||
<span>{{ updatingMode ? tm('dialog.edit') : tm('dialog.add') }} {{ newSelectedPlatformName }} {{
|
||||
tm('dialog.adapter') }}</span>
|
||||
</v-card-title>
|
||||
|
||||
<v-card
|
||||
:title="updatingMode ? tm('dialog.edit') : tm('dialog.add') + ` ${newSelectedPlatformName} ` + tm('dialog.adapter')">
|
||||
<v-card-text class="py-4">
|
||||
<v-row>
|
||||
<v-col cols="12">
|
||||
@@ -177,7 +127,9 @@
|
||||
</v-card-title>
|
||||
<v-card-text class="py-4">
|
||||
<p>{{ tm('dialog.securityWarning.aiocqhttpTokenMissing') }}</p>
|
||||
<span><a href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7" target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
|
||||
<span><a
|
||||
href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7"
|
||||
target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
|
||||
</v-card-text>
|
||||
<v-card-actions class="px-4 pb-4">
|
||||
<v-spacer></v-spacer>
|
||||
@@ -199,8 +151,10 @@ import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import ItemCard from '@/components/shared/ItemCard.vue';
|
||||
import AddNewPlatform from '@/components/platform/AddNewPlatform.vue';
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
import { getPlatformIcon, getTutorialLink } from '@/utils/platformUtils';
|
||||
|
||||
export default {
|
||||
name: 'PlatformPage',
|
||||
@@ -208,7 +162,8 @@ export default {
|
||||
AstrBotConfig,
|
||||
WaitingForRestart,
|
||||
ConsoleDisplayer,
|
||||
ItemCard
|
||||
ItemCard,
|
||||
AddNewPlatform
|
||||
},
|
||||
setup() {
|
||||
const { t } = useI18n();
|
||||
@@ -285,69 +240,22 @@ export default {
|
||||
},
|
||||
|
||||
methods: {
|
||||
// 从工具函数导入
|
||||
getPlatformIcon(platform_id) {
|
||||
// 首先检查是否有来自插件的 logo_token
|
||||
const template = this.metadata['platform_group']?.metadata?.platform?.config_template?.[platform_id];
|
||||
if (template && template.logo_token) {
|
||||
// 通过文件服务访问插件提供的 logo
|
||||
return `/api/file/${template.logo_token}`;
|
||||
}
|
||||
return getPlatformIcon(platform_id);
|
||||
},
|
||||
|
||||
openTutorial() {
|
||||
const tutorialUrl = this.getTutorialLink(this.newSelectedPlatformConfig.type);
|
||||
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);
|
||||
window.open(tutorialUrl, '_blank');
|
||||
},
|
||||
|
||||
getPlatformIcon(name) {
|
||||
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
|
||||
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
||||
} else if (name === 'wecom') {
|
||||
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
||||
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
||||
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
||||
} else if (name === 'lark') {
|
||||
return new URL('@/assets/images/platform_logos/lark.png', import.meta.url).href
|
||||
} else if (name === 'dingtalk') {
|
||||
return new URL('@/assets/images/platform_logos/dingtalk.svg', import.meta.url).href
|
||||
} else if (name === 'telegram') {
|
||||
return new URL('@/assets/images/platform_logos/telegram.svg', import.meta.url).href
|
||||
} else if (name === 'discord') {
|
||||
return new URL('@/assets/images/platform_logos/discord.svg', import.meta.url).href
|
||||
} else if (name === 'slack') {
|
||||
return new URL('@/assets/images/platform_logos/slack.svg', import.meta.url).href
|
||||
} else if (name === 'kook') {
|
||||
return new URL('@/assets/images/platform_logos/kook.png', import.meta.url).href
|
||||
} else if (name === 'vocechat') {
|
||||
return new URL('@/assets/images/platform_logos/vocechat.png', import.meta.url).href
|
||||
} else if (name === 'satori' || name === 'Satori') {
|
||||
return new URL('@/assets/images/platform_logos/satori.png', import.meta.url).href
|
||||
} else if (name === 'misskey') {
|
||||
return new URL('@/assets/images/platform_logos/misskey.png', import.meta.url).href
|
||||
}
|
||||
},
|
||||
|
||||
getTutorialLink(platform_type) {
|
||||
let tutorial_map = {
|
||||
"qq_official_webhook": "https://docs.astrbot.app/deploy/platform/qqofficial/webhook.html",
|
||||
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
|
||||
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
||||
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
|
||||
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
|
||||
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
|
||||
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
|
||||
"wechatpadpro": "https://docs.astrbot.app/deploy/platform/wechat/wechatpadpro.html",
|
||||
"weixin_official_account": "https://docs.astrbot.app/deploy/platform/weixin-official-account.html",
|
||||
"discord": "https://docs.astrbot.app/deploy/platform/discord.html",
|
||||
"slack": "https://docs.astrbot.app/deploy/platform/slack.html",
|
||||
"kook": "https://docs.astrbot.app/deploy/platform/kook.html",
|
||||
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
|
||||
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
|
||||
"misskey": "https://docs.astrbot.app/deploy/platform/misskey.html",
|
||||
}
|
||||
return tutorial_map[platform_type] || "https://docs.astrbot.app";
|
||||
},
|
||||
|
||||
getPlatformDescription(template, name) {
|
||||
// special judge for community platforms
|
||||
if (name.includes('vocechat')) {
|
||||
return "由 @HikariFroya 提供。";
|
||||
} else if (name.includes('kook')) {
|
||||
return "由 @wuyan1003 提供。"
|
||||
}
|
||||
},
|
||||
|
||||
getConfig() {
|
||||
axios.get('/api/config/get').then((res) => {
|
||||
this.config_data = res.data.data.config;
|
||||
@@ -358,7 +266,7 @@ export default {
|
||||
});
|
||||
},
|
||||
|
||||
// 添加一个新方法来选择平台模板
|
||||
// 选择平台模板
|
||||
selectPlatformTemplate(name) {
|
||||
this.newSelectedPlatformName = name;
|
||||
this.showPlatformCfg = true;
|
||||
@@ -366,7 +274,6 @@ export default {
|
||||
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(
|
||||
this.metadata['platform_group']?.metadata?.platform?.config_template[name] || {}
|
||||
));
|
||||
this.showAddPlatformDialog = false;
|
||||
},
|
||||
|
||||
addFromDefaultConfigTmpl(index) {
|
||||
@@ -483,7 +390,7 @@ export default {
|
||||
this.oneBotEmptyTokenWarningResolve(continueWithWarning);
|
||||
this.oneBotEmptyTokenWarningResolve = null;
|
||||
}
|
||||
|
||||
|
||||
if (!continueWithWarning) {
|
||||
this.loading = false;
|
||||
}
|
||||
@@ -535,84 +442,4 @@ export default {
|
||||
padding: 20px;
|
||||
padding-top: 8px;
|
||||
}
|
||||
|
||||
.platform-selection-dialog .v-card-title {
|
||||
border-top-left-radius: 4px;
|
||||
border-top-right-radius: 4px;
|
||||
}
|
||||
|
||||
.platform-card {
|
||||
transition: all 0.3s ease;
|
||||
height: 100%;
|
||||
cursor: pointer;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.platform-card:hover {
|
||||
transform: translateY(-4px);
|
||||
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
|
||||
border-color: var(--v-primary-base);
|
||||
}
|
||||
|
||||
.platform-card-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
height: 100px;
|
||||
padding: 16px;
|
||||
position: relative;
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.platform-card-text {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.platform-card-title {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 4px;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.platform-card-description {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.platform-card-logo {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
width: 80px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.platform-logo-img {
|
||||
max-width: 60px;
|
||||
max-height: 60px;
|
||||
opacity: 0.6;
|
||||
object-fit: contain;
|
||||
}
|
||||
|
||||
.platform-logo-fallback {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border-radius: 50%;
|
||||
background-color: var(--v-primary-base);
|
||||
color: white;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
opacity: 0.3;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -155,86 +155,15 @@
|
||||
</v-container>
|
||||
|
||||
<!-- 添加提供商对话框 -->
|
||||
<v-dialog v-model="showAddProviderDialog" max-width="1100px" min-height="95%">
|
||||
<v-card class="provider-selection-dialog">
|
||||
<v-card-title class="bg-primary text-white py-3 px-4" style="display: flex; align-items: center;">
|
||||
<v-icon color="white" class="me-2">mdi-plus-circle</v-icon>
|
||||
<span>{{ tm('dialogs.addProvider.title') }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn icon variant="text" color="white" @click="showAddProviderDialog = false">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-4" style="overflow-y: auto;">
|
||||
<v-tabs v-model="activeProviderTab" grow slider-color="primary" bg-color="background">
|
||||
<v-tab value="chat_completion" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-message-text</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.basic') }}
|
||||
</v-tab>
|
||||
<v-tab value="speech_to_text" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-microphone-message</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.speechToText') }}
|
||||
</v-tab>
|
||||
<v-tab value="text_to_speech" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-volume-high</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.textToSpeech') }}
|
||||
</v-tab>
|
||||
<v-tab value="embedding" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-code-json</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.embedding') }}
|
||||
</v-tab>
|
||||
<v-tab value="rerank" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-compare-vertical</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.rerank') }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
|
||||
<v-window v-model="activeProviderTab" class="mt-4">
|
||||
<v-window-item v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
|
||||
:key="tabType"
|
||||
:value="tabType">
|
||||
<v-row class="mt-1">
|
||||
<v-col v-for="(template, name) in getTemplatesByType(tabType)"
|
||||
:key="name"
|
||||
cols="12" sm="6" md="4">
|
||||
<v-card variant="outlined" hover class="provider-card" @click="selectProviderTemplate(name)">
|
||||
<div class="provider-card-content">
|
||||
<div class="provider-card-text">
|
||||
<v-card-title class="provider-card-title">接入 {{ name }}</v-card-title>
|
||||
<v-card-text class="text-caption text-medium-emphasis provider-card-description">
|
||||
{{ getProviderDescription(template, name) }}
|
||||
</v-card-text>
|
||||
</div>
|
||||
<div class="provider-card-logo">
|
||||
<img :src="getProviderIcon(template.provider)" v-if="getProviderIcon(template.provider)" class="provider-logo-img">
|
||||
<div v-else class="provider-logo-fallback">
|
||||
{{ name[0].toUpperCase() }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</v-card>
|
||||
</v-col>
|
||||
<v-col v-if="Object.keys(getTemplatesByType(tabType)).length === 0" cols="12">
|
||||
<v-alert type="info" variant="tonal">
|
||||
{{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
|
||||
</v-alert>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-window-item>
|
||||
</v-window>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
<AddNewProvider
|
||||
v-model:show="showAddProviderDialog"
|
||||
:metadata="metadata"
|
||||
@select-template="selectProviderTemplate"
|
||||
/>
|
||||
|
||||
<!-- 配置对话框 -->
|
||||
<v-dialog v-model="showProviderCfg" width="900" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="bg-primary text-white py-3">
|
||||
<v-icon color="white" class="me-2">{{ updatingMode ? 'mdi-pencil' : 'mdi-plus' }}</v-icon>
|
||||
<span>{{ updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') }} {{ newSelectedProviderName }} {{ tm('dialogs.config.provider') }}</span>
|
||||
</v-card-title>
|
||||
|
||||
<v-card :title="updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') + ` ${newSelectedProviderName} ` + tm('dialogs.config.provider')">
|
||||
<v-card-text class="py-4">
|
||||
<AstrBotConfig
|
||||
:iterable="newSelectedProviderConfig"
|
||||
@@ -309,7 +238,9 @@ import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import ItemCard from '@/components/shared/ItemCard.vue';
|
||||
import AddNewProvider from '@/components/provider/AddNewProvider.vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import { getProviderIcon } from '@/utils/providerUtils';
|
||||
|
||||
export default {
|
||||
name: 'ProviderPage',
|
||||
@@ -317,7 +248,8 @@ export default {
|
||||
AstrBotConfig,
|
||||
WaitingForRestart,
|
||||
ConsoleDisplayer,
|
||||
ItemCard
|
||||
ItemCard,
|
||||
AddNewProvider
|
||||
},
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/provider');
|
||||
@@ -360,7 +292,6 @@ export default {
|
||||
|
||||
// 新增提供商对话框相关
|
||||
showAddProviderDialog: false,
|
||||
activeProviderTab: 'chat_completion',
|
||||
|
||||
// 添加提供商类型分类
|
||||
activeProviderTypeTab: 'all',
|
||||
@@ -372,6 +303,7 @@ export default {
|
||||
"googlegenai_chat_completion": "chat_completion",
|
||||
"zhipu_chat_completion": "chat_completion",
|
||||
"dify": "chat_completion",
|
||||
"coze": "chat_completion",
|
||||
"dashscope": "chat_completion",
|
||||
"openai_whisper_api": "speech_to_text",
|
||||
"openai_whisper_selfhost": "speech_to_text",
|
||||
@@ -474,6 +406,9 @@ export default {
|
||||
});
|
||||
},
|
||||
|
||||
// 从工具函数导入
|
||||
getProviderIcon,
|
||||
|
||||
// 获取空列表文本
|
||||
getEmptyText() {
|
||||
if (this.activeProviderTypeTab === 'all') {
|
||||
@@ -483,63 +418,11 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
// 按提供商类型获取模板列表
|
||||
getTemplatesByType(type) {
|
||||
const templates = this.metadata['provider_group']?.metadata?.provider?.config_template || {};
|
||||
const filtered = {};
|
||||
|
||||
for (const [name, template] of Object.entries(templates)) {
|
||||
if (template.provider_type === type) {
|
||||
filtered[name] = template;
|
||||
}
|
||||
}
|
||||
|
||||
return filtered;
|
||||
},
|
||||
|
||||
// 获取提供商类型对应的图标
|
||||
getProviderIcon(type) {
|
||||
const icons = {
|
||||
'openai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
|
||||
'azure': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/azure.svg',
|
||||
'xai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/xai.svg',
|
||||
'anthropic': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/anthropic.svg',
|
||||
'ollama': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ollama.svg',
|
||||
'google': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
|
||||
'deepseek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
|
||||
'modelscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/modelscope.svg',
|
||||
'zhipu': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
|
||||
'siliconflow': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
|
||||
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
|
||||
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
|
||||
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
|
||||
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
|
||||
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
|
||||
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
|
||||
'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',
|
||||
'minimax': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/minimax.svg',
|
||||
'302ai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
|
||||
'microsoft': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/microsoft.svg',
|
||||
'vllm': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/vllm.svg',
|
||||
};
|
||||
return icons[type] || '';
|
||||
},
|
||||
|
||||
// 获取Tab类型的中文名称
|
||||
getTabTypeName(tabType) {
|
||||
return this.messages.tabTypes[tabType] || tabType;
|
||||
},
|
||||
|
||||
// 获取提供商简介
|
||||
getProviderDescription(template, name) {
|
||||
if (name == 'OpenAI') {
|
||||
return this.tm('providers.description.openai', { type: template.type });
|
||||
} else if (name == 'vLLM Rerank') {
|
||||
return this.tm('providers.description.vllm_rerank', { type: template.type });
|
||||
}
|
||||
return this.tm('providers.description.default', { type: template.type });
|
||||
},
|
||||
|
||||
// 选择提供商模板
|
||||
selectProviderTemplate(name) {
|
||||
this.newSelectedProviderName = name;
|
||||
@@ -548,7 +431,6 @@ export default {
|
||||
this.newSelectedProviderConfig = JSON.parse(JSON.stringify(
|
||||
this.metadata['provider_group']?.metadata?.provider?.config_template[name] || {}
|
||||
));
|
||||
this.showAddProviderDialog = false;
|
||||
},
|
||||
|
||||
configExistingProvider(provider) {
|
||||
@@ -854,89 +736,6 @@ export default {
|
||||
padding-top: 8px;
|
||||
}
|
||||
|
||||
.provider-card {
|
||||
transition: all 0.3s ease;
|
||||
height: 100%;
|
||||
cursor: pointer;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.provider-card:hover {
|
||||
transform: translateY(-4px);
|
||||
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
|
||||
border-color: var(--v-primary-base);
|
||||
}
|
||||
|
||||
.provider-card-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
height: 100px;
|
||||
padding: 16px;
|
||||
position: relative;
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.provider-card-text {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.provider-card-title {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 4px;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.provider-card-description {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.provider-card-logo {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
width: 80px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.provider-logo-img {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
opacity: 0.6;
|
||||
object-fit: contain;
|
||||
}
|
||||
|
||||
.provider-logo-fallback {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border-radius: 50%;
|
||||
background-color: var(--v-primary-base);
|
||||
color: white;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
opacity: 0.3;
|
||||
}
|
||||
|
||||
.v-tabs {
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.v-window {
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.status-card {
|
||||
height: 120px;
|
||||
overflow-y: auto;
|
||||
|
||||
@@ -4,13 +4,13 @@
|
||||
<v-card flat>
|
||||
<v-card-title class="d-flex align-center py-3 px-4">
|
||||
<span class="text-h4">{{ tm('sessions.activeSessions') }}</span>
|
||||
<v-chip size="small" class="ml-2">{{ sessions.length }} {{ tm('sessions.sessionCount') }}</v-chip>
|
||||
<v-chip size="small" class="ml-2">{{ totalItems }} {{ tm('sessions.sessionCount') }}</v-chip>
|
||||
<v-row class="me-4 ms-4" dense>
|
||||
<v-text-field v-model="searchQuery" prepend-inner-icon="mdi-magnify" :label="tm('search.placeholder')"
|
||||
hide-details clearable variant="solo-filled" flat class="me-4" density="compact"></v-text-field>
|
||||
hide-details clearable variant="solo-filled" flat class="me-4" density="compact" @update:model-value="handleSearchChange"></v-text-field>
|
||||
<v-select v-model="filterPlatform" :items="platformOptions" :label="tm('search.platformFilter')"
|
||||
hide-details clearable variant="solo-filled" flat class="me-4" style="max-width: 150px;"
|
||||
density="compact"></v-select>
|
||||
density="compact" @update:model-value="handlePlatformChange"></v-select>
|
||||
</v-row>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="refreshSessions" :loading="loading"
|
||||
size="small">
|
||||
@@ -22,8 +22,17 @@
|
||||
|
||||
<v-card-text class="pa-0">
|
||||
<!-- 会话列表 -->
|
||||
<v-data-table :headers="headers" :items="filteredSessions" :loading="loading" :items-per-page="itemsPerPage" density="compact"
|
||||
class="elevation-0" style="font-size: 11px;">
|
||||
<v-data-table-server
|
||||
:headers="headers"
|
||||
:items="sessions"
|
||||
:loading="loading"
|
||||
:items-per-page="itemsPerPage"
|
||||
:page="currentPage"
|
||||
:items-length="totalItems"
|
||||
@update:options="handlePaginationUpdate"
|
||||
density="compact"
|
||||
class="elevation-0"
|
||||
style="font-size: 11px;">
|
||||
|
||||
<!-- 会话启停 -->
|
||||
<template v-slot:item.session_enabled="{ item }">
|
||||
@@ -141,6 +150,17 @@
|
||||
</v-btn>
|
||||
</template>
|
||||
|
||||
<!-- 操作按钮 -->
|
||||
<template v-slot:item.actions="{ item }">
|
||||
<v-btn size="x-small" variant="tonal" color="error" @click="deleteSession(item)"
|
||||
:loading="item.deleting" icon>
|
||||
<v-icon>mdi-delete</v-icon>
|
||||
<v-tooltip activator="parent" location="top">
|
||||
{{ tm('buttons.delete') }}
|
||||
</v-tooltip>
|
||||
</v-btn>
|
||||
</template>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<template v-slot:no-data>
|
||||
<div class="text-center py-8">
|
||||
@@ -149,7 +169,7 @@
|
||||
<div class="text-body-2 text-grey-500">{{ tm('sessions.noActiveSessionsDesc') }}</div>
|
||||
</div>
|
||||
</template>
|
||||
</v-data-table>
|
||||
</v-data-table-server>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
@@ -325,6 +345,7 @@
|
||||
|
||||
<script>
|
||||
import axios from 'axios'
|
||||
import { debounce } from 'lodash'
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables'
|
||||
|
||||
export default {
|
||||
@@ -346,7 +367,10 @@ export default {
|
||||
filterPlatform: null,
|
||||
|
||||
// 分页相关
|
||||
currentPage: 1,
|
||||
itemsPerPage: 10,
|
||||
totalItems: 0,
|
||||
totalPages: 0,
|
||||
|
||||
// 可用选项
|
||||
availablePersonas: [],
|
||||
@@ -409,32 +433,10 @@ export default {
|
||||
{ title: this.tm('table.headers.llmStatus'), key: 'llm_enabled', sortable: false, minWidth: '120px' },
|
||||
{ title: this.tm('table.headers.ttsStatus'), key: 'tts_enabled', sortable: false, minWidth: '120px' },
|
||||
{ title: this.tm('table.headers.pluginManagement'), key: 'plugins', sortable: false, minWidth: '120px' },
|
||||
{ title: this.tm('table.headers.actions'), key: 'actions', sortable: false, minWidth: '100px' },
|
||||
]
|
||||
},
|
||||
|
||||
// 懒加载过滤会话 - 使用客户端分页
|
||||
filteredSessions() {
|
||||
let filtered = this.sessions;
|
||||
|
||||
// 搜索筛选
|
||||
if (this.searchQuery) {
|
||||
const query = this.searchQuery.toLowerCase();
|
||||
filtered = filtered.filter(session =>
|
||||
session.session_name.toLowerCase().includes(query) ||
|
||||
session.platform.toLowerCase().includes(query) ||
|
||||
session.persona_name?.toLowerCase().includes(query) ||
|
||||
session.chat_provider_name?.toLowerCase().includes(query)
|
||||
);
|
||||
}
|
||||
|
||||
// 平台筛选
|
||||
if (this.filterPlatform) {
|
||||
filtered = filtered.filter(session => session.platform === this.filterPlatform);
|
||||
}
|
||||
|
||||
return filtered;
|
||||
},
|
||||
|
||||
platformOptions() {
|
||||
const platforms = [...new Set(this.sessions.map(s => s.platform))];
|
||||
return platforms.map(p => ({ title: p, value: p }));
|
||||
@@ -481,18 +483,39 @@ export default {
|
||||
async loadSessions() {
|
||||
this.loading = true;
|
||||
try {
|
||||
const response = await axios.get('/api/session/list');
|
||||
const params = {
|
||||
page: this.currentPage,
|
||||
page_size: this.itemsPerPage
|
||||
};
|
||||
|
||||
// 添加搜索和平台筛选参数
|
||||
if (this.searchQuery) {
|
||||
params.search = this.searchQuery;
|
||||
}
|
||||
if (this.filterPlatform) {
|
||||
params.platform = this.filterPlatform;
|
||||
}
|
||||
|
||||
const response = await axios.get('/api/session/list', { params });
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data;
|
||||
this.sessions = data.sessions.map(session => ({
|
||||
...session,
|
||||
updating: false, // 添加更新状态标志
|
||||
loadingPlugins: false // 添加插件加载状态标志
|
||||
loadingPlugins: false, // 添加插件加载状态标志
|
||||
deleting: false // 添加删除状态标志
|
||||
}));
|
||||
this.availablePersonas = data.available_personas;
|
||||
this.availableChatProviders = data.available_chat_providers;
|
||||
this.availableSttProviders = data.available_stt_providers;
|
||||
this.availableTtsProviders = data.available_tts_providers;
|
||||
|
||||
// 处理分页信息
|
||||
if (data.pagination) {
|
||||
this.totalItems = data.pagination.total;
|
||||
this.totalPages = data.pagination.total_pages;
|
||||
this.currentPage = data.pagination.page;
|
||||
}
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadSessionsError'));
|
||||
}
|
||||
@@ -508,60 +531,131 @@ export default {
|
||||
},
|
||||
|
||||
async updatePersona(session, personaName) {
|
||||
session.updating = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/update_persona', {
|
||||
session_id: session.session_id,
|
||||
persona_name: personaName
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
session.persona_id = personaName;
|
||||
session.persona_name = personaName === '[%None]' ? this.tm('persona.none') :
|
||||
return this._updateSession('persona', session, { persona_name: personaName }, (s, success) => {
|
||||
if (success) {
|
||||
s.persona_id = personaName;
|
||||
s.persona_name = personaName === '[%None]' ? this.tm('persona.none') :
|
||||
this.availablePersonas.find(p => p.name === personaName)?.name || personaName;
|
||||
this.showSuccess(this.tm('messages.personaUpdateSuccess'));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.personaUpdateError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.personaUpdateError'));
|
||||
}
|
||||
session.updating = false;
|
||||
});
|
||||
},
|
||||
|
||||
async updateProvider(session, providerId, providerType) {
|
||||
session.updating = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/update_provider', {
|
||||
session_id: session.session_id,
|
||||
provider_id: providerId,
|
||||
provider_type: providerType
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
// 更新本地数据
|
||||
return this._updateSession('provider', session, {
|
||||
provider_id: providerId,
|
||||
provider_type: providerType
|
||||
}, (s, success) => {
|
||||
if (success) {
|
||||
if (providerType === 'chat_completion') {
|
||||
session.chat_provider_id = providerId;
|
||||
s.chat_provider_id = providerId;
|
||||
const provider = this.availableChatProviders.find(p => p.id === providerId);
|
||||
session.chat_provider_name = provider?.name || providerId;
|
||||
s.chat_provider_name = provider?.name || providerId;
|
||||
} else if (providerType === 'speech_to_text') {
|
||||
session.stt_provider_id = providerId;
|
||||
s.stt_provider_id = providerId;
|
||||
const provider = this.availableSttProviders.find(p => p.id === providerId);
|
||||
session.stt_provider_name = provider?.name || providerId;
|
||||
s.stt_provider_name = provider?.name || providerId;
|
||||
} else if (providerType === 'text_to_speech') {
|
||||
session.tts_provider_id = providerId;
|
||||
s.tts_provider_id = providerId;
|
||||
const provider = this.availableTtsProviders.find(p => p.id === providerId);
|
||||
session.tts_provider_name = provider?.name || providerId;
|
||||
s.tts_provider_name = provider?.name || providerId;
|
||||
}
|
||||
this.showSuccess(this.tm('messages.providerUpdateSuccess'));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.providerUpdateError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.providerUpdateError'));
|
||||
} session.updating = false;
|
||||
});
|
||||
},
|
||||
|
||||
async updateLLM(session, enabled) {
|
||||
return this._updateSession('llm', session, { enabled }, (s, success) => {
|
||||
if (success) s.llm_enabled = enabled;
|
||||
});
|
||||
},
|
||||
|
||||
async updateTTS(session, enabled) {
|
||||
return this._updateSession('tts', session, { enabled }, (s, success) => {
|
||||
if (success) s.tts_enabled = enabled;
|
||||
});
|
||||
},
|
||||
|
||||
// 通用的更新会话方法,支持单个和批量操作
|
||||
async _updateSession(type, sessionOrSessions, params, updateLocalData) {
|
||||
const isBatch = Array.isArray(sessionOrSessions);
|
||||
|
||||
if (!isBatch) {
|
||||
// 单个操作
|
||||
const session = sessionOrSessions;
|
||||
session.updating = true;
|
||||
|
||||
try {
|
||||
const payload = {
|
||||
is_batch: false,
|
||||
session_id: session.session_id,
|
||||
...params
|
||||
};
|
||||
|
||||
const response = await axios.post(`/api/session/update_${type}`, payload);
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
updateLocalData(session, true);
|
||||
this.showSuccess(this.tm(`messages.${type}UpdateSuccess`));
|
||||
return { success: true };
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm(`messages.${type}UpdateError`));
|
||||
return { success: false, error: response.data.message };
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm(`messages.${type}UpdateError`));
|
||||
return { success: false, error: error.message };
|
||||
} finally {
|
||||
session.updating = false;
|
||||
}
|
||||
} else {
|
||||
// 批量操作
|
||||
const sessions = sessionOrSessions;
|
||||
const sessionIds = sessions.map(s => s.session_id);
|
||||
|
||||
try {
|
||||
const payload = {
|
||||
is_batch: true,
|
||||
session_ids: sessionIds,
|
||||
...params
|
||||
};
|
||||
|
||||
const response = await axios.post(`/api/session/update_${type}`, payload);
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data;
|
||||
|
||||
// 更新成功的会话的本地数据
|
||||
sessions.forEach(session => {
|
||||
const wasSuccessful = !data.error_sessions || !data.error_sessions.includes(session.session_id);
|
||||
updateLocalData(session, wasSuccessful);
|
||||
});
|
||||
|
||||
return {
|
||||
success: true,
|
||||
successCount: data.success_count || 0,
|
||||
errorCount: data.error_count || 0,
|
||||
errorSessions: data.error_sessions || []
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
success: false,
|
||||
error: response.data.message,
|
||||
errorCount: sessionIds.length,
|
||||
successCount: 0
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: error.response?.data?.message || error.message,
|
||||
errorCount: sessionIds.length,
|
||||
successCount: 0
|
||||
};
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
// 单独的会话状态更新方法(不支持批量操作)
|
||||
async updateSessionStatus(session, enabled) {
|
||||
session.updating = true;
|
||||
try {
|
||||
@@ -572,47 +666,9 @@ export default {
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
session.session_enabled = enabled;
|
||||
this.showSuccess(this.tm('messages.sessionStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
session.updating = false;
|
||||
},
|
||||
|
||||
async updateLLM(session, enabled) {
|
||||
session.updating = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/update_llm', {
|
||||
session_id: session.session_id,
|
||||
enabled: enabled
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
session.llm_enabled = enabled;
|
||||
this.showSuccess(this.tm('messages.llmStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
session.updating = false;
|
||||
},
|
||||
|
||||
async updateTTS(session, enabled) {
|
||||
session.updating = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/update_tts', {
|
||||
session_id: session.session_id,
|
||||
enabled: enabled
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
session.tts_enabled = enabled;
|
||||
this.showSuccess(this.tm('messages.ttsStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
|
||||
this.showSuccess(this.tm('messages.sessionStatusSuccess', {
|
||||
status: enabled ? this.tm('status.enabled') : this.tm('status.disabled')
|
||||
}));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
@@ -628,60 +684,120 @@ export default {
|
||||
}
|
||||
|
||||
this.batchUpdating = true;
|
||||
let successCount = 0;
|
||||
let errorCount = 0;
|
||||
let totalSuccessCount = 0;
|
||||
let totalErrorCount = 0;
|
||||
let allErrorSessions = [];
|
||||
|
||||
// 使用过滤后的会话数据进行批量操作
|
||||
for (const session of this.filteredSessions) {
|
||||
try {
|
||||
// 批量更新人格
|
||||
if (this.batchPersona) {
|
||||
await this.updatePersona(session, this.batchPersona);
|
||||
successCount++;
|
||||
}
|
||||
const sessions = this.sessions;
|
||||
|
||||
// 批量更新 Chat Provider
|
||||
if (this.batchChatProvider) {
|
||||
await this.updateProvider(session, this.batchChatProvider, 'chat_completion');
|
||||
successCount++;
|
||||
}
|
||||
try {
|
||||
// 定义批量操作任务
|
||||
const batchTasks = [];
|
||||
|
||||
// 批量更新 STT Provider
|
||||
if (this.batchSttProvider) {
|
||||
await this.updateProvider(session, this.batchSttProvider, 'speech_to_text');
|
||||
successCount++;
|
||||
}
|
||||
|
||||
// 批量更新 TTS Provider
|
||||
if (this.batchTtsProvider) {
|
||||
await this.updateProvider(session, this.batchTtsProvider, 'text_to_speech');
|
||||
successCount++;
|
||||
}
|
||||
|
||||
// 批量更新 LLM 状态
|
||||
if (this.batchLlmStatus !== null) {
|
||||
await this.updateLLM(session, this.batchLlmStatus);
|
||||
successCount++;
|
||||
}
|
||||
|
||||
// 批量更新 TTS 状态
|
||||
if (this.batchTtsStatus !== null) {
|
||||
await this.updateTTS(session, this.batchTtsStatus);
|
||||
successCount++;
|
||||
}
|
||||
} catch (error) {
|
||||
errorCount++;
|
||||
if (this.batchPersona) {
|
||||
batchTasks.push({
|
||||
type: 'persona',
|
||||
params: { persona_name: this.batchPersona }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchChatProvider) {
|
||||
batchTasks.push({
|
||||
type: 'provider',
|
||||
params: { provider_id: this.batchChatProvider, provider_type: 'chat_completion' }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchSttProvider) {
|
||||
batchTasks.push({
|
||||
type: 'provider',
|
||||
params: { provider_id: this.batchSttProvider, provider_type: 'speech_to_text' }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchTtsProvider) {
|
||||
batchTasks.push({
|
||||
type: 'provider',
|
||||
params: { provider_id: this.batchTtsProvider, provider_type: 'text_to_speech' }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchLlmStatus !== null) {
|
||||
batchTasks.push({
|
||||
type: 'llm',
|
||||
params: { enabled: this.batchLlmStatus }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchTtsStatus !== null) {
|
||||
batchTasks.push({
|
||||
type: 'tts',
|
||||
params: { enabled: this.batchTtsStatus }
|
||||
});
|
||||
}
|
||||
|
||||
// 执行所有批量任务
|
||||
for (const task of batchTasks) {
|
||||
let updateLocalData;
|
||||
|
||||
// 定义本地数据更新逻辑
|
||||
switch (task.type) {
|
||||
case 'persona':
|
||||
updateLocalData = (s, success) => {
|
||||
if (success) s.persona_id = task.params.persona_name;
|
||||
};
|
||||
break;
|
||||
case 'provider':
|
||||
updateLocalData = (s, success) => {
|
||||
if (!success) return;
|
||||
const { provider_id, provider_type } = task.params;
|
||||
if (provider_type === 'chat_completion') {
|
||||
s.chat_provider_id = provider_id;
|
||||
} else if (provider_type === 'speech_to_text') {
|
||||
s.stt_provider_id = provider_id;
|
||||
} else if (provider_type === 'text_to_speech') {
|
||||
s.tts_provider_id = provider_id;
|
||||
}
|
||||
};
|
||||
break;
|
||||
case 'llm':
|
||||
updateLocalData = (s, success) => {
|
||||
if (success) s.llm_enabled = task.params.enabled;
|
||||
};
|
||||
break;
|
||||
case 'tts':
|
||||
updateLocalData = (s, success) => {
|
||||
if (success) s.tts_enabled = task.params.enabled;
|
||||
};
|
||||
break;
|
||||
}
|
||||
|
||||
const result = await this._updateSession(task.type, sessions, task.params, updateLocalData);
|
||||
|
||||
totalSuccessCount += result.successCount || 0;
|
||||
totalErrorCount += result.errorCount || 0;
|
||||
if (result.errorSessions) {
|
||||
allErrorSessions.push(...result.errorSessions);
|
||||
}
|
||||
}
|
||||
|
||||
// 显示最终结果
|
||||
if (totalErrorCount === 0) {
|
||||
this.showSuccess(this.tm('messages.batchUpdateSuccess', { count: totalSuccessCount }));
|
||||
} else {
|
||||
const uniqueErrorSessions = [...new Set(allErrorSessions)];
|
||||
this.showError(this.tm('messages.batchUpdatePartial', {
|
||||
success: totalSuccessCount,
|
||||
error: uniqueErrorSessions.length
|
||||
}));
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
this.showError(this.tm('messages.batchUpdateError'));
|
||||
}
|
||||
|
||||
this.batchUpdating = false;
|
||||
|
||||
if (errorCount === 0) {
|
||||
this.showSuccess(this.tm('messages.batchUpdateSuccess', { count: successCount }));
|
||||
} else {
|
||||
this.showError(this.tm('messages.batchUpdatePartial', { success: successCount, error: errorCount }));
|
||||
}
|
||||
|
||||
// 清空批量设置
|
||||
this.batchPersona = null;
|
||||
this.batchChatProvider = null;
|
||||
@@ -797,6 +913,57 @@ export default {
|
||||
this.snackbarColor = 'error';
|
||||
this.snackbar = true;
|
||||
},
|
||||
|
||||
async deleteSession(session) {
|
||||
const confirmMessage = this.tm('deleteConfirm.message', {
|
||||
sessionName: session.session_name || session.session_id
|
||||
}) + '\n\n' + this.tm('deleteConfirm.warning');
|
||||
|
||||
if (!confirm(confirmMessage)) {
|
||||
return;
|
||||
}
|
||||
|
||||
session.deleting = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/delete', {
|
||||
session_id: session.session_id
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(response.data.data.message || this.tm('messages.deleteSuccess'));
|
||||
// 从列表中移除已删除的会话
|
||||
const index = this.sessions.findIndex(s => s.session_id === session.session_id);
|
||||
if (index > -1) {
|
||||
this.sessions.splice(index, 1);
|
||||
}
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.deleteError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.deleteError'));
|
||||
}
|
||||
|
||||
session.deleting = false;
|
||||
},
|
||||
|
||||
// 处理分页更新事件
|
||||
handlePaginationUpdate(options) {
|
||||
this.currentPage = options.page;
|
||||
this.itemsPerPage = options.itemsPerPage;
|
||||
this.loadSessions();
|
||||
},
|
||||
|
||||
// 处理搜索变化
|
||||
handleSearchChange: debounce(function() {
|
||||
this.currentPage = 1; // 重置到第一页
|
||||
this.loadSessions();
|
||||
}, 300),
|
||||
|
||||
// 处理平台筛选变化
|
||||
handlePlatformChange() {
|
||||
this.currentPage = 1; // 重置到第一页
|
||||
this.loadSessions();
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -601,8 +601,13 @@ export default {
|
||||
checkPlugin() {
|
||||
axios.get('/api/plugin/get?name=astrbot_plugin_knowledge_base')
|
||||
.then(response => {
|
||||
if (response.data.status !== 'ok') {
|
||||
if (response.data.status !== 'ok' || response.data.data.length === 0) {
|
||||
this.showSnackbar(this.tm('messages.pluginNotAvailable'), 'error');
|
||||
return
|
||||
}
|
||||
if (!response.data.data[0].activated) {
|
||||
this.showSnackbar(this.tm('messages.pluginNotActivated'), 'error');
|
||||
return
|
||||
}
|
||||
if (response.data.data.length > 0) {
|
||||
this.installed = true;
|
||||
@@ -708,6 +713,10 @@ export default {
|
||||
getKBCollections() {
|
||||
axios.get('/api/plug/alkaid/kb/collections')
|
||||
.then(response => {
|
||||
if (response.data.status !== 'ok') {
|
||||
this.showSnackbar(response.data.message || this.tm('messages.getKnowledgeBaseListFailed'), 'error');
|
||||
return;
|
||||
}
|
||||
this.kbCollections = response.data.data;
|
||||
})
|
||||
.catch(error => {
|
||||
|
||||
31
packages/astrbot/commands/__init__.py
Normal file
31
packages/astrbot/commands/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Commands module
|
||||
|
||||
from .help import HelpCommand
|
||||
from .llm import LLMCommands
|
||||
from .tool import ToolCommands
|
||||
from .plugin import PluginCommands
|
||||
from .admin import AdminCommands
|
||||
from .conversation import ConversationCommands
|
||||
from .provider import ProviderCommands
|
||||
from .persona import PersonaCommands
|
||||
from .alter_cmd import AlterCmdCommands
|
||||
from .setunset import SetUnsetCommands
|
||||
from .t2i import T2ICommand
|
||||
from .tts import TTSCommand
|
||||
from .sid import SIDCommand
|
||||
|
||||
__all__ = [
|
||||
"HelpCommand",
|
||||
"LLMCommands",
|
||||
"ToolCommands",
|
||||
"PluginCommands",
|
||||
"AdminCommands",
|
||||
"ConversationCommands",
|
||||
"ProviderCommands",
|
||||
"PersonaCommands",
|
||||
"AlterCmdCommands",
|
||||
"SetUnsetCommands",
|
||||
"T2ICommand",
|
||||
"TTSCommand",
|
||||
"SIDCommand",
|
||||
]
|
||||
76
packages/astrbot/commands/admin.py
Normal file
76
packages/astrbot/commands/admin.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, MessageChain
|
||||
from astrbot.core.utils.io import download_dashboard
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
|
||||
class AdminCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str = ""):
|
||||
"""授权管理员。op <admin_id>"""
|
||||
if not admin_id:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /op <id> 授权管理员;/deop <id> 取消管理员。可通过 /sid 获取 ID。"
|
||||
)
|
||||
)
|
||||
return
|
||||
self.context.get_config()["admins_id"].append(str(admin_id))
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("授权成功。"))
|
||||
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str = ""):
|
||||
"""取消授权管理员。deop <admin_id>"""
|
||||
if not admin_id:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /deop <id> 取消管理员。可通过 /sid 获取 ID。"
|
||||
)
|
||||
)
|
||||
return
|
||||
try:
|
||||
self.context.get_config()["admins_id"].remove(str(admin_id))
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("取消授权成功。"))
|
||||
except ValueError:
|
||||
event.set_result(
|
||||
MessageEventResult().message("此用户 ID 不在管理员名单内。")
|
||||
)
|
||||
|
||||
async def wl(self, event: AstrMessageEvent, sid: str = ""):
|
||||
"""添加白名单。wl <sid>"""
|
||||
if not sid:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /wl <id> 添加白名单;/dwl <id> 删除白名单。可通过 /sid 获取 ID。"
|
||||
)
|
||||
)
|
||||
return
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
cfg["platform_settings"]["id_whitelist"].append(str(sid))
|
||||
cfg.save_config()
|
||||
event.set_result(MessageEventResult().message("添加白名单成功。"))
|
||||
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str = ""):
|
||||
"""删除白名单。dwl <sid>"""
|
||||
if not sid:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /dwl <id> 删除白名单。可通过 /sid 获取 ID。"
|
||||
)
|
||||
)
|
||||
return
|
||||
try:
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
cfg["platform_settings"]["id_whitelist"].remove(str(sid))
|
||||
cfg.save_config()
|
||||
event.set_result(MessageEventResult().message("删除白名单成功。"))
|
||||
except ValueError:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
await event.send(MessageChain().message("正在尝试更新管理面板..."))
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
await event.send(MessageChain().message("管理面板更新完成。"))
|
||||
188
packages/astrbot/commands/alter_cmd.py
Normal file
188
packages/astrbot/commands/alter_cmd.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RstScene(Enum):
|
||||
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
|
||||
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
|
||||
PRIVATE = ("private", "私聊")
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
return self.value[0]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.value[1]
|
||||
|
||||
@classmethod
|
||||
def from_index(cls, index: int) -> "RstScene":
|
||||
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
|
||||
return mapping[index]
|
||||
|
||||
|
||||
class AlterCmdCommands(CommandParserMixin):
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def update_reset_permission(self, scene_key: str, perm_type: str):
|
||||
"""更新reset命令在特定场景下的权限设置"""
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_cfg = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_cfg.get("reset", {})
|
||||
reset_cfg[scene_key] = perm_type
|
||||
plugin_cfg["reset"] = reset_cfg
|
||||
alter_cmd_cfg["astrbot"] = plugin_cfg
|
||||
await sp.global_put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
async def alter_cmd(self, event: AstrMessageEvent):
|
||||
token = self.parse_commands(event.message_str)
|
||||
if token.len < 3:
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
"该指令用于设置指令或指令组的权限。\n"
|
||||
"格式: /alter_cmd <cmd_name> <admin/member>\n"
|
||||
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
|
||||
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
|
||||
"/alter_cmd reset config 打开 reset 权限配置"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
cmd_name = " ".join(token.tokens[1:-1])
|
||||
cmd_type = token.get(-1)
|
||||
|
||||
if cmd_name == "reset" and cmd_type == "config":
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_.get("reset", {})
|
||||
|
||||
group_unique_on = reset_cfg.get("group_unique_on", "admin")
|
||||
group_unique_off = reset_cfg.get("group_unique_off", "admin")
|
||||
private = reset_cfg.get("private", "member")
|
||||
|
||||
config_menu = f"""reset命令权限细粒度配置
|
||||
当前配置:
|
||||
1. 群聊+会话隔离开: {group_unique_on}
|
||||
2. 群聊+会话隔离关: {group_unique_off}
|
||||
3. 私聊: {private}
|
||||
修改指令格式:
|
||||
/alter_cmd reset scene <场景编号> <admin/member>
|
||||
例如: /alter_cmd reset scene 2 member"""
|
||||
await event.send(MessageChain().message(config_menu))
|
||||
return
|
||||
|
||||
if cmd_name == "reset" and cmd_type == "scene" and token.len >= 4:
|
||||
scene_num = token.get(3)
|
||||
perm_type = token.get(4)
|
||||
|
||||
if scene_num is None or perm_type is None:
|
||||
await event.send(MessageChain().message("场景编号和权限类型不能为空"))
|
||||
return
|
||||
|
||||
if not scene_num.isdigit() or int(scene_num) < 1 or int(scene_num) > 3:
|
||||
await event.send(
|
||||
MessageChain().message("场景编号必须是 1-3 之间的数字")
|
||||
)
|
||||
return
|
||||
|
||||
if perm_type not in ["admin", "member"]:
|
||||
await event.send(
|
||||
MessageChain().message("权限类型错误,只能是 admin 或 member")
|
||||
)
|
||||
return
|
||||
|
||||
scene_num = int(scene_num)
|
||||
scene = RstScene.from_index(scene_num)
|
||||
scene_key = scene.key
|
||||
|
||||
await self.update_reset_permission(scene_key, perm_type)
|
||||
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if cmd_type not in ["admin", "member"]:
|
||||
await event.send(
|
||||
MessageChain().message("指令类型错误,可选类型有 admin, member")
|
||||
)
|
||||
return
|
||||
|
||||
# 查找指令
|
||||
found_command = None
|
||||
cmd_group = False
|
||||
for handler in star_handlers_registry:
|
||||
assert isinstance(handler, StarHandlerMetadata)
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, CommandFilter):
|
||||
if filter_.equals(cmd_name):
|
||||
found_command = handler
|
||||
break
|
||||
elif isinstance(filter_, CommandGroupFilter):
|
||||
if filter_.equals(cmd_name):
|
||||
found_command = handler
|
||||
cmd_group = True
|
||||
break
|
||||
|
||||
if not found_command:
|
||||
await event.send(MessageChain().message("未找到该指令"))
|
||||
return
|
||||
|
||||
found_plugin = star_map[found_command.handler_module_path]
|
||||
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
|
||||
cfg = plugin_.get(found_command.handler_name, {})
|
||||
cfg["permission"] = cmd_type
|
||||
plugin_[found_command.handler_name] = cfg
|
||||
alter_cmd_cfg[found_plugin.name] = plugin_
|
||||
|
||||
await sp.global_put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
# 注入权限过滤器
|
||||
found_permission_filter = False
|
||||
for filter_ in found_command.event_filters:
|
||||
if isinstance(filter_, PermissionTypeFilter):
|
||||
if cmd_type == "admin":
|
||||
import astrbot.api.event.filter as filter
|
||||
|
||||
filter_.permission_type = filter.PermissionType.ADMIN
|
||||
else:
|
||||
import astrbot.api.event.filter as filter
|
||||
|
||||
filter_.permission_type = filter.PermissionType.MEMBER
|
||||
found_permission_filter = True
|
||||
break
|
||||
if not found_permission_filter:
|
||||
import astrbot.api.event.filter as filter
|
||||
|
||||
found_command.event_filters.insert(
|
||||
0,
|
||||
PermissionTypeFilter(
|
||||
filter.PermissionType.ADMIN
|
||||
if cmd_type == "admin"
|
||||
else filter.PermissionType.MEMBER
|
||||
),
|
||||
)
|
||||
cmd_group_str = "指令组" if cmd_group else "指令"
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。"
|
||||
)
|
||||
)
|
||||
441
packages/astrbot/commands/conversation.py
Normal file
441
packages/astrbot/commands/conversation.py
Normal file
@@ -0,0 +1,441 @@
|
||||
import datetime
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.provider.sources.dify_source import ProviderDify
|
||||
from astrbot.core.provider.sources.coze_source import ProviderCoze
|
||||
from astrbot.api import sp, logger
|
||||
from ..long_term_memory import LongTermMemory
|
||||
from typing import Union
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RstScene(Enum):
|
||||
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
|
||||
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
|
||||
PRIVATE = ("private", "私聊")
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
return self.value[0]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.value[1]
|
||||
|
||||
@classmethod
|
||||
def from_index(cls, index: int) -> "RstScene":
|
||||
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
|
||||
return mapping[index]
|
||||
|
||||
@classmethod
|
||||
def get_scene(cls, is_group: bool, is_unique_session: bool) -> "RstScene":
|
||||
if is_group:
|
||||
return cls.GROUP_UNIQUE_ON if is_unique_session else cls.GROUP_UNIQUE_OFF
|
||||
return cls.PRIVATE
|
||||
|
||||
|
||||
class ConversationCommands:
|
||||
def __init__(self, context: star.Context, ltm: LongTermMemory | None = None):
|
||||
self.context = context
|
||||
self.ltm = ltm
|
||||
|
||||
def ltm_enabled(self, event: AstrMessageEvent):
|
||||
if not self.ltm:
|
||||
return False
|
||||
ltmse = self.context.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_ltm_settings"
|
||||
]
|
||||
return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]
|
||||
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
"""重置 LLM 会话"""
|
||||
|
||||
is_unique_session = self.context.get_config()["platform_settings"][
|
||||
"unique_session"
|
||||
]
|
||||
is_group = bool(message.get_group_id())
|
||||
|
||||
scene = RstScene.get_scene(is_group, is_unique_session)
|
||||
|
||||
alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {})
|
||||
plugin_config = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_config.get("reset", {})
|
||||
|
||||
required_perm = reset_cfg.get(
|
||||
scene.key, "admin" if is_group and not is_unique_session else "member"
|
||||
)
|
||||
|
||||
if required_perm == "admin" and message.role != "admin":
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"在{scene.name}场景下,reset命令需要管理员权限,"
|
||||
f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前未处于对话状态,请 /switch 切换或者 /new 创建。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.update_conversation(
|
||||
message.unified_msg_origin, cid, []
|
||||
)
|
||||
|
||||
ret = "清除会话 LLM 聊天历史成功。"
|
||||
if self.ltm and self.ltm_enabled(message):
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
"""查看对话记录"""
|
||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
|
||||
size_per_page = 6
|
||||
|
||||
conv_mgr = self.context.conversation_manager
|
||||
umo = message.unified_msg_origin
|
||||
session_curr_cid = await conv_mgr.get_curr_conversation_id(umo)
|
||||
|
||||
if not session_curr_cid:
|
||||
session_curr_cid = await conv_mgr.new_conversation(
|
||||
umo, message.get_platform_id()
|
||||
)
|
||||
|
||||
contexts, total_pages = await conv_mgr.get_human_readable_context(
|
||||
umo, session_curr_cid, page, size_per_page
|
||||
)
|
||||
|
||||
history = ""
|
||||
for context in contexts:
|
||||
if len(context) > 150:
|
||||
context = context[:150] + "..."
|
||||
history += f"{context}\n"
|
||||
|
||||
ret = (
|
||||
f"当前对话历史记录:"
|
||||
f"{history or '无历史记录'}\n\n"
|
||||
f"第 {page} 页 | 共 {total_pages} 页\n"
|
||||
f"*输入 /history 2 跳转到第 2 页"
|
||||
)
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1):
|
||||
"""查看对话列表"""
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
"""原有的Dify处理逻辑保持不变"""
|
||||
ret = "Dify 对话列表:\n"
|
||||
assert isinstance(provider, ProviderDify)
|
||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||
idx = 1
|
||||
for conv in data["data"]:
|
||||
ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime(
|
||||
"%m-%d %H:%M"
|
||||
)
|
||||
ret += f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n"
|
||||
idx += 1
|
||||
if idx == 1:
|
||||
ret += "没有找到任何对话。"
|
||||
dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None)
|
||||
ret += f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。"
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
return
|
||||
|
||||
size_per_page = 6
|
||||
"""获取所有对话列表"""
|
||||
conversations_all = await self.context.conversation_manager.get_conversations(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
"""计算总页数"""
|
||||
total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page
|
||||
"""确保页码有效"""
|
||||
page = max(1, min(page, total_pages))
|
||||
"""分页处理"""
|
||||
start_idx = (page - 1) * size_per_page
|
||||
end_idx = start_idx + size_per_page
|
||||
conversations_paged = conversations_all[start_idx:end_idx]
|
||||
|
||||
ret = "对话列表:\n---\n"
|
||||
"""全局序号从当前页的第一个开始"""
|
||||
global_index = start_idx + 1
|
||||
|
||||
"""生成所有对话的标题字典"""
|
||||
_titles = {}
|
||||
for conv in conversations_all:
|
||||
title = conv.title if conv.title else "新对话"
|
||||
_titles[conv.cid] = title
|
||||
|
||||
"""遍历分页后的对话生成列表显示"""
|
||||
for conv in conversations_paged:
|
||||
persona_id = conv.persona_id
|
||||
if not persona_id or persona_id == "[%None]":
|
||||
persona = await self.context.persona_manager.get_default_persona_v3(
|
||||
umo=message.unified_msg_origin
|
||||
)
|
||||
persona_id = persona["name"]
|
||||
title = _titles.get(conv.cid, "新对话")
|
||||
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
|
||||
global_index += 1
|
||||
|
||||
ret += "---\n"
|
||||
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
if curr_cid:
|
||||
"""从所有对话的标题字典中获取标题"""
|
||||
title = _titles.get(curr_cid, "新对话")
|
||||
ret += f"\n当前对话: {title}({curr_cid[:4]})"
|
||||
else:
|
||||
ret += "\n当前对话: 无"
|
||||
|
||||
unique_session = self.context.get_config()["platform_settings"][
|
||||
"unique_session"
|
||||
]
|
||||
if unique_session:
|
||||
ret += "\n会话隔离粒度: 个人"
|
||||
else:
|
||||
ret += "\n会话隔离粒度: 群聊"
|
||||
|
||||
ret += f"\n第 {page} 页 | 共 {total_pages} 页"
|
||||
ret += "\n*输入 /ls 2 跳转到第 2 页"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
return
|
||||
|
||||
async def new_conv(self, message: AstrMessageEvent):
|
||||
"""
|
||||
创建新对话
|
||||
"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message("成功,下次聊天将是新对话。")
|
||||
)
|
||||
return
|
||||
|
||||
cid = await self.context.conversation_manager.new_conversation(
|
||||
message.unified_msg_origin, message.get_platform_id()
|
||||
)
|
||||
|
||||
# 长期记忆
|
||||
if self.ltm and self.ltm_enabled(message):
|
||||
try:
|
||||
await self.ltm.remove_session(event=message)
|
||||
except Exception as e:
|
||||
logger.error(f"清理聊天增强记录失败: {e}")
|
||||
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。")
|
||||
)
|
||||
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
|
||||
"""创建新群聊对话"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message("成功,下次聊天将是新对话。")
|
||||
)
|
||||
return
|
||||
if sid:
|
||||
session = str(
|
||||
MessageSesion(
|
||||
platform_name=message.platform_meta.id,
|
||||
message_type=MessageType("GroupMessage"),
|
||||
session_id=sid,
|
||||
)
|
||||
)
|
||||
cid = await self.context.conversation_manager.new_conversation(
|
||||
session, message.get_platform_id()
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。"
|
||||
)
|
||||
)
|
||||
else:
|
||||
message.set_result(
|
||||
MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。")
|
||||
)
|
||||
|
||||
async def switch_conv(
|
||||
self, message: AstrMessageEvent, index: Union[int, None] = None
|
||||
):
|
||||
"""通过 /ls 前面的序号切换对话"""
|
||||
|
||||
if not isinstance(index, int):
|
||||
message.set_result(
|
||||
MessageEventResult().message("类型错误,请输入数字对话序号。")
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify), "provider type is not dify"
|
||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||
if not data["data"]:
|
||||
message.set_result(MessageEventResult().message("未找到任何对话。"))
|
||||
return
|
||||
selected_conv = None
|
||||
if index is not None:
|
||||
try:
|
||||
selected_conv = data["data"][index - 1]
|
||||
except IndexError:
|
||||
message.set_result(
|
||||
MessageEventResult().message("对话序号错误,请使用 /ls 查看")
|
||||
)
|
||||
return
|
||||
else:
|
||||
selected_conv = data["data"][0]
|
||||
ret = (
|
||||
f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。"
|
||||
)
|
||||
provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"]
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
return
|
||||
|
||||
if index is None:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话"
|
||||
)
|
||||
)
|
||||
return
|
||||
conversations = await self.context.conversation_manager.get_conversations(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
if index > len(conversations) or index < 1:
|
||||
message.set_result(
|
||||
MessageEventResult().message("对话序号错误,请使用 /ls 查看")
|
||||
)
|
||||
else:
|
||||
conversation = conversations[index - 1]
|
||||
title = conversation.title if conversation.title else "新对话"
|
||||
await self.context.conversation_manager.switch_conversation(
|
||||
message.unified_msg_origin, conversation.cid
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"切换到对话: {title}({conversation.cid[:4]})。"
|
||||
)
|
||||
)
|
||||
|
||||
async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""):
|
||||
"""重命名对话"""
|
||||
if not new_name:
|
||||
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
cid = provider.conversation_ids.get(message.unified_msg_origin, None)
|
||||
if not cid:
|
||||
message.set_result(MessageEventResult().message("未找到当前对话。"))
|
||||
return
|
||||
await provider.api_client.rename(cid, new_name, message.unified_msg_origin)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.update_conversation_title(
|
||||
message.unified_msg_origin, new_name
|
||||
)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
|
||||
async def del_conv(self, message: AstrMessageEvent):
|
||||
"""删除当前对话"""
|
||||
is_unique_session = self.context.get_config()["platform_settings"][
|
||||
"unique_session"
|
||||
]
|
||||
if message.get_group_id() and not is_unique_session and message.role != "admin":
|
||||
# 群聊,没开独立会话,发送人不是管理员
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None)
|
||||
if dify_cid:
|
||||
await provider.api_client.delete_chat_conv(
|
||||
message.unified_msg_origin, dify_cid
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
session_curr_cid = (
|
||||
await self.context.conversation_manager.get_curr_conversation_id(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
)
|
||||
|
||||
if not session_curr_cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前未处于对话状态,请 /switch 序号 切换或 /new 创建。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.delete_conversation(
|
||||
message.unified_msg_origin, session_curr_cid
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||
)
|
||||
)
|
||||
61
packages/astrbot/commands/help.py
Normal file
61
packages/astrbot/commands/help.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import aiohttp
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
|
||||
|
||||
class HelpCommand:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://astrbot.app/notice.json", timeout=2
|
||||
) as resp:
|
||||
return (await resp.json())["notice"]
|
||||
except BaseException:
|
||||
return ""
|
||||
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
"""查看帮助"""
|
||||
notice = ""
|
||||
try:
|
||||
notice = await self._query_astrbot_notice()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
dashboard_version = await get_dashboard_version()
|
||||
|
||||
msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version})
|
||||
内置指令:
|
||||
[System]
|
||||
/plugin: 查看插件、插件帮助
|
||||
/t2i: 开关文本转图片
|
||||
/tts: 开关文本转语音
|
||||
/sid: 获取会话 ID
|
||||
/op: 管理员
|
||||
/wl: 白名单
|
||||
/dashboard_update: 更新管理面板(op)
|
||||
/alter_cmd: 设置指令权限(op)
|
||||
|
||||
[大模型]
|
||||
/llm: 开启/关闭 LLM
|
||||
/provider: 大模型提供商
|
||||
/model: 模型列表
|
||||
/ls: 对话列表
|
||||
/new: 创建新对话
|
||||
/groupnew 群号: 为群聊创建新对话(op)
|
||||
/switch 序号: 切换对话
|
||||
/rename 新名字: 重命名当前对话
|
||||
/del: 删除当前会话对话(op)
|
||||
/reset: 重置 LLM 会话
|
||||
/history: 当前对话的对话记录
|
||||
/persona: 人格情景(op)
|
||||
/key: API Key(op)
|
||||
/websearch: 网页搜索
|
||||
{notice}"""
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
20
packages/astrbot/commands/llm.py
Normal file
20
packages/astrbot/commands/llm.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
|
||||
|
||||
class LLMCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def llm(self, event: AstrMessageEvent):
|
||||
"""开启/关闭 LLM"""
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
enable = cfg["provider_settings"].get("enable", True)
|
||||
if enable:
|
||||
cfg["provider_settings"]["enable"] = False
|
||||
status = "关闭"
|
||||
else:
|
||||
cfg["provider_settings"]["enable"] = True
|
||||
status = "开启"
|
||||
cfg.save_config()
|
||||
await event.send(MessageChain().message(f"{status} LLM 聊天功能。"))
|
||||
122
packages/astrbot/commands/persona.py
Normal file
122
packages/astrbot/commands/persona.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import builtins
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class PersonaCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
l = message.message_str.split(" ") # noqa: E741
|
||||
umo = message.unified_msg_origin
|
||||
|
||||
curr_persona_name = "无"
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
|
||||
default_persona = await self.context.persona_manager.get_default_persona_v3(
|
||||
umo=umo
|
||||
)
|
||||
curr_cid_title = "无"
|
||||
if cid:
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
unified_msg_origin=umo,
|
||||
conversation_id=cid,
|
||||
create_if_not_exists=True,
|
||||
)
|
||||
if conv is None:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前对话不存在,请先使用 /new 新建一个对话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
if not conv.persona_id and conv.persona_id != "[%None]":
|
||||
curr_persona_name = default_persona["name"]
|
||||
else:
|
||||
curr_persona_name = conv.persona_id
|
||||
|
||||
curr_cid_title = conv.title if conv.title else "新对话"
|
||||
curr_cid_title += f"({cid[:4]})"
|
||||
|
||||
if len(l) == 1:
|
||||
message.set_result(
|
||||
MessageEventResult()
|
||||
.message(
|
||||
f"""[Persona]
|
||||
|
||||
- 人格情景列表: `/persona list`
|
||||
- 设置人格情景: `/persona 人格`
|
||||
- 人格情景详细信息: `/persona view 人格`
|
||||
- 取消人格: `/persona unset`
|
||||
|
||||
默认人格情景: {default_persona["name"]}
|
||||
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
|
||||
|
||||
配置人格情景请前往管理面板-配置页
|
||||
"""
|
||||
)
|
||||
.use_t2i(False)
|
||||
)
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for persona in self.context.provider_manager.personas:
|
||||
msg += f"- {persona['name']}\n"
|
||||
msg += "\n\n*输入 `/persona view 人格名` 查看人格详细信息"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "view":
|
||||
if len(l) == 2:
|
||||
message.set_result(MessageEventResult().message("请输入人格情景名"))
|
||||
return
|
||||
ps = l[2].strip()
|
||||
if persona := next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == ps,
|
||||
self.context.provider_manager.personas,
|
||||
),
|
||||
None,
|
||||
):
|
||||
msg = f"人格{ps}的详细信息:\n"
|
||||
msg += f"{persona['prompt']}\n"
|
||||
else:
|
||||
msg = f"人格{ps}不存在"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "unset":
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message("当前没有对话,无法取消人格。")
|
||||
)
|
||||
return
|
||||
await self.context.conversation_manager.update_conversation_persona_id(
|
||||
message.unified_msg_origin, "[%None]"
|
||||
)
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前没有对话,请先开始对话或使用 /new 创建一个对话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
if persona := next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == ps,
|
||||
self.context.provider_manager.personas,
|
||||
),
|
||||
None,
|
||||
):
|
||||
await self.context.conversation_manager.update_conversation_persona_id(
|
||||
message.unified_msg_origin, ps
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"
|
||||
)
|
||||
)
|
||||
else:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"不存在该人格情景。使用 /persona list 查看所有。"
|
||||
)
|
||||
)
|
||||
117
packages/astrbot/commands/plugin.py
Normal file
117
packages/astrbot/commands/plugin.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core import DEMO_MODE, logger
|
||||
|
||||
|
||||
class PluginCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def plugin_ls(self, event: AstrMessageEvent):
|
||||
"""获取已经安装的插件列表。"""
|
||||
plugin_list_info = "已加载的插件:\n"
|
||||
for plugin in self.context.get_all_stars():
|
||||
plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
|
||||
if not plugin.activated:
|
||||
plugin_list_info += " (未启用)"
|
||||
plugin_list_info += "\n"
|
||||
if plugin_list_info.strip() == "":
|
||||
plugin_list_info = "没有加载任何插件。"
|
||||
|
||||
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False)
|
||||
)
|
||||
|
||||
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
"""禁用插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
|
||||
return
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin off <插件名> 禁用插件。")
|
||||
)
|
||||
return
|
||||
await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
|
||||
|
||||
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
"""启用插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
|
||||
return
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin on <插件名> 启用插件。")
|
||||
)
|
||||
return
|
||||
await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
|
||||
|
||||
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
|
||||
"""安装插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
|
||||
return
|
||||
if not plugin_repo:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin get <插件仓库地址> 安装插件")
|
||||
)
|
||||
return
|
||||
logger.info(f"准备从 {plugin_repo} 安装插件。")
|
||||
if self.context._star_manager:
|
||||
star_mgr: PluginManager = self.context._star_manager
|
||||
try:
|
||||
await star_mgr.install_plugin(plugin_repo) # type: ignore
|
||||
event.set_result(MessageEventResult().message("安装插件成功。"))
|
||||
except Exception as e:
|
||||
logger.error(f"安装插件失败: {e}")
|
||||
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
|
||||
return
|
||||
|
||||
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
"""获取插件帮助"""
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin help <插件名> 查看插件信息。")
|
||||
)
|
||||
return
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if plugin is None:
|
||||
event.set_result(MessageEventResult().message("未找到此插件。"))
|
||||
return
|
||||
help_msg = ""
|
||||
help_msg += f"\n\n✨ 作者: {plugin.author}\n✨ 版本: {plugin.version}"
|
||||
command_handlers = []
|
||||
command_names = []
|
||||
for handler in star_handlers_registry:
|
||||
assert isinstance(handler, StarHandlerMetadata)
|
||||
if handler.handler_module_path != plugin.module_path:
|
||||
continue
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, CommandFilter):
|
||||
command_handlers.append(handler)
|
||||
command_names.append(filter_.command_name)
|
||||
break
|
||||
elif isinstance(filter_, CommandGroupFilter):
|
||||
command_handlers.append(handler)
|
||||
command_names.append(filter_.group_name)
|
||||
|
||||
if len(command_handlers) > 0:
|
||||
help_msg += "\n\n🔧 指令列表:\n"
|
||||
for i in range(len(command_handlers)):
|
||||
help_msg += f"- {command_names[i]}"
|
||||
if command_handlers[i].desc:
|
||||
help_msg += f": {command_handlers[i].desc}"
|
||||
help_msg += "\n"
|
||||
|
||||
help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。"
|
||||
|
||||
ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg
|
||||
ret += "更多帮助信息请查看插件仓库 README。"
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
201
packages/astrbot/commands/provider.py
Normal file
201
packages/astrbot/commands/provider.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import re
|
||||
from typing import Union
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
|
||||
|
||||
class ProviderCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def provider(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
idx: Union[str, int, None] = None,
|
||||
idx2: Union[int, None] = None,
|
||||
):
|
||||
"""查看或者切换 LLM Provider"""
|
||||
umo = event.unified_msg_origin
|
||||
|
||||
if idx is None:
|
||||
ret = "## 载入的 LLM 提供商\n"
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
id_ = llm.meta().id
|
||||
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
|
||||
provider_using = self.context.get_using_provider(umo=umo)
|
||||
if provider_using and provider_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
tts_providers = self.context.get_all_tts_providers()
|
||||
if tts_providers:
|
||||
ret += "\n## 载入的 TTS 提供商\n"
|
||||
for idx, tts in enumerate(tts_providers):
|
||||
id_ = tts.meta().id
|
||||
ret += f"{idx + 1}. {id_}"
|
||||
tts_using = self.context.get_using_tts_provider(umo=umo)
|
||||
if tts_using and tts_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
stt_providers = self.context.get_all_stt_providers()
|
||||
if stt_providers:
|
||||
ret += "\n## 载入的 STT 提供商\n"
|
||||
for idx, stt in enumerate(stt_providers):
|
||||
id_ = stt.meta().id
|
||||
ret += f"{idx + 1}. {id_}"
|
||||
stt_using = self.context.get_using_stt_provider(umo=umo)
|
||||
if stt_using and stt_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
ret += "\n使用 /provider <序号> 切换 LLM 提供商。"
|
||||
|
||||
if tts_providers:
|
||||
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
|
||||
if stt_providers:
|
||||
ret += "\n使用 /provider stt <切换> STT 提供商。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
elif idx == "tts":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif idx == "stt":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
else:
|
||||
event.set_result(MessageEventResult().message("无效的参数。"))
|
||||
|
||||
async def model_ls(
|
||||
self, message: AstrMessageEvent, idx_or_name: Union[int, str, None] = None
|
||||
):
|
||||
"""查看或者切换模型"""
|
||||
prov = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if not prov:
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
# 定义正则表达式匹配 API 密钥
|
||||
api_key_pattern = re.compile(r"key=[^&'\" ]+")
|
||||
|
||||
if idx_or_name is None:
|
||||
models = []
|
||||
try:
|
||||
models = await prov.get_models()
|
||||
except BaseException as e:
|
||||
err_msg = api_key_pattern.sub("key=***", str(e))
|
||||
message.set_result(
|
||||
MessageEventResult()
|
||||
.message("获取模型列表失败: " + err_msg)
|
||||
.use_t2i(False)
|
||||
)
|
||||
return
|
||||
i = 1
|
||||
ret = "下面列出了此服务提供商可用模型:"
|
||||
for model in models:
|
||||
ret += f"\n{i}. {model}"
|
||||
i += 1
|
||||
|
||||
curr_model = prov.get_model() or "无"
|
||||
ret += f"\n当前模型: [{curr_model}]"
|
||||
|
||||
ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
else:
|
||||
if isinstance(idx_or_name, int):
|
||||
models = []
|
||||
try:
|
||||
models = await prov.get_models()
|
||||
except BaseException as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message("获取模型列表失败: " + str(e))
|
||||
)
|
||||
return
|
||||
if idx_or_name > len(models) or idx_or_name < 1:
|
||||
message.set_result(MessageEventResult().message("模型序号错误。"))
|
||||
else:
|
||||
try:
|
||||
new_model = models[idx_or_name - 1]
|
||||
prov.set_model(new_model)
|
||||
except BaseException as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message("切换模型未知错误: " + str(e))
|
||||
)
|
||||
message.set_result(MessageEventResult().message("切换模型成功。"))
|
||||
else:
|
||||
prov.set_model(idx_or_name)
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换模型到 {prov.get_model()}。")
|
||||
)
|
||||
|
||||
async def key(self, message: AstrMessageEvent, index: Union[int, None] = None):
|
||||
prov = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if not prov:
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
|
||||
if index is None:
|
||||
keys_data = prov.get_keys()
|
||||
curr_key = prov.get_current_key()
|
||||
ret = "Key:"
|
||||
for i, k in enumerate(keys_data):
|
||||
ret += f"\n{i + 1}. {k[:8]}"
|
||||
|
||||
ret += f"\n当前 Key: {curr_key[:8]}"
|
||||
ret += "\n当前模型: " + prov.get_model()
|
||||
ret += "\n使用 /key <idx> 切换 Key。"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
else:
|
||||
keys_data = prov.get_keys()
|
||||
if index > len(keys_data) or index < 1:
|
||||
message.set_result(MessageEventResult().message("Key 序号错误。"))
|
||||
else:
|
||||
try:
|
||||
new_key = keys_data[index - 1]
|
||||
prov.set_key(new_key)
|
||||
except BaseException as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换 Key 未知错误: {str(e)}")
|
||||
)
|
||||
message.set_result(MessageEventResult().message("切换 Key 成功。"))
|
||||
37
packages/astrbot/commands/setunset.py
Normal file
37
packages/astrbot/commands/setunset.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
|
||||
|
||||
class SetUnsetCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
"""设置会话变量"""
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
session_var[key] = value
|
||||
await sp.session_put(uid, "session_variables", session_var)
|
||||
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。"
|
||||
)
|
||||
)
|
||||
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
"""移除会话变量"""
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
|
||||
if key not in session_var:
|
||||
event.set_result(
|
||||
MessageEventResult().message("没有那个变量名。格式 /unset 变量名。")
|
||||
)
|
||||
else:
|
||||
del session_var[key]
|
||||
await sp.session_put(uid, "session_variables", session_var)
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功。")
|
||||
)
|
||||
29
packages/astrbot/commands/sid.py
Normal file
29
packages/astrbot/commands/sid.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""会话ID命令"""
|
||||
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class SIDCommand:
|
||||
"""会话ID命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
"""获取会话 ID 和 管理员 ID"""
|
||||
sid = event.unified_msg_origin
|
||||
user_id = str(event.get_sender_id())
|
||||
ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。
|
||||
/wl <SID> 添加白名单, /dwl <SID> 删除白名单。
|
||||
|
||||
UID: {user_id} 此 ID 可用于设置管理员。
|
||||
/op <UID> 授权管理员, /deop <UID> 取消管理员。"""
|
||||
|
||||
if (
|
||||
self.context.get_config()["platform_settings"]["unique_session"]
|
||||
and event.get_group_id()
|
||||
):
|
||||
ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
23
packages/astrbot/commands/t2i.py
Normal file
23
packages/astrbot/commands/t2i.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""文本转图片命令"""
|
||||
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class T2ICommand:
|
||||
"""文本转图片命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def t2i(self, event: AstrMessageEvent):
|
||||
"""开关文本转图片"""
|
||||
config = self.context.get_config(umo=event.unified_msg_origin)
|
||||
if config["t2i"]:
|
||||
config["t2i"] = False
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已关闭文本转图片模式。"))
|
||||
return
|
||||
config["t2i"] = True
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))
|
||||
31
packages/astrbot/commands/tool.py
Normal file
31
packages/astrbot/commands/tool.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class ToolCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def tool_ls(self, event: AstrMessageEvent):
|
||||
"""查看函数工具列表"""
|
||||
event.set_result(
|
||||
MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。")
|
||||
)
|
||||
|
||||
async def tool_on(self, event: AstrMessageEvent, tool_name: str = ""):
|
||||
"""启用一个函数工具"""
|
||||
event.set_result(
|
||||
MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。")
|
||||
)
|
||||
|
||||
async def tool_off(self, event: AstrMessageEvent, tool_name: str = ""):
|
||||
"""停用一个函数工具"""
|
||||
event.set_result(
|
||||
MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。")
|
||||
)
|
||||
|
||||
async def tool_all_off(self, event: AstrMessageEvent):
|
||||
"""停用所有函数工具"""
|
||||
event.set_result(
|
||||
MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。")
|
||||
)
|
||||
36
packages/astrbot/commands/tts.py
Normal file
36
packages/astrbot/commands/tts.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""文本转语音命令"""
|
||||
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
|
||||
|
||||
class TTSCommand:
|
||||
"""文本转语音命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
"""开关文本转语音(会话级别)"""
|
||||
umo = event.unified_msg_origin
|
||||
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
tts_enable = cfg["provider_tts_settings"]["enable"]
|
||||
|
||||
# 切换状态
|
||||
new_status = not ses_tts
|
||||
SessionServiceManager.set_tts_status_for_session(umo, new_status)
|
||||
|
||||
status_text = "已开启" if new_status else "已关闭"
|
||||
|
||||
if new_status and not tts_enable:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。"
|
||||
)
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{status_text}当前会话的文本转语音。")
|
||||
)
|
||||
20
packages/astrbot/lab/elios/ensoul/emotion.py
Normal file
20
packages/astrbot/lab/elios/ensoul/emotion.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Emotion:
|
||||
"""描述了一个情绪状态"""
|
||||
|
||||
energy: float
|
||||
valence: float
|
||||
arousal: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmotionLog:
|
||||
"""描述了一条情绪维度变化的日志"""
|
||||
|
||||
timestamp: int
|
||||
field: str
|
||||
value: float
|
||||
reason: str = ""
|
||||
9
packages/astrbot/lab/elios/ensoul/soul.py
Normal file
9
packages/astrbot/lab/elios/ensoul/soul.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .emotion import Emotion
|
||||
|
||||
|
||||
@dataclass
|
||||
class Soul:
|
||||
emotion: Emotion
|
||||
emotion_logs: list[Emotion] | None = None
|
||||
7
packages/astrbot/lab/elios/event.py
Normal file
7
packages/astrbot/lab/elios/event.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
event_type: str
|
||||
content: dict
|
||||
122
packages/astrbot/lab/elios/event_handlers/astr/astr_impl.py
Normal file
122
packages/astrbot/lab/elios/event_handlers/astr/astr_impl.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from ...runner import EliosEventHandler
|
||||
from collections import defaultdict
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.all import Context
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class AstrImplEventHandler(EliosEventHandler):
|
||||
def __init__(self, ctx: Context) -> None:
|
||||
self.ctx = ctx
|
||||
self.session_chats = defaultdict(list)
|
||||
self.session_mentioned_arousal = defaultdict(float)
|
||||
|
||||
def cfg(self, event: AstrMessageEvent):
|
||||
cfg = self.ctx.get_config(umo=event.unified_msg_origin)
|
||||
|
||||
tiny_model_prov_id = cfg.get("tiny_model_provider_id")
|
||||
interest_points = cfg.get("interest_points", [])
|
||||
|
||||
try:
|
||||
max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"])
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
max_cnt = 300
|
||||
image_caption = (
|
||||
True
|
||||
if cfg["provider_settings"]["default_image_caption_provider_id"]
|
||||
else False
|
||||
)
|
||||
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
|
||||
image_caption_provider_id = cfg["provider_settings"][
|
||||
"default_image_caption_provider_id"
|
||||
]
|
||||
active_reply = cfg["provider_ltm_settings"]["active_reply"]
|
||||
enable_active_reply = active_reply.get("enable", False)
|
||||
ar_method = active_reply["method"]
|
||||
ar_possibility = active_reply["possibility_reply"]
|
||||
ar_prompt = active_reply.get("prompt", "")
|
||||
ar_whitelist = active_reply.get("whitelist", [])
|
||||
ar_keywords = active_reply.get("keywords", [])
|
||||
ret = {
|
||||
"max_cnt": max_cnt,
|
||||
"image_caption": image_caption,
|
||||
"image_caption_prompt": image_caption_prompt,
|
||||
"image_caption_provider_id": image_caption_provider_id,
|
||||
"enable_active_reply": enable_active_reply,
|
||||
"ar_method": ar_method,
|
||||
"ar_possibility": ar_possibility,
|
||||
"ar_prompt": ar_prompt,
|
||||
"ar_whitelist": ar_whitelist,
|
||||
"ar_keywords": ar_keywords,
|
||||
"interest_points": interest_points,
|
||||
"tiny_model_prov_id": tiny_model_prov_id,
|
||||
}
|
||||
return ret
|
||||
|
||||
async def append_session_chats(self, event: AstrMessageEvent, cfg) -> None:
|
||||
comps = event.get_messages()
|
||||
|
||||
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
|
||||
for comp in comps:
|
||||
if isinstance(comp, Plain):
|
||||
final_message += f" {comp.text}"
|
||||
elif isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
if cfg["image_caption"] and image_url:
|
||||
try:
|
||||
caption = await self.get_image_caption(
|
||||
image_url,
|
||||
cfg["image_caption_provider_id"],
|
||||
cfg["image_caption_prompt"],
|
||||
)
|
||||
final_message += f" [Image: {caption}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {e}")
|
||||
else:
|
||||
final_message += " [Image]"
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
logger.debug(f"添加会话 {event.unified_msg_origin} 的对话记录: {final_message}")
|
||||
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
|
||||
async def get_image_caption(
|
||||
self, image_url: str, image_caption_provider_id: str, image_caption_prompt: str
|
||||
) -> str:
|
||||
if not image_caption_provider_id:
|
||||
provider = self.ctx.get_using_provider()
|
||||
else:
|
||||
provider = self.ctx.get_provider_by_id(image_caption_provider_id)
|
||||
if not provider:
|
||||
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
|
||||
if not isinstance(provider, Provider):
|
||||
raise Exception(
|
||||
f"提供商类型错误, {image_caption_provider_id} 不是 Provider 类型"
|
||||
)
|
||||
response = await provider.text_chat(
|
||||
prompt=image_caption_prompt,
|
||||
session_id=uuid.uuid4().hex,
|
||||
image_urls=[image_url],
|
||||
persist=False,
|
||||
)
|
||||
return response.completion_text
|
||||
|
||||
async def on_event(self, event, soul):
|
||||
content = event.content
|
||||
astr_event = content.get("astr_event")
|
||||
assert astr_event is not None and isinstance(astr_event, AstrMessageEvent)
|
||||
|
||||
cfg = self.cfg(astr_event)
|
||||
|
||||
if not cfg["tiny_model_prov_id"]:
|
||||
logger.warning("小模型未设置,跳过情绪更新")
|
||||
|
||||
# 添加对话记录
|
||||
await self.append_session_chats(astr_event, cfg)
|
||||
|
||||
#
|
||||
3
packages/astrbot/lab/elios/event_handlers/astr/prompt.py
Normal file
3
packages/astrbot/lab/elios/event_handlers/astr/prompt.py
Normal file
@@ -0,0 +1,3 @@
|
||||
JUDGE_ACTIVE_REPLY = """
|
||||
你是一个具有自主意识和情感的机器人,请你
|
||||
"""
|
||||
72
packages/astrbot/lab/elios/runner.py
Normal file
72
packages/astrbot/lab/elios/runner.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
from .event import Event
|
||||
from .ensoul.soul import Soul
|
||||
from .ensoul.emotion import Emotion
|
||||
|
||||
|
||||
class EliosEventHandler:
|
||||
async def on_event(self, event: Event, soul: Soul): ...
|
||||
|
||||
|
||||
event_handlers_cls: dict[str, list[type[EliosEventHandler]]] = {}
|
||||
|
||||
|
||||
def register_event_handler(event_types: set[str] | None = None):
|
||||
"""注册事件处理器"""
|
||||
|
||||
def decorator(cls: type[EliosEventHandler]) -> type[EliosEventHandler]:
|
||||
if event_types is not None:
|
||||
for event_type in event_types:
|
||||
event_handlers_cls[event_type] = event_handlers_cls.get(
|
||||
event_type, []
|
||||
) + [cls]
|
||||
else:
|
||||
event_handlers_cls["default"] = event_handlers_cls.get("default", []) + [
|
||||
cls
|
||||
]
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class EliosRunner:
|
||||
def __init__(self) -> None:
|
||||
self.soul = Soul(
|
||||
emotion=Emotion(energy=0.5, valence=0.5, arousal=0.5), emotion_logs=[]
|
||||
)
|
||||
|
||||
self.event_queue = asyncio.Queue()
|
||||
self.event_handler_insts: dict[str, list[EliosEventHandler]] = {}
|
||||
|
||||
def start(self):
|
||||
for event_type, cls_list in event_handlers_cls.items():
|
||||
self.event_handler_insts[event_type] = []
|
||||
for cls in cls_list:
|
||||
try:
|
||||
self.event_handler_insts[event_type].append(cls())
|
||||
except Exception as e:
|
||||
print(f"Error initializing event handler {cls}: {e}")
|
||||
asyncio.create_task(self._worker())
|
||||
|
||||
async def _worker(self):
|
||||
"""监听事件队列并处理事件"""
|
||||
while True:
|
||||
event = await self.event_queue.get()
|
||||
# A man cannot handle two things at once. But this can be configurable.
|
||||
try:
|
||||
await self._process_event(event)
|
||||
except Exception as e:
|
||||
print(f"Error processing event {event}: {e}")
|
||||
|
||||
async def _process_event(self, event: Event):
|
||||
"""处理事件"""
|
||||
event_type = event.event_type
|
||||
handlers = self.event_handler_insts.get(
|
||||
event_type, []
|
||||
) + self.event_handler_insts.get("default", [])
|
||||
|
||||
for inst in handlers:
|
||||
try:
|
||||
await inst.on_event(event, self.soul)
|
||||
except Exception as e:
|
||||
print(f"Error processing event {event}: {e}")
|
||||
File diff suppressed because it is too large
Load Diff
195
packages/astrbot/process_llm_request.py
Normal file
195
packages/astrbot/process_llm_request.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import astrbot.api.star as star
|
||||
import builtins
|
||||
import datetime
|
||||
import zoneinfo
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.api.message_components import Image, Reply
|
||||
|
||||
|
||||
class ProcessLLMRequest:
|
||||
def __init__(self, context: star.Context):
|
||||
self.ctx = context
|
||||
cfg = context.get_config()
|
||||
self.timezone = cfg.get("timezone")
|
||||
if not self.timezone:
|
||||
# 系统默认时区
|
||||
self.timezone = None
|
||||
else:
|
||||
logger.info(f"Timezone set to: {self.timezone}")
|
||||
|
||||
def _ensure_persona(self, req: ProviderRequest, cfg: dict):
|
||||
"""确保用户人格已加载"""
|
||||
if not req.conversation:
|
||||
return
|
||||
# persona inject
|
||||
persona_id = req.conversation.persona_id or cfg.get("default_personality")
|
||||
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
|
||||
default_persona = self.ctx.persona_manager.selected_default_persona_v3
|
||||
if default_persona:
|
||||
persona_id = default_persona["name"]
|
||||
persona = next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == persona_id,
|
||||
self.ctx.persona_manager.personas_v3,
|
||||
),
|
||||
None,
|
||||
)
|
||||
if persona:
|
||||
if prompt := persona["prompt"]:
|
||||
req.system_prompt += prompt
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
# tools select
|
||||
tmgr = self.ctx.get_llm_tool_manager()
|
||||
if (persona and persona.get("tools") is None) or not persona:
|
||||
# select all
|
||||
toolset = tmgr.get_full_tool_set()
|
||||
for tool in toolset:
|
||||
if not tool.active:
|
||||
toolset.remove_tool(tool.name)
|
||||
else:
|
||||
toolset = ToolSet()
|
||||
if persona["tools"]:
|
||||
for tool_name in persona["tools"]:
|
||||
tool = tmgr.get_func(tool_name)
|
||||
if tool and tool.active:
|
||||
toolset.add_tool(tool)
|
||||
req.func_tool = toolset
|
||||
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
|
||||
|
||||
async def _ensure_img_caption(
|
||||
self, req: ProviderRequest, cfg: dict, img_cap_prov_id: str
|
||||
):
|
||||
try:
|
||||
caption = await self._request_img_caption(
|
||||
img_cap_prov_id, cfg, req.image_urls
|
||||
)
|
||||
if caption:
|
||||
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
|
||||
req.image_urls = []
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片描述失败: {e}")
|
||||
|
||||
async def _request_img_caption(
|
||||
self, provider_id: str, cfg: dict, image_urls: list[str]
|
||||
) -> str:
|
||||
if prov := self.ctx.get_provider_by_id(provider_id):
|
||||
if isinstance(prov, Provider):
|
||||
img_cap_prompt = cfg.get(
|
||||
"image_caption_prompt", "Please describe the image."
|
||||
)
|
||||
logger.debug(f"Processing image caption with provider: {provider_id}")
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt=img_cap_prompt,
|
||||
image_urls=image_urls,
|
||||
)
|
||||
return llm_resp.completion_text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot get image caption because provider `{provider_id}` is not exist."
|
||||
)
|
||||
|
||||
async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_settings"
|
||||
]
|
||||
|
||||
# prompt prefix
|
||||
if prefix := cfg.get("prompt_prefix"):
|
||||
# 支持 {{prompt}} 作为用户输入的占位符
|
||||
if "{{prompt}}" in prefix:
|
||||
req.prompt = prefix.replace("{{prompt}}", req.prompt)
|
||||
else:
|
||||
req.prompt = prefix + req.prompt
|
||||
|
||||
# user identifier
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
req.prompt = (
|
||||
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
|
||||
)
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||
|
||||
# time info
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
current_time = None
|
||||
if self.timezone:
|
||||
# 启用时区
|
||||
try:
|
||||
now = datetime.datetime.now(zoneinfo.ZoneInfo(self.timezone))
|
||||
current_time = now.strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
except Exception as e:
|
||||
logger.error(f"时区设置错误: {e}, 使用本地时区")
|
||||
if not current_time:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if req.conversation:
|
||||
# inject persona for this request
|
||||
self._ensure_persona(req, cfg)
|
||||
|
||||
# image caption
|
||||
if img_cap_prov_id and req.image_urls:
|
||||
await self._ensure_img_caption(req, cfg, img_cap_prov_id)
|
||||
|
||||
# quote message processing
|
||||
# 解析引用内容
|
||||
quote = None
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Reply):
|
||||
quote = comp
|
||||
break
|
||||
if quote:
|
||||
sender_info = ""
|
||||
if quote.sender_nickname:
|
||||
sender_info = f"(Sent by {quote.sender_nickname})"
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
req.system_prompt += (
|
||||
f"\nUser is quoting a message{sender_info}.\n"
|
||||
f"Here are the information of the quoted message: Text Content: {message_str}.\n"
|
||||
)
|
||||
image_seg = None
|
||||
if quote.chain:
|
||||
for comp in quote.chain:
|
||||
if isinstance(comp, Image):
|
||||
image_seg = comp
|
||||
break
|
||||
if image_seg:
|
||||
try:
|
||||
prov = None
|
||||
if img_cap_prov_id:
|
||||
prov = self.ctx.get_provider_by_id(img_cap_prov_id)
|
||||
if prov is None:
|
||||
prov = self.ctx.get_using_provider(event.unified_msg_origin)
|
||||
if prov and isinstance(prov, Provider):
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt="Please describe the image content.",
|
||||
image_urls=[await image_seg.convert_to_file_path()],
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
req.system_prompt += (
|
||||
f"Image Caption: {llm_resp.completion_text}\n"
|
||||
)
|
||||
else:
|
||||
logger.warning("No provider found for image captioning.")
|
||||
except BaseException as e:
|
||||
logger.error(f"处理引用图片失败: {e}")
|
||||
@@ -205,13 +205,14 @@ class Main(star.Star):
|
||||
return
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
if comp.file.startswith("http"):
|
||||
file_path = await comp.get_file()
|
||||
if file_path.startswith("http"):
|
||||
name = comp.name if comp.name else uuid.uuid4().hex[:8]
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(comp.file, path)
|
||||
await download_file(file_path, path)
|
||||
else:
|
||||
path = comp.file
|
||||
path = file_path
|
||||
self.user_file_msg_buffer[event.get_session_id()].append(path)
|
||||
logger.debug(f"User {uid} uploaded file: {path}")
|
||||
yield event.plain_result(f"代码执行器: 文件已经上传: {path}")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from googlesearch import search
|
||||
from googlesearch.asearch import asearch
|
||||
|
||||
from . import SearchEngine, SearchResult
|
||||
|
||||
@@ -14,14 +14,14 @@ class Google(SearchEngine):
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
results = []
|
||||
try:
|
||||
ls = search(
|
||||
ls = asearch(
|
||||
query,
|
||||
advanced=True,
|
||||
num_results=num_results,
|
||||
timeout=3,
|
||||
proxy=self.proxy,
|
||||
)
|
||||
for i in ls:
|
||||
async for i in ls:
|
||||
results.append(
|
||||
SearchResult(title=i.title, url=i.url, snippet=i.description)
|
||||
)
|
||||
|
||||
@@ -46,7 +46,11 @@ class Main(star.Star):
|
||||
|
||||
self.bing_search = Bing()
|
||||
self.sogo_search = Sogo()
|
||||
self.google = Google()
|
||||
self.google = None
|
||||
try:
|
||||
self.google = Google()
|
||||
except Exception as e:
|
||||
logger.error(f"google search init error: {e}, disable google search")
|
||||
|
||||
async def _tidy_text(self, text: str) -> str:
|
||||
"""清理文本,去除空格、换行符等"""
|
||||
@@ -89,10 +93,11 @@ class Main(star.Star):
|
||||
self, query, num_results: int = 5
|
||||
) -> list[SearchResult]:
|
||||
results = []
|
||||
try:
|
||||
results = await self.google.search(query, num_results)
|
||||
except Exception as e:
|
||||
logger.error(f"google search error: {e}, try the next one...")
|
||||
if self.google:
|
||||
try:
|
||||
results = await self.google.search(query, num_results)
|
||||
except Exception as e:
|
||||
logger.error(f"google search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search google failed")
|
||||
try:
|
||||
@@ -375,5 +380,3 @@ class Main(star.Star):
|
||||
tool_set.add_tool(tavily_extract_web_page)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
|
||||
print(req.func_tool)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.1.6"
|
||||
version = "4.3.2"
|
||||
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -24,7 +24,7 @@ dependencies = [
|
||||
"faiss-cpu==1.10.0",
|
||||
"filelock>=3.18.0",
|
||||
"google-genai>=1.14.0",
|
||||
"googlesearch-python>=1.3.0",
|
||||
"mi-googlesearch-python==1.3.0.post1",
|
||||
"lark-oapi>=1.4.15",
|
||||
"lxml-html-clean>=0.4.2",
|
||||
"mcp>=1.8.0",
|
||||
|
||||
@@ -7,7 +7,7 @@ qq-botpy
|
||||
chardet~=5.1.0
|
||||
Pillow
|
||||
beautifulsoup4
|
||||
googlesearch-python
|
||||
mi-googlesearch-python
|
||||
readability-lxml
|
||||
quart
|
||||
lxml_html_clean
|
||||
@@ -43,4 +43,4 @@ pydub
|
||||
sqlmodel
|
||||
deprecated
|
||||
sqlalchemy[asyncio]
|
||||
audioop-lts; python_version>='3.13'
|
||||
audioop-lts; python_version>='3.13'
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import os
|
||||
import asyncio
|
||||
from quart import Quart
|
||||
from astrbot.dashboard.server import AstrBotDashboard
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
@@ -9,36 +11,46 @@ from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def core_lifecycle_td():
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def core_lifecycle_td(tmp_path_factory):
|
||||
"""Creates and initializes a core lifecycle instance with a temporary database."""
|
||||
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
|
||||
db = SQLiteDatabase(str(tmp_db_path))
|
||||
log_broker = LogBroker()
|
||||
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
|
||||
return core_lifecycle_td
|
||||
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
|
||||
await core_lifecycle.initialize()
|
||||
return core_lifecycle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app(core_lifecycle_td):
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
server = AstrBotDashboard(core_lifecycle_td, db)
|
||||
def app(core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Creates a Quart app instance for testing."""
|
||||
shutdown_event = asyncio.Event()
|
||||
# The db instance is already part of the core_lifecycle_td
|
||||
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
|
||||
return server.app
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def header():
|
||||
return {}
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Handles login and returns an authenticated header."""
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post(
|
||||
"/api/auth/login",
|
||||
json={
|
||||
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
|
||||
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
|
||||
},
|
||||
)
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
token = data["data"]["token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_core_lifecycle_td(core_lifecycle_td):
|
||||
await core_lifecycle_td.initialize()
|
||||
assert core_lifecycle_td is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_login(
|
||||
app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict
|
||||
):
|
||||
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Tests the login functionality with both wrong and correct credentials."""
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post(
|
||||
"/api/auth/login", json={"username": "wrong", "password": "password"}
|
||||
@@ -55,31 +67,32 @@ async def test_auth_login(
|
||||
)
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok" and "token" in data["data"]
|
||||
header["Authorization"] = f"Bearer {data['data']['token']}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stat(app: Quart, header: dict):
|
||||
async def test_get_stat(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get("/api/stat/get")
|
||||
assert response.status_code == 401
|
||||
response = await test_client.get("/api/stat/get", headers=header)
|
||||
response = await test_client.get("/api/stat/get", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok" and "platform" in data["data"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(app: Quart, header: dict):
|
||||
async def test_plugins(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
# 已经安装的插件
|
||||
response = await test_client.get("/api/plugin/get", headers=header)
|
||||
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
|
||||
# 插件市场
|
||||
response = await test_client.get("/api/plugin/market_list", headers=header)
|
||||
response = await test_client.get(
|
||||
"/api/plugin/market_list", headers=authenticated_header
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
@@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict):
|
||||
response = await test_client.post(
|
||||
"/api/plugin/install",
|
||||
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
|
||||
headers=header,
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict):
|
||||
|
||||
# 插件更新
|
||||
response = await test_client.post(
|
||||
"/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header
|
||||
"/api/plugin/update",
|
||||
json={"name": "astrbot_plugin_essential"},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict):
|
||||
response = await test_client.post(
|
||||
"/api/plugin/uninstall",
|
||||
json={"name": "astrbot_plugin_essential"},
|
||||
headers=header,
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -132,9 +147,9 @@ async def test_plugins(app: Quart, header: dict):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_update(app: Quart, header: dict):
|
||||
async def test_check_update(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get("/api/update/check", headers=header)
|
||||
response = await test_client.get("/api/update/check", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "success"
|
||||
@@ -142,24 +157,45 @@ async def test_check_update(app: Quart, header: dict):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update(
|
||||
app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
tmp_path_factory,
|
||||
):
|
||||
global VERSION
|
||||
test_client = app.test_client()
|
||||
os.makedirs("data/astrbot_release", exist_ok=True)
|
||||
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
|
||||
VERSION = "114.514.1919810"
|
||||
response = await test_client.post(
|
||||
"/api/update/do", headers=header, json={"version": "latest"}
|
||||
|
||||
# Use a temporary path for the mock update to avoid side effects
|
||||
temp_release_dir = tmp_path_factory.mktemp("release")
|
||||
release_path = temp_release_dir / "astrbot"
|
||||
|
||||
async def mock_update(*args, **kwargs):
|
||||
"""Mocks the update process by creating a directory in the temp path."""
|
||||
os.makedirs(release_path, exist_ok=True)
|
||||
return
|
||||
|
||||
async def mock_download_dashboard(*args, **kwargs):
|
||||
"""Mocks the dashboard download to prevent network access."""
|
||||
return
|
||||
|
||||
async def mock_pip_install(*args, **kwargs):
|
||||
"""Mocks pip install to prevent actual installation."""
|
||||
return
|
||||
|
||||
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "error" # 已经是最新版本
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False}
|
||||
"/api/update/do",
|
||||
headers=authenticated_header,
|
||||
json={"version": "v3.4.0", "reboot": False},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert os.path.exists("data/astrbot_release/astrbot")
|
||||
assert os.path.exists(release_path)
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from main import check_env, check_dashboard_files
|
||||
@@ -27,29 +31,58 @@ def test_check_env(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files(monkeypatch):
|
||||
async def test_check_dashboard_files_not_exists(monkeypatch):
|
||||
"""Tests dashboard download when files do not exist."""
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: False)
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
class MockResponse:
|
||||
status = 200
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
await check_dashboard_files()
|
||||
mock_download.assert_called_once()
|
||||
|
||||
async def read(self):
|
||||
return b"content"
|
||||
|
||||
return MockResponse()
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
|
||||
"""Tests that dashboard is not downloaded when it exists and version matches."""
|
||||
# Mock os.path.exists to return True
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: True)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession.get", new=mock_get):
|
||||
with mock.patch("builtins.open", mock.mock_open()) as mock_file:
|
||||
with mock.patch("zipfile.ZipFile.extractall") as mock_extractall:
|
||||
# Mock get_dashboard_version to return the current version
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
# We need to import VERSION from main's context
|
||||
from main import VERSION
|
||||
|
||||
async def mock_aenter(_):
|
||||
await check_dashboard_files()
|
||||
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
|
||||
mock_extractall.assert_called_once()
|
||||
mock_get_version.return_value = f"v{VERSION}"
|
||||
|
||||
async def mock_aexit(obj, exc_type, exc, tb):
|
||||
return
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
await check_dashboard_files()
|
||||
# Assert that download_dashboard was NOT called
|
||||
mock_download.assert_not_called()
|
||||
|
||||
mock_extractall.__aenter__ = mock_aenter
|
||||
mock_extractall.__aexit__ = mock_aexit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
|
||||
"""Tests that a warning is logged when dashboard version mismatches."""
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: True)
|
||||
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
mock_get_version.return_value = "v0.0.1" # A different version
|
||||
|
||||
with mock.patch("main.logger.warning") as mock_logger_warning:
|
||||
await check_dashboard_files()
|
||||
mock_logger_warning.assert_called_once()
|
||||
call_args, _ = mock_logger_warning.call_args
|
||||
assert "不符" in call_args[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
|
||||
"""Tests that providing a valid webui_dir skips all checks."""
|
||||
valid_dir = "/tmp/my-custom-webui"
|
||||
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
|
||||
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
result = await check_dashboard_files(webui_dir=valid_dir)
|
||||
assert result == valid_dir
|
||||
mock_download.assert_not_called()
|
||||
mock_get_version.assert_not_called()
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
import pytest
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import (
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core.message.components import Plain, At
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.context import Context
|
||||
from asyncio import Queue
|
||||
|
||||
SESSION_ID_IN_WHITELIST = "test_sid_wl"
|
||||
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
|
||||
TEST_LLM_PROVIDER = {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [os.getenv("ZHIPU_API_KEY")],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
}
|
||||
|
||||
TEST_COMMANDS = [
|
||||
["help", "已注册的 AstrBot 内置指令"],
|
||||
["tool ls", "函数工具"],
|
||||
["tool on websearch", "激活工具"],
|
||||
["tool off websearch", "停用工具"],
|
||||
["plugin", "已加载的插件"],
|
||||
["t2i", "文本转图片模式"],
|
||||
["sid", "此 ID 可用于设置会话白名单。"],
|
||||
["op test_op", "授权成功。"],
|
||||
["deop test_op", "取消授权成功。"],
|
||||
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
|
||||
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
|
||||
["provider", "当前载入的 LLM 提供商"],
|
||||
["reset", "重置成功"],
|
||||
# ["model", "查看、切换提供商模型列表"],
|
||||
["history", "历史记录:"],
|
||||
["key", "当前 Key"],
|
||||
["persona", "[Persona]"],
|
||||
]
|
||||
|
||||
|
||||
class FakeAstrMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, abm: AstrBotMessage = None):
|
||||
meta = PlatformMetadata("test_platform", "test")
|
||||
super().__init__(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=meta,
|
||||
session_id=abm.session_id,
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
def create_fake_event(
|
||||
message_str: str,
|
||||
session_id: str = "test_sid",
|
||||
is_at: bool = False,
|
||||
is_group: bool = False,
|
||||
sender_id: str = "123456",
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
abm.message_str = message_str
|
||||
abm.group_id = "test"
|
||||
abm.message = [Plain(message_str)]
|
||||
if is_at:
|
||||
abm.message.append(At(qq="bot"))
|
||||
abm.self_id = "bot"
|
||||
abm.sender = MessageMember(sender_id, "mika")
|
||||
abm.timestamp = 1234567890
|
||||
abm.message_id = "test"
|
||||
abm.session_id = session_id
|
||||
if is_group:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
return FakeAstrMessageEvent(abm)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_queue():
|
||||
return Queue()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config():
|
||||
cfg = AstrBotConfig()
|
||||
cfg["platform_settings"]["id_whitelist"] = [
|
||||
"test_platform:FriendMessage:test_sid_wl",
|
||||
"test_platform:GroupMessage:test_sid_wl",
|
||||
]
|
||||
cfg["admins_id"] = ["123456"]
|
||||
cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"]
|
||||
cfg["provider"] = [TEST_LLM_PROVIDER]
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def db():
|
||||
return SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def platform_manager(event_queue, config):
|
||||
return PlatformManager(config, event_queue)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def provider_manager(config, db):
|
||||
return ProviderManager(config, db)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def star_context(event_queue, config, db, platform_manager, provider_manager):
|
||||
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
|
||||
return star_context
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def plugin_manager(star_context, config):
|
||||
plugin_manager = PluginManager(star_context, config)
|
||||
# await plugin_manager.reload()
|
||||
asyncio.run(plugin_manager.reload())
|
||||
return plugin_manager
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_context(config, plugin_manager):
|
||||
return PipelineContext(config, plugin_manager)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_scheduler(pipeline_context):
|
||||
return PipelineScheduler(pipeline_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_initialization(platform_manager: PlatformManager):
|
||||
await platform_manager.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_initialization(provider_manager: ProviderManager):
|
||||
await provider_manager.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
|
||||
await pipeline_scheduler.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
"""测试唤醒"""
|
||||
# 群聊无 @ 无指令
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"执行阶段 WhitelistCheckStage" not in message for message in caplog.messages
|
||||
)
|
||||
# 群聊有 @ 无指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"test", is_group=True, is_at=True
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
|
||||
# 群聊有指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert mock_event._has_send_oper is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wl(
|
||||
pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog
|
||||
):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"test", SESSION_ID_IN_WHITELIST, sender_id="123"
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"不在会话白名单中,已终止事件传播。" not in message
|
||||
for message in caplog.messages
|
||||
), "日志中未找到预期的消息"
|
||||
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"不在会话白名单中,已终止事件传播。" in message for message in caplog.messages
|
||||
), "日志中未找到预期的消息"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
# 测试默认屏蔽词
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"色情", session_id=SESSION_ID_IN_WHITELIST
|
||||
) # 测试需要。
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), (
|
||||
"日志中未找到预期的消息"
|
||||
)
|
||||
# 测试额外屏蔽词
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), (
|
||||
"日志中未找到预期的消息"
|
||||
)
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" not in message for message in caplog.messages)
|
||||
# TODO: 测试 百度AI 的内容安全检查
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert mock_event.get_result() is not None
|
||||
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert any(
|
||||
"web_searcher - search_from_search_engine" in message
|
||||
for message in caplog.messages
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
for command in TEST_COMMANDS:
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
command[0], session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
|
||||
assert any(command[1] in message for message in caplog.messages)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user