Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3c8c28ebd5 | ||
|
|
524285f767 | ||
|
|
c2a34475f1 | ||
|
|
a69195a02b | ||
|
|
19d7438499 | ||
|
|
ccb380ce06 | ||
|
|
a35c439bbd | ||
|
|
09d1f96603 | ||
|
|
26aa18d980 | ||
|
|
d10b542797 | ||
|
|
ce4e4fb8dd | ||
|
|
8f4a31cf8c | ||
|
|
23549f13d6 | ||
|
|
869d11f9a6 | ||
|
|
02e73b82ee | ||
|
|
f85f87f545 | ||
|
|
1fff5713f3 | ||
|
|
8453ec36f0 | ||
|
|
d5b3ce8424 | ||
|
|
80cbbfa5ca | ||
|
|
9177bb660f | ||
|
|
a3df39a01a | ||
|
|
25dce05cbb | ||
|
|
1542ea3e03 | ||
|
|
6084abbcfe | ||
|
|
ed19b63914 | ||
|
|
4efeb85296 | ||
|
|
fc76665615 | ||
|
|
3a044bb71a | ||
|
|
cddd606562 | ||
|
|
7a5bc51c11 | ||
|
|
9f939b4b6f | ||
|
|
80a86f5b1b | ||
|
|
a0ce1855ab | ||
|
|
a4b43b884a | ||
|
|
824c0f6667 | ||
|
|
a030fe8491 | ||
|
|
3a9429e8ef | ||
|
|
c4eb1ab748 | ||
|
|
29ed19d600 | ||
|
|
0cc65513a5 | ||
|
|
debc048659 | ||
|
|
92f5c918dd | ||
|
|
9519f1e8e2 | ||
|
|
a8f874bf05 | ||
|
|
9d9917e45b | ||
|
|
91ee0a870d | ||
|
|
6cbbffc5a9 | ||
|
|
8f26fd34d1 | ||
|
|
fda655f6d7 |
4
.github/workflows/code-format.yml
vendored
4
.github/workflows/code-format.yml
vendored
@@ -12,10 +12,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -18,7 +18,8 @@
|
|||||||
|
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||||
<a href="https://astrbot.app/">查看文档</a> |
|
<a href="https://astrbot.app/">文档</a> |
|
||||||
|
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -110,7 +111,6 @@ uv run main.py
|
|||||||
|
|
||||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
|
|
||||||
|
|
||||||
## ⚡ 消息平台支持情况
|
## ⚡ 消息平台支持情况
|
||||||
|
|
||||||
| 平台 | 支持性 |
|
| 平台 | 支持性 |
|
||||||
@@ -127,6 +127,8 @@ uv run main.py
|
|||||||
| Discord | ✔ |
|
| Discord | ✔ |
|
||||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||||
|
| Satori | ✔ |
|
||||||
|
| Misskey | ✔ |
|
||||||
|
|
||||||
## ⚡ 提供商支持情况
|
## ⚡ 提供商支持情况
|
||||||
|
|
||||||
@@ -172,7 +174,6 @@ pip install pre-commit
|
|||||||
pre-commit install
|
pre-commit install
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## ❤️ Special Thanks
|
## ❤️ Special Thanks
|
||||||
|
|
||||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||||
@@ -205,9 +206,6 @@ pre-commit install
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
|
|||||||
@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
|
|||||||
class Agent(Generic[TContext]):
|
class Agent(Generic[TContext]):
|
||||||
name: str
|
name: str
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
tools: list[str, FunctionTool] | None = None
|
tools: list[str | FunctionTool] | None = None
|
||||||
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class MCPClient:
|
|||||||
self.session: Optional[mcp.ClientSession] = None
|
self.session: Optional[mcp.ClientSession] = None
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
self.name = None
|
self.name: str | None = None
|
||||||
self.active: bool = True
|
self.active: bool = True
|
||||||
self.tools: list[mcp.Tool] = []
|
self.tools: list[mcp.Tool] = []
|
||||||
self.server_errlogs: list[str] = []
|
self.server_errlogs: list[str] = []
|
||||||
@@ -198,6 +198,8 @@ class MCPClient:
|
|||||||
|
|
||||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||||
"""List all tools from the server and save them to self.tools"""
|
"""List all tools from the server and save them to self.tools"""
|
||||||
|
if not self.session:
|
||||||
|
raise Exception("MCP Client is not initialized")
|
||||||
response = await self.session.list_tools()
|
response = await self.session.list_tools()
|
||||||
self.tools = response.tools
|
self.tools = response.tools
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
yield MessageChain(
|
yield MessageChain(
|
||||||
type="tool_direct_result"
|
type="tool_direct_result"
|
||||||
).base64_image(res.content[0].data)
|
).base64_image(resource.blob)
|
||||||
else:
|
else:
|
||||||
tool_call_result_blocks.append(
|
tool_call_result_blocks.append(
|
||||||
ToolCallMessageSegment(
|
ToolCallMessageSegment(
|
||||||
@@ -269,17 +269,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
yield MessageChain().message("返回的数据类型不受支持。")
|
yield MessageChain().message("返回的数据类型不受支持。")
|
||||||
|
|
||||||
try:
|
|
||||||
await self.agent_hooks.on_tool_end(
|
|
||||||
self.run_context,
|
|
||||||
func_tool_name,
|
|
||||||
func_tool_args,
|
|
||||||
resp,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
elif resp is None:
|
elif resp is None:
|
||||||
# Tool 直接请求发送消息给用户
|
# Tool 直接请求发送消息给用户
|
||||||
# 这里我们将直接结束 Agent Loop。
|
# 这里我们将直接结束 Agent Loop。
|
||||||
@@ -289,14 +278,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
yield MessageChain(
|
yield MessageChain(
|
||||||
chain=res.chain, type="tool_direct_result"
|
chain=res.chain, type="tool_direct_result"
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
await self.agent_hooks.on_tool_end(
|
|
||||||
self.run_context, func_tool_name, func_tool_args, None
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||||
@@ -304,12 +285,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self.agent_hooks.on_tool_end(
|
await self.agent_hooks.on_tool_end(
|
||||||
self.run_context, func_tool_name, func_tool_args, None
|
self.run_context, func_tool, func_tool_args, None
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
||||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.run_context.event.clear_result()
|
self.run_context.event.clear_result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
from typing import Awaitable, Literal, Any, Optional
|
from typing import Awaitable, Callable, Literal, Any, Optional
|
||||||
from .mcp_client import MCPClient
|
from .mcp_client import MCPClient
|
||||||
|
|
||||||
|
|
||||||
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
|
|||||||
class FunctionTool:
|
class FunctionTool:
|
||||||
"""A class representing a function tool that can be used in function calling."""
|
"""A class representing a function tool that can be used in function calling."""
|
||||||
|
|
||||||
name: str | None = None
|
name: str
|
||||||
parameters: dict | None = None
|
parameters: dict | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
handler: Awaitable | None = None
|
handler: Callable[..., Awaitable[Any]] | None = None
|
||||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||||
handler_module_path: str | None = None
|
handler_module_path: str | None = None
|
||||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||||
@@ -51,7 +51,7 @@ class ToolSet:
|
|||||||
This class provides methods to add, remove, and retrieve tools, as well as
|
This class provides methods to add, remove, and retrieve tools, as well as
|
||||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||||
|
|
||||||
def __init__(self, tools: list[FunctionTool] = None):
|
def __init__(self, tools: list[FunctionTool] | None = None):
|
||||||
self.tools: list[FunctionTool] = tools or []
|
self.tools: list[FunctionTool] = tools or []
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
@@ -79,7 +79,13 @@ class ToolSet:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||||
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
def add_func(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
func_args: list,
|
||||||
|
desc: str,
|
||||||
|
handler: Callable[..., Awaitable[Any]],
|
||||||
|
):
|
||||||
"""Add a function tool to the set."""
|
"""Add a function tool to the set."""
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
@@ -104,7 +110,7 @@ class ToolSet:
|
|||||||
self.remove_tool(name)
|
self.remove_tool(name)
|
||||||
|
|
||||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||||
def get_func(self, name: str) -> list[FunctionTool]:
|
def get_func(self, name: str) -> FunctionTool | None:
|
||||||
"""Get all function tools."""
|
"""Get all function tools."""
|
||||||
return self.get_tool(name)
|
return self.get_tool(name)
|
||||||
|
|
||||||
@@ -125,7 +131,11 @@ class ToolSet:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if tool.parameters.get("properties") or not omit_empty_parameter_field:
|
if (
|
||||||
|
tool.parameters
|
||||||
|
and tool.parameters.get("properties")
|
||||||
|
or not omit_empty_parameter_field
|
||||||
|
):
|
||||||
func_def["function"]["parameters"] = tool.parameters
|
func_def["function"]["parameters"] = tool.parameters
|
||||||
|
|
||||||
result.append(func_def)
|
result.append(func_def)
|
||||||
@@ -135,14 +145,14 @@ class ToolSet:
|
|||||||
"""Convert tools to Anthropic API format."""
|
"""Convert tools to Anthropic API format."""
|
||||||
result = []
|
result = []
|
||||||
for tool in self.tools:
|
for tool in self.tools:
|
||||||
|
input_schema = {"type": "object"}
|
||||||
|
if tool.parameters:
|
||||||
|
input_schema["properties"] = tool.parameters.get("properties", {})
|
||||||
|
input_schema["required"] = tool.parameters.get("required", [])
|
||||||
tool_def = {
|
tool_def = {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"input_schema": {
|
"input_schema": input_schema,
|
||||||
"type": "object",
|
|
||||||
"properties": tool.parameters.get("properties", {}),
|
|
||||||
"required": tool.parameters.get("required", []),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
result.append(tool_def)
|
result.append(tool_def)
|
||||||
return result
|
return result
|
||||||
@@ -210,14 +220,15 @@ class ToolSet:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
tools = [
|
tools = []
|
||||||
{
|
for tool in self.tools:
|
||||||
|
d = {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": convert_schema(tool.parameters),
|
|
||||||
}
|
}
|
||||||
for tool in self.tools
|
if tool.parameters:
|
||||||
]
|
d["parameters"] = convert_schema(tool.parameters)
|
||||||
|
tools.append(d)
|
||||||
|
|
||||||
declarations = {}
|
declarations = {}
|
||||||
if tools:
|
if tools:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.1.2"
|
VERSION = "4.2.0"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
@@ -60,6 +60,7 @@ DEFAULT_CONFIG = {
|
|||||||
"web_search_link": False,
|
"web_search_link": False,
|
||||||
"display_reasoning_text": False,
|
"display_reasoning_text": False,
|
||||||
"identifier": False,
|
"identifier": False,
|
||||||
|
"group_name_display": False,
|
||||||
"datetime_system_prompt": True,
|
"datetime_system_prompt": True,
|
||||||
"default_personality": "default",
|
"default_personality": "default",
|
||||||
"persona_pool": ["*"],
|
"persona_pool": ["*"],
|
||||||
@@ -235,6 +236,16 @@ CONFIG_METADATA_2 = {
|
|||||||
"discord_guild_id_for_debug": "",
|
"discord_guild_id_for_debug": "",
|
||||||
"discord_activity_name": "",
|
"discord_activity_name": "",
|
||||||
},
|
},
|
||||||
|
"Misskey": {
|
||||||
|
"id": "misskey",
|
||||||
|
"type": "misskey",
|
||||||
|
"enable": False,
|
||||||
|
"misskey_instance_url": "https://misskey.example",
|
||||||
|
"misskey_token": "",
|
||||||
|
"misskey_default_visibility": "public",
|
||||||
|
"misskey_local_only": False,
|
||||||
|
"misskey_enable_chat": True,
|
||||||
|
},
|
||||||
"Slack": {
|
"Slack": {
|
||||||
"id": "slack",
|
"id": "slack",
|
||||||
"type": "slack",
|
"type": "slack",
|
||||||
@@ -252,7 +263,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "satori",
|
"type": "satori",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"satori_api_base_url": "http://localhost:5140/satori/v1",
|
"satori_api_base_url": "http://localhost:5140/satori/v1",
|
||||||
"satori_endpoint": "ws://127.0.0.1:5140/satori/v1/events",
|
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
|
||||||
"satori_token": "",
|
"satori_token": "",
|
||||||
"satori_auto_reconnect": True,
|
"satori_auto_reconnect": True,
|
||||||
"satori_heartbeat_interval": 10,
|
"satori_heartbeat_interval": 10,
|
||||||
@@ -261,34 +272,34 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
"satori_api_base_url": {
|
"satori_api_base_url": {
|
||||||
"description": "Satori API Base URL",
|
"description": "Satori API 终结点",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "The base URL for the Satori API.",
|
"hint": "Satori API 的基础地址。",
|
||||||
},
|
},
|
||||||
"satori_endpoint": {
|
"satori_endpoint": {
|
||||||
"description": "Satori WebSocket Endpoint",
|
"description": "Satori WebSocket 终结点",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "The WebSocket endpoint for Satori events.",
|
"hint": "Satori 事件的 WebSocket 端点。",
|
||||||
},
|
},
|
||||||
"satori_token": {
|
"satori_token": {
|
||||||
"description": "Satori Token",
|
"description": "Satori 令牌",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "The token used for authenticating with the Satori API.",
|
"hint": "用于 Satori API 身份验证的令牌。",
|
||||||
},
|
},
|
||||||
"satori_auto_reconnect": {
|
"satori_auto_reconnect": {
|
||||||
"description": "Enable Auto Reconnect",
|
"description": "启用自动重连",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "Whether to automatically reconnect the WebSocket on disconnection.",
|
"hint": "断开连接时是否自动重新连接 WebSocket。",
|
||||||
},
|
},
|
||||||
"satori_heartbeat_interval": {
|
"satori_heartbeat_interval": {
|
||||||
"description": "Satori Heartbeat Interval",
|
"description": "Satori 心跳间隔",
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"hint": "The interval (in seconds) for sending heartbeat messages.",
|
"hint": "发送心跳消息的间隔(秒)。",
|
||||||
},
|
},
|
||||||
"satori_reconnect_delay": {
|
"satori_reconnect_delay": {
|
||||||
"description": "Satori Reconnect Delay",
|
"description": "Satori 重连延迟",
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"hint": "The delay (in seconds) before attempting to reconnect.",
|
"hint": "尝试重新连接前的延迟时间(秒)。",
|
||||||
},
|
},
|
||||||
"slack_connection_mode": {
|
"slack_connection_mode": {
|
||||||
"description": "Slack Connection Mode",
|
"description": "Slack Connection Mode",
|
||||||
@@ -336,6 +347,32 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||||
},
|
},
|
||||||
|
"misskey_instance_url": {
|
||||||
|
"description": "Misskey 实例 URL",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "例如 https://misskey.example,填写 Bot 账号所在的 Misskey 实例地址",
|
||||||
|
},
|
||||||
|
"misskey_token": {
|
||||||
|
"description": "Misskey Access Token",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "连接服务设置生成的 API 鉴权访问令牌(Access token)",
|
||||||
|
},
|
||||||
|
"misskey_default_visibility": {
|
||||||
|
"description": "默认帖子可见性",
|
||||||
|
"type": "string",
|
||||||
|
"options": ["public", "home", "followers"],
|
||||||
|
"hint": "机器人发帖时的默认可见性设置。public:公开,home:主页时间线,followers:仅关注者。",
|
||||||
|
},
|
||||||
|
"misskey_local_only": {
|
||||||
|
"description": "仅限本站(不参与联合)",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
|
||||||
|
},
|
||||||
|
"misskey_enable_chat": {
|
||||||
|
"description": "启用聊天消息响应",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,机器人将会监听和响应私信聊天消息",
|
||||||
|
},
|
||||||
"telegram_command_register": {
|
"telegram_command_register": {
|
||||||
"description": "Telegram 命令注册",
|
"description": "Telegram 命令注册",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -832,6 +869,18 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||||
},
|
},
|
||||||
|
"Coze": {
|
||||||
|
"id": "coze",
|
||||||
|
"provider": "coze",
|
||||||
|
"provider_type": "chat_completion",
|
||||||
|
"type": "coze",
|
||||||
|
"enable": True,
|
||||||
|
"coze_api_key": "",
|
||||||
|
"bot_id": "",
|
||||||
|
"coze_api_base": "https://api.coze.cn",
|
||||||
|
"timeout": 60,
|
||||||
|
"auto_save_history": True,
|
||||||
|
},
|
||||||
"阿里云百炼应用": {
|
"阿里云百炼应用": {
|
||||||
"id": "dashscope",
|
"id": "dashscope",
|
||||||
"provider": "dashscope",
|
"provider": "dashscope",
|
||||||
@@ -1698,6 +1747,26 @@ CONFIG_METADATA_2 = {
|
|||||||
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
||||||
"obvious": True,
|
"obvious": True,
|
||||||
},
|
},
|
||||||
|
"coze_api_key": {
|
||||||
|
"description": "Coze API Key",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "Coze API 密钥,用于访问 Coze 服务。",
|
||||||
|
},
|
||||||
|
"bot_id": {
|
||||||
|
"description": "Bot ID",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "Coze 机器人的 ID,在 Coze 平台上创建机器人后获得。",
|
||||||
|
},
|
||||||
|
"coze_api_base": {
|
||||||
|
"description": "API Base URL",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
|
||||||
|
},
|
||||||
|
"auto_save_history": {
|
||||||
|
"description": "由 Coze 管理对话记录",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_settings": {
|
"provider_settings": {
|
||||||
@@ -1724,6 +1793,9 @@ CONFIG_METADATA_2 = {
|
|||||||
"identifier": {
|
"identifier": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
|
"group_name_display": {
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
"datetime_system_prompt": {
|
"datetime_system_prompt": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
@@ -1903,17 +1975,31 @@ CONFIG_METADATA_3 = {
|
|||||||
"_special": "select_provider",
|
"_special": "select_provider",
|
||||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||||
},
|
},
|
||||||
|
"provider_stt_settings.enable": {
|
||||||
|
"description": "默认启用语音转文本",
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
"provider_stt_settings.provider_id": {
|
"provider_stt_settings.provider_id": {
|
||||||
"description": "语音转文本模型",
|
"description": "语音转文本模型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "留空代表不使用。",
|
"hint": "留空代表不使用。",
|
||||||
"_special": "select_provider_stt",
|
"_special": "select_provider_stt",
|
||||||
|
"condition": {
|
||||||
|
"provider_stt_settings.enable": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"provider_tts_settings.enable": {
|
||||||
|
"description": "默认启用文本转语音",
|
||||||
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"provider_tts_settings.provider_id": {
|
"provider_tts_settings.provider_id": {
|
||||||
"description": "文本转语音模型",
|
"description": "文本转语音模型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "留空代表不使用。",
|
"hint": "留空代表不使用。",
|
||||||
"_special": "select_provider_tts",
|
"_special": "select_provider_tts",
|
||||||
|
"condition": {
|
||||||
|
"provider_tts_settings.enable": True,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"provider_settings.image_caption_prompt": {
|
"provider_settings.image_caption_prompt": {
|
||||||
"description": "图片转述提示词",
|
"description": "图片转述提示词",
|
||||||
@@ -1983,6 +2069,11 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "用户识别",
|
"description": "用户识别",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
|
"provider_settings.group_name_display": {
|
||||||
|
"description": "显示群名称",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
|
||||||
|
},
|
||||||
"provider_settings.datetime_system_prompt": {
|
"provider_settings.datetime_system_prompt": {
|
||||||
"description": "现实世界时间感知",
|
"description": "现实世界时间感知",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
|
|||||||
@@ -87,14 +87,22 @@ class ConversationManager:
|
|||||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
"""
|
"""
|
||||||
f = False
|
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
if conversation_id:
|
|
||||||
f = True
|
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
await self.db.delete_conversation(cid=conversation_id)
|
await self.db.delete_conversation(cid=conversation_id)
|
||||||
if f:
|
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
|
||||||
|
if curr_cid == conversation_id:
|
||||||
|
self.session_conversations.pop(unified_msg_origin, None)
|
||||||
|
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||||
|
|
||||||
|
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
|
||||||
|
"""删除会话的所有对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
"""
|
||||||
|
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
|
||||||
self.session_conversations.pop(unified_msg_origin, None)
|
self.session_conversations.pop(unified_msg_origin, None)
|
||||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||||
|
|
||||||
|
|||||||
@@ -154,6 +154,11 @@ class BaseDatabase(abc.ABC):
|
|||||||
"""Delete a conversation by its ID."""
|
"""Delete a conversation by its ID."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||||
|
"""Delete all conversations for a specific user."""
|
||||||
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def insert_platform_message_history(
|
async def insert_platform_message_history(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from astrbot.core.db.po import (
|
|||||||
from sqlalchemy import select, update, delete, text
|
from sqlalchemy import select, update, delete, text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||||
|
|
||||||
@@ -153,8 +154,22 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
ConversationV2.platform_id.in_(platform_ids)
|
ConversationV2.platform_id.in_(platform_ids)
|
||||||
)
|
)
|
||||||
if search_query:
|
if search_query:
|
||||||
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||||
base_query = base_query.where(
|
base_query = base_query.where(
|
||||||
ConversationV2.title.ilike(f"%{search_query}%")
|
or_(
|
||||||
|
ConversationV2.title.ilike(f"%{search_query}%"),
|
||||||
|
ConversationV2.content.ilike(f"%{search_query}%"),
|
||||||
|
ConversationV2.user_id.ilike(f"%{search_query}%"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
||||||
|
for msg_type in kwargs["message_types"]:
|
||||||
|
base_query = base_query.where(
|
||||||
|
ConversationV2.user_id.ilike(f"%:{msg_type}:%")
|
||||||
|
)
|
||||||
|
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
||||||
|
base_query = base_query.where(
|
||||||
|
ConversationV2.platform_id.in_(kwargs["platforms"])
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get total count matching the filters
|
# Get total count matching the filters
|
||||||
@@ -234,6 +249,14 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
|
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
await session.execute(
|
||||||
|
delete(ConversationV2).where(ConversationV2.user_id == user_id)
|
||||||
|
)
|
||||||
|
|
||||||
async def insert_platform_message_history(
|
async def insert_platform_message_history(
|
||||||
self,
|
self,
|
||||||
platform_id,
|
platform_id,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
|
|||||||
self.strategy_selector = StrategySelector(config)
|
self.strategy_selector = StrategySelector(config)
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent, check_text: str = None
|
self, event: AstrMessageEvent, check_text: str | None = None
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
"""检查内容安全"""
|
"""检查内容安全"""
|
||||||
text = check_text if check_text else event.get_message_str()
|
text = check_text if check_text else event.get_message_str()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
|
|||||||
self.secret_key = sk
|
self.secret_key = sk
|
||||||
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
||||||
|
|
||||||
def check(self, content: str):
|
def check(self, content: str) -> tuple[bool, str]:
|
||||||
res = self.client.textCensorUserDefined(content)
|
res = self.client.textCensorUserDefined(content)
|
||||||
if "conclusionType" not in res:
|
if "conclusionType" not in res:
|
||||||
return False, ""
|
return False, ""
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
|
|||||||
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
||||||
# )
|
# )
|
||||||
|
|
||||||
def check(self, content: str) -> bool:
|
def check(self, content: str) -> tuple[bool, str]:
|
||||||
for keyword in self.keywords:
|
for keyword in self.keywords:
|
||||||
if re.search(keyword, content):
|
if re.search(keyword, content):
|
||||||
return False, "内容安全检查不通过,匹配到敏感词。"
|
return False, "内容安全检查不通过,匹配到敏感词。"
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
|
|
||||||
async def call_handler(
|
async def call_handler(
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
handler: T.Awaitable,
|
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T.AsyncGenerator[T.Any, None]:
|
) -> T.AsyncGenerator[T.Any, None]:
|
||||||
@@ -36,6 +36,9 @@ async def call_handler(
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||||
|
|
||||||
|
if not ready_to_call:
|
||||||
|
return
|
||||||
|
|
||||||
if inspect.isasyncgen(ready_to_call):
|
if inspect.isasyncgen(ready_to_call):
|
||||||
_has_yielded = False
|
_has_yielded = False
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from typing import AsyncGenerator, Union
|
from typing import AsyncGenerator, Union
|
||||||
|
from astrbot.core.conversation_mgr import Conversation
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.message.components import Image
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
@@ -133,6 +134,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
|
|
||||||
if agent_runner.done():
|
if agent_runner.done():
|
||||||
llm_response = agent_runner.get_final_llm_resp()
|
llm_response = agent_runner.get_final_llm_resp()
|
||||||
|
|
||||||
|
if not llm_response:
|
||||||
|
text_content = mcp.types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"error when deligate task to {tool.agent.name}",
|
||||||
|
)
|
||||||
|
yield mcp.types.CallToolResult(content=[text_content])
|
||||||
|
return
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
||||||
)
|
)
|
||||||
@@ -148,7 +158,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
)
|
)
|
||||||
yield mcp.types.CallToolResult(content=[text_content])
|
yield mcp.types.CallToolResult(content=[text_content])
|
||||||
else:
|
else:
|
||||||
yield mcp.types.TextContent(
|
text_content = mcp.types.TextContent(
|
||||||
type="text",
|
type="text",
|
||||||
text=f"error when deligate task to {tool.agent.name}",
|
text=f"error when deligate task to {tool.agent.name}",
|
||||||
)
|
)
|
||||||
@@ -200,7 +210,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
):
|
):
|
||||||
if not tool.mcp_client:
|
if not tool.mcp_client:
|
||||||
raise ValueError("MCP client is not available for MCP function tools.")
|
raise ValueError("MCP client is not available for MCP function tools.")
|
||||||
res = await tool.mcp_client.session.call_tool(
|
|
||||||
|
session = tool.mcp_client.session
|
||||||
|
if not session:
|
||||||
|
raise ValueError("MCP session is not available for MCP function tools.")
|
||||||
|
res = await session.call_tool(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
arguments=tool_args,
|
arguments=tool_args,
|
||||||
)
|
)
|
||||||
@@ -271,19 +285,12 @@ async def run_agent(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
astr_event.set_result(
|
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||||
MessageEventResult().message(
|
if agent_runner.streaming:
|
||||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
yield MessageChain().message(err_msg)
|
||||||
)
|
else:
|
||||||
)
|
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||||
return
|
return
|
||||||
asyncio.create_task(
|
|
||||||
Metric.upload(
|
|
||||||
llm_tick=1,
|
|
||||||
model_name=agent_runner.provider.get_model(),
|
|
||||||
provider_type=agent_runner.provider.meta().type,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestSubStage(Stage):
|
class LLMRequestSubStage(Stage):
|
||||||
@@ -325,7 +332,7 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||||
|
|
||||||
async def _get_session_conv(self, event: AstrMessageEvent):
|
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
|
||||||
umo = event.unified_msg_origin
|
umo = event.unified_msg_origin
|
||||||
conv_mgr = self.conv_manager
|
conv_mgr = self.conv_manager
|
||||||
|
|
||||||
@@ -337,6 +344,8 @@ class LLMRequestSubStage(Stage):
|
|||||||
if not conversation:
|
if not conversation:
|
||||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||||
|
if not conversation:
|
||||||
|
raise RuntimeError("无法创建新的对话。")
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
@@ -444,7 +453,10 @@ class LLMRequestSubStage(Stage):
|
|||||||
if event.plugins_name is not None and req.func_tool:
|
if event.plugins_name is not None and req.func_tool:
|
||||||
new_tool_set = ToolSet()
|
new_tool_set = ToolSet()
|
||||||
for tool in req.func_tool.tools:
|
for tool in req.func_tool.tools:
|
||||||
plugin = star_map.get(tool.handler_module_path)
|
mp = tool.handler_module_path
|
||||||
|
if not mp:
|
||||||
|
continue
|
||||||
|
plugin = star_map.get(mp)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
continue
|
continue
|
||||||
if plugin.name in event.plugins_name or plugin.reserved:
|
if plugin.name in event.plugins_name or plugin.reserved:
|
||||||
@@ -505,6 +517,14 @@ class LLMRequestSubStage(Stage):
|
|||||||
if event.get_platform_name() == "webchat":
|
if event.get_platform_name() == "webchat":
|
||||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(
|
||||||
|
llm_tick=1,
|
||||||
|
model_name=agent_runner.provider.get_model(),
|
||||||
|
provider_type=agent_runner.provider.meta().type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def _handle_webchat(
|
async def _handle_webchat(
|
||||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||||
):
|
):
|
||||||
@@ -517,7 +537,23 @@ class LLMRequestSubStage(Stage):
|
|||||||
latest_pair = messages[-2:]
|
latest_pair = messages[-2:]
|
||||||
if not latest_pair:
|
if not latest_pair:
|
||||||
return
|
return
|
||||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
content = latest_pair[0].get("content", "")
|
||||||
|
if isinstance(content, list):
|
||||||
|
# 多模态
|
||||||
|
text_parts = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
if item.get("type") == "text":
|
||||||
|
text_parts.append(item.get("text", ""))
|
||||||
|
elif item.get("type") == "image":
|
||||||
|
text_parts.append("[图片]")
|
||||||
|
elif isinstance(item, str):
|
||||||
|
text_parts.append(item)
|
||||||
|
cleaned_text = "User: " + " ".join(text_parts).strip()
|
||||||
|
elif isinstance(content, str):
|
||||||
|
cleaned_text = "User: " + content.strip()
|
||||||
|
else:
|
||||||
|
return
|
||||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||||
llm_resp = await prov.text_chat(
|
llm_resp = await prov.text_chat(
|
||||||
system_prompt="You are expert in summarizing user's query.",
|
system_prompt="You are expert in summarizing user's query.",
|
||||||
|
|||||||
@@ -34,12 +34,14 @@ class StarRequestSubStage(Stage):
|
|||||||
|
|
||||||
for handler in activated_handlers:
|
for handler in activated_handlers:
|
||||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||||
try:
|
md = star_map.get(handler.handler_module_path)
|
||||||
if handler.handler_module_path not in star_map:
|
if not md:
|
||||||
continue
|
logger.warning(
|
||||||
logger.debug(
|
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
|
||||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
|
||||||
|
try:
|
||||||
wrapper = call_handler(event, handler.handler, **params)
|
wrapper = call_handler(event, handler.handler, **params)
|
||||||
async for ret in wrapper:
|
async for ret in wrapper:
|
||||||
yield ret
|
yield ret
|
||||||
@@ -49,7 +51,7 @@ class StarRequestSubStage(Stage):
|
|||||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||||
|
|
||||||
if event.is_at_or_wake_command:
|
if event.is_at_or_wake_command:
|
||||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||||
event.set_result(MessageEventResult().message(ret))
|
event.set_result(MessageEventResult().message(ret))
|
||||||
yield
|
yield
|
||||||
event.clear_result()
|
event.clear_result()
|
||||||
|
|||||||
@@ -1,17 +1,15 @@
|
|||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
import traceback
|
|
||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ..stage import register_stage, Stage
|
from ..stage import register_stage, Stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext, call_event_hook
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
from astrbot.core.message.components import BaseMessageComponent, ComponentType
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import EventType
|
||||||
from astrbot.core.star.star import star_map
|
|
||||||
from astrbot.core.utils.path_util import path_Mapping
|
from astrbot.core.utils.path_util import path_Mapping
|
||||||
from astrbot.core.utils.session_lock import session_lock_manager
|
from astrbot.core.utils.session_lock import session_lock_manager
|
||||||
|
|
||||||
@@ -114,6 +112,43 @@ class RespondStage(Stage):
|
|||||||
# 如果所有组件都为空
|
# 如果所有组件都为空
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
|
||||||
|
"""检查是否需要分段回复"""
|
||||||
|
if not self.enable_seg:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if event.get_platform_name() in [
|
||||||
|
"qq_official",
|
||||||
|
"weixin_official_account",
|
||||||
|
"dingtalk",
|
||||||
|
]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _extract_comp(
|
||||||
|
self,
|
||||||
|
raw_chain: list[BaseMessageComponent],
|
||||||
|
extract_types: set[ComponentType],
|
||||||
|
modify_raw_chain: bool = True,
|
||||||
|
):
|
||||||
|
extracted = []
|
||||||
|
if modify_raw_chain:
|
||||||
|
remaining = []
|
||||||
|
for comp in raw_chain:
|
||||||
|
if comp.type in extract_types:
|
||||||
|
extracted.append(comp)
|
||||||
|
else:
|
||||||
|
remaining.append(comp)
|
||||||
|
raw_chain[:] = remaining
|
||||||
|
else:
|
||||||
|
extracted = [comp for comp in raw_chain if comp.type in extract_types]
|
||||||
|
|
||||||
|
return extracted
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
@@ -123,7 +158,14 @@ class RespondStage(Stage):
|
|||||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||||
|
)
|
||||||
|
|
||||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
|
if result.async_stream is None:
|
||||||
|
logger.warning("async_stream 为空,跳过发送。")
|
||||||
|
return
|
||||||
# 流式结果直接交付平台适配器处理
|
# 流式结果直接交付平台适配器处理
|
||||||
use_fallback = self.config.get("provider_settings", {}).get(
|
use_fallback = self.config.get("provider_settings", {}).get(
|
||||||
"streaming_segmented", False
|
"streaming_segmented", False
|
||||||
@@ -148,87 +190,71 @@ class RespondStage(Stage):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"空内容检查异常: {e}")
|
logger.warning(f"空内容检查异常: {e}")
|
||||||
|
|
||||||
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
# 发送消息链
|
||||||
non_record_comps = [
|
# Record 需要强制单独发送
|
||||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
need_separately = {ComponentType.Record}
|
||||||
]
|
if self.is_seg_reply_required(event):
|
||||||
|
header_comps = self._extract_comp(
|
||||||
if (
|
result.chain,
|
||||||
self.enable_seg
|
{ComponentType.Reply, ComponentType.At},
|
||||||
and (
|
modify_raw_chain=True,
|
||||||
(self.only_llm_result and result.is_llm_result())
|
|
||||||
or not self.only_llm_result
|
|
||||||
)
|
)
|
||||||
and event.get_platform_name()
|
if not result.chain or len(result.chain) == 0:
|
||||||
not in ["qq_official", "weixin_official_account", "dingtalk"]
|
# may fix #2670
|
||||||
):
|
logger.warning(
|
||||||
decorated_comps = []
|
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
|
||||||
if self.reply_with_mention:
|
)
|
||||||
for comp in result.chain:
|
return
|
||||||
if isinstance(comp, Comp.At):
|
|
||||||
decorated_comps.append(comp)
|
|
||||||
result.chain.remove(comp)
|
|
||||||
break
|
|
||||||
if self.reply_with_quote:
|
|
||||||
for comp in result.chain:
|
|
||||||
if isinstance(comp, Comp.Reply):
|
|
||||||
decorated_comps.append(comp)
|
|
||||||
result.chain.remove(comp)
|
|
||||||
break
|
|
||||||
|
|
||||||
# leverage lock to guarentee the order of message sending among different events
|
|
||||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||||
for rcomp in record_comps:
|
for comp in result.chain:
|
||||||
i = await self._calc_comp_interval(rcomp)
|
|
||||||
await asyncio.sleep(i)
|
|
||||||
try:
|
|
||||||
await event.send(MessageChain([rcomp]))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
|
||||||
break
|
|
||||||
# 分段回复
|
|
||||||
for comp in non_record_comps:
|
|
||||||
i = await self._calc_comp_interval(comp)
|
i = await self._calc_comp_interval(comp)
|
||||||
await asyncio.sleep(i)
|
await asyncio.sleep(i)
|
||||||
try:
|
try:
|
||||||
await event.send(MessageChain([*decorated_comps, comp]))
|
if comp.type in need_separately:
|
||||||
decorated_comps = [] # 清空已发送的装饰组件
|
await event.send(MessageChain([comp]))
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
for rcomp in record_comps:
|
await event.send(MessageChain([*header_comps, comp]))
|
||||||
try:
|
header_comps.clear()
|
||||||
await event.send(MessageChain([rcomp]))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
logger.error(
|
||||||
|
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if all(
|
||||||
|
comp.type in {ComponentType.Reply, ComponentType.At}
|
||||||
|
for comp in result.chain
|
||||||
|
):
|
||||||
|
# may fix #2670
|
||||||
|
logger.warning(
|
||||||
|
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
sep_comps = self._extract_comp(
|
||||||
|
result.chain,
|
||||||
|
need_separately,
|
||||||
|
modify_raw_chain=True,
|
||||||
|
)
|
||||||
|
for comp in sep_comps:
|
||||||
|
chain = MessageChain([comp])
|
||||||
try:
|
try:
|
||||||
await event.send(MessageChain(non_record_comps))
|
await event.send(chain)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
f"发送消息链失败: chain = {chain}, error = {e}",
|
||||||
|
exc_info=True,
|
||||||
logger.info(
|
|
||||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
|
||||||
)
|
)
|
||||||
|
chain = MessageChain(result.chain)
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
if result.chain and len(result.chain) > 0:
|
||||||
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
|
|
||||||
)
|
|
||||||
for handler in handlers:
|
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
await event.send(chain)
|
||||||
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"发送消息链失败: chain = {chain}, error = {e}",
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
await handler.handler(event)
|
|
||||||
except BaseException:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
if event.is_stopped():
|
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
|
||||||
logger.info(
|
|
||||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
event.clear_result()
|
event.clear_result()
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
|
|||||||
"""检查会话是否整体启用"""
|
"""检查会话是否整体启用"""
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
pass
|
self.ctx = ctx
|
||||||
|
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
|
|||||||
# 检查会话是否整体启用
|
# 检查会话是否整体启用
|
||||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||||
|
|
||||||
|
# workaround for #2309
|
||||||
|
conv_id = await self.conv_mgr.get_curr_conversation_id(
|
||||||
|
event.unified_msg_origin
|
||||||
|
)
|
||||||
|
if not conv_id:
|
||||||
|
await self.conv_mgr.new_conversation(
|
||||||
|
event.unified_msg_origin, platform_id=event.get_platform_id()
|
||||||
|
)
|
||||||
|
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
|
|||||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||||
|
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||||
@@ -170,10 +171,14 @@ class WakingCheckStage(Stage):
|
|||||||
is_wake = True
|
is_wake = True
|
||||||
event.is_wake = True
|
event.is_wake = True
|
||||||
|
|
||||||
|
is_group_cmd_handler = any(
|
||||||
|
isinstance(f, CommandGroupFilter) for f in handler.event_filters
|
||||||
|
)
|
||||||
|
if not is_group_cmd_handler:
|
||||||
activated_handlers.append(handler)
|
activated_handlers.append(handler)
|
||||||
if "parsed_params" in event.get_extra():
|
if "parsed_params" in event.get_extra(default={}):
|
||||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
handlers_parsed_params[handler.handler_full_name] = (
|
||||||
"parsed_params"
|
event.get_extra("parsed_params")
|
||||||
)
|
)
|
||||||
|
|
||||||
event._extras.pop("parsed_params", None)
|
event._extras.pop("parsed_params", None)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import re
|
|||||||
import hashlib
|
import hashlib
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from typing import List, Union, Optional, AsyncGenerator
|
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.db.po import Conversation
|
from astrbot.core.db.po import Conversation
|
||||||
@@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group
|
|||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
from .message_session import MessageSession, MessageSesion # noqa
|
from .message_session import MessageSession, MessageSesion # noqa
|
||||||
|
|
||||||
|
_VT = TypeVar("_VT")
|
||||||
|
|
||||||
|
|
||||||
class AstrMessageEvent(abc.ABC):
|
class AstrMessageEvent(abc.ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
"""是否唤醒(是否通过 WakingStage)"""
|
"""是否唤醒(是否通过 WakingStage)"""
|
||||||
self.is_at_or_wake_command = False
|
self.is_at_or_wake_command = False
|
||||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||||
self._extras = {}
|
self._extras: dict[str, Any] = {}
|
||||||
self.session = MessageSesion(
|
self.session = MessageSesion(
|
||||||
platform_name=platform_meta.id,
|
platform_name=platform_meta.id,
|
||||||
message_type=message_obj.type,
|
message_type=message_obj.type,
|
||||||
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
)
|
)
|
||||||
self.unified_msg_origin = str(self.session)
|
self.unified_msg_origin = str(self.session)
|
||||||
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
||||||
self._result: MessageEventResult = None
|
self._result: MessageEventResult | None = None
|
||||||
"""消息事件的结果"""
|
"""消息事件的结果"""
|
||||||
|
|
||||||
self._has_send_oper = False
|
self._has_send_oper = False
|
||||||
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
self._extras[key] = value
|
self._extras[key] = value
|
||||||
|
|
||||||
def get_extra(self, key=None):
|
def get_extra(
|
||||||
|
self, key: str | None = None, default: _VT = None
|
||||||
|
) -> dict[str, Any] | _VT:
|
||||||
"""
|
"""
|
||||||
获取额外的信息。
|
获取额外的信息。
|
||||||
"""
|
"""
|
||||||
if key is None:
|
if key is None:
|
||||||
return self._extras
|
return self._extras
|
||||||
return self._extras.get(key, None)
|
return self._extras.get(key, default)
|
||||||
|
|
||||||
def clear_extra(self):
|
def clear_extra(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class AstrBotMessage:
|
|||||||
self_id: str # 机器人的识别id
|
self_id: str # 机器人的识别id
|
||||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||||
message_id: str # 消息id
|
message_id: str # 消息id
|
||||||
group_id: str = "" # 群组id,如果为私聊,则为空
|
group: Group # 群组
|
||||||
sender: MessageMember # 发送者
|
sender: MessageMember # 发送者
|
||||||
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||||
message_str: str # 最直观的纯文本消息字符串
|
message_str: str # 最直观的纯文本消息字符串
|
||||||
@@ -64,6 +64,28 @@ class AstrBotMessage:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.timestamp = int(time.time())
|
self.timestamp = int(time.time())
|
||||||
|
self.group = None
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return str(self.__dict__)
|
return str(self.__dict__)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def group_id(self) -> str:
|
||||||
|
"""
|
||||||
|
向后兼容的 group_id 属性
|
||||||
|
群组id,如果为私聊,则为空
|
||||||
|
"""
|
||||||
|
if self.group:
|
||||||
|
return self.group.group_id
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@group_id.setter
|
||||||
|
def group_id(self, value: str):
|
||||||
|
"""设置 group_id"""
|
||||||
|
if value:
|
||||||
|
if self.group:
|
||||||
|
self.group.group_id = value
|
||||||
|
else:
|
||||||
|
self.group = Group(group_id=value)
|
||||||
|
else:
|
||||||
|
self.group = None
|
||||||
|
|||||||
@@ -90,6 +90,10 @@ class PlatformManager:
|
|||||||
from .sources.discord.discord_platform_adapter import (
|
from .sources.discord.discord_platform_adapter import (
|
||||||
DiscordPlatformAdapter, # noqa: F401
|
DiscordPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
|
case "misskey":
|
||||||
|
from .sources.misskey.misskey_adapter import (
|
||||||
|
MisskeyPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "slack":
|
case "slack":
|
||||||
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
||||||
case "satori":
|
case "satori":
|
||||||
|
|||||||
@@ -182,11 +182,13 @@ class AiocqhttpAdapter(Platform):
|
|||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
str(event.sender["user_id"]), event.sender["nickname"]
|
str(event.sender["user_id"]),
|
||||||
|
event.sender.get("card") or event.sender.get("nickname", "N/A"),
|
||||||
)
|
)
|
||||||
if event["message_type"] == "group":
|
if event["message_type"] == "group":
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
abm.group_id = str(event.group_id)
|
abm.group_id = str(event.group_id)
|
||||||
|
abm.group.group_name = event.get("group_name", "N/A")
|
||||||
elif event["message_type"] == "private":
|
elif event["message_type"] == "private":
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
|||||||
391
astrbot/core/platform/sources/misskey/misskey_adapter.py
Normal file
391
astrbot/core/platform/sources/misskey/misskey_adapter.py
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Optional, Awaitable
|
||||||
|
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
AstrBotMessage,
|
||||||
|
Platform,
|
||||||
|
PlatformMetadata,
|
||||||
|
register_platform_adapter,
|
||||||
|
)
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSession
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
|
||||||
|
from .misskey_api import MisskeyAPI
|
||||||
|
from .misskey_event import MisskeyPlatformEvent
|
||||||
|
from .misskey_utils import (
|
||||||
|
serialize_message_chain,
|
||||||
|
resolve_message_visibility,
|
||||||
|
is_valid_user_session_id,
|
||||||
|
is_valid_room_session_id,
|
||||||
|
add_at_mention_if_needed,
|
||||||
|
process_files,
|
||||||
|
extract_sender_info,
|
||||||
|
create_base_message,
|
||||||
|
process_at_mention,
|
||||||
|
cache_user_info,
|
||||||
|
cache_room_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("misskey", "Misskey 平台适配器")
|
||||||
|
class MisskeyPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
self.config = platform_config or {}
|
||||||
|
self.settings = platform_settings or {}
|
||||||
|
self.instance_url = self.config.get("misskey_instance_url", "")
|
||||||
|
self.access_token = self.config.get("misskey_token", "")
|
||||||
|
self.max_message_length = self.config.get("max_message_length", 3000)
|
||||||
|
self.default_visibility = self.config.get(
|
||||||
|
"misskey_default_visibility", "public"
|
||||||
|
)
|
||||||
|
self.local_only = self.config.get("misskey_local_only", False)
|
||||||
|
self.enable_chat = self.config.get("misskey_enable_chat", True)
|
||||||
|
|
||||||
|
self.unique_session = platform_settings["unique_session"]
|
||||||
|
|
||||||
|
self.api: Optional[MisskeyAPI] = None
|
||||||
|
self._running = False
|
||||||
|
self.client_self_id = ""
|
||||||
|
self._bot_username = ""
|
||||||
|
self._user_cache = {}
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
default_config = {
|
||||||
|
"misskey_instance_url": "",
|
||||||
|
"misskey_token": "",
|
||||||
|
"max_message_length": 3000,
|
||||||
|
"misskey_default_visibility": "public",
|
||||||
|
"misskey_local_only": False,
|
||||||
|
"misskey_enable_chat": True,
|
||||||
|
}
|
||||||
|
default_config.update(self.config)
|
||||||
|
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="misskey",
|
||||||
|
description="Misskey 平台适配器",
|
||||||
|
id=self.config.get("id", "misskey"),
|
||||||
|
default_config_tmpl=default_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
if not self.instance_url or not self.access_token:
|
||||||
|
logger.error("[Misskey] 配置不完整,无法启动")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.api = MisskeyAPI(self.instance_url, self.access_token)
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_info = await self.api.get_current_user()
|
||||||
|
self.client_self_id = str(user_info.get("id", ""))
|
||||||
|
self._bot_username = user_info.get("username", "")
|
||||||
|
logger.info(
|
||||||
|
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey] 获取用户信息失败: {e}")
|
||||||
|
self._running = False
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._start_websocket_connection()
|
||||||
|
|
||||||
|
async def _start_websocket_connection(self):
|
||||||
|
backoff_delay = 1.0
|
||||||
|
max_backoff = 300.0
|
||||||
|
backoff_multiplier = 1.5
|
||||||
|
connection_attempts = 0
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
connection_attempts += 1
|
||||||
|
if not self.api:
|
||||||
|
logger.error("[Misskey] API 客户端未初始化")
|
||||||
|
break
|
||||||
|
|
||||||
|
streaming = self.api.get_streaming_client()
|
||||||
|
streaming.add_message_handler("notification", self._handle_notification)
|
||||||
|
if self.enable_chat:
|
||||||
|
streaming.add_message_handler(
|
||||||
|
"newChatMessage", self._handle_chat_message
|
||||||
|
)
|
||||||
|
streaming.add_message_handler("_debug", self._debug_handler)
|
||||||
|
|
||||||
|
if await streaming.connect():
|
||||||
|
logger.info(
|
||||||
|
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
|
||||||
|
)
|
||||||
|
connection_attempts = 0 # 重置计数器
|
||||||
|
await streaming.subscribe_channel("main")
|
||||||
|
if self.enable_chat:
|
||||||
|
await streaming.subscribe_channel("messaging")
|
||||||
|
await streaming.subscribe_channel("messagingIndex")
|
||||||
|
logger.info("[Misskey] 聊天频道已订阅")
|
||||||
|
|
||||||
|
backoff_delay = 1.0 # 重置延迟
|
||||||
|
await streaming.listen()
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._running:
|
||||||
|
logger.info(
|
||||||
|
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(backoff_delay)
|
||||||
|
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
|
||||||
|
|
||||||
|
async def _handle_notification(self, data: Dict[str, Any]):
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
notification_type = data.get("type")
|
||||||
|
if notification_type in ["mention", "reply", "quote"]:
|
||||||
|
note = data.get("note")
|
||||||
|
if note and self._is_bot_mentioned(note):
|
||||||
|
logger.info(
|
||||||
|
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
|
||||||
|
)
|
||||||
|
message = await self.convert_message(note)
|
||||||
|
event = MisskeyPlatformEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
client=self.api,
|
||||||
|
)
|
||||||
|
self.commit_event(event)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey] 处理通知失败: {e}")
|
||||||
|
|
||||||
|
async def _handle_chat_message(self, data: Dict[str, Any]):
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
sender_id = str(
|
||||||
|
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
|
||||||
|
)
|
||||||
|
if sender_id == self.client_self_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
room_id = data.get("toRoomId")
|
||||||
|
if room_id:
|
||||||
|
raw_text = data.get("text", "")
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
message = await self.convert_room_message(data)
|
||||||
|
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
|
||||||
|
else:
|
||||||
|
message = await self.convert_chat_message(data)
|
||||||
|
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
|
||||||
|
|
||||||
|
event = MisskeyPlatformEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
client=self.api,
|
||||||
|
)
|
||||||
|
self.commit_event(event)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
|
||||||
|
|
||||||
|
async def _debug_handler(self, data: Dict[str, Any]):
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey] 收到未处理事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
|
||||||
|
text = note.get("text", "")
|
||||||
|
if not text:
|
||||||
|
return False
|
||||||
|
|
||||||
|
mentions = note.get("mentions", [])
|
||||||
|
if self._bot_username and f"@{self._bot_username}" in text:
|
||||||
|
return True
|
||||||
|
if self.client_self_id in [str(uid) for uid in mentions]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
reply = note.get("reply")
|
||||||
|
if reply and isinstance(reply, dict):
|
||||||
|
reply_user_id = str(reply.get("user", {}).get("id", ""))
|
||||||
|
if reply_user_id == self.client_self_id:
|
||||||
|
return bool(self._bot_username and f"@{self._bot_username}" in text)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSession, message_chain: MessageChain
|
||||||
|
) -> Awaitable[Any]:
|
||||||
|
if not self.api:
|
||||||
|
logger.error("[Misskey] API 客户端未初始化")
|
||||||
|
return await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
try:
|
||||||
|
session_id = session.session_id
|
||||||
|
text, has_at_user = serialize_message_chain(message_chain.chain)
|
||||||
|
|
||||||
|
if not has_at_user and session_id:
|
||||||
|
user_info = self._user_cache.get(session_id)
|
||||||
|
text = add_at_mention_if_needed(text, user_info, has_at_user)
|
||||||
|
|
||||||
|
if not text or not text.strip():
|
||||||
|
logger.warning("[Misskey] 消息内容为空,跳过发送")
|
||||||
|
return await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
if len(text) > self.max_message_length:
|
||||||
|
text = text[: self.max_message_length] + "..."
|
||||||
|
|
||||||
|
if session_id and is_valid_user_session_id(session_id):
|
||||||
|
from .misskey_utils import extract_user_id_from_session_id
|
||||||
|
|
||||||
|
user_id = extract_user_id_from_session_id(session_id)
|
||||||
|
await self.api.send_message(user_id, text)
|
||||||
|
elif session_id and is_valid_room_session_id(session_id):
|
||||||
|
from .misskey_utils import extract_room_id_from_session_id
|
||||||
|
|
||||||
|
room_id = extract_room_id_from_session_id(session_id)
|
||||||
|
await self.api.send_room_message(room_id, text)
|
||||||
|
else:
|
||||||
|
visibility, visible_user_ids = resolve_message_visibility(
|
||||||
|
user_id=session_id,
|
||||||
|
user_cache=self._user_cache,
|
||||||
|
self_id=self.client_self_id,
|
||||||
|
default_visibility=self.default_visibility,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.api.create_note(
|
||||||
|
text,
|
||||||
|
visibility=visibility,
|
||||||
|
visible_user_ids=visible_user_ids,
|
||||||
|
local_only=self.local_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey] 发送消息失败: {e}")
|
||||||
|
|
||||||
|
return await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||||
|
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
|
||||||
|
sender_info = extract_sender_info(raw_data, is_chat=False)
|
||||||
|
message = create_base_message(
|
||||||
|
raw_data,
|
||||||
|
sender_info,
|
||||||
|
self.client_self_id,
|
||||||
|
is_chat=False,
|
||||||
|
unique_session=self.unique_session,
|
||||||
|
)
|
||||||
|
cache_user_info(
|
||||||
|
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
||||||
|
)
|
||||||
|
|
||||||
|
message_parts = []
|
||||||
|
raw_text = raw_data.get("text", "")
|
||||||
|
|
||||||
|
if raw_text:
|
||||||
|
text_parts, processed_text = process_at_mention(
|
||||||
|
message, raw_text, self._bot_username, self.client_self_id
|
||||||
|
)
|
||||||
|
message_parts.extend(text_parts)
|
||||||
|
|
||||||
|
files = raw_data.get("files", [])
|
||||||
|
file_parts = process_files(message, files)
|
||||||
|
message_parts.extend(file_parts)
|
||||||
|
|
||||||
|
message.message_str = (
|
||||||
|
" ".join(part for part in message_parts if part.strip())
|
||||||
|
if message_parts
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||||
|
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
|
||||||
|
sender_info = extract_sender_info(raw_data, is_chat=True)
|
||||||
|
message = create_base_message(
|
||||||
|
raw_data,
|
||||||
|
sender_info,
|
||||||
|
self.client_self_id,
|
||||||
|
is_chat=True,
|
||||||
|
unique_session=self.unique_session,
|
||||||
|
)
|
||||||
|
cache_user_info(
|
||||||
|
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_text = raw_data.get("text", "")
|
||||||
|
if raw_text:
|
||||||
|
message.message.append(Comp.Plain(raw_text))
|
||||||
|
|
||||||
|
files = raw_data.get("files", [])
|
||||||
|
process_files(message, files, include_text_parts=False)
|
||||||
|
|
||||||
|
message.message_str = raw_text if raw_text else ""
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||||
|
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
|
||||||
|
sender_info = extract_sender_info(raw_data, is_chat=True)
|
||||||
|
room_id = raw_data.get("toRoomId", "")
|
||||||
|
message = create_base_message(
|
||||||
|
raw_data,
|
||||||
|
sender_info,
|
||||||
|
self.client_self_id,
|
||||||
|
is_chat=False,
|
||||||
|
room_id=room_id,
|
||||||
|
unique_session=self.unique_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_user_info(
|
||||||
|
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
||||||
|
)
|
||||||
|
cache_room_info(self._user_cache, raw_data, self.client_self_id)
|
||||||
|
|
||||||
|
raw_text = raw_data.get("text", "")
|
||||||
|
message_parts = []
|
||||||
|
|
||||||
|
if raw_text:
|
||||||
|
if self._bot_username and f"@{self._bot_username}" in raw_text:
|
||||||
|
text_parts, processed_text = process_at_mention(
|
||||||
|
message, raw_text, self._bot_username, self.client_self_id
|
||||||
|
)
|
||||||
|
message_parts.extend(text_parts)
|
||||||
|
else:
|
||||||
|
message.message.append(Comp.Plain(raw_text))
|
||||||
|
message_parts.append(raw_text)
|
||||||
|
|
||||||
|
files = raw_data.get("files", [])
|
||||||
|
file_parts = process_files(message, files)
|
||||||
|
message_parts.extend(file_parts)
|
||||||
|
|
||||||
|
message.message_str = (
|
||||||
|
" ".join(part for part in message_parts if part.strip())
|
||||||
|
if message_parts
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self._running = False
|
||||||
|
if self.api:
|
||||||
|
await self.api.close()
|
||||||
|
|
||||||
|
def get_client(self) -> Any:
|
||||||
|
return self.api
|
||||||
404
astrbot/core/platform/sources/misskey/misskey_api.py
Normal file
404
astrbot/core/platform/sources/misskey/misskey_api.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any, Optional, Dict, List, Callable, Awaitable
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aiohttp
|
||||||
|
import websockets
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
API_MAX_RETRIES = 3
|
||||||
|
HTTP_OK = 200
|
||||||
|
|
||||||
|
|
||||||
|
class APIError(Exception):
|
||||||
|
"""Misskey API 基础异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class APIConnectionError(APIError):
|
||||||
|
"""网络连接异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class APIRateLimitError(APIError):
|
||||||
|
"""API 频率限制异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(APIError):
|
||||||
|
"""认证失败异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketError(APIError):
|
||||||
|
"""WebSocket 连接异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingClient:
|
||||||
|
def __init__(self, instance_url: str, access_token: str):
|
||||||
|
self.instance_url = instance_url.rstrip("/")
|
||||||
|
self.access_token = access_token
|
||||||
|
self.websocket: Optional[Any] = None
|
||||||
|
self.is_connected = False
|
||||||
|
self.message_handlers: Dict[str, Callable] = {}
|
||||||
|
self.channels: Dict[str, str] = {}
|
||||||
|
self._running = False
|
||||||
|
self._last_pong = None
|
||||||
|
|
||||||
|
async def connect(self) -> bool:
|
||||||
|
try:
|
||||||
|
ws_url = self.instance_url.replace("https://", "wss://").replace(
|
||||||
|
"http://", "ws://"
|
||||||
|
)
|
||||||
|
ws_url += f"/streaming?i={self.access_token}"
|
||||||
|
|
||||||
|
self.websocket = await websockets.connect(
|
||||||
|
ws_url, ping_interval=30, ping_timeout=10
|
||||||
|
)
|
||||||
|
self.is_connected = True
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
logger.info("[Misskey WebSocket] 已连接")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
|
||||||
|
self.is_connected = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
self._running = False
|
||||||
|
if self.websocket:
|
||||||
|
await self.websocket.close()
|
||||||
|
self.websocket = None
|
||||||
|
self.is_connected = False
|
||||||
|
logger.info("[Misskey WebSocket] 连接已断开")
|
||||||
|
|
||||||
|
async def subscribe_channel(
|
||||||
|
self, channel_type: str, params: Optional[Dict] = None
|
||||||
|
) -> str:
|
||||||
|
if not self.is_connected or not self.websocket:
|
||||||
|
raise WebSocketError("WebSocket 未连接")
|
||||||
|
|
||||||
|
channel_id = str(uuid.uuid4())
|
||||||
|
message = {
|
||||||
|
"type": "connect",
|
||||||
|
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
self.channels[channel_id] = channel_type
|
||||||
|
return channel_id
|
||||||
|
|
||||||
|
async def unsubscribe_channel(self, channel_id: str):
|
||||||
|
if (
|
||||||
|
not self.is_connected
|
||||||
|
or not self.websocket
|
||||||
|
or channel_id not in self.channels
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
message = {"type": "disconnect", "body": {"id": channel_id}}
|
||||||
|
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
del self.channels[channel_id]
|
||||||
|
|
||||||
|
def add_message_handler(
|
||||||
|
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
|
||||||
|
):
|
||||||
|
self.message_handlers[event_type] = handler
|
||||||
|
|
||||||
|
async def listen(self):
|
||||||
|
if not self.is_connected or not self.websocket:
|
||||||
|
raise WebSocketError("WebSocket 未连接")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in self.websocket:
|
||||||
|
if not self._running:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
await self._handle_message(data)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosedError as e:
|
||||||
|
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
|
||||||
|
self.is_connected = False
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
|
||||||
|
)
|
||||||
|
self.is_connected = False
|
||||||
|
except websockets.exceptions.InvalidHandshake as e:
|
||||||
|
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
|
||||||
|
self.is_connected = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
async def _handle_message(self, data: Dict[str, Any]):
|
||||||
|
message_type = data.get("type")
|
||||||
|
body = data.get("body", {})
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if message_type == "channel":
|
||||||
|
channel_id = body.get("id")
|
||||||
|
event_type = body.get("type")
|
||||||
|
event_body = body.get("body", {})
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if channel_id in self.channels:
|
||||||
|
channel_type = self.channels[channel_id]
|
||||||
|
handler_key = f"{channel_type}:{event_type}"
|
||||||
|
|
||||||
|
if handler_key in self.message_handlers:
|
||||||
|
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
|
||||||
|
await self.message_handlers[handler_key](event_body)
|
||||||
|
elif event_type in self.message_handlers:
|
||||||
|
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
|
||||||
|
await self.message_handlers[event_type](event_body)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey WebSocket] 未找到处理器: {handler_key} 或 {event_type}"
|
||||||
|
)
|
||||||
|
if "_debug" in self.message_handlers:
|
||||||
|
await self.message_handlers["_debug"](
|
||||||
|
{
|
||||||
|
"type": event_type,
|
||||||
|
"body": event_body,
|
||||||
|
"channel": channel_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
elif message_type in self.message_handlers:
|
||||||
|
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
|
||||||
|
await self.message_handlers[message_type](body)
|
||||||
|
else:
|
||||||
|
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
|
||||||
|
if "_debug" in self.message_handlers:
|
||||||
|
await self.message_handlers["_debug"](data)
|
||||||
|
|
||||||
|
|
||||||
|
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
|
||||||
|
def decorator(func):
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
last_exc = None
|
||||||
|
for _ in range(max_retries):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except retryable_exceptions as e:
|
||||||
|
last_exc = e
|
||||||
|
continue
|
||||||
|
if last_exc:
|
||||||
|
raise last_exc
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class MisskeyAPI:
|
||||||
|
def __init__(self, instance_url: str, access_token: str):
|
||||||
|
self.instance_url = instance_url.rstrip("/")
|
||||||
|
self.access_token = access_token
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
self.streaming: Optional[StreamingClient] = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.close()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self.streaming:
|
||||||
|
await self.streaming.disconnect()
|
||||||
|
self.streaming = None
|
||||||
|
if self._session:
|
||||||
|
await self._session.close()
|
||||||
|
self._session = None
|
||||||
|
logger.debug("[Misskey API] 客户端已关闭")
|
||||||
|
|
||||||
|
def get_streaming_client(self) -> StreamingClient:
|
||||||
|
if not self.streaming:
|
||||||
|
self.streaming = StreamingClient(self.instance_url, self.access_token)
|
||||||
|
return self.streaming
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
headers = {"Authorization": f"Bearer {self.access_token}"}
|
||||||
|
self._session = aiohttp.ClientSession(headers=headers)
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
def _handle_response_status(self, status: int, endpoint: str):
|
||||||
|
"""处理 HTTP 响应状态码"""
|
||||||
|
if status == 400:
|
||||||
|
logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
|
||||||
|
raise APIError(f"Bad request for {endpoint}")
|
||||||
|
elif status in (401, 403):
|
||||||
|
logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
|
||||||
|
raise AuthenticationError(f"Authentication failed for {endpoint}")
|
||||||
|
elif status == 429:
|
||||||
|
logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
|
||||||
|
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
|
||||||
|
else:
|
||||||
|
logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
|
||||||
|
raise APIConnectionError(f"HTTP {status} for {endpoint}")
|
||||||
|
|
||||||
|
async def _process_response(
|
||||||
|
self, response: aiohttp.ClientResponse, endpoint: str
|
||||||
|
) -> Any:
|
||||||
|
"""处理 API 响应"""
|
||||||
|
if response.status == HTTP_OK:
|
||||||
|
try:
|
||||||
|
result = await response.json()
|
||||||
|
if endpoint == "i/notifications":
|
||||||
|
notifications_data = (
|
||||||
|
result
|
||||||
|
if isinstance(result, list)
|
||||||
|
else result.get("notifications", [])
|
||||||
|
if isinstance(result, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
if notifications_data:
|
||||||
|
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
|
||||||
|
else:
|
||||||
|
logger.debug(f"API 请求成功: {endpoint}")
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"响应不是有效的 JSON 格式: {e}")
|
||||||
|
raise APIConnectionError("Invalid JSON response") from e
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
error_text = await response.text()
|
||||||
|
logger.error(
|
||||||
|
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
|
||||||
|
|
||||||
|
self._handle_response_status(response.status, endpoint)
|
||||||
|
raise APIConnectionError(f"Request failed for {endpoint}")
|
||||||
|
|
||||||
|
@retry_async(
|
||||||
|
max_retries=API_MAX_RETRIES,
|
||||||
|
retryable_exceptions=(APIConnectionError, APIRateLimitError),
|
||||||
|
)
|
||||||
|
async def _make_request(
|
||||||
|
self, endpoint: str, data: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Any:
|
||||||
|
url = f"{self.instance_url}/api/{endpoint}"
|
||||||
|
payload = {"i": self.access_token}
|
||||||
|
if data:
|
||||||
|
payload.update(data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.session.post(url, json=payload) as response:
|
||||||
|
return await self._process_response(response, endpoint)
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
logger.error(f"HTTP 请求错误: {e}")
|
||||||
|
raise APIConnectionError(f"HTTP request failed: {e}") from e
|
||||||
|
|
||||||
|
async def create_note(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
visibility: str = "public",
|
||||||
|
reply_id: Optional[str] = None,
|
||||||
|
visible_user_ids: Optional[List[str]] = None,
|
||||||
|
local_only: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""创建新贴文"""
|
||||||
|
data: Dict[str, Any] = {
|
||||||
|
"text": text,
|
||||||
|
"visibility": visibility,
|
||||||
|
"localOnly": local_only,
|
||||||
|
}
|
||||||
|
if reply_id:
|
||||||
|
data["replyId"] = reply_id
|
||||||
|
if visible_user_ids and visibility == "specified":
|
||||||
|
data["visibleUserIds"] = visible_user_ids
|
||||||
|
|
||||||
|
result = await self._make_request("notes/create", data)
|
||||||
|
note_id = result.get("createdNote", {}).get("id", "unknown")
|
||||||
|
logger.debug(f"发帖成功,note_id: {note_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_current_user(self) -> Dict[str, Any]:
|
||||||
|
"""获取当前用户信息"""
|
||||||
|
return await self._make_request("i", {})
|
||||||
|
|
||||||
|
async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
|
||||||
|
"""发送聊天消息"""
|
||||||
|
result = await self._make_request(
|
||||||
|
"chat/messages/create-to-user", {"toUserId": user_id, "text": text}
|
||||||
|
)
|
||||||
|
message_id = result.get("id", "unknown")
|
||||||
|
logger.debug(f"聊天发送成功,message_id: {message_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
|
||||||
|
"""发送房间消息"""
|
||||||
|
result = await self._make_request(
|
||||||
|
"chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
|
||||||
|
)
|
||||||
|
message_id = result.get("id", "unknown")
|
||||||
|
logger.debug(f"房间消息发送成功,message_id: {message_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_messages(
|
||||||
|
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""获取聊天消息历史"""
|
||||||
|
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
|
||||||
|
if since_id:
|
||||||
|
data["sinceId"] = since_id
|
||||||
|
|
||||||
|
result = await self._make_request("chat/messages/user-timeline", data)
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_mentions(
|
||||||
|
self, limit: int = 10, since_id: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""获取提及通知"""
|
||||||
|
data: Dict[str, Any] = {"limit": limit}
|
||||||
|
if since_id:
|
||||||
|
data["sinceId"] = since_id
|
||||||
|
data["includeTypes"] = ["mention", "reply", "quote"]
|
||||||
|
|
||||||
|
result = await self._make_request("i/notifications", data)
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
elif isinstance(result, dict) and "notifications" in result:
|
||||||
|
return result["notifications"]
|
||||||
|
else:
|
||||||
|
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
|
||||||
|
return []
|
||||||
123
astrbot/core/platform/sources/misskey/misskey_event.py
Normal file
123
astrbot/core/platform/sources/misskey/misskey_event.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
|
||||||
|
from astrbot.api.message_components import Plain
|
||||||
|
|
||||||
|
from .misskey_utils import (
|
||||||
|
serialize_message_chain,
|
||||||
|
resolve_visibility_from_raw_message,
|
||||||
|
is_valid_user_session_id,
|
||||||
|
is_valid_room_session_id,
|
||||||
|
add_at_mention_if_needed,
|
||||||
|
extract_user_id_from_session_id,
|
||||||
|
extract_room_id_from_session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MisskeyPlatformEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
client,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
def _is_system_command(self, message_str: str) -> bool:
|
||||||
|
"""检测是否为系统指令"""
|
||||||
|
if not message_str or not message_str.strip():
|
||||||
|
return False
|
||||||
|
|
||||||
|
system_prefixes = ["/", "!", "#", ".", "^"]
|
||||||
|
message_trimmed = message_str.strip()
|
||||||
|
|
||||||
|
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
content, has_at = serialize_message_chain(message.chain)
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
original_message_id = getattr(self.message_obj, "message_id", None)
|
||||||
|
raw_message = getattr(self.message_obj, "raw_message", {})
|
||||||
|
|
||||||
|
if raw_message and not has_at:
|
||||||
|
user_data = raw_message.get("user", {})
|
||||||
|
user_info = {
|
||||||
|
"username": user_data.get("username", ""),
|
||||||
|
"nickname": user_data.get("name", user_data.get("username", "")),
|
||||||
|
}
|
||||||
|
content = add_at_mention_if_needed(content, user_info, has_at)
|
||||||
|
|
||||||
|
# 根据会话类型选择发送方式
|
||||||
|
if hasattr(self.client, "send_message") and is_valid_user_session_id(
|
||||||
|
self.session_id
|
||||||
|
):
|
||||||
|
user_id = extract_user_id_from_session_id(self.session_id)
|
||||||
|
await self.client.send_message(user_id, content)
|
||||||
|
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
|
||||||
|
self.session_id
|
||||||
|
):
|
||||||
|
room_id = extract_room_id_from_session_id(self.session_id)
|
||||||
|
await self.client.send_room_message(room_id, content)
|
||||||
|
elif original_message_id and hasattr(self.client, "create_note"):
|
||||||
|
visibility, visible_user_ids = resolve_visibility_from_raw_message(
|
||||||
|
raw_message
|
||||||
|
)
|
||||||
|
await self.client.create_note(
|
||||||
|
content,
|
||||||
|
reply_id=original_message_id,
|
||||||
|
visibility=visibility,
|
||||||
|
visible_user_ids=visible_user_ids,
|
||||||
|
)
|
||||||
|
elif hasattr(self.client, "create_note"):
|
||||||
|
logger.debug("[MisskeyEvent] 创建新帖子")
|
||||||
|
await self.client.create_note(content)
|
||||||
|
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[MisskeyEvent] 发送失败: {e}")
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||||
|
):
|
||||||
|
if not use_fallback:
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
for comp in chain.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
buffer += comp.text
|
||||||
|
if any(p in buffer for p in "。?!~…"):
|
||||||
|
buffer = await self.process_buffer(buffer, pattern)
|
||||||
|
else:
|
||||||
|
await self.send(MessageChain(chain=[comp]))
|
||||||
|
await asyncio.sleep(1.5) # 限速
|
||||||
|
|
||||||
|
if buffer.strip():
|
||||||
|
await self.send(MessageChain([Plain(buffer)]))
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
327
astrbot/core/platform/sources/misskey/misskey_utils.py
Normal file
327
astrbot/core/platform/sources/misskey/misskey_utils.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
"""Misskey 平台适配器通用工具函数"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List, Tuple, Optional, Union
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
|
||||||
|
"""将消息链序列化为文本字符串"""
|
||||||
|
text_parts = []
|
||||||
|
has_at = False
|
||||||
|
|
||||||
|
def process_component(component):
|
||||||
|
nonlocal has_at
|
||||||
|
if isinstance(component, Comp.Plain):
|
||||||
|
return component.text
|
||||||
|
elif isinstance(component, Comp.File):
|
||||||
|
file_name = getattr(component, "name", "文件")
|
||||||
|
return f"[文件: {file_name}]"
|
||||||
|
elif isinstance(component, Comp.At):
|
||||||
|
has_at = True
|
||||||
|
return f"@{component.qq}"
|
||||||
|
elif hasattr(component, "text"):
|
||||||
|
text = getattr(component, "text", "")
|
||||||
|
if "@" in text:
|
||||||
|
has_at = True
|
||||||
|
return text
|
||||||
|
else:
|
||||||
|
return str(component)
|
||||||
|
|
||||||
|
for component in chain:
|
||||||
|
if isinstance(component, Comp.Node) and component.content:
|
||||||
|
for node_comp in component.content:
|
||||||
|
result = process_component(node_comp)
|
||||||
|
if result:
|
||||||
|
text_parts.append(result)
|
||||||
|
else:
|
||||||
|
result = process_component(component)
|
||||||
|
if result:
|
||||||
|
text_parts.append(result)
|
||||||
|
|
||||||
|
return "".join(text_parts), has_at
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_message_visibility(
|
||||||
|
user_id: Optional[str],
|
||||||
|
user_cache: Dict[str, Any],
|
||||||
|
self_id: Optional[str],
|
||||||
|
default_visibility: str = "public",
|
||||||
|
) -> Tuple[str, Optional[List[str]]]:
|
||||||
|
"""解析 Misskey 消息的可见性设置"""
|
||||||
|
visibility = default_visibility
|
||||||
|
visible_user_ids = None
|
||||||
|
|
||||||
|
if user_id and user_cache:
|
||||||
|
user_info = user_cache.get(user_id)
|
||||||
|
if user_info:
|
||||||
|
original_visibility = user_info.get("visibility", default_visibility)
|
||||||
|
if original_visibility == "specified":
|
||||||
|
visibility = "specified"
|
||||||
|
original_visible_users = user_info.get("visible_user_ids", [])
|
||||||
|
users_to_include = [user_id]
|
||||||
|
if self_id:
|
||||||
|
users_to_include.append(self_id)
|
||||||
|
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||||
|
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||||
|
else:
|
||||||
|
visibility = original_visibility
|
||||||
|
|
||||||
|
return visibility, visible_user_ids
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_visibility_from_raw_message(
|
||||||
|
raw_message: Dict[str, Any], self_id: Optional[str] = None
|
||||||
|
) -> Tuple[str, Optional[List[str]]]:
|
||||||
|
"""从原始消息数据中解析可见性设置"""
|
||||||
|
visibility = "public"
|
||||||
|
visible_user_ids = None
|
||||||
|
|
||||||
|
if not raw_message:
|
||||||
|
return visibility, visible_user_ids
|
||||||
|
|
||||||
|
original_visibility = raw_message.get("visibility", "public")
|
||||||
|
if original_visibility == "specified":
|
||||||
|
visibility = "specified"
|
||||||
|
original_visible_users = raw_message.get("visibleUserIds", [])
|
||||||
|
sender_id = raw_message.get("userId", "")
|
||||||
|
|
||||||
|
users_to_include = []
|
||||||
|
if sender_id:
|
||||||
|
users_to_include.append(sender_id)
|
||||||
|
if self_id:
|
||||||
|
users_to_include.append(self_id)
|
||||||
|
|
||||||
|
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||||
|
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||||
|
else:
|
||||||
|
visibility = original_visibility
|
||||||
|
|
||||||
|
return visibility, visible_user_ids
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
|
||||||
|
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
|
||||||
|
if not isinstance(session_id, str) or "%" not in session_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
parts = session_id.split("%")
|
||||||
|
return (
|
||||||
|
len(parts) == 2
|
||||||
|
and parts[0] == "chat"
|
||||||
|
and bool(parts[1])
|
||||||
|
and parts[1] != "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
|
||||||
|
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
|
||||||
|
if not isinstance(session_id, str) or "%" not in session_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
parts = session_id.split("%")
|
||||||
|
return (
|
||||||
|
len(parts) == 2
|
||||||
|
and parts[0] == "room"
|
||||||
|
and bool(parts[1])
|
||||||
|
and parts[1] != "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_user_id_from_session_id(session_id: str) -> str:
|
||||||
|
"""从 session_id 中提取用户 ID"""
|
||||||
|
if "%" in session_id:
|
||||||
|
parts = session_id.split("%")
|
||||||
|
if len(parts) >= 2:
|
||||||
|
return parts[1]
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def extract_room_id_from_session_id(session_id: str) -> str:
|
||||||
|
"""从 session_id 中提取房间 ID"""
|
||||||
|
if "%" in session_id:
|
||||||
|
parts = session_id.split("%")
|
||||||
|
if len(parts) >= 2 and parts[0] == "room":
|
||||||
|
return parts[1]
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def add_at_mention_if_needed(
|
||||||
|
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""如果需要且没有@用户,则添加@用户"""
|
||||||
|
if has_at or not user_info:
|
||||||
|
return text
|
||||||
|
|
||||||
|
username = user_info.get("username")
|
||||||
|
nickname = user_info.get("nickname")
|
||||||
|
|
||||||
|
if username:
|
||||||
|
mention = f"@{username}"
|
||||||
|
if not text.startswith(mention):
|
||||||
|
text = f"{mention}\n{text}".strip()
|
||||||
|
elif nickname:
|
||||||
|
mention = f"@{nickname}"
|
||||||
|
if not text.startswith(mention):
|
||||||
|
text = f"{mention}\n{text}".strip()
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
|
||||||
|
"""创建文件组件和描述文本"""
|
||||||
|
file_url = file_info.get("url", "")
|
||||||
|
file_name = file_info.get("name", "未知文件")
|
||||||
|
file_type = file_info.get("type", "")
|
||||||
|
|
||||||
|
if file_type.startswith("image/"):
|
||||||
|
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
|
||||||
|
elif file_type.startswith("audio/"):
|
||||||
|
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
|
||||||
|
elif file_type.startswith("video/"):
|
||||||
|
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
|
||||||
|
else:
|
||||||
|
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
|
||||||
|
|
||||||
|
|
||||||
|
def process_files(
|
||||||
|
message: AstrBotMessage, files: list, include_text_parts: bool = True
|
||||||
|
) -> list:
|
||||||
|
"""处理文件列表,添加到消息组件中并返回文本描述"""
|
||||||
|
file_parts = []
|
||||||
|
for file_info in files:
|
||||||
|
component, part_text = create_file_component(file_info)
|
||||||
|
message.message.append(component)
|
||||||
|
if include_text_parts:
|
||||||
|
file_parts.append(part_text)
|
||||||
|
return file_parts
|
||||||
|
|
||||||
|
|
||||||
|
def extract_sender_info(
|
||||||
|
raw_data: Dict[str, Any], is_chat: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""提取发送者信息"""
|
||||||
|
if is_chat:
|
||||||
|
sender = raw_data.get("fromUser", {})
|
||||||
|
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
|
||||||
|
else:
|
||||||
|
sender = raw_data.get("user", {})
|
||||||
|
sender_id = str(sender.get("id", ""))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sender": sender,
|
||||||
|
"sender_id": sender_id,
|
||||||
|
"nickname": sender.get("name", sender.get("username", "")),
|
||||||
|
"username": sender.get("username", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_base_message(
|
||||||
|
raw_data: Dict[str, Any],
|
||||||
|
sender_info: Dict[str, Any],
|
||||||
|
client_self_id: str,
|
||||||
|
is_chat: bool = False,
|
||||||
|
room_id: Optional[str] = None,
|
||||||
|
unique_session: bool = False,
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
"""创建基础消息对象"""
|
||||||
|
message = AstrBotMessage()
|
||||||
|
message.raw_message = raw_data
|
||||||
|
message.message = []
|
||||||
|
|
||||||
|
message.sender = MessageMember(
|
||||||
|
user_id=sender_info["sender_id"],
|
||||||
|
nickname=sender_info["nickname"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if room_id:
|
||||||
|
session_prefix = "room"
|
||||||
|
session_id = f"{session_prefix}%{room_id}"
|
||||||
|
if unique_session:
|
||||||
|
session_id += f"_{sender_info['sender_id']}"
|
||||||
|
message.type = MessageType.GROUP_MESSAGE
|
||||||
|
message.group_id = room_id
|
||||||
|
elif is_chat:
|
||||||
|
session_prefix = "chat"
|
||||||
|
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||||
|
message.type = MessageType.FRIEND_MESSAGE
|
||||||
|
else:
|
||||||
|
session_prefix = "note"
|
||||||
|
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||||
|
message.type = MessageType.FRIEND_MESSAGE
|
||||||
|
|
||||||
|
message.session_id = (
|
||||||
|
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
|
||||||
|
)
|
||||||
|
message.message_id = str(raw_data.get("id", ""))
|
||||||
|
message.self_id = client_self_id
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def process_at_mention(
|
||||||
|
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
|
||||||
|
) -> Tuple[List[str], str]:
|
||||||
|
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
|
||||||
|
message_parts = []
|
||||||
|
|
||||||
|
if not raw_text:
|
||||||
|
return message_parts, ""
|
||||||
|
|
||||||
|
if bot_username and raw_text.startswith(f"@{bot_username}"):
|
||||||
|
at_mention = f"@{bot_username}"
|
||||||
|
message.message.append(Comp.At(qq=client_self_id))
|
||||||
|
remaining_text = raw_text[len(at_mention) :].strip()
|
||||||
|
if remaining_text:
|
||||||
|
message.message.append(Comp.Plain(remaining_text))
|
||||||
|
message_parts.append(remaining_text)
|
||||||
|
return message_parts, remaining_text
|
||||||
|
else:
|
||||||
|
message.message.append(Comp.Plain(raw_text))
|
||||||
|
message_parts.append(raw_text)
|
||||||
|
return message_parts, raw_text
|
||||||
|
|
||||||
|
|
||||||
|
def cache_user_info(
|
||||||
|
user_cache: Dict[str, Any],
|
||||||
|
sender_info: Dict[str, Any],
|
||||||
|
raw_data: Dict[str, Any],
|
||||||
|
client_self_id: str,
|
||||||
|
is_chat: bool = False,
|
||||||
|
):
|
||||||
|
"""缓存用户信息"""
|
||||||
|
if is_chat:
|
||||||
|
user_cache_data = {
|
||||||
|
"username": sender_info["username"],
|
||||||
|
"nickname": sender_info["nickname"],
|
||||||
|
"visibility": "specified",
|
||||||
|
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
user_cache_data = {
|
||||||
|
"username": sender_info["username"],
|
||||||
|
"nickname": sender_info["nickname"],
|
||||||
|
"visibility": raw_data.get("visibility", "public"),
|
||||||
|
"visible_user_ids": raw_data.get("visibleUserIds", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
user_cache[sender_info["sender_id"]] = user_cache_data
|
||||||
|
|
||||||
|
|
||||||
|
def cache_room_info(
|
||||||
|
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
|
||||||
|
):
|
||||||
|
"""缓存房间信息"""
|
||||||
|
room_data = raw_data.get("toRoom")
|
||||||
|
room_id = raw_data.get("toRoomId")
|
||||||
|
|
||||||
|
if room_data and room_id:
|
||||||
|
room_cache_key = f"room:{room_id}"
|
||||||
|
user_cache[room_cache_key] = {
|
||||||
|
"room_id": room_id,
|
||||||
|
"room_name": room_data.get("name", ""),
|
||||||
|
"room_description": room_data.get("description", ""),
|
||||||
|
"owner_id": room_data.get("ownerId", ""),
|
||||||
|
"visibility": "specified",
|
||||||
|
"visible_user_ids": [client_self_id],
|
||||||
|
}
|
||||||
@@ -17,7 +17,14 @@ from astrbot.api.platform import (
|
|||||||
register_platform_adapter,
|
register_platform_adapter,
|
||||||
)
|
)
|
||||||
from astrbot.core.platform.astr_message_event import MessageSession
|
from astrbot.core.platform.astr_message_event import MessageSession
|
||||||
from astrbot.api.message_components import Plain, Image, At, File, Record
|
from astrbot.api.message_components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
At,
|
||||||
|
File,
|
||||||
|
Record,
|
||||||
|
Reply,
|
||||||
|
)
|
||||||
from xml.etree import ElementTree as ET
|
from xml.etree import ElementTree as ET
|
||||||
|
|
||||||
|
|
||||||
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
self.token = self.config.get("satori_token", "")
|
self.token = self.config.get("satori_token", "")
|
||||||
self.endpoint = self.config.get(
|
self.endpoint = self.config.get(
|
||||||
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
|
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
|
||||||
)
|
)
|
||||||
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
|
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
|
||||||
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
|
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
|
||||||
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
|
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
|
||||||
|
|
||||||
|
self.metadata = PlatformMetadata(
|
||||||
|
name="satori",
|
||||||
|
description="Satori 通用协议适配器",
|
||||||
|
id=self.config["id"],
|
||||||
|
)
|
||||||
|
|
||||||
self.ws: Optional[ClientConnection] = None
|
self.ws: Optional[ClientConnection] = None
|
||||||
self.session: Optional[ClientSession] = None
|
self.session: Optional[ClientSession] = None
|
||||||
self.sequence = 0
|
self.sequence = 0
|
||||||
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
|
|||||||
await super().send_by_session(session, message_chain)
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
|
return self.metadata
|
||||||
|
|
||||||
def _is_websocket_closed(self, ws) -> bool:
|
def _is_websocket_closed(self, ws) -> bool:
|
||||||
"""检查WebSocket连接是否已关闭"""
|
"""检查WebSocket连接是否已关闭"""
|
||||||
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
|
|||||||
|
|
||||||
abm.self_id = login.get("user", {}).get("id", "")
|
abm.self_id = login.get("user", {}).get("id", "")
|
||||||
|
|
||||||
content = message.get("content", "")
|
# 消息链
|
||||||
abm.message = await self.parse_satori_elements(content)
|
abm.message = []
|
||||||
|
|
||||||
|
content = message.get("content", "")
|
||||||
|
|
||||||
|
quote = message.get("quote")
|
||||||
|
content_for_parsing = content # 副本
|
||||||
|
|
||||||
|
# 提取<quote>标签
|
||||||
|
if "<quote" in content:
|
||||||
|
try:
|
||||||
|
quote_info = await self._extract_quote_element(content)
|
||||||
|
if quote_info:
|
||||||
|
quote = quote_info["quote"]
|
||||||
|
content_for_parsing = quote_info["content_without_quote"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
|
||||||
|
|
||||||
|
if quote:
|
||||||
|
# 引用消息
|
||||||
|
quote_abm = await self._convert_quote_message(quote)
|
||||||
|
if quote_abm:
|
||||||
|
sender_id = quote_abm.sender.user_id
|
||||||
|
if isinstance(sender_id, str) and sender_id.isdigit():
|
||||||
|
sender_id = int(sender_id)
|
||||||
|
elif not isinstance(sender_id, int):
|
||||||
|
sender_id = 0 # 默认值
|
||||||
|
|
||||||
|
reply_component = Reply(
|
||||||
|
id=quote_abm.message_id,
|
||||||
|
chain=quote_abm.message,
|
||||||
|
sender_id=quote_abm.sender.user_id,
|
||||||
|
sender_nickname=quote_abm.sender.nickname,
|
||||||
|
time=quote_abm.timestamp,
|
||||||
|
message_str=quote_abm.message_str,
|
||||||
|
text=quote_abm.message_str,
|
||||||
|
qq=sender_id,
|
||||||
|
)
|
||||||
|
abm.message.append(reply_component)
|
||||||
|
|
||||||
|
# 解析消息内容
|
||||||
|
content_elements = await self.parse_satori_elements(content_for_parsing)
|
||||||
|
abm.message.extend(content_elements)
|
||||||
|
|
||||||
# parse message_str
|
|
||||||
abm.message_str = ""
|
abm.message_str = ""
|
||||||
for comp in abm.message:
|
for comp in content_elements:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
abm.message_str += comp.text
|
abm.message_str += comp.text
|
||||||
|
|
||||||
@@ -333,6 +386,163 @@ class SatoriPlatformAdapter(Platform):
|
|||||||
logger.error(f"转换 Satori 消息失败: {e}")
|
logger.error(f"转换 Satori 消息失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _extract_namespace_prefixes(self, content: str) -> set:
|
||||||
|
"""提取XML内容中的命名空间前缀"""
|
||||||
|
prefixes = set()
|
||||||
|
|
||||||
|
# 查找所有标签
|
||||||
|
i = 0
|
||||||
|
while i < len(content):
|
||||||
|
# 查找开始标签
|
||||||
|
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
|
||||||
|
# 找到标签结束位置
|
||||||
|
tag_end = content.find(">", i)
|
||||||
|
if tag_end != -1:
|
||||||
|
# 提取标签内容
|
||||||
|
tag_content = content[i + 1 : tag_end]
|
||||||
|
# 检查是否有命名空间前缀
|
||||||
|
if ":" in tag_content and "xmlns:" not in tag_content:
|
||||||
|
# 分割标签名
|
||||||
|
parts = tag_content.split()
|
||||||
|
if parts:
|
||||||
|
tag_name = parts[0]
|
||||||
|
if ":" in tag_name:
|
||||||
|
prefix = tag_name.split(":")[0]
|
||||||
|
# 确保是有效的命名空间前缀
|
||||||
|
if (
|
||||||
|
prefix.isalnum()
|
||||||
|
or prefix.replace("_", "").isalnum()
|
||||||
|
):
|
||||||
|
prefixes.add(prefix)
|
||||||
|
i = tag_end + 1
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
# 查找结束标签
|
||||||
|
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
|
||||||
|
# 找到标签结束位置
|
||||||
|
tag_end = content.find(">", i)
|
||||||
|
if tag_end != -1:
|
||||||
|
# 提取标签内容
|
||||||
|
tag_content = content[i + 2 : tag_end]
|
||||||
|
# 检查是否有命名空间前缀
|
||||||
|
if ":" in tag_content:
|
||||||
|
prefix = tag_content.split(":")[0]
|
||||||
|
# 确保是有效的命名空间前缀
|
||||||
|
if prefix.isalnum() or prefix.replace("_", "").isalnum():
|
||||||
|
prefixes.add(prefix)
|
||||||
|
i = tag_end + 1
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return prefixes
|
||||||
|
|
||||||
|
async def _extract_quote_element(self, content: str) -> Optional[dict]:
|
||||||
|
"""提取<quote>标签信息"""
|
||||||
|
try:
|
||||||
|
# 处理命名空间前缀问题
|
||||||
|
processed_content = content
|
||||||
|
if ":" in content and not content.startswith("<root"):
|
||||||
|
prefixes = self._extract_namespace_prefixes(content)
|
||||||
|
|
||||||
|
# 构建命名空间声明
|
||||||
|
ns_declarations = " ".join(
|
||||||
|
[
|
||||||
|
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
||||||
|
for prefix in prefixes
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 包装内容
|
||||||
|
processed_content = f"<root {ns_declarations}>{content}</root>"
|
||||||
|
elif not content.startswith("<root"):
|
||||||
|
processed_content = f"<root>{content}</root>"
|
||||||
|
else:
|
||||||
|
processed_content = content
|
||||||
|
|
||||||
|
root = ET.fromstring(processed_content)
|
||||||
|
|
||||||
|
# 查找<quote>标签
|
||||||
|
quote_element = None
|
||||||
|
for elem in root.iter():
|
||||||
|
tag_name = elem.tag
|
||||||
|
if "}" in tag_name:
|
||||||
|
tag_name = tag_name.split("}")[1]
|
||||||
|
if tag_name.lower() == "quote":
|
||||||
|
quote_element = elem
|
||||||
|
break
|
||||||
|
|
||||||
|
if quote_element is not None:
|
||||||
|
# 提取quote标签的属性
|
||||||
|
quote_id = quote_element.get("id", "")
|
||||||
|
|
||||||
|
# 提取<quote>标签内部的内容
|
||||||
|
inner_content = ""
|
||||||
|
if quote_element.text:
|
||||||
|
inner_content += quote_element.text
|
||||||
|
for child in quote_element:
|
||||||
|
inner_content += ET.tostring(
|
||||||
|
child, encoding="unicode", method="xml"
|
||||||
|
)
|
||||||
|
if child.tail:
|
||||||
|
inner_content += child.tail
|
||||||
|
|
||||||
|
# 构造移除了<quote>标签的内容
|
||||||
|
content_without_quote = content.replace(
|
||||||
|
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"quote": {"id": quote_id, "content": inner_content},
|
||||||
|
"content_without_quote": content_without_quote,
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取<quote>标签时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
|
||||||
|
"""转换引用消息"""
|
||||||
|
try:
|
||||||
|
quote_abm = AstrBotMessage()
|
||||||
|
quote_abm.message_id = quote.get("id", "")
|
||||||
|
|
||||||
|
# 解析引用消息的发送者
|
||||||
|
quote_author = quote.get("author", {})
|
||||||
|
if quote_author:
|
||||||
|
quote_abm.sender = MessageMember(
|
||||||
|
user_id=quote_author.get("id", ""),
|
||||||
|
nickname=quote_author.get("nick", quote_author.get("name", "")),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果没有作者信息,使用默认值
|
||||||
|
quote_abm.sender = MessageMember(
|
||||||
|
user_id=quote.get("user_id", ""),
|
||||||
|
nickname="内容",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析引用消息内容
|
||||||
|
quote_content = quote.get("content", "")
|
||||||
|
quote_abm.message = await self.parse_satori_elements(quote_content)
|
||||||
|
|
||||||
|
quote_abm.message_str = ""
|
||||||
|
for comp in quote_abm.message:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
quote_abm.message_str += comp.text
|
||||||
|
|
||||||
|
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
|
||||||
|
|
||||||
|
# 如果没有任何内容,使用默认文本
|
||||||
|
if not quote_abm.message_str.strip():
|
||||||
|
quote_abm.message_str = "[引用消息]"
|
||||||
|
|
||||||
|
return quote_abm
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"转换引用消息失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
async def parse_satori_elements(self, content: str) -> list:
|
async def parse_satori_elements(self, content: str) -> list:
|
||||||
"""解析 Satori 消息元素"""
|
"""解析 Satori 消息元素"""
|
||||||
elements = []
|
elements = []
|
||||||
@@ -341,12 +551,35 @@ class SatoriPlatformAdapter(Platform):
|
|||||||
return elements
|
return elements
|
||||||
|
|
||||||
try:
|
try:
|
||||||
wrapped_content = f"<root>{content}</root>"
|
# 处理命名空间前缀问题
|
||||||
root = ET.fromstring(wrapped_content)
|
processed_content = content
|
||||||
|
if ":" in content and not content.startswith("<root"):
|
||||||
|
prefixes = self._extract_namespace_prefixes(content)
|
||||||
|
|
||||||
|
# 构建命名空间声明
|
||||||
|
ns_declarations = " ".join(
|
||||||
|
[
|
||||||
|
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
||||||
|
for prefix in prefixes
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 包装内容
|
||||||
|
processed_content = f"<root {ns_declarations}>{content}</root>"
|
||||||
|
elif not content.startswith("<root"):
|
||||||
|
processed_content = f"<root>{content}</root>"
|
||||||
|
else:
|
||||||
|
processed_content = content
|
||||||
|
|
||||||
|
root = ET.fromstring(processed_content)
|
||||||
await self._parse_xml_node(root, elements)
|
await self._parse_xml_node(root, elements)
|
||||||
except ET.ParseError as e:
|
except ET.ParseError as e:
|
||||||
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
|
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
|
||||||
|
# 如果解析失败,将整个内容当作纯文本
|
||||||
|
if content.strip():
|
||||||
|
elements.append(Plain(text=content))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# 如果没有解析到任何元素,将整个内容当作纯文本
|
# 如果没有解析到任何元素,将整个内容当作纯文本
|
||||||
@@ -361,7 +594,12 @@ class SatoriPlatformAdapter(Platform):
|
|||||||
elements.append(Plain(text=node.text))
|
elements.append(Plain(text=node.text))
|
||||||
|
|
||||||
for child in node:
|
for child in node:
|
||||||
tag_name = child.tag.lower()
|
# 获取标签名,去除命名空间前缀
|
||||||
|
tag_name = child.tag
|
||||||
|
if "}" in tag_name:
|
||||||
|
tag_name = tag_name.split("}")[1]
|
||||||
|
tag_name = tag_name.lower()
|
||||||
|
|
||||||
attrs = child.attrib
|
attrs = child.attrib
|
||||||
|
|
||||||
if tag_name == "at":
|
if tag_name == "at":
|
||||||
@@ -372,31 +610,59 @@ class SatoriPlatformAdapter(Platform):
|
|||||||
src = attrs.get("src", "")
|
src = attrs.get("src", "")
|
||||||
if not src:
|
if not src:
|
||||||
continue
|
continue
|
||||||
if src.startswith("data:image/"):
|
elements.append(Image(file=src))
|
||||||
src = src.split(",")[1]
|
|
||||||
elements.append(Image.fromBase64(src))
|
|
||||||
elif src.startswith("http"):
|
|
||||||
elements.append(Image.fromURL(src))
|
|
||||||
else:
|
|
||||||
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
|
|
||||||
|
|
||||||
elif tag_name == "file":
|
elif tag_name == "file":
|
||||||
src = attrs.get("src", "")
|
src = attrs.get("src", "")
|
||||||
name = attrs.get("name", "文件")
|
name = attrs.get("name", "文件")
|
||||||
if src:
|
if src:
|
||||||
elements.append(File(file=src, name=name))
|
elements.append(File(name=name, file=src))
|
||||||
|
|
||||||
elif tag_name in ("audio", "record"):
|
elif tag_name in ("audio", "record"):
|
||||||
src = attrs.get("src", "")
|
src = attrs.get("src", "")
|
||||||
if not src:
|
if not src:
|
||||||
continue
|
continue
|
||||||
if src.startswith("data:audio/"):
|
elements.append(Record(file=src))
|
||||||
src = src.split(",")[1]
|
|
||||||
elements.append(Record.fromBase64(src))
|
elif tag_name == "quote":
|
||||||
elif src.startswith("http"):
|
# quote标签已经被特殊处理
|
||||||
elements.append(Record.fromURL(src))
|
pass
|
||||||
|
|
||||||
|
elif tag_name == "face":
|
||||||
|
face_id = attrs.get("id", "")
|
||||||
|
face_name = attrs.get("name", "")
|
||||||
|
face_type = attrs.get("type", "")
|
||||||
|
|
||||||
|
if face_name:
|
||||||
|
elements.append(Plain(text=f"[表情:{face_name}]"))
|
||||||
|
elif face_id and face_type:
|
||||||
|
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
|
||||||
|
elif face_id:
|
||||||
|
elements.append(Plain(text=f"[表情ID:{face_id}]"))
|
||||||
else:
|
else:
|
||||||
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
|
elements.append(Plain(text="[表情]"))
|
||||||
|
|
||||||
|
elif tag_name == "ark":
|
||||||
|
# 作为纯文本添加到消息链中
|
||||||
|
data = attrs.get("data", "")
|
||||||
|
if data:
|
||||||
|
import html
|
||||||
|
|
||||||
|
decoded_data = html.unescape(data)
|
||||||
|
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
||||||
|
else:
|
||||||
|
elements.append(Plain(text="[ARK卡片]"))
|
||||||
|
|
||||||
|
elif tag_name == "json":
|
||||||
|
# JSON标签 视为ARK卡片消息
|
||||||
|
data = attrs.get("data", "")
|
||||||
|
if data:
|
||||||
|
import html
|
||||||
|
|
||||||
|
decoded_data = html.unescape(data)
|
||||||
|
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
||||||
|
else:
|
||||||
|
elements.append(Plain(text="[JSON卡片]"))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 未知标签,递归处理其内容
|
# 未知标签,递归处理其内容
|
||||||
|
|||||||
@@ -17,6 +17,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
adapter: "SatoriPlatformAdapter",
|
adapter: "SatoriPlatformAdapter",
|
||||||
):
|
):
|
||||||
|
# 更新平台元数据
|
||||||
|
if adapter and hasattr(adapter, "logins") and adapter.logins:
|
||||||
|
current_login = adapter.logins[0]
|
||||||
|
platform_name = current_login.get("platform", "satori")
|
||||||
|
user = current_login.get("user", {})
|
||||||
|
user_id = user.get("id", "") if user else ""
|
||||||
|
if not platform_meta.id and user_id:
|
||||||
|
platform_meta.id = f"{platform_name}({user_id})"
|
||||||
|
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.adapter = adapter
|
self.adapter = adapter
|
||||||
self.platform = None
|
self.platform = None
|
||||||
|
|||||||
@@ -218,7 +218,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
try:
|
try:
|
||||||
msg = await self.client.send_message(text=delta, **payload)
|
msg = await self.client.send_message(text=delta, **payload)
|
||||||
current_content = delta
|
current_content = delta
|
||||||
delta = ""
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
message_id = msg.message_id
|
message_id = msg.message_id
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"wecom",
|
"wecom",
|
||||||
"wecom 适配器",
|
"wecom 适配器",
|
||||||
|
id=self.config.get("id", "wecom"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -184,6 +184,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"weixin_official_account",
|
"weixin_official_account",
|
||||||
"微信公众平台 适配器",
|
"微信公众平台 适配器",
|
||||||
|
id=self.config.get("id", "weixin_official_account"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -65,13 +65,16 @@ class AssistantMessageSegment:
|
|||||||
role: str = "assistant"
|
role: str = "assistant"
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
ret = {
|
ret: dict[str, str | list[dict]] = {
|
||||||
"role": self.role,
|
"role": self.role,
|
||||||
}
|
}
|
||||||
if self.content:
|
if self.content:
|
||||||
ret["content"] = self.content
|
ret["content"] = self.content
|
||||||
if self.tool_calls:
|
if self.tool_calls:
|
||||||
ret["tool_calls"] = self.tool_calls
|
tool_calls_dict = [
|
||||||
|
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
|
||||||
|
]
|
||||||
|
ret["tool_calls"] = tool_calls_dict
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +120,14 @@ class ProviderRequest:
|
|||||||
"""模型名称,为 None 时使用提供商的默认模型"""
|
"""模型名称,为 None 时使用提供商的默认模型"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
return (
|
||||||
|
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
|
||||||
|
f"image_count={len(self.image_urls or [])}, "
|
||||||
|
f"func_tool={self.func_tool}, "
|
||||||
|
f"contexts={self._print_friendly_context()}, "
|
||||||
|
f"system_prompt={self.system_prompt}, "
|
||||||
|
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
|
||||||
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from typing import Dict, List, Awaitable
|
from typing import Dict, List, Awaitable, Callable, Any
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core import sp
|
from astrbot.core import sp
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
func_args: list,
|
func_args: list,
|
||||||
desc: str,
|
desc: str,
|
||||||
handler: Awaitable,
|
handler: Callable[..., Awaitable[Any]],
|
||||||
) -> FuncTool:
|
) -> FuncTool:
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
@@ -132,7 +132,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
func_args: list,
|
func_args: list,
|
||||||
desc: str,
|
desc: str,
|
||||||
handler: Awaitable,
|
handler: Callable[..., Awaitable[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""添加函数调用工具
|
"""添加函数调用工具
|
||||||
|
|
||||||
@@ -220,7 +220,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
cfg: dict,
|
cfg: dict,
|
||||||
event: asyncio.Event,
|
event: asyncio.Event,
|
||||||
ready_future: asyncio.Future = None,
|
ready_future: asyncio.Future | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -7,7 +7,13 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
|
|
||||||
from .entities import ProviderType
|
from .entities import ProviderType
|
||||||
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
|
from .provider import (
|
||||||
|
Provider,
|
||||||
|
STTProvider,
|
||||||
|
TTSProvider,
|
||||||
|
EmbeddingProvider,
|
||||||
|
RerankProvider,
|
||||||
|
)
|
||||||
from .register import llm_tools, provider_cls_map
|
from .register import llm_tools, provider_cls_map
|
||||||
from ..persona_mgr import PersonaManager
|
from ..persona_mgr import PersonaManager
|
||||||
|
|
||||||
@@ -38,7 +44,12 @@ class ProviderManager:
|
|||||||
"""加载的 Text To Speech Provider 的实例"""
|
"""加载的 Text To Speech Provider 的实例"""
|
||||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||||
"""加载的 Embedding Provider 的实例"""
|
"""加载的 Embedding Provider 的实例"""
|
||||||
self.inst_map: dict[str, Provider] = {}
|
self.rerank_provider_insts: List[RerankProvider] = []
|
||||||
|
"""加载的 Rerank Provider 的实例"""
|
||||||
|
self.inst_map: dict[
|
||||||
|
str,
|
||||||
|
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
||||||
|
] = {}
|
||||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||||
self.llm_tools = llm_tools
|
self.llm_tools = llm_tools
|
||||||
|
|
||||||
@@ -87,19 +98,31 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
# 不启用提供商会话隔离模式的情况
|
# 不启用提供商会话隔离模式的情况
|
||||||
self.curr_provider_inst = self.inst_map[provider_id]
|
|
||||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
prov = self.inst_map[provider_id]
|
||||||
|
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
||||||
|
prov, TTSProvider
|
||||||
|
):
|
||||||
|
self.curr_tts_provider_inst = prov
|
||||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
||||||
|
prov, STTProvider
|
||||||
|
):
|
||||||
|
self.curr_stt_provider_inst = prov
|
||||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
||||||
|
prov, Provider
|
||||||
|
):
|
||||||
|
self.curr_provider_inst = prov
|
||||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||||
|
|
||||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||||
"""根据提供商 ID 获取提供商实例"""
|
"""根据提供商 ID 获取提供商实例"""
|
||||||
return self.inst_map.get(provider_id)
|
return self.inst_map.get(provider_id)
|
||||||
|
|
||||||
def get_using_provider(self, provider_type: ProviderType, umo=None):
|
def get_using_provider(
|
||||||
|
self, provider_type: ProviderType, umo=None
|
||||||
|
) -> Provider | STTProvider | TTSProvider | None:
|
||||||
"""获取正在使用的提供商实例。
|
"""获取正在使用的提供商实例。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -211,6 +234,8 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
case "dify":
|
case "dify":
|
||||||
from .sources.dify_source import ProviderDify as ProviderDify
|
from .sources.dify_source import ProviderDify as ProviderDify
|
||||||
|
case "coze":
|
||||||
|
from .sources.coze_source import ProviderCoze as ProviderCoze
|
||||||
case "dashscope":
|
case "dashscope":
|
||||||
from .sources.dashscope_source import (
|
from .sources.dashscope_source import (
|
||||||
ProviderDashscope as ProviderDashscope,
|
ProviderDashscope as ProviderDashscope,
|
||||||
@@ -303,12 +328,14 @@ class ProviderManager:
|
|||||||
provider_metadata = provider_cls_map[provider_config["type"]]
|
provider_metadata = provider_cls_map[provider_config["type"]]
|
||||||
try:
|
try:
|
||||||
# 按任务实例化提供商
|
# 按任务实例化提供商
|
||||||
|
cls_type = provider_metadata.cls_type
|
||||||
|
if not cls_type:
|
||||||
|
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
||||||
|
return
|
||||||
|
|
||||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||||
# STT 任务
|
# STT 任务
|
||||||
inst = provider_metadata.cls_type(
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
provider_config, self.provider_settings
|
|
||||||
)
|
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
@@ -327,9 +354,7 @@ class ProviderManager:
|
|||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||||
# TTS 任务
|
# TTS 任务
|
||||||
inst = provider_metadata.cls_type(
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
provider_config, self.provider_settings
|
|
||||||
)
|
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
@@ -345,7 +370,7 @@ class ProviderManager:
|
|||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||||
# 文本生成任务
|
# 文本生成任务
|
||||||
inst = provider_metadata.cls_type(
|
inst = cls_type(
|
||||||
provider_config,
|
provider_config,
|
||||||
self.provider_settings,
|
self.provider_settings,
|
||||||
self.selected_default_persona,
|
self.selected_default_persona,
|
||||||
@@ -366,16 +391,16 @@ class ProviderManager:
|
|||||||
if not self.curr_provider_inst:
|
if not self.curr_provider_inst:
|
||||||
self.curr_provider_inst = inst
|
self.curr_provider_inst = inst
|
||||||
|
|
||||||
elif provider_metadata.provider_type in [
|
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||||
ProviderType.EMBEDDING,
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
ProviderType.RERANK,
|
|
||||||
]:
|
|
||||||
inst = provider_metadata.cls_type(
|
|
||||||
provider_config, self.provider_settings
|
|
||||||
)
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
self.embedding_provider_insts.append(inst)
|
self.embedding_provider_insts.append(inst)
|
||||||
|
elif provider_metadata.provider_type == ProviderType.RERANK:
|
||||||
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
|
if getattr(inst, "initialize", None):
|
||||||
|
await inst.initialize()
|
||||||
|
self.rerank_provider_insts.append(inst)
|
||||||
|
|
||||||
self.inst_map[provider_config["id"]] = inst
|
self.inst_map[provider_config["id"]] = inst
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -430,11 +455,17 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.inst_map[provider_id] in self.provider_insts:
|
if self.inst_map[provider_id] in self.provider_insts:
|
||||||
self.provider_insts.remove(self.inst_map[provider_id])
|
prov_inst = self.inst_map[provider_id]
|
||||||
|
if isinstance(prov_inst, Provider):
|
||||||
|
self.provider_insts.remove(prov_inst)
|
||||||
if self.inst_map[provider_id] in self.stt_provider_insts:
|
if self.inst_map[provider_id] in self.stt_provider_insts:
|
||||||
self.stt_provider_insts.remove(self.inst_map[provider_id])
|
prov_inst = self.inst_map[provider_id]
|
||||||
|
if isinstance(prov_inst, STTProvider):
|
||||||
|
self.stt_provider_insts.remove(prov_inst)
|
||||||
if self.inst_map[provider_id] in self.tts_provider_insts:
|
if self.inst_map[provider_id] in self.tts_provider_insts:
|
||||||
self.tts_provider_insts.remove(self.inst_map[provider_id])
|
prov_inst = self.inst_map[provider_id]
|
||||||
|
if isinstance(prov_inst, TTSProvider):
|
||||||
|
self.tts_provider_insts.remove(prov_inst)
|
||||||
|
|
||||||
if self.inst_map[provider_id] == self.curr_provider_inst:
|
if self.inst_map[provider_id] == self.curr_provider_inst:
|
||||||
self.curr_provider_inst = None
|
self.curr_provider_inst = None
|
||||||
|
|||||||
314
astrbot/core/provider/sources/coze_api_client.py
Normal file
314
astrbot/core/provider/sources/coze_api_client.py
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import io
|
||||||
|
from typing import Dict, List, Any, AsyncGenerator
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
|
class CozeAPIClient:
|
||||||
|
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.api_base = api_base
|
||||||
|
self.session = None
|
||||||
|
|
||||||
|
async def _ensure_session(self):
|
||||||
|
"""确保HTTP session存在"""
|
||||||
|
if self.session is None:
|
||||||
|
connector = aiohttp.TCPConnector(
|
||||||
|
ssl=False if self.api_base.startswith("http://") else True,
|
||||||
|
limit=100,
|
||||||
|
limit_per_host=30,
|
||||||
|
keepalive_timeout=30,
|
||||||
|
enable_cleanup_closed=True,
|
||||||
|
)
|
||||||
|
timeout = aiohttp.ClientTimeout(
|
||||||
|
total=120, # 默认超时时间
|
||||||
|
connect=30,
|
||||||
|
sock_read=120,
|
||||||
|
)
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
}
|
||||||
|
self.session = aiohttp.ClientSession(
|
||||||
|
headers=headers, timeout=timeout, connector=connector
|
||||||
|
)
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
async def upload_file(
|
||||||
|
self,
|
||||||
|
file_data: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""上传文件到 Coze 并返回 file_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_data (bytes): 文件的二进制数据
|
||||||
|
Returns:
|
||||||
|
str: 上传成功后返回的 file_id
|
||||||
|
"""
|
||||||
|
session = await self._ensure_session()
|
||||||
|
url = f"{self.api_base}/v1/files/upload"
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_io = io.BytesIO(file_data)
|
||||||
|
async with session.post(
|
||||||
|
url,
|
||||||
|
data={
|
||||||
|
"file": file_io,
|
||||||
|
},
|
||||||
|
timeout=aiohttp.ClientTimeout(total=60),
|
||||||
|
) as response:
|
||||||
|
if response.status == 401:
|
||||||
|
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||||
|
|
||||||
|
response_text = await response.text()
|
||||||
|
logger.debug(
|
||||||
|
f"文件上传响应状态: {response.status}, 内容: {response_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception(
|
||||||
|
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await response.json()
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise Exception(f"文件上传响应解析失败: {response_text}")
|
||||||
|
|
||||||
|
if result.get("code") != 0:
|
||||||
|
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
|
||||||
|
|
||||||
|
file_id = result["data"]["id"]
|
||||||
|
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("文件上传超时")
|
||||||
|
raise Exception("文件上传超时")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文件上传失败: {str(e)}")
|
||||||
|
raise Exception(f"文件上传失败: {str(e)}")
|
||||||
|
|
||||||
|
async def download_image(self, image_url: str) -> bytes:
|
||||||
|
"""下载图片并返回字节数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_url (str): 图片的URL
|
||||||
|
Returns:
|
||||||
|
bytes: 图片的二进制数据
|
||||||
|
"""
|
||||||
|
session = await self._ensure_session()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.get(image_url) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception(f"下载图片失败,状态码: {response.status}")
|
||||||
|
|
||||||
|
image_data = await response.read()
|
||||||
|
return image_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载图片失败 {image_url}: {str(e)}")
|
||||||
|
raise Exception(f"下载图片失败: {str(e)}")
|
||||||
|
|
||||||
|
async def chat_messages(
|
||||||
|
self,
|
||||||
|
bot_id: str,
|
||||||
|
user_id: str,
|
||||||
|
additional_messages: List[Dict] | None = None,
|
||||||
|
conversation_id: str | None = None,
|
||||||
|
auto_save_history: bool = True,
|
||||||
|
stream: bool = True,
|
||||||
|
timeout: float = 120,
|
||||||
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
"""发送聊天消息并返回流式响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot_id: Bot ID
|
||||||
|
user_id: 用户ID
|
||||||
|
additional_messages: 额外消息列表
|
||||||
|
conversation_id: 会话ID
|
||||||
|
auto_save_history: 是否自动保存历史
|
||||||
|
stream: 是否流式响应
|
||||||
|
timeout: 超时时间
|
||||||
|
"""
|
||||||
|
session = await self._ensure_session()
|
||||||
|
url = f"{self.api_base}/v3/chat"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"bot_id": bot_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"stream": stream,
|
||||||
|
"auto_save_history": auto_save_history,
|
||||||
|
}
|
||||||
|
|
||||||
|
if additional_messages:
|
||||||
|
payload["additional_messages"] = additional_messages
|
||||||
|
|
||||||
|
params = {}
|
||||||
|
if conversation_id:
|
||||||
|
params["conversation_id"] = conversation_id
|
||||||
|
|
||||||
|
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
params=params,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||||
|
) as response:
|
||||||
|
if response.status == 401:
|
||||||
|
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
||||||
|
|
||||||
|
# SSE
|
||||||
|
buffer = ""
|
||||||
|
event_type = None
|
||||||
|
event_data = None
|
||||||
|
|
||||||
|
async for chunk in response.content:
|
||||||
|
if chunk:
|
||||||
|
buffer += chunk.decode("utf-8", errors="ignore")
|
||||||
|
lines = buffer.split("\n")
|
||||||
|
buffer = lines[-1]
|
||||||
|
|
||||||
|
for line in lines[:-1]:
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
if not line:
|
||||||
|
if event_type and event_data:
|
||||||
|
yield {"event": event_type, "data": event_data}
|
||||||
|
event_type = None
|
||||||
|
event_data = None
|
||||||
|
elif line.startswith("event:"):
|
||||||
|
event_type = line[6:].strip()
|
||||||
|
elif line.startswith("data:"):
|
||||||
|
data_str = line[5:].strip()
|
||||||
|
if data_str and data_str != "[DONE]":
|
||||||
|
try:
|
||||||
|
event_data = json.loads(data_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
event_data = {"content": data_str}
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Coze API 流式请求失败: {str(e)}")
|
||||||
|
|
||||||
|
async def clear_context(self, conversation_id: str):
|
||||||
|
"""清空会话上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: 会话ID
|
||||||
|
Returns:
|
||||||
|
dict: API响应结果
|
||||||
|
"""
|
||||||
|
session = await self._ensure_session()
|
||||||
|
url = f"{self.api_base}/v3/conversation/message/clear_context"
|
||||||
|
payload = {"conversation_id": conversation_id}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(url, json=payload) as response:
|
||||||
|
response_text = await response.text()
|
||||||
|
|
||||||
|
if response.status == 401:
|
||||||
|
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(response_text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise Exception("Coze API 返回非JSON格式")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise Exception("Coze API 请求超时")
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
raise Exception(f"Coze API 请求失败: {str(e)}")
|
||||||
|
|
||||||
|
async def get_message_list(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
order: str = "desc",
|
||||||
|
limit: int = 10,
|
||||||
|
offset: int = 0,
|
||||||
|
):
|
||||||
|
"""获取消息列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: 会话ID
|
||||||
|
order: 排序方式 (asc/desc)
|
||||||
|
limit: 限制数量
|
||||||
|
offset: 偏移量
|
||||||
|
Returns:
|
||||||
|
dict: API响应结果
|
||||||
|
"""
|
||||||
|
session = await self._ensure_session()
|
||||||
|
url = f"{self.api_base}/v3/conversation/message/list"
|
||||||
|
params = {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"order": order,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.get(url, params=params) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
return await response.json()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取Coze消息列表失败: {str(e)}")
|
||||||
|
raise Exception(f"获取Coze消息列表失败: {str(e)}")
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""关闭会话"""
|
||||||
|
if self.session:
|
||||||
|
await self.session.close()
|
||||||
|
self.session = None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def test_coze_api_client():
|
||||||
|
api_key = os.getenv("COZE_API_KEY", "")
|
||||||
|
bot_id = os.getenv("COZE_BOT_ID", "")
|
||||||
|
client = CozeAPIClient(api_key=api_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open("README.md", "rb") as f:
|
||||||
|
file_data = f.read()
|
||||||
|
file_id = await client.upload_file(file_data)
|
||||||
|
print(f"Uploaded file_id: {file_id}")
|
||||||
|
async for event in client.chat_messages(
|
||||||
|
bot_id=bot_id,
|
||||||
|
user_id="test_user",
|
||||||
|
additional_messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": json.dumps(
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "这是什么"},
|
||||||
|
{"type": "file", "file_id": file_id},
|
||||||
|
],
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
|
"content_type": "object_string",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
):
|
||||||
|
print(f"Event: {event}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
asyncio.run(test_coze_api_client())
|
||||||
635
astrbot/core/provider/sources/coze_source.py
Normal file
635
astrbot/core/provider/sources/coze_source.py
Normal file
@@ -0,0 +1,635 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from typing import AsyncGenerator, Dict
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
|
from astrbot.api.provider import Provider
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
|
from ..register import register_provider_adapter
|
||||||
|
from .coze_api_client import CozeAPIClient
|
||||||
|
|
||||||
|
|
||||||
|
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
|
||||||
|
class ProviderCoze(Provider):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider_config,
|
||||||
|
provider_settings,
|
||||||
|
default_persona=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
provider_config,
|
||||||
|
provider_settings,
|
||||||
|
default_persona,
|
||||||
|
)
|
||||||
|
self.api_key = provider_config.get("coze_api_key", "")
|
||||||
|
if not self.api_key:
|
||||||
|
raise Exception("Coze API Key 不能为空。")
|
||||||
|
self.bot_id = provider_config.get("bot_id", "")
|
||||||
|
if not self.bot_id:
|
||||||
|
raise Exception("Coze Bot ID 不能为空。")
|
||||||
|
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||||
|
|
||||||
|
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||||
|
("http://", "https://")
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.timeout = provider_config.get("timeout", 120)
|
||||||
|
if isinstance(self.timeout, str):
|
||||||
|
self.timeout = int(self.timeout)
|
||||||
|
self.auto_save_history = provider_config.get("auto_save_history", True)
|
||||||
|
self.conversation_ids: Dict[str, str] = {}
|
||||||
|
self.file_id_cache: Dict[str, Dict[str, str]] = {}
|
||||||
|
|
||||||
|
# 创建 API 客户端
|
||||||
|
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
||||||
|
|
||||||
|
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
|
||||||
|
"""生成统一的缓存键
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 图片数据或路径
|
||||||
|
is_base64: 是否是 base64 数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 缓存键
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_base64 and data.startswith("data:image/"):
|
||||||
|
try:
|
||||||
|
header, encoded = data.split(",", 1)
|
||||||
|
image_bytes = base64.b64decode(encoded)
|
||||||
|
cache_key = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
return cache_key
|
||||||
|
except Exception:
|
||||||
|
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
|
||||||
|
return cache_key
|
||||||
|
else:
|
||||||
|
if data.startswith(("http://", "https://")):
|
||||||
|
# URL图片,使用URL作为缓存键
|
||||||
|
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||||
|
return cache_key
|
||||||
|
else:
|
||||||
|
clean_path = (
|
||||||
|
data.split("_")[0]
|
||||||
|
if "_" in data and len(data.split("_")) >= 3
|
||||||
|
else data
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.exists(clean_path):
|
||||||
|
with open(clean_path, "rb") as f:
|
||||||
|
file_content = f.read()
|
||||||
|
cache_key = hashlib.md5(file_content).hexdigest()
|
||||||
|
return cache_key
|
||||||
|
else:
|
||||||
|
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
|
||||||
|
return cache_key
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||||
|
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
|
||||||
|
return cache_key
|
||||||
|
|
||||||
|
async def _upload_file(
|
||||||
|
self,
|
||||||
|
file_data: bytes,
|
||||||
|
session_id: str | None = None,
|
||||||
|
cache_key: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""上传文件到 Coze 并返回 file_id"""
|
||||||
|
# 使用 API 客户端上传文件
|
||||||
|
file_id = await self.api_client.upload_file(file_data)
|
||||||
|
|
||||||
|
# 缓存 file_id
|
||||||
|
if session_id and cache_key:
|
||||||
|
if session_id not in self.file_id_cache:
|
||||||
|
self.file_id_cache[session_id] = {}
|
||||||
|
self.file_id_cache[session_id][cache_key] = file_id
|
||||||
|
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||||
|
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
async def _download_and_upload_image(
|
||||||
|
self, image_url: str, session_id: str | None = None
|
||||||
|
) -> str:
|
||||||
|
"""下载图片并上传到 Coze,返回 file_id"""
|
||||||
|
# 计算哈希实现缓存
|
||||||
|
cache_key = self._generate_cache_key(image_url) if session_id else None
|
||||||
|
|
||||||
|
if session_id and cache_key:
|
||||||
|
if session_id not in self.file_id_cache:
|
||||||
|
self.file_id_cache[session_id] = {}
|
||||||
|
|
||||||
|
if cache_key in self.file_id_cache[session_id]:
|
||||||
|
file_id = self.file_id_cache[session_id][cache_key]
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_data = await self.api_client.download_image(image_url)
|
||||||
|
|
||||||
|
file_id = await self._upload_file(image_data, session_id, cache_key)
|
||||||
|
|
||||||
|
if session_id and cache_key:
|
||||||
|
self.file_id_cache[session_id][cache_key] = file_id
|
||||||
|
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理图片失败 {image_url}: {str(e)}")
|
||||||
|
raise Exception(f"处理图片失败: {str(e)}")
|
||||||
|
|
||||||
|
async def _process_context_images(
|
||||||
|
self, content: str | list, session_id: str
|
||||||
|
) -> str:
|
||||||
|
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
|
||||||
|
processed_content = []
|
||||||
|
if session_id not in self.file_id_cache:
|
||||||
|
self.file_id_cache[session_id] = {}
|
||||||
|
|
||||||
|
for item in content:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
processed_content.append(item)
|
||||||
|
continue
|
||||||
|
if item.get("type") == "text":
|
||||||
|
processed_content.append(item)
|
||||||
|
elif item.get("type") == "image_url":
|
||||||
|
# 处理图片逻辑
|
||||||
|
if "file_id" in item:
|
||||||
|
# 已经有 file_id
|
||||||
|
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
|
||||||
|
processed_content.append(item)
|
||||||
|
else:
|
||||||
|
# 获取图片数据
|
||||||
|
image_data = ""
|
||||||
|
if "image_url" in item and isinstance(item["image_url"], dict):
|
||||||
|
image_data = item["image_url"].get("url", "")
|
||||||
|
elif "data" in item:
|
||||||
|
image_data = item.get("data", "")
|
||||||
|
elif "url" in item:
|
||||||
|
image_data = item.get("url", "")
|
||||||
|
|
||||||
|
if not image_data:
|
||||||
|
continue
|
||||||
|
# 计算哈希用于缓存
|
||||||
|
cache_key = self._generate_cache_key(
|
||||||
|
image_data, is_base64=image_data.startswith("data:image/")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
if cache_key in self.file_id_cache[session_id]:
|
||||||
|
file_id = self.file_id_cache[session_id][cache_key]
|
||||||
|
processed_content.append(
|
||||||
|
{"type": "image", "file_id": file_id}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 上传图片并缓存
|
||||||
|
if image_data.startswith("data:image/"):
|
||||||
|
# base64 处理
|
||||||
|
_, encoded = image_data.split(",", 1)
|
||||||
|
image_bytes = base64.b64decode(encoded)
|
||||||
|
file_id = await self._upload_file(
|
||||||
|
image_bytes,
|
||||||
|
session_id,
|
||||||
|
cache_key,
|
||||||
|
)
|
||||||
|
elif image_data.startswith(("http://", "https://")):
|
||||||
|
# URL 图片
|
||||||
|
file_id = await self._download_and_upload_image(
|
||||||
|
image_data, session_id
|
||||||
|
)
|
||||||
|
# 为URL图片也添加缓存
|
||||||
|
self.file_id_cache[session_id][cache_key] = file_id
|
||||||
|
elif os.path.exists(image_data):
|
||||||
|
# 本地文件
|
||||||
|
with open(image_data, "rb") as f:
|
||||||
|
image_bytes = f.read()
|
||||||
|
file_id = await self._upload_file(
|
||||||
|
image_bytes,
|
||||||
|
session_id,
|
||||||
|
cache_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"无法处理的图片格式: {image_data[:50]}..."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
processed_content.append(
|
||||||
|
{"type": "image", "file_id": file_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = json.dumps(processed_content, ensure_ascii=False)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理上下文图片失败: {str(e)}")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
return json.dumps(content, ensure_ascii=False)
|
||||||
|
|
||||||
|
async def text_chat(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
session_id=None,
|
||||||
|
image_urls=None,
|
||||||
|
func_tool=None,
|
||||||
|
contexts=None,
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""文本对话, 内部使用流式接口实现非流式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): 用户提示词
|
||||||
|
session_id (str): 会话ID
|
||||||
|
image_urls (List[str]): 图片URL列表
|
||||||
|
func_tool (FuncCall): 函数调用工具(不支持)
|
||||||
|
contexts (List): 上下文列表
|
||||||
|
system_prompt (str): 系统提示语
|
||||||
|
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
|
||||||
|
model (str): 模型名称(不支持)
|
||||||
|
Returns:
|
||||||
|
LLMResponse: LLM响应对象
|
||||||
|
"""
|
||||||
|
accumulated_content = ""
|
||||||
|
final_response = None
|
||||||
|
|
||||||
|
async for llm_response in self.text_chat_stream(
|
||||||
|
prompt=prompt,
|
||||||
|
session_id=session_id,
|
||||||
|
image_urls=image_urls,
|
||||||
|
func_tool=func_tool,
|
||||||
|
contexts=contexts,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tool_calls_result=tool_calls_result,
|
||||||
|
model=model,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if llm_response.is_chunk:
|
||||||
|
if llm_response.completion_text:
|
||||||
|
accumulated_content += llm_response.completion_text
|
||||||
|
else:
|
||||||
|
final_response = llm_response
|
||||||
|
|
||||||
|
if final_response:
|
||||||
|
return final_response
|
||||||
|
|
||||||
|
if accumulated_content:
|
||||||
|
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||||
|
return LLMResponse(role="assistant", result_chain=chain)
|
||||||
|
else:
|
||||||
|
return LLMResponse(role="assistant", completion_text="")
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
session_id=None,
|
||||||
|
image_urls=None,
|
||||||
|
func_tool=None,
|
||||||
|
contexts=None,
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
"""流式对话接口"""
|
||||||
|
# 用户ID参数(参考文档, 可以自定义)
|
||||||
|
user_id = session_id or kwargs.get("user", "default_user")
|
||||||
|
|
||||||
|
# 获取或创建会话ID
|
||||||
|
conversation_id = self.conversation_ids.get(user_id)
|
||||||
|
|
||||||
|
# 构建消息
|
||||||
|
additional_messages = []
|
||||||
|
|
||||||
|
if system_prompt:
|
||||||
|
if not self.auto_save_history or not conversation_id:
|
||||||
|
additional_messages.append(
|
||||||
|
{"role": "system", "content": system_prompt, "content_type": "text"}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.auto_save_history and contexts:
|
||||||
|
# 如果关闭了自动保存历史,传入上下文
|
||||||
|
for ctx in contexts:
|
||||||
|
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
||||||
|
content = ctx["content"]
|
||||||
|
content_type = ctx.get("content_type", "text")
|
||||||
|
|
||||||
|
# 处理可能包含图片的上下文
|
||||||
|
if (
|
||||||
|
content_type == "object_string"
|
||||||
|
or (isinstance(content, str) and content.startswith("["))
|
||||||
|
or (
|
||||||
|
isinstance(content, list)
|
||||||
|
and any(
|
||||||
|
isinstance(item, dict)
|
||||||
|
and item.get("type") == "image_url"
|
||||||
|
for item in content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
processed_content = await self._process_context_images(
|
||||||
|
content, user_id
|
||||||
|
)
|
||||||
|
additional_messages.append(
|
||||||
|
{
|
||||||
|
"role": ctx["role"],
|
||||||
|
"content": processed_content,
|
||||||
|
"content_type": "object_string",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 纯文本
|
||||||
|
additional_messages.append(
|
||||||
|
{
|
||||||
|
"role": ctx["role"],
|
||||||
|
"content": (
|
||||||
|
content
|
||||||
|
if isinstance(content, str)
|
||||||
|
else json.dumps(content, ensure_ascii=False)
|
||||||
|
),
|
||||||
|
"content_type": "text",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
|
||||||
|
|
||||||
|
if prompt or image_urls:
|
||||||
|
if image_urls:
|
||||||
|
# 多模态
|
||||||
|
object_string_content = []
|
||||||
|
if prompt:
|
||||||
|
object_string_content.append({"type": "text", "text": prompt})
|
||||||
|
|
||||||
|
for url in image_urls:
|
||||||
|
try:
|
||||||
|
if url.startswith(("http://", "https://")):
|
||||||
|
# 网络图片
|
||||||
|
file_id = await self._download_and_upload_image(
|
||||||
|
url, user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 本地文件或 base64
|
||||||
|
if url.startswith("data:image/"):
|
||||||
|
# base64
|
||||||
|
_, encoded = url.split(",", 1)
|
||||||
|
image_data = base64.b64decode(encoded)
|
||||||
|
cache_key = self._generate_cache_key(
|
||||||
|
url, is_base64=True
|
||||||
|
)
|
||||||
|
file_id = await self._upload_file(
|
||||||
|
image_data, user_id, cache_key
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 本地文件
|
||||||
|
if os.path.exists(url):
|
||||||
|
with open(url, "rb") as f:
|
||||||
|
image_data = f.read()
|
||||||
|
# 用文件路径和修改时间来缓存
|
||||||
|
file_stat = os.stat(url)
|
||||||
|
cache_key = self._generate_cache_key(
|
||||||
|
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
|
||||||
|
is_base64=False,
|
||||||
|
)
|
||||||
|
file_id = await self._upload_file(
|
||||||
|
image_data, user_id, cache_key
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"图片文件不存在: {url}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
object_string_content.append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"file_id": file_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理图片失败 {url}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if object_string_content:
|
||||||
|
content = json.dumps(object_string_content, ensure_ascii=False)
|
||||||
|
additional_messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": content,
|
||||||
|
"content_type": "object_string",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 纯文本
|
||||||
|
if prompt:
|
||||||
|
additional_messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
"content_type": "text",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
accumulated_content = ""
|
||||||
|
message_started = False
|
||||||
|
|
||||||
|
async for chunk in self.api_client.chat_messages(
|
||||||
|
bot_id=self.bot_id,
|
||||||
|
user_id=user_id,
|
||||||
|
additional_messages=additional_messages,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
auto_save_history=self.auto_save_history,
|
||||||
|
stream=True,
|
||||||
|
timeout=self.timeout,
|
||||||
|
):
|
||||||
|
event_type = chunk.get("event")
|
||||||
|
data = chunk.get("data", {})
|
||||||
|
|
||||||
|
if event_type == "conversation.chat.created":
|
||||||
|
if isinstance(data, dict) and "conversation_id" in data:
|
||||||
|
self.conversation_ids[user_id] = data["conversation_id"]
|
||||||
|
|
||||||
|
elif event_type == "conversation.message.delta":
|
||||||
|
if isinstance(data, dict):
|
||||||
|
content = data.get("content", "")
|
||||||
|
if not content and "delta" in data:
|
||||||
|
content = data["delta"].get("content", "")
|
||||||
|
if not content and "text" in data:
|
||||||
|
content = data.get("text", "")
|
||||||
|
|
||||||
|
if content:
|
||||||
|
message_started = True
|
||||||
|
accumulated_content += content
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text=content,
|
||||||
|
is_chunk=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "conversation.message.completed":
|
||||||
|
if isinstance(data, dict):
|
||||||
|
msg_type = data.get("type")
|
||||||
|
if msg_type == "answer" and data.get("role") == "assistant":
|
||||||
|
final_content = data.get("content", "")
|
||||||
|
if not accumulated_content and final_content:
|
||||||
|
chain = MessageChain(chain=[Comp.Plain(final_content)])
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
result_chain=chain,
|
||||||
|
is_chunk=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "conversation.chat.completed":
|
||||||
|
if accumulated_content:
|
||||||
|
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
result_chain=chain,
|
||||||
|
is_chunk=False,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
elif event_type == "done":
|
||||||
|
break
|
||||||
|
|
||||||
|
elif event_type == "error":
|
||||||
|
error_msg = (
|
||||||
|
data.get("message", "未知错误")
|
||||||
|
if isinstance(data, dict)
|
||||||
|
else str(data)
|
||||||
|
)
|
||||||
|
logger.error(f"Coze 流式响应错误: {error_msg}")
|
||||||
|
yield LLMResponse(
|
||||||
|
role="err",
|
||||||
|
completion_text=f"Coze 错误: {error_msg}",
|
||||||
|
is_chunk=False,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if not message_started and not accumulated_content:
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text="LLM 未响应任何内容。",
|
||||||
|
is_chunk=False,
|
||||||
|
)
|
||||||
|
elif message_started and accumulated_content:
|
||||||
|
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
result_chain=chain,
|
||||||
|
is_chunk=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Coze 流式请求失败: {str(e)}")
|
||||||
|
yield LLMResponse(
|
||||||
|
role="err",
|
||||||
|
completion_text=f"Coze 流式请求失败: {str(e)}",
|
||||||
|
is_chunk=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def forget(self, session_id: str):
|
||||||
|
"""清空指定会话的上下文"""
|
||||||
|
user_id = session_id
|
||||||
|
conversation_id = self.conversation_ids.get(user_id)
|
||||||
|
|
||||||
|
if user_id in self.file_id_cache:
|
||||||
|
self.file_id_cache.pop(user_id, None)
|
||||||
|
|
||||||
|
if not conversation_id:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self.api_client.clear_context(conversation_id)
|
||||||
|
|
||||||
|
if "code" in response and response["code"] == 0:
|
||||||
|
self.conversation_ids.pop(user_id, None)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"清空 Coze 会话上下文失败: {response}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清空 Coze 会话失败: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_current_key(self):
|
||||||
|
"""获取当前API Key"""
|
||||||
|
return self.api_key
|
||||||
|
|
||||||
|
async def set_key(self, key: str):
|
||||||
|
"""设置新的API Key"""
|
||||||
|
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
|
||||||
|
|
||||||
|
async def get_models(self):
|
||||||
|
"""获取可用模型列表"""
|
||||||
|
return [f"bot_{self.bot_id}"]
|
||||||
|
|
||||||
|
def get_model(self):
|
||||||
|
"""获取当前模型"""
|
||||||
|
return f"bot_{self.bot_id}"
|
||||||
|
|
||||||
|
def set_model(self, model: str):
|
||||||
|
"""设置模型(在Coze中是Bot ID)"""
|
||||||
|
if model.startswith("bot_"):
|
||||||
|
self.bot_id = model[4:]
|
||||||
|
else:
|
||||||
|
self.bot_id = model
|
||||||
|
|
||||||
|
async def get_human_readable_context(
|
||||||
|
self, session_id: str, page: int = 1, page_size: int = 10
|
||||||
|
):
|
||||||
|
"""获取人类可读的上下文历史"""
|
||||||
|
user_id = session_id
|
||||||
|
conversation_id = self.conversation_ids.get(user_id)
|
||||||
|
|
||||||
|
if not conversation_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self.api_client.get_message_list(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
order="desc",
|
||||||
|
limit=page_size,
|
||||||
|
offset=(page - 1) * page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
if data.get("code") != 0:
|
||||||
|
logger.warning(f"获取 Coze 消息历史失败: {data}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
messages = data.get("data", {}).get("messages", [])
|
||||||
|
|
||||||
|
readable_history = []
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "unknown")
|
||||||
|
content = msg.get("content", "")
|
||||||
|
msg_type = msg.get("type", "")
|
||||||
|
|
||||||
|
if role == "user":
|
||||||
|
readable_history.append(f"用户: {content}")
|
||||||
|
elif role == "assistant" and msg_type == "answer":
|
||||||
|
readable_history.append(f"助手: {content}")
|
||||||
|
|
||||||
|
return readable_history
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
"""清理资源"""
|
||||||
|
await self.api_client.close()
|
||||||
@@ -6,6 +6,7 @@ from astrbot.core.provider.provider import (
|
|||||||
TTSProvider,
|
TTSProvider,
|
||||||
STTProvider,
|
STTProvider,
|
||||||
EmbeddingProvider,
|
EmbeddingProvider,
|
||||||
|
RerankProvider,
|
||||||
)
|
)
|
||||||
from astrbot.core.provider.entities import ProviderType
|
from astrbot.core.provider.entities import ProviderType
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
@@ -23,7 +24,7 @@ from .star import star_registry, StarMetadata, star_map
|
|||||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||||
from .filter.command import CommandFilter
|
from .filter.command import CommandFilter
|
||||||
from .filter.regex import RegexFilter
|
from .filter.regex import RegexFilter
|
||||||
from typing import Awaitable
|
from typing import Awaitable, Any, Callable
|
||||||
from astrbot.core.conversation_mgr import ConversationManager
|
from astrbot.core.conversation_mgr import ConversationManager
|
||||||
from astrbot.core.star.filter.platform_adapter_type import (
|
from astrbot.core.star.filter.platform_adapter_type import (
|
||||||
PlatformAdapterType,
|
PlatformAdapterType,
|
||||||
@@ -103,9 +104,14 @@ class Context:
|
|||||||
"""
|
"""
|
||||||
self.provider_manager.provider_insts.append(provider)
|
self.provider_manager.provider_insts.append(provider)
|
||||||
|
|
||||||
def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
def get_provider_by_id(
|
||||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
self, provider_id: str
|
||||||
return self.provider_manager.inst_map.get(provider_id)
|
) -> (
|
||||||
|
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
|
||||||
|
):
|
||||||
|
"""通过 ID 获取对应的 LLM Provider。"""
|
||||||
|
prov = self.provider_manager.inst_map.get(provider_id)
|
||||||
|
return prov
|
||||||
|
|
||||||
def get_all_providers(self) -> List[Provider]:
|
def get_all_providers(self) -> List[Provider]:
|
||||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||||
@@ -130,34 +136,43 @@ class Context:
|
|||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
return self.provider_manager.get_using_provider(
|
prov = self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.CHAT_COMPLETION,
|
provider_type=ProviderType.CHAT_COMPLETION,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
|
if prov and not isinstance(prov, Provider):
|
||||||
|
raise ValueError("返回的 Provider 不是 Provider 类型")
|
||||||
|
return prov
|
||||||
|
|
||||||
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
|
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
|
||||||
"""
|
"""
|
||||||
获取当前使用的用于 TTS 任务的 Provider。
|
获取当前使用的用于 TTS 任务的 Provider。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
return self.provider_manager.get_using_provider(
|
prov = self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
|
if prov and not isinstance(prov, TTSProvider):
|
||||||
|
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
|
||||||
|
return prov
|
||||||
|
|
||||||
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
|
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
|
||||||
"""
|
"""
|
||||||
获取当前使用的用于 STT 任务的 Provider。
|
获取当前使用的用于 STT 任务的 Provider。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
return self.provider_manager.get_using_provider(
|
prov = self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
|
if prov and not isinstance(prov, STTProvider):
|
||||||
|
raise ValueError("返回的 Provider 不是 STTProvider 类型")
|
||||||
|
return prov
|
||||||
|
|
||||||
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
||||||
"""获取 AstrBot 的配置。"""
|
"""获取 AstrBot 的配置。"""
|
||||||
@@ -245,7 +260,11 @@ class Context:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def register_llm_tool(
|
def register_llm_tool(
|
||||||
self, name: str, func_args: list, desc: str, func_obj: Awaitable
|
self,
|
||||||
|
name: str,
|
||||||
|
func_args: list,
|
||||||
|
desc: str,
|
||||||
|
func_obj: Callable[..., Awaitable[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为函数调用(function-calling / tools-use)添加工具。
|
为函数调用(function-calling / tools-use)添加工具。
|
||||||
@@ -267,9 +286,7 @@ class Context:
|
|||||||
desc=desc,
|
desc=desc,
|
||||||
)
|
)
|
||||||
star_handlers_registry.append(md)
|
star_handlers_registry.append(md)
|
||||||
self.provider_manager.llm_tools.add_func(
|
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
|
||||||
name, func_args, desc, func_obj, func_obj
|
|
||||||
)
|
|
||||||
|
|
||||||
def unregister_llm_tool(self, name: str) -> None:
|
def unregister_llm_tool(self, name: str) -> None:
|
||||||
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
|
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
|
||||||
@@ -281,7 +298,7 @@ class Context:
|
|||||||
command_name: str,
|
command_name: str,
|
||||||
desc: str,
|
desc: str,
|
||||||
priority: int,
|
priority: int,
|
||||||
awaitable: Awaitable,
|
awaitable: Callable[..., Awaitable[Any]],
|
||||||
use_regex=False,
|
use_regex=False,
|
||||||
ignore_prefix=False,
|
ignore_prefix=False,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter):
|
|||||||
self.init_handler_md(handler_md)
|
self.init_handler_md(handler_md)
|
||||||
self.custom_filter_list: List[CustomFilter] = []
|
self.custom_filter_list: List[CustomFilter] = []
|
||||||
|
|
||||||
|
# Cache for complete command names list
|
||||||
|
self._cmpl_cmd_names: list | None = None
|
||||||
|
|
||||||
def print_types(self):
|
def print_types(self):
|
||||||
result = ""
|
result = ""
|
||||||
for k, v in self.handler_params.items():
|
for k, v in self.handler_params.items():
|
||||||
@@ -136,6 +139,28 @@ class CommandFilter(HandlerFilter):
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_complete_command_names(self):
|
||||||
|
if self._cmpl_cmd_names is not None:
|
||||||
|
return self._cmpl_cmd_names
|
||||||
|
self._cmpl_cmd_names = [
|
||||||
|
f"{parent} {cmd}" if parent else cmd
|
||||||
|
for cmd in [self.command_name] + list(self.alias)
|
||||||
|
for parent in self.parent_command_names or [""]
|
||||||
|
]
|
||||||
|
return self._cmpl_cmd_names
|
||||||
|
|
||||||
|
def startswith(self, message_str: str) -> bool:
|
||||||
|
for full_cmd in self.get_complete_command_names():
|
||||||
|
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def equals(self, message_str: str) -> bool:
|
||||||
|
for full_cmd in self.get_complete_command_names():
|
||||||
|
if message_str == full_cmd:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||||
if not event.is_at_or_wake_command:
|
if not event.is_at_or_wake_command:
|
||||||
return False
|
return False
|
||||||
@@ -145,19 +170,7 @@ class CommandFilter(HandlerFilter):
|
|||||||
|
|
||||||
# 检查是否以指令开头
|
# 检查是否以指令开头
|
||||||
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
||||||
candidates = [self.command_name] + list(self.alias)
|
if not self.startswith(message_str):
|
||||||
ok = False
|
|
||||||
for candidate in candidates:
|
|
||||||
for parent_command_name in self.parent_command_names:
|
|
||||||
if parent_command_name:
|
|
||||||
_full = f"{parent_command_name} {candidate}"
|
|
||||||
else:
|
|
||||||
_full = candidate
|
|
||||||
if message_str.startswith(f"{_full} ") or message_str == _full:
|
|
||||||
message_str = message_str[len(_full) :].strip()
|
|
||||||
ok = True
|
|
||||||
break
|
|
||||||
if not ok:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 分割为列表
|
# 分割为列表
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
group_name: str,
|
group_name: str,
|
||||||
alias: set = None,
|
alias: set | None = None,
|
||||||
parent_group: CommandGroupFilter = None,
|
parent_group: CommandGroupFilter | None = None,
|
||||||
):
|
):
|
||||||
self.group_name = group_name
|
self.group_name = group_name
|
||||||
self.alias = alias if alias else set()
|
self.alias = alias if alias else set()
|
||||||
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
self.custom_filter_list: List[CustomFilter] = []
|
self.custom_filter_list: List[CustomFilter] = []
|
||||||
self.parent_group = parent_group
|
self.parent_group = parent_group
|
||||||
|
|
||||||
|
# Cache for complete command names list
|
||||||
|
self._cmpl_cmd_names: list | None = None
|
||||||
|
|
||||||
def add_sub_command_filter(
|
def add_sub_command_filter(
|
||||||
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
|
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
|
||||||
):
|
):
|
||||||
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
"""遍历父节点获取完整的指令名。
|
"""遍历父节点获取完整的指令名。
|
||||||
|
|
||||||
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
|
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
|
||||||
|
if self._cmpl_cmd_names is not None:
|
||||||
|
return self._cmpl_cmd_names
|
||||||
|
|
||||||
parent_cmd_names = (
|
parent_cmd_names = (
|
||||||
self.parent_group.get_complete_command_names() if self.parent_group else []
|
self.parent_group.get_complete_command_names() if self.parent_group else []
|
||||||
)
|
)
|
||||||
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
for parent_cmd_name in parent_cmd_names:
|
for parent_cmd_name in parent_cmd_names:
|
||||||
for candidate in candidates:
|
for candidate in candidates:
|
||||||
result.append(parent_cmd_name + " " + candidate)
|
result.append(parent_cmd_name + " " + candidate)
|
||||||
|
self._cmpl_cmd_names = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 以树的形式打印出来
|
# 以树的形式打印出来
|
||||||
@@ -54,8 +61,8 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
self,
|
self,
|
||||||
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
event: AstrMessageEvent = None,
|
event: AstrMessageEvent | None = None,
|
||||||
cfg: AstrBotConfig = None,
|
cfg: AstrBotConfig | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
result = ""
|
result = ""
|
||||||
for sub_filter in sub_command_filters:
|
for sub_filter in sub_command_filters:
|
||||||
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def startswith(self, message_str: str) -> bool:
|
||||||
|
return message_str.startswith(tuple(self.get_complete_command_names()))
|
||||||
|
|
||||||
|
def equals(self, message_str: str) -> bool:
|
||||||
|
return message_str in self.get_complete_command_names()
|
||||||
|
|
||||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||||
if not event.is_at_or_wake_command:
|
if not event.is_at_or_wake_command:
|
||||||
return False
|
return False
|
||||||
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
if not self.custom_filter_ok(event, cfg):
|
if not self.custom_filter_ok(event, cfg):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
complete_command_names = self.get_complete_command_names()
|
if self.equals(event.message_str.strip()):
|
||||||
if event.message_str.strip() in complete_command_names:
|
|
||||||
tree = (
|
tree = (
|
||||||
self.group_name
|
self.group_name
|
||||||
+ "\n"
|
+ "\n"
|
||||||
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||||
)
|
)
|
||||||
|
|
||||||
# complete_command_names = [name + " " for name in complete_command_names]
|
return self.startswith(event.message_str)
|
||||||
# return event.message_str.startswith(tuple(complete_command_names))
|
|
||||||
return False
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import enum
|
|||||||
from . import HandlerFilter
|
from . import HandlerFilter
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
|
|
||||||
class PlatformAdapterType(enum.Flag):
|
class PlatformAdapterType(enum.Flag):
|
||||||
@@ -19,6 +18,7 @@ class PlatformAdapterType(enum.Flag):
|
|||||||
VOCECHAT = enum.auto()
|
VOCECHAT = enum.auto()
|
||||||
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
||||||
SATORI = enum.auto()
|
SATORI = enum.auto()
|
||||||
|
MISSKEY = enum.auto()
|
||||||
ALL = (
|
ALL = (
|
||||||
AIOCQHTTP
|
AIOCQHTTP
|
||||||
| QQOFFICIAL
|
| QQOFFICIAL
|
||||||
@@ -33,6 +33,7 @@ class PlatformAdapterType(enum.Flag):
|
|||||||
| VOCECHAT
|
| VOCECHAT
|
||||||
| WEIXIN_OFFICIAL_ACCOUNT
|
| WEIXIN_OFFICIAL_ACCOUNT
|
||||||
| SATORI
|
| SATORI
|
||||||
|
| MISSKEY
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -50,15 +51,19 @@ ADAPTER_NAME_2_TYPE = {
|
|||||||
"vocechat": PlatformAdapterType.VOCECHAT,
|
"vocechat": PlatformAdapterType.VOCECHAT,
|
||||||
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
||||||
"satori": PlatformAdapterType.SATORI,
|
"satori": PlatformAdapterType.SATORI,
|
||||||
|
"misskey": PlatformAdapterType.MISSKEY,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class PlatformAdapterTypeFilter(HandlerFilter):
|
class PlatformAdapterTypeFilter(HandlerFilter):
|
||||||
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
|
def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str):
|
||||||
self.type_or_str = platform_adapter_type_or_str
|
if isinstance(platform_adapter_type_or_str, str):
|
||||||
|
self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str)
|
||||||
|
else:
|
||||||
|
self.platform_type = platform_adapter_type_or_str
|
||||||
|
|
||||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||||
adapter_name = event.get_platform_name()
|
adapter_name = event.get_platform_name()
|
||||||
if adapter_name in ADAPTER_NAME_2_TYPE:
|
if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None:
|
||||||
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
|
return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ from astrbot.core.star import StarMetadata, star_map
|
|||||||
_warned_register_star = False
|
_warned_register_star = False
|
||||||
|
|
||||||
|
|
||||||
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
def register_star(
|
||||||
|
name: str, author: str, desc: str, version: str, repo: str | None = None
|
||||||
|
):
|
||||||
"""注册一个插件(Star)。
|
"""注册一个插件(Star)。
|
||||||
|
|
||||||
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from ..filter.platform_adapter_type import (
|
|||||||
from ..filter.permission import PermissionTypeFilter, PermissionType
|
from ..filter.permission import PermissionTypeFilter, PermissionType
|
||||||
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
||||||
from ..filter.regex import RegexFilter
|
from ..filter.regex import RegexFilter
|
||||||
from typing import Awaitable
|
from typing import Awaitable, Any, Callable
|
||||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||||
from astrbot.core.provider.register import llm_tools
|
from astrbot.core.provider.register import llm_tools
|
||||||
from astrbot.core.agent.agent import Agent
|
from astrbot.core.agent.agent import Agent
|
||||||
@@ -20,15 +20,19 @@ from astrbot.core.agent.tool import FunctionTool
|
|||||||
from astrbot.core.agent.handoff import HandoffTool
|
from astrbot.core.agent.handoff import HandoffTool
|
||||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
def get_handler_full_name(awaitable: Awaitable) -> str:
|
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
|
||||||
"""获取 Handler 的全名"""
|
"""获取 Handler 的全名"""
|
||||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||||
|
|
||||||
|
|
||||||
def get_handler_or_create(
|
def get_handler_or_create(
|
||||||
handler: Awaitable, event_type: EventType, dont_add=False, **kwargs
|
handler: Callable[..., Awaitable[Any]],
|
||||||
|
event_type: EventType,
|
||||||
|
dont_add=False,
|
||||||
|
**kwargs,
|
||||||
) -> StarHandlerMetadata:
|
) -> StarHandlerMetadata:
|
||||||
"""获取 Handler 或者创建一个新的 Handler"""
|
"""获取 Handler 或者创建一个新的 Handler"""
|
||||||
handler_full_name = get_handler_full_name(handler)
|
handler_full_name = get_handler_full_name(handler)
|
||||||
@@ -59,20 +63,33 @@ def get_handler_or_create(
|
|||||||
|
|
||||||
|
|
||||||
def register_command(
|
def register_command(
|
||||||
command_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
command_name: str | None = None,
|
||||||
|
sub_command: str | None = None,
|
||||||
|
alias: set | None = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""注册一个 Command."""
|
"""注册一个 Command."""
|
||||||
new_command = None
|
new_command = None
|
||||||
add_to_event_filters = False
|
add_to_event_filters = False
|
||||||
if isinstance(command_name, RegisteringCommandable):
|
if isinstance(command_name, RegisteringCommandable):
|
||||||
# 子指令
|
# 子指令
|
||||||
parent_command_names = command_name.parent_group.get_complete_command_names()
|
if sub_command is not None:
|
||||||
|
parent_command_names = (
|
||||||
|
command_name.parent_group.get_complete_command_names()
|
||||||
|
)
|
||||||
new_command = CommandFilter(
|
new_command = CommandFilter(
|
||||||
sub_command, alias, None, parent_command_names=parent_command_names
|
sub_command, alias, None, parent_command_names=parent_command_names
|
||||||
)
|
)
|
||||||
command_name.parent_group.add_sub_command_filter(new_command)
|
command_name.parent_group.add_sub_command_filter(new_command)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"注册指令{command_name} 的子指令时未提供 sub_command 参数。"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 裸指令
|
# 裸指令
|
||||||
|
if command_name is None:
|
||||||
|
logger.warning("注册裸指令时未提供 command_name 参数。")
|
||||||
|
else:
|
||||||
new_command = CommandFilter(command_name, alias, None)
|
new_command = CommandFilter(command_name, alias, None)
|
||||||
add_to_event_filters = True
|
add_to_event_filters = True
|
||||||
|
|
||||||
@@ -84,6 +101,7 @@ def register_command(
|
|||||||
handler_md = get_handler_or_create(
|
handler_md = get_handler_or_create(
|
||||||
awaitable, EventType.AdapterMessageEvent, **kwargs
|
awaitable, EventType.AdapterMessageEvent, **kwargs
|
||||||
)
|
)
|
||||||
|
if new_command:
|
||||||
new_command.init_handler_md(handler_md)
|
new_command.init_handler_md(handler_md)
|
||||||
handler_md.event_filters.append(new_command)
|
handler_md.event_filters.append(new_command)
|
||||||
return awaitable
|
return awaitable
|
||||||
@@ -163,23 +181,35 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def register_command_group(
|
def register_command_group(
|
||||||
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
command_group_name: str | None = None,
|
||||||
|
sub_command: str | None = None,
|
||||||
|
alias: set | None = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""注册一个 CommandGroup"""
|
"""注册一个 CommandGroup"""
|
||||||
new_group = None
|
new_group = None
|
||||||
if isinstance(command_group_name, RegisteringCommandable):
|
if isinstance(command_group_name, RegisteringCommandable):
|
||||||
# 子指令组
|
# 子指令组
|
||||||
|
if sub_command is None:
|
||||||
|
logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
|
||||||
|
else:
|
||||||
new_group = CommandGroupFilter(
|
new_group = CommandGroupFilter(
|
||||||
sub_command, alias, parent_group=command_group_name.parent_group
|
sub_command, alias, parent_group=command_group_name.parent_group
|
||||||
)
|
)
|
||||||
command_group_name.parent_group.add_sub_command_filter(new_group)
|
command_group_name.parent_group.add_sub_command_filter(new_group)
|
||||||
else:
|
else:
|
||||||
# 根指令组
|
# 根指令组
|
||||||
|
if command_group_name is None:
|
||||||
|
logger.warning("根指令组的名称未指定")
|
||||||
|
else:
|
||||||
new_group = CommandGroupFilter(command_group_name, alias)
|
new_group = CommandGroupFilter(command_group_name, alias)
|
||||||
|
|
||||||
def decorator(obj):
|
def decorator(obj):
|
||||||
# 根指令组
|
# 根指令组
|
||||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
|
if new_group:
|
||||||
|
handler_md = get_handler_or_create(
|
||||||
|
obj, EventType.AdapterMessageEvent, **kwargs
|
||||||
|
)
|
||||||
handler_md.event_filters.append(new_group)
|
handler_md.event_filters.append(new_group)
|
||||||
|
|
||||||
return RegisteringCommandable(new_group)
|
return RegisteringCommandable(new_group)
|
||||||
@@ -323,7 +353,7 @@ def register_on_llm_response(**kwargs):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def register_llm_tool(name: str = None, **kwargs):
|
def register_llm_tool(name: str | None = None, **kwargs):
|
||||||
"""为函数调用(function-calling / tools-use)添加工具。
|
"""为函数调用(function-calling / tools-use)添加工具。
|
||||||
|
|
||||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||||
@@ -361,9 +391,10 @@ def register_llm_tool(name: str = None, **kwargs):
|
|||||||
if kwargs.get("registering_agent"):
|
if kwargs.get("registering_agent"):
|
||||||
registering_agent = kwargs["registering_agent"]
|
registering_agent = kwargs["registering_agent"]
|
||||||
|
|
||||||
def decorator(awaitable: Awaitable):
|
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||||
docstring = docstring_parser.parse(awaitable.__doc__)
|
func_doc = awaitable.__doc__ or ""
|
||||||
|
docstring = docstring_parser.parse(func_doc)
|
||||||
args = []
|
args = []
|
||||||
for arg in docstring.params:
|
for arg in docstring.params:
|
||||||
if arg.type_name not in SUPPORTED_TYPES:
|
if arg.type_name not in SUPPORTED_TYPES:
|
||||||
@@ -379,20 +410,18 @@ def register_llm_tool(name: str = None, **kwargs):
|
|||||||
)
|
)
|
||||||
# print(llm_tool_name, registering_agent)
|
# print(llm_tool_name, registering_agent)
|
||||||
if not registering_agent:
|
if not registering_agent:
|
||||||
|
doc_desc = docstring.description.strip() if docstring.description else ""
|
||||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||||
llm_tools.add_func(
|
llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler)
|
||||||
llm_tool_name, args, docstring.description.strip(), md.handler
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
assert isinstance(registering_agent, RegisteringAgent)
|
assert isinstance(registering_agent, RegisteringAgent)
|
||||||
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
||||||
if registering_agent._agent.tools is None:
|
if registering_agent._agent.tools is None:
|
||||||
registering_agent._agent.tools = []
|
registering_agent._agent.tools = []
|
||||||
registering_agent._agent.tools.append(
|
|
||||||
llm_tools.spec_to_func(
|
desc = docstring.description.strip() if docstring.description else ""
|
||||||
llm_tool_name, args, docstring.description.strip(), awaitable
|
tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable)
|
||||||
)
|
registering_agent._agent.tools.append(tool)
|
||||||
)
|
|
||||||
|
|
||||||
return awaitable
|
return awaitable
|
||||||
|
|
||||||
@@ -413,8 +442,8 @@ class RegisteringAgent:
|
|||||||
def register_agent(
|
def register_agent(
|
||||||
name: str,
|
name: str,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
tools: list[str | FunctionTool] = None,
|
tools: list[str | FunctionTool] | None = None,
|
||||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
|
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
|
||||||
):
|
):
|
||||||
"""注册一个 Agent
|
"""注册一个 Agent
|
||||||
|
|
||||||
@@ -426,7 +455,7 @@ def register_agent(
|
|||||||
"""
|
"""
|
||||||
tools_ = tools or []
|
tools_ = tools or []
|
||||||
|
|
||||||
def decorator(awaitable: Awaitable):
|
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||||
AstrAgent = Agent[AstrAgentContext]
|
AstrAgent = Agent[AstrAgentContext]
|
||||||
agent = AstrAgent(
|
agent = AstrAgent(
|
||||||
name=name,
|
name=name,
|
||||||
|
|||||||
@@ -52,10 +52,6 @@ class SessionServiceManager:
|
|||||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||||
"""检查是否应该处理LLM请求
|
"""检查是否应该处理LLM请求
|
||||||
|
|||||||
@@ -140,6 +140,9 @@ class SessionPluginManager:
|
|||||||
filtered_handlers.append(handler)
|
filtered_handlers.append(handler)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if plugin.name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
# 检查插件是否在当前会话中启用
|
# 检查插件是否在当前会话中启用
|
||||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||||
session_id, plugin.name
|
session_id, plugin.name
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import enum
|
import enum
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic
|
||||||
from .filter import HandlerFilter
|
from .filter import HandlerFilter
|
||||||
from .star import star_map
|
from .star import star_map
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
handlers.append(handler)
|
handlers.append(handler)
|
||||||
return handlers
|
return handlers
|
||||||
|
|
||||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None:
|
||||||
return self.star_handlers_map.get(full_name, None)
|
return self.star_handlers_map.get(full_name, None)
|
||||||
|
|
||||||
def get_handlers_by_module_name(
|
def get_handlers_by_module_name(
|
||||||
@@ -87,7 +87,7 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
return len(self._handlers)
|
return len(self._handlers)
|
||||||
|
|
||||||
|
|
||||||
star_handlers_registry = StarHandlerRegistry()
|
star_handlers_registry = StarHandlerRegistry() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class EventType(enum.Enum):
|
class EventType(enum.Enum):
|
||||||
@@ -123,7 +123,7 @@ class StarHandlerMetadata:
|
|||||||
handler_module_path: str
|
handler_module_path: str
|
||||||
"""Handler 所在的模块路径。"""
|
"""Handler 所在的模块路径。"""
|
||||||
|
|
||||||
handler: Awaitable
|
handler: Callable[..., Awaitable[Any]]
|
||||||
"""Handler 的函数对象,应当是一个异步函数"""
|
"""Handler 的函数对象,应当是一个异步函数"""
|
||||||
|
|
||||||
event_filters: List[HandlerFilter]
|
event_filters: List[HandlerFilter]
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class PluginManager:
|
|||||||
self.updator = PluginUpdator()
|
self.updator = PluginUpdator()
|
||||||
|
|
||||||
self.context = context
|
self.context = context
|
||||||
self.context._star_manager = self
|
self.context._star_manager = self # type: ignore
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.plugin_store_path = get_astrbot_plugin_path()
|
self.plugin_store_path = get_astrbot_plugin_path()
|
||||||
@@ -478,6 +478,7 @@ class PluginManager:
|
|||||||
if isinstance(func_tool, HandoffTool):
|
if isinstance(func_tool, HandoffTool):
|
||||||
need_apply = []
|
need_apply = []
|
||||||
sub_tools = func_tool.agent.tools
|
sub_tools = func_tool.agent.tools
|
||||||
|
if sub_tools:
|
||||||
for sub_tool in sub_tools:
|
for sub_tool in sub_tools:
|
||||||
if isinstance(sub_tool, FunctionTool):
|
if isinstance(sub_tool, FunctionTool):
|
||||||
need_apply.append(sub_tool)
|
need_apply.append(sub_tool)
|
||||||
@@ -686,6 +687,9 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 从 star_registry 和 star_map 中删除
|
# 从 star_registry 和 star_map 中删除
|
||||||
|
if plugin.module_path is None or root_dir_name is None:
|
||||||
|
raise Exception(f"插件 {plugin_name} 数据不完整,无法卸载。")
|
||||||
|
|
||||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -800,6 +804,8 @@ class PluginManager:
|
|||||||
|
|
||||||
async def turn_on_plugin(self, plugin_name: str):
|
async def turn_on_plugin(self, plugin_name: str):
|
||||||
plugin = self.context.get_registered_star(plugin_name)
|
plugin = self.context.get_registered_star(plugin_name)
|
||||||
|
if plugin is None:
|
||||||
|
raise Exception(f"插件 {plugin_name} 不存在。")
|
||||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||||
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
||||||
if plugin.module_path in inactivated_plugins:
|
if plugin.module_path in inactivated_plugins:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar
|
||||||
from astrbot.core.message.components import BaseMessageComponent
|
from astrbot.core.message.components import BaseMessageComponent
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
|
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
|
||||||
@@ -221,7 +221,11 @@ class StarTools:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_llm_tool(
|
def register_llm_tool(
|
||||||
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
|
cls,
|
||||||
|
name: str,
|
||||||
|
func_args: list,
|
||||||
|
desc: str,
|
||||||
|
func_obj: Callable[..., Awaitable[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为函数调用(function-calling/tools-use)添加工具
|
为函数调用(function-calling/tools-use)添加工具
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ class PluginUpdator(RepoZipUpdator):
|
|||||||
if not repo_url:
|
if not repo_url:
|
||||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||||
|
|
||||||
|
if not plugin.root_dir_name:
|
||||||
|
raise Exception(f"插件 {plugin.name} 的根目录名未指定。")
|
||||||
|
|
||||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||||
|
|
||||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||||
|
|||||||
@@ -1,9 +1,33 @@
|
|||||||
|
import codecs
|
||||||
import json
|
import json
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession, ClientResponse
|
||||||
from typing import Dict, List, Any, AsyncGenerator
|
from typing import Dict, List, Any, AsyncGenerator
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
|
||||||
|
decoder = codecs.getincrementaldecoder("utf-8")()
|
||||||
|
buffer = ""
|
||||||
|
async for chunk in resp.content.iter_chunked(8192):
|
||||||
|
buffer += decoder.decode(chunk)
|
||||||
|
while "\n\n" in buffer:
|
||||||
|
block, buffer = buffer.split("\n\n", 1)
|
||||||
|
if block.strip().startswith("data:"):
|
||||||
|
try:
|
||||||
|
yield json.loads(block[5:])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Drop invalid dify json data: {block[5:]}")
|
||||||
|
continue
|
||||||
|
# flush any remaining text
|
||||||
|
buffer += decoder.decode(b"", final=True)
|
||||||
|
if buffer.strip().startswith("data:"):
|
||||||
|
try:
|
||||||
|
yield json.loads(buffer[5:])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Drop invalid dify json data: {buffer[5:]}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DifyAPIClient:
|
class DifyAPIClient:
|
||||||
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
|
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -33,31 +57,11 @@ class DifyAPIClient:
|
|||||||
) as resp:
|
) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
text = await resp.text()
|
text = await resp.text()
|
||||||
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
|
raise Exception(
|
||||||
|
f"Dify /chat-messages 接口请求失败:{resp.status}. {text}"
|
||||||
buffer = ""
|
)
|
||||||
while True:
|
async for event in _stream_sse(resp):
|
||||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
yield event
|
||||||
chunk = await resp.content.read(8192)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
|
|
||||||
buffer += chunk.decode("utf-8")
|
|
||||||
blocks = buffer.split("\n\n")
|
|
||||||
|
|
||||||
# 处理完整的数据块
|
|
||||||
for block in blocks[:-1]:
|
|
||||||
if block.strip() and block.startswith("data:"):
|
|
||||||
try:
|
|
||||||
json_str = block[5:] # 移除 "data:" 前缀
|
|
||||||
json_obj = json.loads(json_str)
|
|
||||||
yield json_obj
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"JSON解析错误: {str(e)}")
|
|
||||||
logger.error(f"原始数据块: {json_str}")
|
|
||||||
|
|
||||||
# 保留最后一个可能不完整的块
|
|
||||||
buffer = blocks[-1] if blocks else ""
|
|
||||||
|
|
||||||
async def workflow_run(
|
async def workflow_run(
|
||||||
self,
|
self,
|
||||||
@@ -77,31 +81,11 @@ class DifyAPIClient:
|
|||||||
) as resp:
|
) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
text = await resp.text()
|
text = await resp.text()
|
||||||
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
|
raise Exception(
|
||||||
|
f"Dify /workflows/run 接口请求失败:{resp.status}. {text}"
|
||||||
buffer = ""
|
)
|
||||||
while True:
|
async for event in _stream_sse(resp):
|
||||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
yield event
|
||||||
chunk = await resp.content.read(8192)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
|
|
||||||
buffer += chunk.decode("utf-8")
|
|
||||||
blocks = buffer.split("\n\n")
|
|
||||||
|
|
||||||
# 处理完整的数据块
|
|
||||||
for block in blocks[:-1]:
|
|
||||||
if block.strip() and block.startswith("data:"):
|
|
||||||
try:
|
|
||||||
json_str = block[5:] # 移除 "data:" 前缀
|
|
||||||
json_obj = json.loads(json_str)
|
|
||||||
yield json_obj
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"JSON解析错误: {str(e)}")
|
|
||||||
logger.error(f"原始数据块: {json_str}")
|
|
||||||
|
|
||||||
# 保留最后一个可能不完整的块
|
|
||||||
buffer = blocks[-1] if blocks else ""
|
|
||||||
|
|
||||||
async def file_upload(
|
async def file_upload(
|
||||||
self,
|
self,
|
||||||
@@ -109,11 +93,14 @@ class DifyAPIClient:
|
|||||||
user: str,
|
user: str,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
url = f"{self.api_base}/files/upload"
|
url = f"{self.api_base}/files/upload"
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
payload = {
|
payload = {
|
||||||
"user": user,
|
"user": user,
|
||||||
"file": open(file_path, "rb"),
|
"file": f,
|
||||||
}
|
}
|
||||||
async with self.session.post(url, data=payload, headers=self.headers) as resp:
|
async with self.session.post(
|
||||||
|
url, data=payload, headers=self.headers
|
||||||
|
) as resp:
|
||||||
return await resp.json() # {"id": "xxx", ...}
|
return await resp.json() # {"id": "xxx", ...}
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
|
|||||||
@@ -227,9 +227,11 @@ async def download_dashboard(
|
|||||||
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
||||||
|
|
||||||
if latest or len(str(version)) != 40:
|
if latest or len(str(version)) != 40:
|
||||||
logger.info(f"准备下载 {version} 发行版本的 AstrBot WebUI 文件")
|
|
||||||
ver_name = "latest" if latest else version
|
ver_name = "latest" if latest else version
|
||||||
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
|
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
|
||||||
|
logger.info(
|
||||||
|
f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
await download_file(dashboard_release_url, path, show_progress=True)
|
await download_file(dashboard_release_url, path, show_progress=True)
|
||||||
except BaseException as _:
|
except BaseException as _:
|
||||||
@@ -241,24 +243,10 @@ async def download_dashboard(
|
|||||||
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
|
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
|
||||||
await download_file(dashboard_release_url, path, show_progress=True)
|
await download_file(dashboard_release_url, path, show_progress=True)
|
||||||
else:
|
else:
|
||||||
logger.info(f"准备下载指定版本的 AstrBot WebUI: {version}")
|
url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
|
||||||
|
logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
|
||||||
url = (
|
|
||||||
"https://api.github.com/repos/AstrBotDevs/astrbot-release-harbour/releases"
|
|
||||||
)
|
|
||||||
if proxy:
|
if proxy:
|
||||||
url = f"{proxy}/{url}"
|
url = f"{proxy}/{url}"
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
await download_file(url, path, show_progress=True)
|
||||||
async with session.get(url) as resp:
|
|
||||||
if resp.status == 200:
|
|
||||||
releases = await resp.json()
|
|
||||||
for release in releases:
|
|
||||||
if version in release["tag_name"]:
|
|
||||||
download_url = release["assets"][0]["browser_download_url"]
|
|
||||||
await download_file(download_url, path, show_progress=True)
|
|
||||||
else:
|
|
||||||
logger.warning(f"未找到指定的版本的 Dashboard 构建文件: {version}")
|
|
||||||
return
|
|
||||||
|
|
||||||
with zipfile.ZipFile(path, "r") as z:
|
with zipfile.ZipFile(path, "r") as z:
|
||||||
z.extractall(extract_path)
|
z.extractall(extract_path)
|
||||||
|
|||||||
@@ -1,17 +1,27 @@
|
|||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from .route import Route, Response, RouteContext
|
from .route import Route, Response, RouteContext
|
||||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||||
from quart import request, Response as QuartResponse, g, make_response
|
from quart import request, Response as QuartResponse, g, make_response
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
import asyncio
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
from astrbot.core.platform.astr_message_event import MessageSession
|
from astrbot.core.platform.astr_message_event import MessageSession
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def track_conversation(convs: dict, conv_id: str):
|
||||||
|
convs[conv_id] = True
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
convs.pop(conv_id, None)
|
||||||
|
|
||||||
|
|
||||||
class ChatRoute(Route):
|
class ChatRoute(Route):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -40,6 +50,8 @@ class ChatRoute(Route):
|
|||||||
self.conv_mgr = core_lifecycle.conversation_manager
|
self.conv_mgr = core_lifecycle.conversation_manager
|
||||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||||
|
|
||||||
|
self.running_convs: dict[str, bool] = {}
|
||||||
|
|
||||||
async def get_file(self):
|
async def get_file(self):
|
||||||
filename = request.args.get("filename")
|
filename = request.args.get("filename")
|
||||||
if not filename:
|
if not filename:
|
||||||
@@ -139,12 +151,20 @@ class ChatRoute(Route):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def stream():
|
async def stream():
|
||||||
|
client_disconnected = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with track_conversation(self.running_convs, webchat_conv_id):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(back_queue.get(), timeout=10)
|
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||||
|
client_disconnected = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebChat stream error: {e}")
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
continue
|
continue
|
||||||
@@ -152,8 +172,23 @@ class ChatRoute(Route):
|
|||||||
result_text = result["data"]
|
result_text = result["data"]
|
||||||
type = result.get("type")
|
type = result.get("type")
|
||||||
streaming = result.get("streaming", False)
|
streaming = result.get("streaming", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not client_disconnected:
|
||||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
if not client_disconnected:
|
||||||
|
logger.debug(
|
||||||
|
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
|
||||||
|
)
|
||||||
|
client_disconnected = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not client_disconnected:
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||||
|
client_disconnected = True
|
||||||
|
|
||||||
if type == "end":
|
if type == "end":
|
||||||
break
|
break
|
||||||
@@ -171,10 +206,8 @@ class ChatRoute(Route):
|
|||||||
sender_id="bot",
|
sender_id="bot",
|
||||||
sender_name="bot",
|
sender_name="bot",
|
||||||
)
|
)
|
||||||
|
except BaseException as e:
|
||||||
except BaseException as _:
|
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||||
logger.debug(f"用户 {username} 断开聊天长连接。")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Put message to conversation-specific queue
|
# Put message to conversation-specific queue
|
||||||
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
||||||
@@ -291,6 +324,7 @@ class ChatRoute(Route):
|
|||||||
.ok(
|
.ok(
|
||||||
data={
|
data={
|
||||||
"history": history_res,
|
"history": history_res,
|
||||||
|
"is_running": self.running_convs.get(webchat_conv_id, False),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
.__dict__
|
.__dict__
|
||||||
|
|||||||
@@ -51,24 +51,6 @@ def validate_config(
|
|||||||
def validate(data: dict, metadata: dict = schema, path=""):
|
def validate(data: dict, metadata: dict = schema, path=""):
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key not in metadata:
|
if key not in metadata:
|
||||||
# 无 schema 的配置项,执行类型猜测
|
|
||||||
if isinstance(value, str):
|
|
||||||
try:
|
|
||||||
data[key] = int(value)
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
data[key] = float(value)
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if value.lower() == "true":
|
|
||||||
data[key] = True
|
|
||||||
elif value.lower() == "false":
|
|
||||||
data[key] = False
|
|
||||||
continue
|
continue
|
||||||
meta = metadata[key]
|
meta = metadata[key]
|
||||||
if "type" not in meta:
|
if "type" not in meta:
|
||||||
@@ -127,12 +109,12 @@ def validate_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_core:
|
if is_core:
|
||||||
for key, group in schema.items():
|
meta_all = {
|
||||||
group_meta = group.get("metadata")
|
**schema["platform_group"]["metadata"],
|
||||||
if not group_meta:
|
**schema["provider_group"]["metadata"],
|
||||||
continue
|
**schema["misc_config_group"]["metadata"],
|
||||||
# logger.info(f"验证配置: 组 {key} ...")
|
}
|
||||||
validate(data, group_meta, path=f"{key}.")
|
validate(data, meta_all)
|
||||||
else:
|
else:
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
@@ -142,6 +124,7 @@ def validate_config(
|
|||||||
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
|
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
|
||||||
"""验证并保存配置"""
|
"""验证并保存配置"""
|
||||||
errors = None
|
errors = None
|
||||||
|
logger.info(f"Saving config, is_core={is_core}")
|
||||||
try:
|
try:
|
||||||
if is_core:
|
if is_core:
|
||||||
errors, post_config = validate_config(
|
errors, post_config = validate_config(
|
||||||
|
|||||||
@@ -169,11 +169,61 @@ class ConversationRoute(Route):
|
|||||||
"""删除对话"""
|
"""删除对话"""
|
||||||
try:
|
try:
|
||||||
data = await request.get_json()
|
data = await request.get_json()
|
||||||
|
|
||||||
|
# 检查是否是批量删除
|
||||||
|
if "conversations" in data:
|
||||||
|
# 批量删除
|
||||||
|
conversations = data.get("conversations", [])
|
||||||
|
if not conversations:
|
||||||
|
return (
|
||||||
|
Response().error("批量删除时conversations参数不能为空").__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
failed_items = []
|
||||||
|
|
||||||
|
for conv in conversations:
|
||||||
|
user_id = conv.get("user_id")
|
||||||
|
cid = conv.get("cid")
|
||||||
|
|
||||||
|
if not user_id or not cid:
|
||||||
|
failed_items.append(
|
||||||
|
f"user_id:{user_id}, cid:{cid} - 缺少必要参数"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||||
|
unified_msg_origin=user_id, conversation_id=cid
|
||||||
|
)
|
||||||
|
deleted_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
failed_items.append(f"user_id:{user_id}, cid:{cid} - {str(e)}")
|
||||||
|
|
||||||
|
message = f"成功删除 {deleted_count} 个对话"
|
||||||
|
if failed_items:
|
||||||
|
message += f",失败 {len(failed_items)} 个"
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"message": message,
|
||||||
|
"deleted_count": deleted_count,
|
||||||
|
"failed_count": len(failed_items),
|
||||||
|
"failed_items": failed_items,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 单个删除
|
||||||
user_id = data.get("user_id")
|
user_id = data.get("user_id")
|
||||||
cid = data.get("cid")
|
cid = data.get("cid")
|
||||||
|
|
||||||
if not user_id or not cid:
|
if not user_id or not cid:
|
||||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||||
|
|
||||||
await self.core_lifecycle.conversation_manager.delete_conversation(
|
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||||
unified_msg_origin=user_id, conversation_id=cid
|
unified_msg_origin=user_id, conversation_id=cid
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class SessionManagementRoute(Route):
|
|||||||
"/session/update_tts": ("POST", self.update_session_tts),
|
"/session/update_tts": ("POST", self.update_session_tts),
|
||||||
"/session/update_name": ("POST", self.update_session_name),
|
"/session/update_name": ("POST", self.update_session_name),
|
||||||
"/session/update_status": ("POST", self.update_session_status),
|
"/session/update_status": ("POST", self.update_session_status),
|
||||||
|
"/session/delete": ("POST", self.delete_session),
|
||||||
}
|
}
|
||||||
self.conv_mgr = core_lifecycle.conversation_manager
|
self.conv_mgr = core_lifecycle.conversation_manager
|
||||||
self.core_lifecycle = core_lifecycle
|
self.core_lifecycle = core_lifecycle
|
||||||
@@ -180,39 +181,101 @@ class SessionManagementRoute(Route):
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return Response().error(f"获取会话列表失败: {str(e)}").__dict__
|
return Response().error(f"获取会话列表失败: {str(e)}").__dict__
|
||||||
|
|
||||||
async def update_session_persona(self):
|
async def _update_single_session_persona(self, session_id: str, persona_name: str):
|
||||||
"""更新指定会话的 persona"""
|
"""更新单个会话的 persona 的内部方法"""
|
||||||
try:
|
|
||||||
data = await request.get_json()
|
|
||||||
session_id = data.get("session_id")
|
|
||||||
persona_name = data.get("persona_name")
|
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
return Response().error("缺少必要参数: session_id").__dict__
|
|
||||||
|
|
||||||
if persona_name is None:
|
|
||||||
return Response().error("缺少必要参数: persona_name").__dict__
|
|
||||||
|
|
||||||
# 获取会话当前的对话 ID
|
|
||||||
conversation_manager = self.core_lifecycle.star_context.conversation_manager
|
conversation_manager = self.core_lifecycle.star_context.conversation_manager
|
||||||
conversation_id = await conversation_manager.get_curr_conversation_id(
|
conversation_id = await conversation_manager.get_curr_conversation_id(
|
||||||
session_id
|
session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not conversation_id:
|
conv = None
|
||||||
# 如果没有对话,创建一个新的对话
|
if conversation_id:
|
||||||
conversation_id = await conversation_manager.new_conversation(
|
conv = await conversation_manager.get_conversation(
|
||||||
session_id
|
unified_msg_origin=session_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
)
|
)
|
||||||
|
if not conv or not conversation_id:
|
||||||
|
conversation_id = await conversation_manager.new_conversation(session_id)
|
||||||
|
|
||||||
# 更新 persona
|
# 更新 persona
|
||||||
await conversation_manager.update_conversation_persona_id(
|
await conversation_manager.update_conversation_persona_id(
|
||||||
session_id, persona_name
|
session_id, persona_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _handle_batch_operation(
|
||||||
|
self, session_ids: list, operation_func, operation_name: str, **kwargs
|
||||||
|
):
|
||||||
|
"""通用的批量操作处理方法"""
|
||||||
|
success_count = 0
|
||||||
|
error_sessions = []
|
||||||
|
|
||||||
|
for session_id in session_ids:
|
||||||
|
try:
|
||||||
|
await operation_func(session_id, **kwargs)
|
||||||
|
success_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量{operation_name} 会话 {session_id} 失败: {str(e)}")
|
||||||
|
error_sessions.append(session_id)
|
||||||
|
|
||||||
|
if error_sessions:
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"})
|
.ok(
|
||||||
|
{
|
||||||
|
"message": f"批量更新完成,成功: {success_count},失败: {len(error_sessions)}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": len(error_sessions),
|
||||||
|
"error_sessions": error_sessions,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"message": f"成功批量{operation_name} {success_count} 个会话",
|
||||||
|
"success_count": success_count,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_session_persona(self):
|
||||||
|
"""更新指定会话的 persona,支持批量操作"""
|
||||||
|
try:
|
||||||
|
data = await request.get_json()
|
||||||
|
is_batch = data.get("is_batch", False)
|
||||||
|
persona_name = data.get("persona_name")
|
||||||
|
|
||||||
|
if persona_name is None:
|
||||||
|
return Response().error("缺少必要参数: persona_name").__dict__
|
||||||
|
|
||||||
|
if is_batch:
|
||||||
|
session_ids = data.get("session_ids", [])
|
||||||
|
if not session_ids:
|
||||||
|
return Response().error("缺少必要参数: session_ids").__dict__
|
||||||
|
|
||||||
|
return await self._handle_batch_operation(
|
||||||
|
session_ids,
|
||||||
|
self._update_single_session_persona,
|
||||||
|
"更新人格",
|
||||||
|
persona_name=persona_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
session_id = data.get("session_id")
|
||||||
|
if not session_id:
|
||||||
|
return Response().error("缺少必要参数: session_id").__dict__
|
||||||
|
|
||||||
|
await self._update_single_session_persona(session_id, persona_name)
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"
|
||||||
|
}
|
||||||
|
)
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -221,19 +284,29 @@ class SessionManagementRoute(Route):
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return Response().error(f"更新会话人格失败: {str(e)}").__dict__
|
return Response().error(f"更新会话人格失败: {str(e)}").__dict__
|
||||||
|
|
||||||
|
async def _update_single_session_provider(
|
||||||
|
self, session_id: str, provider_id: str, provider_type_enum
|
||||||
|
):
|
||||||
|
"""更新单个会话的 provider 的内部方法"""
|
||||||
|
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||||
|
await provider_manager.set_provider(
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_type=provider_type_enum,
|
||||||
|
umo=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
async def update_session_provider(self):
|
async def update_session_provider(self):
|
||||||
"""更新指定会话的 provider"""
|
"""更新指定会话的 provider,支持批量操作"""
|
||||||
try:
|
try:
|
||||||
data = await request.get_json()
|
data = await request.get_json()
|
||||||
session_id = data.get("session_id")
|
is_batch = data.get("is_batch", False)
|
||||||
provider_id = data.get("provider_id")
|
provider_id = data.get("provider_id")
|
||||||
# "chat_completion", "speech_to_text", "text_to_speech"
|
|
||||||
provider_type = data.get("provider_type")
|
provider_type = data.get("provider_type")
|
||||||
|
|
||||||
if not session_id or not provider_id or not provider_type:
|
if not provider_id or not provider_type:
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error("缺少必要参数: session_id, provider_id, provider_type")
|
.error("缺少必要参数: provider_id, provider_type")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -251,14 +324,26 @@ class SessionManagementRoute(Route):
|
|||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置 provider
|
if is_batch:
|
||||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
session_ids = data.get("session_ids", [])
|
||||||
await provider_manager.set_provider(
|
if not session_ids:
|
||||||
provider_id=provider_id,
|
return Response().error("缺少必要参数: session_ids").__dict__
|
||||||
provider_type=provider_type_enum,
|
|
||||||
umo=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
return await self._handle_batch_operation(
|
||||||
|
session_ids,
|
||||||
|
self._update_single_session_provider,
|
||||||
|
f"更新 {provider_type} 提供商",
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_type_enum=provider_type_enum,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
session_id = data.get("session_id")
|
||||||
|
if not session_id:
|
||||||
|
return Response().error("缺少必要参数: session_id").__dict__
|
||||||
|
|
||||||
|
await self._update_single_session_provider(
|
||||||
|
session_id, provider_id, provider_type_enum
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.ok(
|
.ok(
|
||||||
@@ -376,22 +461,38 @@ class SessionManagementRoute(Route):
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
|
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
|
||||||
|
|
||||||
|
async def _update_single_session_llm(self, session_id: str, enabled: bool):
|
||||||
|
"""更新单个会话的LLM状态的内部方法"""
|
||||||
|
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
|
||||||
|
|
||||||
async def update_session_llm(self):
|
async def update_session_llm(self):
|
||||||
"""更新指定会话的LLM启停状态"""
|
"""更新指定会话的LLM启停状态,支持批量操作"""
|
||||||
try:
|
try:
|
||||||
data = await request.get_json()
|
data = await request.get_json()
|
||||||
session_id = data.get("session_id")
|
is_batch = data.get("is_batch", False)
|
||||||
enabled = data.get("enabled")
|
enabled = data.get("enabled")
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
return Response().error("缺少必要参数: session_id").__dict__
|
|
||||||
|
|
||||||
if enabled is None:
|
if enabled is None:
|
||||||
return Response().error("缺少必要参数: enabled").__dict__
|
return Response().error("缺少必要参数: enabled").__dict__
|
||||||
|
|
||||||
# 使用 SessionServiceManager 更新LLM状态
|
if is_batch:
|
||||||
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
|
session_ids = data.get("session_ids", [])
|
||||||
|
if not session_ids:
|
||||||
|
return Response().error("缺少必要参数: session_ids").__dict__
|
||||||
|
|
||||||
|
result = await self._handle_batch_operation(
|
||||||
|
session_ids,
|
||||||
|
self._update_single_session_llm,
|
||||||
|
f"{'启用' if enabled else '禁用'}LLM",
|
||||||
|
enabled=enabled,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
session_id = data.get("session_id")
|
||||||
|
if not session_id:
|
||||||
|
return Response().error("缺少必要参数: session_id").__dict__
|
||||||
|
|
||||||
|
await self._update_single_session_llm(session_id, enabled)
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.ok(
|
.ok(
|
||||||
@@ -409,22 +510,38 @@ class SessionManagementRoute(Route):
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
|
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
|
||||||
|
|
||||||
|
async def _update_single_session_tts(self, session_id: str, enabled: bool):
|
||||||
|
"""更新单个会话的TTS状态的内部方法"""
|
||||||
|
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
|
||||||
|
|
||||||
async def update_session_tts(self):
|
async def update_session_tts(self):
|
||||||
"""更新指定会话的TTS启停状态"""
|
"""更新指定会话的TTS启停状态,支持批量操作"""
|
||||||
try:
|
try:
|
||||||
data = await request.get_json()
|
data = await request.get_json()
|
||||||
session_id = data.get("session_id")
|
is_batch = data.get("is_batch", False)
|
||||||
enabled = data.get("enabled")
|
enabled = data.get("enabled")
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
return Response().error("缺少必要参数: session_id").__dict__
|
|
||||||
|
|
||||||
if enabled is None:
|
if enabled is None:
|
||||||
return Response().error("缺少必要参数: enabled").__dict__
|
return Response().error("缺少必要参数: enabled").__dict__
|
||||||
|
|
||||||
# 使用 SessionServiceManager 更新TTS状态
|
if is_batch:
|
||||||
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
|
session_ids = data.get("session_ids", [])
|
||||||
|
if not session_ids:
|
||||||
|
return Response().error("缺少必要参数: session_ids").__dict__
|
||||||
|
|
||||||
|
result = await self._handle_batch_operation(
|
||||||
|
session_ids,
|
||||||
|
self._update_single_session_tts,
|
||||||
|
f"{'启用' if enabled else '禁用'}TTS",
|
||||||
|
enabled=enabled,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
session_id = data.get("session_id")
|
||||||
|
if not session_id:
|
||||||
|
return Response().error("缺少必要参数: session_id").__dict__
|
||||||
|
|
||||||
|
await self._update_single_session_tts(session_id, enabled)
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.ok(
|
.ok(
|
||||||
@@ -507,3 +624,43 @@ class SessionManagementRoute(Route):
|
|||||||
error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}"
|
error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__
|
return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__
|
||||||
|
|
||||||
|
async def delete_session(self):
|
||||||
|
"""删除指定会话及其所有相关数据"""
|
||||||
|
try:
|
||||||
|
data = await request.get_json()
|
||||||
|
session_id = data.get("session_id")
|
||||||
|
|
||||||
|
if not session_id:
|
||||||
|
return Response().error("缺少必要参数: session_id").__dict__
|
||||||
|
|
||||||
|
# 删除会话的所有相关数据
|
||||||
|
conversation_manager = self.core_lifecycle.conversation_manager
|
||||||
|
|
||||||
|
# 1. 删除会话的所有对话
|
||||||
|
try:
|
||||||
|
await conversation_manager.delete_conversations_by_user_id(session_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"删除会话 {session_id} 的对话失败: {str(e)}")
|
||||||
|
|
||||||
|
# 2. 清除会话的偏好设置数据(清空该会话的所有配置)
|
||||||
|
try:
|
||||||
|
await sp.clear_async("umo", session_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"清除会话 {session_id} 的偏好设置失败: {str(e)}")
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"message": f"会话 {session_id} 及其相关所有对话数据已成功删除",
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"删除会话失败: {str(e)}\n{traceback.format_exc()}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return Response().error(f"删除会话失败: {str(e)}").__dict__
|
||||||
|
|||||||
8
changelogs/v4.1.3.md
Normal file
8
changelogs/v4.1.3.md
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
0. ‼️ fix: 修复 4.0.0 版本之后,配置默认 TTS 或者 STT 模型之后仍无法生效的问题 ([#2758](https://github.com/Soulter/AstrBot/issues/2758))
|
||||||
|
1. ‼️ fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 ([#2757](https://github.com/Soulter/AstrBot/issues/2757))
|
||||||
|
2. feat: 支持在 WebUI 复制提供商配置以简化操作 ([#2767](https://github.com/Soulter/AstrBot/issues/2767))
|
||||||
|
3. fix: handle image value correctly for mcp BlobResourceContents ([#2753](https://github.com/Soulter/AstrBot/issues/2753))
|
||||||
|
4. feat: 增加 QQ 群名称识别到 system prompt, 并提供相应的配置 ([#2770](https://github.com/Soulter/AstrBot/issues/2770))
|
||||||
|
5. fix: parameter type/default handling in CommandFilter
|
||||||
10
changelogs/v4.1.4.md
Normal file
10
changelogs/v4.1.4.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
0. ‼️ fix: 修复 4.0.0 版本之后,配置默认 TTS 或者 STT 模型之后仍无法生效的问题 ([#2758](https://github.com/Soulter/AstrBot/issues/2758))
|
||||||
|
1. ‼️ fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 ([#2757](https://github.com/Soulter/AstrBot/issues/2757))
|
||||||
|
2. feat: 支持在 WebUI 复制提供商配置以简化操作 ([#2767](https://github.com/Soulter/AstrBot/issues/2767))
|
||||||
|
3. fix: handle image value correctly for mcp BlobResourceContents ([#2753](https://github.com/Soulter/AstrBot/issues/2753))
|
||||||
|
4. feat: 增加 QQ 群名称识别到 system prompt, 并提供相应的配置 ([#2770](https://github.com/Soulter/AstrBot/issues/2770))
|
||||||
|
5. fix: 修复 4.1.3 的异常问题
|
||||||
|
|
||||||
|
**总之上个版本有很严重的 bug 赶快更新!**
|
||||||
11
changelogs/v4.1.5.md
Normal file
11
changelogs/v4.1.5.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
0. feat: 新增 Misskey 平台适配器 ([#2774](https://github.com/AstrBotDevs/AstrBot/issues/2774))
|
||||||
|
1. fix: 修复aiocqhttp适配器at会获取群昵称而消息不会获取的逻辑不一致 ([#2769](https://github.com/AstrBotDevs/AstrBot/issues/2769))
|
||||||
|
2. fix: 修复「对话管理」页面的关键词搜索功能失效的问题并优化一些 UI 样式 ([#2837](https://github.com/AstrBotDevs/AstrBot/issues/2837))
|
||||||
|
3. fix: 识别「引用消息」的图片时优先使用默认图片转述提供商 ([#2836](https://github.com/AstrBotDevs/AstrBot/issues/2836))
|
||||||
|
5. fix: 修复 Telegram 下流式传输时,第一次输出的内容会被覆盖掉的问题
|
||||||
|
6. perf: 优化统计页内存占用和消息数据趋势的样式 ([#2826](https://github.com/AstrBotDevs/AstrBot/issues/2826))
|
||||||
|
7. perf: 优化 「插件页」、「对话管理页」、「会话管理页」的样式
|
||||||
|
8. fix: on_tool_end hook unavailable
|
||||||
|
9. feat: add audioop-lts dependencies ([#2809](https://github.com/AstrBotDevs/AstrBot/issues/2809))
|
||||||
3
changelogs/v4.1.6.md
Normal file
3
changelogs/v4.1.6.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
1. fix: 修复在某些情况下,出现 「返回的 Provider 不是 Provider 类型的错误」
|
||||||
8
changelogs/v4.1.7.md
Normal file
8
changelogs/v4.1.7.md
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
1. perf: 优化 WebChat 等组件的 UI 风格
|
||||||
|
2. fix: 修复 4.1.6 版本可能无法点击更新按钮的问题
|
||||||
|
3. fix: 修复更新开发版的时候,可能无法同时更新 WebUI 的问题
|
||||||
|
4. feat: 支持在「对话数据」页批量删除对话
|
||||||
|
5. fix: 修复部分错误地显示「格式校验未通过」的问题
|
||||||
|
6. perf: WebChat 支持手动填写模型名称
|
||||||
1
changelogs/v4.2.0.md
Normal file
1
changelogs/v4.2.0.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# What's Changed
|
||||||
@@ -1,7 +1,28 @@
|
|||||||
<template>
|
<template>
|
||||||
<RouterView></RouterView>
|
<RouterView></RouterView>
|
||||||
|
|
||||||
|
<!-- 全局唯一 snackbar -->
|
||||||
|
<v-snackbar v-if="toastStore.current" v-model="snackbarShow" :color="toastStore.current.color"
|
||||||
|
:timeout="toastStore.current.timeout" :multi-line="toastStore.current.multiLine"
|
||||||
|
:location="toastStore.current.location" close-on-back>
|
||||||
|
{{ toastStore.current.message }}
|
||||||
|
<template #actions v-if="toastStore.current.closable">
|
||||||
|
<v-btn variant="text" @click="snackbarShow = false">关闭</v-btn>
|
||||||
|
</template>
|
||||||
|
</v-snackbar>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup>
|
||||||
import { RouterView } from 'vue-router';
|
import { RouterView } from 'vue-router';
|
||||||
|
import { computed } from 'vue'
|
||||||
|
import { useToastStore } from '@/stores/toast'
|
||||||
|
|
||||||
|
const toastStore = useToastStore()
|
||||||
|
|
||||||
|
const snackbarShow = computed({
|
||||||
|
get: () => !!toastStore.current,
|
||||||
|
set: (val) => {
|
||||||
|
if (!val) toastStore.shift()
|
||||||
|
}
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
BIN
dashboard/src/assets/images/platform_logos/misskey.png
Normal file
BIN
dashboard/src/assets/images/platform_logos/misskey.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
1376
dashboard/src/components/chat/Chat.vue
Normal file
1376
dashboard/src/components/chat/Chat.vue
Normal file
File diff suppressed because it is too large
Load Diff
774
dashboard/src/components/chat/MessageList.vue
Normal file
774
dashboard/src/components/chat/MessageList.vue
Normal file
@@ -0,0 +1,774 @@
|
|||||||
|
<template>
|
||||||
|
<div class="messages-container" ref="messageContainer">
|
||||||
|
<!-- 聊天消息列表 -->
|
||||||
|
<div class="message-list">
|
||||||
|
<div class="message-item fade-in" v-for="(msg, index) in messages" :key="index">
|
||||||
|
<!-- 用户消息 -->
|
||||||
|
<div v-if="msg.content.type == 'user'" class="user-message">
|
||||||
|
<div class="message-bubble user-bubble" :class="{ 'has-audio': msg.content.audio_url }"
|
||||||
|
:style="{ backgroundColor: isDark ? '#2d2e30' : '#e7ebf4' }">
|
||||||
|
<pre
|
||||||
|
style="font-family: inherit; white-space: pre-wrap; word-wrap: break-word;">{{ msg.content.message }}</pre>
|
||||||
|
|
||||||
|
<!-- 图片附件 -->
|
||||||
|
<div class="image-attachments" v-if="msg.content.image_url && msg.content.image_url.length > 0">
|
||||||
|
<div v-for="(img, index) in msg.content.image_url" :key="index" class="image-attachment">
|
||||||
|
<img :src="img" class="attached-image" @click="$emit('openImagePreview', img)" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 音频附件 -->
|
||||||
|
<div class="audio-attachment" v-if="msg.content.audio_url && msg.content.audio_url.length > 0">
|
||||||
|
<audio controls class="audio-player">
|
||||||
|
<source :src="msg.content.audio_url" type="audio/wav">
|
||||||
|
{{ t('messages.errors.browser.audioNotSupported') }}
|
||||||
|
</audio>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Bot Messages -->
|
||||||
|
<div v-else class="bot-message">
|
||||||
|
|
||||||
|
<v-avatar class="bot-avatar" size="36">
|
||||||
|
<v-progress-circular :index="index" v-if="isStreaming && index === messages.length - 1" indeterminate size="28"
|
||||||
|
width="2"></v-progress-circular>
|
||||||
|
<span v-else-if="messages[index - 1]?.content.type !== 'bot'" class="text-h2">✨</span>
|
||||||
|
</v-avatar>
|
||||||
|
<div class="bot-message-content">
|
||||||
|
<div class="message-bubble bot-bubble">
|
||||||
|
<!-- Text -->
|
||||||
|
<div v-if="msg.content.message && msg.content.message.trim()"
|
||||||
|
v-html="md.render(msg.content.message)" class="markdown-content"></div>
|
||||||
|
|
||||||
|
<!-- Image -->
|
||||||
|
<div class="embedded-images"
|
||||||
|
v-if="msg.content.embedded_images && msg.content.embedded_images.length > 0">
|
||||||
|
<div v-for="(img, imgIndex) in msg.content.embedded_images" :key="imgIndex"
|
||||||
|
class="embedded-image">
|
||||||
|
<img :src="img" class="bot-embedded-image"
|
||||||
|
@click="$emit('openImagePreview', img)" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Audio -->
|
||||||
|
<div class="embedded-audio" v-if="msg.content.embedded_audio">
|
||||||
|
<audio controls class="audio-player">
|
||||||
|
<source :src="msg.content.embedded_audio" type="audio/wav">
|
||||||
|
{{ t('messages.errors.browser.audioNotSupported') }}
|
||||||
|
</audio>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="message-actions">
|
||||||
|
<v-btn :icon="getCopyIcon(index)" size="small" variant="text" class="copy-message-btn"
|
||||||
|
:class="{ 'copy-success': isCopySuccess(index) }"
|
||||||
|
@click="copyBotMessage(msg.content.message, index)" :title="t('core.common.copy')" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||||
|
import MarkdownIt from 'markdown-it';
|
||||||
|
import hljs from 'highlight.js';
|
||||||
|
import 'highlight.js/styles/github.css';
|
||||||
|
|
||||||
|
const md = new MarkdownIt({
|
||||||
|
html: false,
|
||||||
|
breaks: true,
|
||||||
|
linkify: true,
|
||||||
|
highlight: function (code, lang) {
|
||||||
|
if (lang && hljs.getLanguage(lang)) {
|
||||||
|
try {
|
||||||
|
return hljs.highlight(code, { language: lang }).value;
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Highlight error:', err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hljs.highlightAuto(code).value;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
export default {
|
||||||
|
name: 'MessageList',
|
||||||
|
props: {
|
||||||
|
messages: {
|
||||||
|
type: Array,
|
||||||
|
required: true
|
||||||
|
},
|
||||||
|
isDark: {
|
||||||
|
type: Boolean,
|
||||||
|
default: false
|
||||||
|
},
|
||||||
|
isStreaming: {
|
||||||
|
type: Boolean,
|
||||||
|
default: false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
emits: ['openImagePreview'],
|
||||||
|
setup() {
|
||||||
|
const { t } = useI18n();
|
||||||
|
const { tm } = useModuleI18n('features/chat');
|
||||||
|
|
||||||
|
return {
|
||||||
|
t,
|
||||||
|
tm,
|
||||||
|
md
|
||||||
|
};
|
||||||
|
},
|
||||||
|
data() {
|
||||||
|
return {
|
||||||
|
copiedMessages: new Set(),
|
||||||
|
isUserNearBottom: true,
|
||||||
|
scrollThreshold: 1,
|
||||||
|
scrollTimer: null
|
||||||
|
};
|
||||||
|
},
|
||||||
|
mounted() {
|
||||||
|
this.initCodeCopyButtons();
|
||||||
|
this.initImageClickEvents();
|
||||||
|
this.addScrollListener();
|
||||||
|
this.scrollToBottom();
|
||||||
|
},
|
||||||
|
updated() {
|
||||||
|
this.initCodeCopyButtons();
|
||||||
|
this.initImageClickEvents();
|
||||||
|
if (this.isUserNearBottom) {
|
||||||
|
this.scrollToBottom();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
methods: {
|
||||||
|
// 复制代码到剪贴板
|
||||||
|
copyCodeToClipboard(code) {
|
||||||
|
navigator.clipboard.writeText(code).then(() => {
|
||||||
|
console.log('代码已复制到剪贴板');
|
||||||
|
}).catch(err => {
|
||||||
|
console.error('复制失败:', err);
|
||||||
|
// 如果现代API失败,使用传统方法
|
||||||
|
const textArea = document.createElement('textarea');
|
||||||
|
textArea.value = code;
|
||||||
|
document.body.appendChild(textArea);
|
||||||
|
textArea.select();
|
||||||
|
try {
|
||||||
|
document.execCommand('copy');
|
||||||
|
console.log('代码已复制到剪贴板 (fallback)');
|
||||||
|
} catch (fallbackErr) {
|
||||||
|
console.error('复制失败 (fallback):', fallbackErr);
|
||||||
|
}
|
||||||
|
document.body.removeChild(textArea);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
// 复制bot消息到剪贴板
|
||||||
|
copyBotMessage(message, messageIndex) {
|
||||||
|
// 获取对应的消息对象
|
||||||
|
const msgObj = this.messages[messageIndex].content;
|
||||||
|
let textToCopy = '';
|
||||||
|
|
||||||
|
// 如果有文本消息,添加到复制内容中
|
||||||
|
if (message && message.trim()) {
|
||||||
|
// 移除HTML标签,获取纯文本
|
||||||
|
const tempDiv = document.createElement('div');
|
||||||
|
tempDiv.innerHTML = message;
|
||||||
|
textToCopy = tempDiv.textContent || tempDiv.innerText || message;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有内嵌图片,添加说明
|
||||||
|
if (msgObj && msgObj.embedded_images && msgObj.embedded_images.length > 0) {
|
||||||
|
if (textToCopy) textToCopy += '\n\n';
|
||||||
|
textToCopy += `[包含 ${msgObj.embedded_images.length} 张图片]`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有内嵌音频,添加说明
|
||||||
|
if (msgObj && msgObj.embedded_audio) {
|
||||||
|
if (textToCopy) textToCopy += '\n\n';
|
||||||
|
textToCopy += '[包含音频内容]';
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有任何内容,使用默认文本
|
||||||
|
if (!textToCopy.trim()) {
|
||||||
|
textToCopy = '[媒体内容]';
|
||||||
|
}
|
||||||
|
|
||||||
|
navigator.clipboard.writeText(textToCopy).then(() => {
|
||||||
|
console.log('消息已复制到剪贴板');
|
||||||
|
this.showCopySuccess(messageIndex);
|
||||||
|
}).catch(err => {
|
||||||
|
console.error('复制失败:', err);
|
||||||
|
// 如果现代API失败,使用传统方法
|
||||||
|
const textArea = document.createElement('textarea');
|
||||||
|
textArea.value = textToCopy;
|
||||||
|
document.body.appendChild(textArea);
|
||||||
|
textArea.select();
|
||||||
|
try {
|
||||||
|
document.execCommand('copy');
|
||||||
|
console.log('消息已复制到剪贴板 (fallback)');
|
||||||
|
this.showCopySuccess(messageIndex);
|
||||||
|
} catch (fallbackErr) {
|
||||||
|
console.error('复制失败 (fallback):', fallbackErr);
|
||||||
|
}
|
||||||
|
document.body.removeChild(textArea);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
// 显示复制成功提示
|
||||||
|
showCopySuccess(messageIndex) {
|
||||||
|
this.copiedMessages.add(messageIndex);
|
||||||
|
|
||||||
|
// 2秒后移除成功状态
|
||||||
|
setTimeout(() => {
|
||||||
|
this.copiedMessages.delete(messageIndex);
|
||||||
|
}, 2000);
|
||||||
|
},
|
||||||
|
|
||||||
|
// 获取复制按钮图标
|
||||||
|
getCopyIcon(messageIndex) {
|
||||||
|
return this.copiedMessages.has(messageIndex) ? 'mdi-check' : 'mdi-content-copy';
|
||||||
|
},
|
||||||
|
|
||||||
|
// 检查是否为复制成功状态
|
||||||
|
isCopySuccess(messageIndex) {
|
||||||
|
return this.copiedMessages.has(messageIndex);
|
||||||
|
},
|
||||||
|
|
||||||
|
// 获取复制图标SVG
|
||||||
|
getCopyIconSvg() {
|
||||||
|
return '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>';
|
||||||
|
},
|
||||||
|
|
||||||
|
// 获取成功图标SVG
|
||||||
|
getSuccessIconSvg() {
|
||||||
|
return '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><polyline points="20,6 9,17 4,12"></polyline></svg>';
|
||||||
|
},
|
||||||
|
|
||||||
|
// 初始化代码块复制按钮
|
||||||
|
initCodeCopyButtons() {
|
||||||
|
this.$nextTick(() => {
|
||||||
|
const codeBlocks = this.$refs.messageContainer?.querySelectorAll('pre code') || [];
|
||||||
|
codeBlocks.forEach((codeBlock, index) => {
|
||||||
|
const pre = codeBlock.parentElement;
|
||||||
|
if (pre && !pre.querySelector('.copy-code-btn')) {
|
||||||
|
const button = document.createElement('button');
|
||||||
|
button.className = 'copy-code-btn';
|
||||||
|
button.innerHTML = this.getCopyIconSvg();
|
||||||
|
button.title = '复制代码';
|
||||||
|
button.addEventListener('click', () => {
|
||||||
|
this.copyCodeToClipboard(codeBlock.textContent);
|
||||||
|
// 显示复制成功提示
|
||||||
|
button.innerHTML = this.getSuccessIconSvg();
|
||||||
|
button.style.color = '#4caf50';
|
||||||
|
setTimeout(() => {
|
||||||
|
button.innerHTML = this.getCopyIconSvg();
|
||||||
|
button.style.color = '';
|
||||||
|
}, 2000);
|
||||||
|
});
|
||||||
|
pre.style.position = 'relative';
|
||||||
|
pre.appendChild(button);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
initImageClickEvents() {
|
||||||
|
this.$nextTick(() => {
|
||||||
|
// 查找所有动态生成的图片(在markdown-content中)
|
||||||
|
const images = document.querySelectorAll('.markdown-content img');
|
||||||
|
images.forEach((img) => {
|
||||||
|
if (!img.hasAttribute('data-click-enabled')) {
|
||||||
|
img.style.cursor = 'pointer';
|
||||||
|
img.setAttribute('data-click-enabled', 'true');
|
||||||
|
img.onclick = () => this.$emit('openImagePreview', img.src);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
scrollToBottom() {
|
||||||
|
this.$nextTick(() => {
|
||||||
|
const container = this.$refs.messageContainer;
|
||||||
|
if (container) {
|
||||||
|
container.scrollTop = container.scrollHeight;
|
||||||
|
this.isUserNearBottom = true; // 程序滚动到底部后标记用户在底部
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
// 添加滚动事件监听器
|
||||||
|
addScrollListener() {
|
||||||
|
const container = this.$refs.messageContainer;
|
||||||
|
if (container) {
|
||||||
|
container.addEventListener('scroll', this.throttledHandleScroll);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// 节流处理滚动事件
|
||||||
|
throttledHandleScroll() {
|
||||||
|
if (this.scrollTimer) return;
|
||||||
|
|
||||||
|
this.scrollTimer = setTimeout(() => {
|
||||||
|
this.handleScroll();
|
||||||
|
this.scrollTimer = null;
|
||||||
|
}, 50); // 50ms 节流
|
||||||
|
},
|
||||||
|
|
||||||
|
// 处理滚动事件
|
||||||
|
handleScroll() {
|
||||||
|
const container = this.$refs.messageContainer;
|
||||||
|
if (container) {
|
||||||
|
const { scrollTop, scrollHeight, clientHeight } = container;
|
||||||
|
const distanceFromBottom = scrollHeight - (scrollTop + clientHeight);
|
||||||
|
|
||||||
|
// 判断用户是否在底部附近
|
||||||
|
this.isUserNearBottom = distanceFromBottom <= this.scrollThreshold;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// 组件销毁时移除监听器
|
||||||
|
beforeUnmount() {
|
||||||
|
const container = this.$refs.messageContainer;
|
||||||
|
if (container) {
|
||||||
|
container.removeEventListener('scroll', this.throttledHandleScroll);
|
||||||
|
}
|
||||||
|
// 清理定时器
|
||||||
|
if (this.scrollTimer) {
|
||||||
|
clearTimeout(this.scrollTimer);
|
||||||
|
this.scrollTimer = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
/* 基础动画 */
|
||||||
|
@keyframes fadeIn {
|
||||||
|
from {
|
||||||
|
opacity: 0;
|
||||||
|
transform: translateY(10px);
|
||||||
|
}
|
||||||
|
|
||||||
|
to {
|
||||||
|
opacity: 1;
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.messages-container {
|
||||||
|
height: 100%;
|
||||||
|
max-height: 100%;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 16px;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
flex: 1;
|
||||||
|
min-height: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 消息列表样式 */
|
||||||
|
.message-list {
|
||||||
|
max-width: 900px;
|
||||||
|
margin: 0 auto;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-item {
|
||||||
|
margin-bottom: 24px;
|
||||||
|
animation: fadeIn 0.3s ease-out;
|
||||||
|
}
|
||||||
|
|
||||||
|
.user-message {
|
||||||
|
display: flex;
|
||||||
|
justify-content: flex-end;
|
||||||
|
align-items: flex-start;
|
||||||
|
gap: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bot-message {
|
||||||
|
display: flex;
|
||||||
|
justify-content: flex-start;
|
||||||
|
align-items: flex-start;
|
||||||
|
gap: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bot-message-content {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: flex-start;
|
||||||
|
max-width: 80%;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-actions {
|
||||||
|
display: flex;
|
||||||
|
gap: 4px;
|
||||||
|
opacity: 0;
|
||||||
|
transition: opacity 0.2s ease;
|
||||||
|
margin-left: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bot-message:hover .message-actions {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.copy-message-btn {
|
||||||
|
opacity: 0.6;
|
||||||
|
transition: all 0.2s ease;
|
||||||
|
color: var(--v-theme-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.copy-message-btn:hover {
|
||||||
|
opacity: 1;
|
||||||
|
background-color: rgba(103, 58, 183, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.copy-message-btn.copy-success {
|
||||||
|
color: #4caf50;
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.copy-message-btn.copy-success:hover {
|
||||||
|
color: #4caf50;
|
||||||
|
background-color: rgba(76, 175, 80, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-bubble {
|
||||||
|
padding: 2px 16px;
|
||||||
|
border-radius: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.user-bubble {
|
||||||
|
color: var(--v-theme-primaryText);
|
||||||
|
padding: 12px 18px;
|
||||||
|
font-size: 15px;
|
||||||
|
max-width: 60%;
|
||||||
|
border-radius: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bot-bubble {
|
||||||
|
border: 1px solid var(--v-theme-border);
|
||||||
|
color: var(--v-theme-primaryText);
|
||||||
|
font-size: 15px;
|
||||||
|
max-width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.user-avatar,
|
||||||
|
.bot-avatar {
|
||||||
|
align-self: flex-start;
|
||||||
|
margin-top: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 附件样式 */
|
||||||
|
.image-attachments {
|
||||||
|
display: flex;
|
||||||
|
gap: 8px;
|
||||||
|
margin-top: 8px;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.image-attachment {
|
||||||
|
position: relative;
|
||||||
|
display: inline-block;
|
||||||
|
}
|
||||||
|
|
||||||
|
.attached-image {
|
||||||
|
width: 120px;
|
||||||
|
height: 120px;
|
||||||
|
object-fit: cover;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||||
|
transition: transform 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.audio-attachment {
|
||||||
|
margin-top: 8px;
|
||||||
|
min-width: 250px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 包含音频的消息气泡最小宽度 */
|
||||||
|
.message-bubble.has-audio {
|
||||||
|
min-width: 280px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.audio-player {
|
||||||
|
width: 100%;
|
||||||
|
height: 36px;
|
||||||
|
border-radius: 18px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.embedded-images {
|
||||||
|
margin-top: 8px;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.embedded-image {
|
||||||
|
display: flex;
|
||||||
|
justify-content: flex-start;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bot-embedded-image {
|
||||||
|
max-width: 80%;
|
||||||
|
width: auto;
|
||||||
|
height: auto;
|
||||||
|
border-radius: 8px;
|
||||||
|
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
|
||||||
|
cursor: pointer;
|
||||||
|
transition: transform 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.bot-embedded-image:hover {
|
||||||
|
transform: scale(1.02);
|
||||||
|
}
|
||||||
|
|
||||||
|
.embedded-audio {
|
||||||
|
width: 300px;
|
||||||
|
margin-top: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.embedded-audio .audio-player {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 300px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 动画类 */
|
||||||
|
.fade-in {
|
||||||
|
animation: fadeIn 0.3s ease-in-out;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|
||||||
|
<style>
|
||||||
|
/* Markdown内容样式 - 需要全局样式 */
|
||||||
|
.markdown-content {
|
||||||
|
font-family: inherit;
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content h1,
|
||||||
|
.markdown-content h2,
|
||||||
|
.markdown-content h3,
|
||||||
|
.markdown-content h4,
|
||||||
|
.markdown-content h5,
|
||||||
|
.markdown-content h6 {
|
||||||
|
margin-top: 16px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--v-theme-primaryText);
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content h1 {
|
||||||
|
font-size: 1.8em;
|
||||||
|
border-bottom: 1px solid var(--v-theme-border);
|
||||||
|
padding-bottom: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content h2 {
|
||||||
|
font-size: 1.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content h3 {
|
||||||
|
font-size: 1.3em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content li {
|
||||||
|
margin-left: 16px;
|
||||||
|
margin-bottom: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content p {
|
||||||
|
margin-top: .5rem;
|
||||||
|
margin-bottom: .5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content pre {
|
||||||
|
background-color: var(--v-theme-surface);
|
||||||
|
padding: 12px;
|
||||||
|
border-radius: 6px;
|
||||||
|
overflow-x: auto;
|
||||||
|
margin: 12px 0;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content code {
|
||||||
|
background-color: rgb(var(--v-theme-codeBg));
|
||||||
|
padding: 2px 4px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-family: 'Fira Code', monospace;
|
||||||
|
font-size: 0.9em;
|
||||||
|
color: var(--v-theme-code);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 代码块中的code标签样式 */
|
||||||
|
.markdown-content pre code {
|
||||||
|
background-color: transparent;
|
||||||
|
padding: 0;
|
||||||
|
border-radius: 0;
|
||||||
|
font-family: 'Fira Code', 'Consolas', 'Monaco', 'Courier New', monospace;
|
||||||
|
font-size: 0.85em;
|
||||||
|
color: inherit;
|
||||||
|
display: block;
|
||||||
|
overflow-x: auto;
|
||||||
|
line-height: 1.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 自定义代码高亮样式 */
|
||||||
|
.markdown-content pre {
|
||||||
|
border: 1px solid var(--v-theme-border);
|
||||||
|
background-color: rgb(var(--v-theme-preBg));
|
||||||
|
border-radius: 16px;
|
||||||
|
padding: 16px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 确保highlight.js的样式正确应用 */
|
||||||
|
.markdown-content pre code.hljs {
|
||||||
|
background: transparent !important;
|
||||||
|
color: inherit;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 亮色主题下的代码高亮 */
|
||||||
|
.v-theme--light .markdown-content pre {
|
||||||
|
background-color: #f6f8fa;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 暗色主题下的代码块样式 */
|
||||||
|
.v-theme--dark .markdown-content pre {
|
||||||
|
background-color: #0d1117 !important;
|
||||||
|
border-color: rgba(255, 255, 255, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-theme--dark .markdown-content pre code {
|
||||||
|
color: #e6edf3 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 暗色主题下的highlight.js样式覆盖 */
|
||||||
|
.v-theme--dark .hljs {
|
||||||
|
background: #0d1117 !important;
|
||||||
|
color: #e6edf3 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-theme--dark .hljs-keyword,
|
||||||
|
.v-theme--dark .hljs-selector-tag,
|
||||||
|
.v-theme--dark .hljs-built_in,
|
||||||
|
.v-theme--dark .hljs-name,
|
||||||
|
.v-theme--dark .hljs-tag {
|
||||||
|
color: #ff7b72 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-theme--dark .hljs-string,
|
||||||
|
.v-theme--dark .hljs-title,
|
||||||
|
.v-theme--dark .hljs-section,
|
||||||
|
.v-theme--dark .hljs-attribute,
|
||||||
|
.v-theme--dark .hljs-literal,
|
||||||
|
.v-theme--dark .hljs-template-tag,
|
||||||
|
.v-theme--dark .hljs-template-variable,
|
||||||
|
.v-theme--dark .hljs-type,
|
||||||
|
.v-theme--dark .hljs-addition {
|
||||||
|
color: #a5d6ff !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-theme--dark .hljs-comment,
|
||||||
|
.v-theme--dark .hljs-quote,
|
||||||
|
.v-theme--dark .hljs-deletion,
|
||||||
|
.v-theme--dark .hljs-meta {
|
||||||
|
color: #8b949e !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-theme--dark .hljs-number,
|
||||||
|
.v-theme--dark .hljs-regexp,
|
||||||
|
.v-theme--dark .hljs-symbol,
|
||||||
|
.v-theme--dark .hljs-variable,
|
||||||
|
.v-theme--dark .hljs-template-variable,
|
||||||
|
.v-theme--dark .hljs-link,
|
||||||
|
.v-theme--dark .hljs-selector-attr,
|
||||||
|
.v-theme--dark .hljs-selector-pseudo {
|
||||||
|
color: #79c0ff !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-theme--dark .hljs-function,
|
||||||
|
.v-theme--dark .hljs-class,
|
||||||
|
.v-theme--dark .hljs-title.class_ {
|
||||||
|
color: #d2a8ff !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 复制按钮样式 */
|
||||||
|
.copy-code-btn {
|
||||||
|
position: absolute;
|
||||||
|
top: 8px;
|
||||||
|
right: 8px;
|
||||||
|
background: rgba(255, 255, 255, 0.9);
|
||||||
|
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||||
|
border-radius: 4px;
|
||||||
|
padding: 6px;
|
||||||
|
cursor: pointer;
|
||||||
|
opacity: 0;
|
||||||
|
transition: all 0.2s ease;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
color: #666;
|
||||||
|
font-size: 12px;
|
||||||
|
z-index: 10;
|
||||||
|
backdrop-filter: blur(4px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.copy-code-btn:hover {
|
||||||
|
background: rgba(255, 255, 255, 1);
|
||||||
|
color: #333;
|
||||||
|
transform: scale(1.05);
|
||||||
|
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
.copy-code-btn:active {
|
||||||
|
transform: scale(0.95);
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content pre:hover .copy-code-btn {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-theme--dark .copy-code-btn {
|
||||||
|
background: rgba(45, 45, 45, 0.9);
|
||||||
|
border-color: rgba(255, 255, 255, 0.15);
|
||||||
|
color: #ccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-theme--dark .copy-code-btn:hover {
|
||||||
|
background: rgba(45, 45, 45, 1);
|
||||||
|
color: #fff;
|
||||||
|
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content img {
|
||||||
|
max-width: 100%;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin: 10px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content blockquote {
|
||||||
|
border-left: 4px solid var(--v-theme-secondary);
|
||||||
|
padding-left: 16px;
|
||||||
|
color: var(--v-theme-secondaryText);
|
||||||
|
margin: 16px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content table {
|
||||||
|
border-collapse: collapse;
|
||||||
|
width: 100%;
|
||||||
|
margin: 16px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content th,
|
||||||
|
.markdown-content td {
|
||||||
|
border: 1px solid var(--v-theme-background);
|
||||||
|
padding: 8px 12px;
|
||||||
|
text-align: left;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-content th {
|
||||||
|
background-color: var(--v-theme-containerBg);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -1,21 +1,11 @@
|
|||||||
<template>
|
<template>
|
||||||
<div>
|
<div>
|
||||||
<!-- 选择提供商和模型按钮 -->
|
<!-- 选择提供商和模型按钮 -->
|
||||||
<v-btn
|
<v-btn class="text-none" variant="tonal" rounded="xl" size="small"
|
||||||
class="text-none"
|
v-if="selectedProviderId && selectedModelName" @click="openDialog">
|
||||||
variant="tonal"
|
|
||||||
rounded="xl"
|
|
||||||
size="small"
|
|
||||||
v-if="selectedProviderId && selectedModelName"
|
|
||||||
@click="showDialog = true">
|
|
||||||
{{ selectedProviderId }} / {{ selectedModelName }}
|
{{ selectedProviderId }} / {{ selectedModelName }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn
|
<v-btn variant="tonal" rounded="xl" size="small" v-else @click="openDialog">
|
||||||
variant="tonal"
|
|
||||||
rounded="xl"
|
|
||||||
size="small"
|
|
||||||
v-else
|
|
||||||
@click="showDialog = true">
|
|
||||||
选择模型
|
选择模型
|
||||||
</v-btn>
|
</v-btn>
|
||||||
|
|
||||||
@@ -33,16 +23,12 @@
|
|||||||
<h4>提供商</h4>
|
<h4>提供商</h4>
|
||||||
</div>
|
</div>
|
||||||
<v-list density="compact" nav class="provider-list">
|
<v-list density="compact" nav class="provider-list">
|
||||||
<v-list-item
|
<v-list-item v-for="provider in providerConfigs" :key="provider.id" :value="provider.id"
|
||||||
v-for="provider in providerConfigs"
|
@click="selectProvider(provider)" :active="tempSelectedProviderId === provider.id"
|
||||||
:key="provider.id"
|
rounded="lg" class="provider-item">
|
||||||
:value="provider.id"
|
|
||||||
@click="selectProvider(provider)"
|
|
||||||
:active="selectedProviderId === provider.id"
|
|
||||||
rounded="lg"
|
|
||||||
class="provider-item">
|
|
||||||
<v-list-item-title>{{ provider.id }}</v-list-item-title>
|
<v-list-item-title>{{ provider.id }}</v-list-item-title>
|
||||||
<v-list-item-subtitle v-if="provider.api_base">{{ provider.api_base }}</v-list-item-subtitle>
|
<v-list-item-subtitle v-if="provider.api_base">{{ provider.api_base
|
||||||
|
}}</v-list-item-subtitle>
|
||||||
</v-list-item>
|
</v-list-item>
|
||||||
</v-list>
|
</v-list>
|
||||||
<div v-if="providerConfigs.length === 0" class="empty-state">
|
<div v-if="providerConfigs.length === 0" class="empty-state">
|
||||||
@@ -55,33 +41,28 @@
|
|||||||
<div class="model-list-panel">
|
<div class="model-list-panel">
|
||||||
<div class="panel-header">
|
<div class="panel-header">
|
||||||
<h4>模型</h4>
|
<h4>模型</h4>
|
||||||
<v-btn
|
<v-btn v-if="tempSelectedProviderId" icon="mdi-refresh" size="small" variant="text"
|
||||||
v-if="selectedProviderId"
|
@click="refreshModels" :loading="loadingModels">
|
||||||
icon="mdi-refresh"
|
|
||||||
size="small"
|
|
||||||
variant="text"
|
|
||||||
@click="refreshModels"
|
|
||||||
:loading="loadingModels">
|
|
||||||
</v-btn>
|
</v-btn>
|
||||||
</div>
|
</div>
|
||||||
<v-list density="compact" nav class="model-list" v-if="selectedProviderId">
|
<v-list density="compact" nav class="model-list" v-if="tempSelectedProviderId">
|
||||||
<v-list-item
|
|
||||||
v-for="model in modelList"
|
<v-text-field v-model="tempSelectedModelName" placeholder="自定义模型" hide-details solo variant="outlined" density="compact" class="mb-2 mx-2"></v-text-field>
|
||||||
:key="model"
|
|
||||||
:value="model"
|
<v-list-item v-for="model in modelList" :key="model" :value="model"
|
||||||
@click="selectModel(model)"
|
@click="selectModel(model)" :active="tempSelectedModelName === model" rounded="lg"
|
||||||
:active="selectedModelName === model"
|
|
||||||
rounded="lg"
|
|
||||||
class="model-item">
|
class="model-item">
|
||||||
<v-list-item-title>{{ model }}</v-list-item-title>
|
<v-list-item-title>{{ model }}</v-list-item-title>
|
||||||
<v-list-item-subtitle v-if="model.description">{{ model.description }}</v-list-item-subtitle>
|
<v-list-item-subtitle v-if="model.description">{{ model.description
|
||||||
|
}}</v-list-item-subtitle>
|
||||||
</v-list-item>
|
</v-list-item>
|
||||||
</v-list>
|
</v-list>
|
||||||
<div v-else class="empty-state">
|
<div v-else class="empty-state">
|
||||||
<v-icon icon="mdi-robot-outline" size="large" color="grey-lighten-1"></v-icon>
|
<v-icon icon="mdi-robot-outline" size="large" color="grey-lighten-1"></v-icon>
|
||||||
<div class="empty-text">请先选择提供商</div>
|
<div class="empty-text">请先选择提供商</div>
|
||||||
</div>
|
</div>
|
||||||
<div v-if="selectedProviderId && modelList.length === 0 && !loadingModels" class="empty-state">
|
<div v-if="tempSelectedProviderId && modelList.length === 0 && !loadingModels"
|
||||||
|
class="empty-state">
|
||||||
<v-icon icon="mdi-robot-off-outline" size="large" color="grey-lighten-1"></v-icon>
|
<v-icon icon="mdi-robot-off-outline" size="large" color="grey-lighten-1"></v-icon>
|
||||||
<div class="empty-text">该提供商暂无可用模型</div>
|
<div class="empty-text">该提供商暂无可用模型</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -91,11 +72,8 @@
|
|||||||
<v-card-actions>
|
<v-card-actions>
|
||||||
<v-spacer></v-spacer>
|
<v-spacer></v-spacer>
|
||||||
<v-btn text @click="closeDialog" color="grey-darken-1">取消</v-btn>
|
<v-btn text @click="closeDialog" color="grey-darken-1">取消</v-btn>
|
||||||
<v-btn
|
<v-btn text @click="confirmSelection" color="primary"
|
||||||
text
|
:disabled="!tempSelectedProviderId || !tempSelectedModelName">
|
||||||
@click="confirmSelection"
|
|
||||||
color="primary"
|
|
||||||
:disabled="!selectedProviderId || !selectedModelName">
|
|
||||||
确认选择
|
确认选择
|
||||||
</v-btn>
|
</v-btn>
|
||||||
</v-card-actions>
|
</v-card-actions>
|
||||||
@@ -127,12 +105,17 @@ export default {
|
|||||||
modelList: [],
|
modelList: [],
|
||||||
selectedProviderId: '',
|
selectedProviderId: '',
|
||||||
selectedModelName: '',
|
selectedModelName: '',
|
||||||
|
// 临时选择状态,用于对话框内的选择
|
||||||
|
tempSelectedProviderId: '',
|
||||||
|
tempSelectedModelName: '',
|
||||||
loadingModels: false
|
loadingModels: false
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
mounted() {
|
mounted() {
|
||||||
// 从localStorage加载保存的选择
|
// 从localStorage加载保存的选择
|
||||||
this.loadFromStorage();
|
this.loadFromStorage();
|
||||||
|
// 初始化临时选择
|
||||||
|
this.resetTempSelection();
|
||||||
// 获取提供商列表
|
// 获取提供商列表
|
||||||
this.loadProviderConfigs();
|
this.loadProviderConfigs();
|
||||||
// 如果有保存的选择,加载对应的模型列表
|
// 如果有保存的选择,加载对应的模型列表
|
||||||
@@ -215,27 +198,31 @@ export default {
|
|||||||
|
|
||||||
// 选择提供商
|
// 选择提供商
|
||||||
selectProvider(provider) {
|
selectProvider(provider) {
|
||||||
this.selectedProviderId = provider.id;
|
this.tempSelectedProviderId = provider.id;
|
||||||
this.selectedModelName = ''; // 清空已选择的模型
|
this.tempSelectedModelName = ''; // 清空已选择的模型
|
||||||
this.modelList = []; // 清空模型列表
|
this.modelList = []; // 清空模型列表
|
||||||
this.getProviderModels(provider.id); // 获取该提供商的模型列表
|
this.getProviderModels(provider.id); // 获取该提供商的模型列表
|
||||||
},
|
},
|
||||||
|
|
||||||
// 选择模型
|
// 选择模型
|
||||||
selectModel(model) {
|
selectModel(model) {
|
||||||
this.selectedModelName = model;
|
this.tempSelectedModelName = model;
|
||||||
},
|
},
|
||||||
|
|
||||||
// 刷新模型列表
|
// 刷新模型列表
|
||||||
refreshModels() {
|
refreshModels() {
|
||||||
if (this.selectedProviderId) {
|
if (this.tempSelectedProviderId) {
|
||||||
this.getProviderModels(this.selectedProviderId);
|
this.getProviderModels(this.tempSelectedProviderId);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
// 确认选择
|
// 确认选择
|
||||||
confirmSelection() {
|
confirmSelection() {
|
||||||
if (this.selectedProviderId && this.selectedModelName) {
|
if (this.tempSelectedProviderId && this.tempSelectedModelName) {
|
||||||
|
// 将临时选择应用到正式选择
|
||||||
|
this.selectedProviderId = this.tempSelectedProviderId;
|
||||||
|
this.selectedModelName = this.tempSelectedModelName;
|
||||||
|
|
||||||
// 保存到localStorage
|
// 保存到localStorage
|
||||||
this.saveToStorage();
|
this.saveToStorage();
|
||||||
|
|
||||||
@@ -252,6 +239,24 @@ export default {
|
|||||||
// 关闭对话框
|
// 关闭对话框
|
||||||
closeDialog() {
|
closeDialog() {
|
||||||
this.showDialog = false;
|
this.showDialog = false;
|
||||||
|
// 重置临时选择为当前选择
|
||||||
|
this.resetTempSelection();
|
||||||
|
},
|
||||||
|
|
||||||
|
// 重置临时选择
|
||||||
|
resetTempSelection() {
|
||||||
|
this.tempSelectedProviderId = this.selectedProviderId;
|
||||||
|
this.tempSelectedModelName = this.selectedModelName;
|
||||||
|
// 如果有临时选择的提供商,重新加载模型列表
|
||||||
|
if (this.tempSelectedProviderId) {
|
||||||
|
this.getProviderModels(this.tempSelectedProviderId);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// 打开对话框
|
||||||
|
openDialog() {
|
||||||
|
this.resetTempSelection();
|
||||||
|
this.showDialog = true;
|
||||||
},
|
},
|
||||||
|
|
||||||
// 公开方法:获取当前选择
|
// 公开方法:获取当前选择
|
||||||
|
|||||||
173
dashboard/src/components/platform/AddNewPlatform.vue
Normal file
173
dashboard/src/components/platform/AddNewPlatform.vue
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
<template>
|
||||||
|
<v-dialog v-model="showDialog" max-width="900px" min-height="80%">
|
||||||
|
<v-card class="platform-selection-dialog" :title="tm('dialog.addPlatform')">
|
||||||
|
<v-card-text class="pa-4" style="overflow-y: auto;">
|
||||||
|
<v-row style="padding: 0px 8px;">
|
||||||
|
<v-col v-for="(template, name) in platformTemplates"
|
||||||
|
:key="name" cols="12" sm="6" md="6">
|
||||||
|
<v-card variant="outlined" hover class="platform-card" @click="selectTemplate(name)">
|
||||||
|
<div class="platform-card-content">
|
||||||
|
<div class="platform-card-text">
|
||||||
|
<v-card-title class="platform-card-title">{{ tm('dialog.connectTitle', { name }) }}</v-card-title>
|
||||||
|
<v-card-text class="text-caption text-medium-emphasis platform-card-description">
|
||||||
|
{{ getPlatformDescription(template, name) }}
|
||||||
|
</v-card-text>
|
||||||
|
</div>
|
||||||
|
<div class="platform-card-logo">
|
||||||
|
<img :src="getPlatformIcon(template.type)" v-if="getPlatformIcon(template.type)" class="platform-logo-img">
|
||||||
|
<div v-else class="platform-logo-fallback">
|
||||||
|
{{ name[0].toUpperCase() }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</v-card>
|
||||||
|
</v-col>
|
||||||
|
<v-col
|
||||||
|
v-if="Object.keys(platformTemplates).length === 0"
|
||||||
|
cols="12">
|
||||||
|
<v-alert type="info" variant="tonal">
|
||||||
|
{{ tm('dialog.noTemplates') }}
|
||||||
|
</v-alert>
|
||||||
|
</v-col>
|
||||||
|
</v-row>
|
||||||
|
</v-card-text>
|
||||||
|
<v-card-actions>
|
||||||
|
<v-spacer></v-spacer>
|
||||||
|
<v-btn text @click="closeDialog">{{ tm('dialog.cancel') }}</v-btn>
|
||||||
|
</v-card-actions>
|
||||||
|
</v-card>
|
||||||
|
</v-dialog>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
|
import { getPlatformIcon, getPlatformDescription } from '@/utils/platformUtils';
|
||||||
|
|
||||||
|
export default {
|
||||||
|
name: 'AddNewPlatform',
|
||||||
|
emits: ['update:show', 'select-template'],
|
||||||
|
props: {
|
||||||
|
show: {
|
||||||
|
type: Boolean,
|
||||||
|
default: false
|
||||||
|
},
|
||||||
|
metadata: {
|
||||||
|
type: Object,
|
||||||
|
default: () => ({})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
setup() {
|
||||||
|
const { tm } = useModuleI18n('features/platform');
|
||||||
|
return { tm };
|
||||||
|
},
|
||||||
|
computed: {
|
||||||
|
showDialog: {
|
||||||
|
get() {
|
||||||
|
return this.show;
|
||||||
|
},
|
||||||
|
set(value) {
|
||||||
|
this.$emit('update:show', value);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
platformTemplates() {
|
||||||
|
return this.metadata['platform_group']?.metadata?.platform?.config_template || {};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
methods: {
|
||||||
|
// 从工具函数导入
|
||||||
|
getPlatformIcon,
|
||||||
|
getPlatformDescription,
|
||||||
|
|
||||||
|
selectTemplate(name) {
|
||||||
|
this.$emit('select-template', name);
|
||||||
|
this.closeDialog();
|
||||||
|
},
|
||||||
|
|
||||||
|
closeDialog() {
|
||||||
|
this.showDialog = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.platform-selection-dialog .v-card-title {
|
||||||
|
border-top-left-radius: 4px;
|
||||||
|
border-top-right-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-card {
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
height: 100%;
|
||||||
|
cursor: pointer;
|
||||||
|
overflow: hidden;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-card:hover {
|
||||||
|
transform: translateY(-4px);
|
||||||
|
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
|
||||||
|
border-color: var(--v-primary-base);
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-card-content {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
height: 100px;
|
||||||
|
padding: 16px;
|
||||||
|
position: relative;
|
||||||
|
z-index: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-card-text {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-card-title {
|
||||||
|
font-size: 15px;
|
||||||
|
font-weight: 600;
|
||||||
|
margin-bottom: 4px;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-card-description {
|
||||||
|
padding: 0;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-card-logo {
|
||||||
|
position: absolute;
|
||||||
|
right: 0;
|
||||||
|
top: 0;
|
||||||
|
bottom: 0;
|
||||||
|
width: 80px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-logo-img {
|
||||||
|
max-width: 60px;
|
||||||
|
max-height: 60px;
|
||||||
|
opacity: 0.6;
|
||||||
|
object-fit: contain;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platform-logo-fallback {
|
||||||
|
width: 50px;
|
||||||
|
height: 50px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background-color: var(--v-primary-base);
|
||||||
|
color: white;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
font-size: 24px;
|
||||||
|
font-weight: bold;
|
||||||
|
opacity: 0.3;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
237
dashboard/src/components/provider/AddNewProvider.vue
Normal file
237
dashboard/src/components/provider/AddNewProvider.vue
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
<template>
|
||||||
|
<v-dialog v-model="showDialog" max-width="1100px" min-height="95%">
|
||||||
|
<v-card :title="tm('dialogs.addProvider.title')">
|
||||||
|
<v-card-text style="overflow-y: auto;">
|
||||||
|
<v-tabs v-model="activeProviderTab" grow>
|
||||||
|
<v-tab value="chat_completion" class="font-weight-medium px-3">
|
||||||
|
<v-icon start>mdi-message-text</v-icon>
|
||||||
|
{{ tm('dialogs.addProvider.tabs.basic') }}
|
||||||
|
</v-tab>
|
||||||
|
<v-tab value="speech_to_text" class="font-weight-medium px-3">
|
||||||
|
<v-icon start>mdi-microphone-message</v-icon>
|
||||||
|
{{ tm('dialogs.addProvider.tabs.speechToText') }}
|
||||||
|
</v-tab>
|
||||||
|
<v-tab value="text_to_speech" class="font-weight-medium px-3">
|
||||||
|
<v-icon start>mdi-volume-high</v-icon>
|
||||||
|
{{ tm('dialogs.addProvider.tabs.textToSpeech') }}
|
||||||
|
</v-tab>
|
||||||
|
<v-tab value="embedding" class="font-weight-medium px-3">
|
||||||
|
<v-icon start>mdi-code-json</v-icon>
|
||||||
|
{{ tm('dialogs.addProvider.tabs.embedding') }}
|
||||||
|
</v-tab>
|
||||||
|
<v-tab value="rerank" class="font-weight-medium px-3">
|
||||||
|
<v-icon start>mdi-compare-vertical</v-icon>
|
||||||
|
{{ tm('dialogs.addProvider.tabs.rerank') }}
|
||||||
|
</v-tab>
|
||||||
|
</v-tabs>
|
||||||
|
|
||||||
|
<v-window v-model="activeProviderTab" class="mt-4">
|
||||||
|
<v-window-item
|
||||||
|
v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
|
||||||
|
:key="tabType" :value="tabType">
|
||||||
|
<v-row class="mt-1">
|
||||||
|
<v-col v-for="(template, name) in getTemplatesByType(tabType)" :key="name" cols="12" sm="6"
|
||||||
|
md="4">
|
||||||
|
<v-card variant="outlined" hover class="provider-card"
|
||||||
|
@click="selectProviderTemplate(name)">
|
||||||
|
<div class="provider-card-content">
|
||||||
|
<div class="provider-card-text">
|
||||||
|
<v-card-title class="provider-card-title">接入 {{ name }}</v-card-title>
|
||||||
|
<v-card-text
|
||||||
|
class="text-caption text-medium-emphasis provider-card-description">
|
||||||
|
{{ getProviderDescription(template, name) }}
|
||||||
|
</v-card-text>
|
||||||
|
</div>
|
||||||
|
<div class="provider-card-logo">
|
||||||
|
<img :src="getProviderIcon(template.provider)"
|
||||||
|
v-if="getProviderIcon(template.provider)" class="provider-logo-img">
|
||||||
|
<div v-else class="provider-logo-fallback">
|
||||||
|
{{ name[0].toUpperCase() }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</v-card>
|
||||||
|
</v-col>
|
||||||
|
<v-col v-if="Object.keys(getTemplatesByType(tabType)).length === 0" cols="12">
|
||||||
|
<v-alert type="info" variant="tonal">
|
||||||
|
{{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
|
||||||
|
</v-alert>
|
||||||
|
</v-col>
|
||||||
|
</v-row>
|
||||||
|
</v-window-item>
|
||||||
|
</v-window>
|
||||||
|
</v-card-text>
|
||||||
|
<v-card-actions>
|
||||||
|
<v-spacer></v-spacer>
|
||||||
|
<v-btn text @click="closeDialog">{{ tm('dialogs.config.cancel') }}</v-btn>
|
||||||
|
</v-card-actions>
|
||||||
|
</v-card>
|
||||||
|
</v-dialog>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
|
import { getProviderIcon, getProviderDescription } from '@/utils/providerUtils';
|
||||||
|
|
||||||
|
export default {
|
||||||
|
name: 'AddNewProvider',
|
||||||
|
props: {
|
||||||
|
show: {
|
||||||
|
type: Boolean,
|
||||||
|
default: false
|
||||||
|
},
|
||||||
|
metadata: {
|
||||||
|
type: Object,
|
||||||
|
default: () => ({})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
emits: ['update:show', 'select-template'],
|
||||||
|
setup() {
|
||||||
|
const { tm } = useModuleI18n('features/provider');
|
||||||
|
return { tm };
|
||||||
|
},
|
||||||
|
data() {
|
||||||
|
return {
|
||||||
|
activeProviderTab: 'chat_completion'
|
||||||
|
};
|
||||||
|
},
|
||||||
|
computed: {
|
||||||
|
showDialog: {
|
||||||
|
get() {
|
||||||
|
return this.show;
|
||||||
|
},
|
||||||
|
set(value) {
|
||||||
|
this.$emit('update:show', value);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// 翻译消息的计算属性
|
||||||
|
messages() {
|
||||||
|
return {
|
||||||
|
tabTypes: {
|
||||||
|
'chat_completion': this.tm('providers.tabs.chatCompletion'),
|
||||||
|
'speech_to_text': this.tm('providers.tabs.speechToText'),
|
||||||
|
'text_to_speech': this.tm('providers.tabs.textToSpeech'),
|
||||||
|
'embedding': this.tm('providers.tabs.embedding'),
|
||||||
|
'rerank': this.tm('providers.tabs.rerank')
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
methods: {
|
||||||
|
closeDialog() {
|
||||||
|
this.showDialog = false;
|
||||||
|
},
|
||||||
|
|
||||||
|
// 按提供商类型获取模板列表
|
||||||
|
getTemplatesByType(type) {
|
||||||
|
const templates = this.metadata['provider_group']?.metadata?.provider?.config_template || {};
|
||||||
|
const filtered = {};
|
||||||
|
|
||||||
|
for (const [name, template] of Object.entries(templates)) {
|
||||||
|
if (template.provider_type === type) {
|
||||||
|
filtered[name] = template;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filtered;
|
||||||
|
},
|
||||||
|
|
||||||
|
// 从工具函数导入
|
||||||
|
getProviderIcon,
|
||||||
|
|
||||||
|
// 获取Tab类型的中文名称
|
||||||
|
getTabTypeName(tabType) {
|
||||||
|
return this.messages.tabTypes[tabType] || tabType;
|
||||||
|
},
|
||||||
|
|
||||||
|
// 获取提供商简介
|
||||||
|
getProviderDescription(template, name) {
|
||||||
|
return getProviderDescription(template, name, this.tm);
|
||||||
|
},
|
||||||
|
|
||||||
|
// 选择提供商模板
|
||||||
|
selectProviderTemplate(name) {
|
||||||
|
this.$emit('select-template', name);
|
||||||
|
this.closeDialog();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.provider-card {
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
height: 100%;
|
||||||
|
cursor: pointer;
|
||||||
|
overflow: hidden;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-card:hover {
|
||||||
|
transform: translateY(-4px);
|
||||||
|
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
|
||||||
|
border-color: var(--v-primary-base);
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-card-content {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
height: 100px;
|
||||||
|
padding: 16px;
|
||||||
|
position: relative;
|
||||||
|
z-index: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-card-text {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-card-title {
|
||||||
|
font-size: 15px;
|
||||||
|
font-weight: 600;
|
||||||
|
margin-bottom: 4px;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-card-description {
|
||||||
|
padding: 0;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-card-logo {
|
||||||
|
position: absolute;
|
||||||
|
right: 0;
|
||||||
|
top: 0;
|
||||||
|
bottom: 0;
|
||||||
|
width: 80px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-logo-img {
|
||||||
|
width: 60px;
|
||||||
|
height: 60px;
|
||||||
|
opacity: 0.6;
|
||||||
|
object-fit: contain;
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-logo-fallback {
|
||||||
|
width: 50px;
|
||||||
|
height: 50px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background-color: var(--v-primary-base);
|
||||||
|
color: white;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
font-size: 24px;
|
||||||
|
font-weight: bold;
|
||||||
|
opacity: 0.3;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -40,6 +40,15 @@
|
|||||||
>
|
>
|
||||||
{{ t('core.common.itemCard.edit') }}
|
{{ t('core.common.itemCard.edit') }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
|
<v-btn
|
||||||
|
v-if="showCopyButton"
|
||||||
|
variant="tonal"
|
||||||
|
color="secondary"
|
||||||
|
rounded="xl"
|
||||||
|
@click="$emit('copy', item)"
|
||||||
|
>
|
||||||
|
{{ t('core.common.itemCard.copy') }}
|
||||||
|
</v-btn>
|
||||||
<v-spacer></v-spacer>
|
<v-spacer></v-spacer>
|
||||||
</v-card-actions>
|
</v-card-actions>
|
||||||
|
|
||||||
@@ -83,9 +92,13 @@ export default {
|
|||||||
loading: {
|
loading: {
|
||||||
type: Boolean,
|
type: Boolean,
|
||||||
default: false
|
default: false
|
||||||
|
},
|
||||||
|
showCopyButton: {
|
||||||
|
type: Boolean,
|
||||||
|
default: false
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
emits: ['toggle-enabled', 'delete', 'edit'],
|
emits: ['toggle-enabled', 'delete', 'edit', 'copy'],
|
||||||
methods: {
|
methods: {
|
||||||
getItemTitle() {
|
getItemTitle() {
|
||||||
return this.item[this.titleField];
|
return this.item[this.titleField];
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
<script setup lang="ts">
|
|
||||||
const props = defineProps({
|
|
||||||
title: String
|
|
||||||
});
|
|
||||||
</script>
|
|
||||||
|
|
||||||
<template>
|
|
||||||
<v-card variant="outlined" elevation="0" class="withbg">
|
|
||||||
<v-card-item>
|
|
||||||
<div class="d-sm-flex align-center justify-space-between">
|
|
||||||
<v-card-title>{{ props.title }}</v-card-title>
|
|
||||||
<slot name="action"></slot>
|
|
||||||
</div>
|
|
||||||
</v-card-item>
|
|
||||||
<v-divider></v-divider>
|
|
||||||
<v-card-text>
|
|
||||||
<slot />
|
|
||||||
</v-card-text>
|
|
||||||
</v-card>
|
|
||||||
</template>
|
|
||||||
@@ -12,6 +12,13 @@
|
|||||||
"title": "Conversation History",
|
"title": "Conversation History",
|
||||||
"refresh": "Refresh"
|
"refresh": "Refresh"
|
||||||
},
|
},
|
||||||
|
"batch": {
|
||||||
|
"deleteSelected": "Delete Selected ({count})"
|
||||||
|
},
|
||||||
|
"pagination": {
|
||||||
|
"itemsPerPage": "Items per page",
|
||||||
|
"showingItems": "Showing {start}-{end} of {total} items"
|
||||||
|
},
|
||||||
"table": {
|
"table": {
|
||||||
"headers": {
|
"headers": {
|
||||||
"title": "Conversation Title",
|
"title": "Conversation Title",
|
||||||
@@ -61,6 +68,13 @@
|
|||||||
"message": "Are you sure you want to delete conversation {title}? This action cannot be undone.",
|
"message": "Are you sure you want to delete conversation {title}? This action cannot be undone.",
|
||||||
"cancel": "Cancel",
|
"cancel": "Cancel",
|
||||||
"confirm": "Delete"
|
"confirm": "Delete"
|
||||||
|
},
|
||||||
|
"batchDelete": {
|
||||||
|
"title": "Batch Delete Confirmation",
|
||||||
|
"message": "Are you sure you want to delete the selected {count} conversations? This action cannot be undone, please proceed with caution!",
|
||||||
|
"andMore": "and {count} more",
|
||||||
|
"cancel": "Cancel",
|
||||||
|
"confirm": "Batch Delete"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"messages": {
|
"messages": {
|
||||||
@@ -72,6 +86,10 @@
|
|||||||
"historyError": "Failed to fetch conversation history",
|
"historyError": "Failed to fetch conversation history",
|
||||||
"historySaveSuccess": "Conversation history saved successfully",
|
"historySaveSuccess": "Conversation history saved successfully",
|
||||||
"historySaveError": "Failed to save conversation history",
|
"historySaveError": "Failed to save conversation history",
|
||||||
"invalidJson": "Invalid JSON format"
|
"invalidJson": "Invalid JSON format",
|
||||||
|
"noItemSelected": "Please select conversations to delete first",
|
||||||
|
"batchDeleteSuccess": "Successfully deleted {count} conversations",
|
||||||
|
"batchDeleteError": "Batch delete failed",
|
||||||
|
"batchDeletePartial": "Delete completed: {deleted} successful, {failed} failed"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -7,7 +7,8 @@
|
|||||||
"apply": "Apply Batch Settings",
|
"apply": "Apply Batch Settings",
|
||||||
"editName": "Edit Session Name",
|
"editName": "Edit Session Name",
|
||||||
"save": "Save",
|
"save": "Save",
|
||||||
"cancel": "Cancel"
|
"cancel": "Cancel",
|
||||||
|
"delete": "Delete"
|
||||||
},
|
},
|
||||||
"sessions": {
|
"sessions": {
|
||||||
"activeSessions": "Active Sessions",
|
"activeSessions": "Active Sessions",
|
||||||
@@ -29,7 +30,8 @@
|
|||||||
"ttsProvider": "TTS Provider",
|
"ttsProvider": "TTS Provider",
|
||||||
"llmStatus": "LLM Status",
|
"llmStatus": "LLM Status",
|
||||||
"ttsStatus": "TTS Status",
|
"ttsStatus": "TTS Status",
|
||||||
"pluginManagement": "Plugin Management"
|
"pluginManagement": "Plugin Management",
|
||||||
|
"actions": "Actions"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"status": {
|
"status": {
|
||||||
@@ -65,6 +67,10 @@
|
|||||||
"fullSessionId": "Full Session ID",
|
"fullSessionId": "Full Session ID",
|
||||||
"hint": "Custom names help you easily identify sessions. The small information icon (!) will show the actual UMO when hovering."
|
"hint": "Custom names help you easily identify sessions. The small information icon (!) will show the actual UMO when hovering."
|
||||||
},
|
},
|
||||||
|
"deleteConfirm": {
|
||||||
|
"message": "Are you sure you want to delete session {sessionName}?",
|
||||||
|
"warning": "This action will permanently delete all chat history and preference settings for this session (except for data linked via plugins), and this cannot be undone. Continue?"
|
||||||
|
},
|
||||||
"messages": {
|
"messages": {
|
||||||
"refreshSuccess": "Session list refreshed",
|
"refreshSuccess": "Session list refreshed",
|
||||||
"personaUpdateSuccess": "Persona updated successfully",
|
"personaUpdateSuccess": "Persona updated successfully",
|
||||||
@@ -82,6 +88,8 @@
|
|||||||
"pluginStatusSuccess": "Plugin {name} {status}",
|
"pluginStatusSuccess": "Plugin {name} {status}",
|
||||||
"pluginStatusError": "Failed to update plugin status",
|
"pluginStatusError": "Failed to update plugin status",
|
||||||
"nameUpdateSuccess": "Session name updated successfully",
|
"nameUpdateSuccess": "Session name updated successfully",
|
||||||
"nameUpdateError": "Failed to update session name"
|
"nameUpdateError": "Failed to update session name",
|
||||||
|
"deleteSuccess": "Session deleted successfully",
|
||||||
|
"deleteError": "Failed to delete session"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,6 +73,7 @@
|
|||||||
"disabled": "已禁用",
|
"disabled": "已禁用",
|
||||||
"delete": "删除",
|
"delete": "删除",
|
||||||
"edit": "编辑",
|
"edit": "编辑",
|
||||||
|
"copy": "复制",
|
||||||
"noData": "暂无数据"
|
"noData": "暂无数据"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
"subtitle": "管理和查看用户对话历史记录",
|
"subtitle": "管理和查看用户对话历史记录",
|
||||||
"filters": {
|
"filters": {
|
||||||
"title": "筛选条件",
|
"title": "筛选条件",
|
||||||
"platform": "平台",
|
"platform": "消息平台 ID",
|
||||||
"type": "类型",
|
"type": "类型",
|
||||||
"search": "搜索关键词",
|
"search": "搜索关键词",
|
||||||
"reset": "重置"
|
"reset": "重置"
|
||||||
@@ -12,12 +12,19 @@
|
|||||||
"title": "对话历史",
|
"title": "对话历史",
|
||||||
"refresh": "刷新"
|
"refresh": "刷新"
|
||||||
},
|
},
|
||||||
|
"batch": {
|
||||||
|
"deleteSelected": "删除选中 ({count})"
|
||||||
|
},
|
||||||
|
"pagination": {
|
||||||
|
"itemsPerPage": "每页",
|
||||||
|
"showingItems": "显示 {start}-{end} 项,共 {total} 项"
|
||||||
|
},
|
||||||
"table": {
|
"table": {
|
||||||
"headers": {
|
"headers": {
|
||||||
"title": "对话标题",
|
"title": "对话标题",
|
||||||
"platform": "平台",
|
"platform": "消息平台 ID",
|
||||||
"type": "类型",
|
"type": "类型",
|
||||||
"sessionId": "ID",
|
"sessionId": "ID (UMO)",
|
||||||
"createdAt": "创建时间",
|
"createdAt": "创建时间",
|
||||||
"updatedAt": "更新时间",
|
"updatedAt": "更新时间",
|
||||||
"actions": "操作"
|
"actions": "操作"
|
||||||
@@ -61,6 +68,13 @@
|
|||||||
"message": "确定要删除对话 {title} 吗?此操作不可恢复。",
|
"message": "确定要删除对话 {title} 吗?此操作不可恢复。",
|
||||||
"cancel": "取消",
|
"cancel": "取消",
|
||||||
"confirm": "删除"
|
"confirm": "删除"
|
||||||
|
},
|
||||||
|
"batchDelete": {
|
||||||
|
"title": "批量删除确认",
|
||||||
|
"message": "确定要删除选中的 {count} 个对话吗?此操作不可恢复,请谨慎操作!",
|
||||||
|
"andMore": "等 {count} 个",
|
||||||
|
"cancel": "取消",
|
||||||
|
"confirm": "批量删除"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"messages": {
|
"messages": {
|
||||||
@@ -72,6 +86,10 @@
|
|||||||
"historyError": "获取对话历史失败",
|
"historyError": "获取对话历史失败",
|
||||||
"historySaveSuccess": "对话历史保存成功",
|
"historySaveSuccess": "对话历史保存成功",
|
||||||
"historySaveError": "对话历史保存失败",
|
"historySaveError": "对话历史保存失败",
|
||||||
"invalidJson": "JSON格式无效"
|
"invalidJson": "JSON格式无效",
|
||||||
|
"noItemSelected": "请先选择要删除的对话",
|
||||||
|
"batchDeleteSuccess": "成功删除 {count} 个对话",
|
||||||
|
"batchDeleteError": "批量删除失败",
|
||||||
|
"batchDeletePartial": "删除完成:成功 {deleted} 个,失败 {failed} 个"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -5,9 +5,10 @@
|
|||||||
"refresh": "刷新",
|
"refresh": "刷新",
|
||||||
"edit": "编辑",
|
"edit": "编辑",
|
||||||
"apply": "应用批量设置",
|
"apply": "应用批量设置",
|
||||||
"editName": "编辑会话名称",
|
"editName": "备注",
|
||||||
"save": "保存",
|
"save": "保存",
|
||||||
"cancel": "取消"
|
"cancel": "取消",
|
||||||
|
"delete": "删除"
|
||||||
},
|
},
|
||||||
"sessions": {
|
"sessions": {
|
||||||
"activeSessions": "活跃会话",
|
"activeSessions": "活跃会话",
|
||||||
@@ -22,14 +23,15 @@
|
|||||||
"table": {
|
"table": {
|
||||||
"headers": {
|
"headers": {
|
||||||
"sessionStatus": "会话状态",
|
"sessionStatus": "会话状态",
|
||||||
"sessionInfo": "会话信息",
|
"sessionInfo": "ID (UMO)",
|
||||||
"persona": "人格",
|
"persona": "人格",
|
||||||
"chatProvider": "Chat Provider",
|
"chatProvider": "聊天模型",
|
||||||
"sttProvider": "STT Provider",
|
"sttProvider": "语音识别模型",
|
||||||
"ttsProvider": "TTS Provider",
|
"ttsProvider": "语音合成模型",
|
||||||
"llmStatus": "LLM启停",
|
"llmStatus": "启用 LLM",
|
||||||
"ttsStatus": "TTS启停",
|
"ttsStatus": "启用 TTS",
|
||||||
"pluginManagement": "插件管理"
|
"pluginManagement": "插件管理",
|
||||||
|
"actions": "操作"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"status": {
|
"status": {
|
||||||
@@ -65,6 +67,10 @@
|
|||||||
"fullSessionId": "完整会话ID",
|
"fullSessionId": "完整会话ID",
|
||||||
"hint": "自定义名称帮助您轻松识别会话。当设置了自定义名称时,会显示一个小感叹号标识(!),鼠标悬停时会显示实际的UMO。"
|
"hint": "自定义名称帮助您轻松识别会话。当设置了自定义名称时,会显示一个小感叹号标识(!),鼠标悬停时会显示实际的UMO。"
|
||||||
},
|
},
|
||||||
|
"deleteConfirm": {
|
||||||
|
"message": "确定要删除会话 {sessionName} 吗?",
|
||||||
|
"warning": "此操作将永久删除本次会话的「全部对话记录」与「偏好设置」(插件对会话的关联数据除外),且无法恢复。确认继续?"
|
||||||
|
},
|
||||||
"messages": {
|
"messages": {
|
||||||
"refreshSuccess": "会话列表已刷新",
|
"refreshSuccess": "会话列表已刷新",
|
||||||
"personaUpdateSuccess": "人格更新成功",
|
"personaUpdateSuccess": "人格更新成功",
|
||||||
@@ -82,6 +88,8 @@
|
|||||||
"pluginStatusSuccess": "插件 {name} {status}",
|
"pluginStatusSuccess": "插件 {name} {status}",
|
||||||
"pluginStatusError": "插件状态更新失败",
|
"pluginStatusError": "插件状态更新失败",
|
||||||
"nameUpdateSuccess": "会话名称更新成功",
|
"nameUpdateSuccess": "会话名称更新成功",
|
||||||
"nameUpdateError": "会话名称更新失败"
|
"nameUpdateError": "会话名称更新失败",
|
||||||
|
"deleteSuccess": "会话删除成功",
|
||||||
|
"deleteError": "会话删除失败"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import {ref, computed} from 'vue';
|
import { ref, computed } from 'vue';
|
||||||
import {useCustomizerStore} from '@/stores/customizer';
|
import { useCustomizerStore } from '@/stores/customizer';
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
import Logo from '@/components/shared/Logo.vue';
|
import Logo from '@/components/shared/Logo.vue';
|
||||||
import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue';
|
import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue';
|
||||||
import {md5} from 'js-md5';
|
import { md5 } from 'js-md5';
|
||||||
import {useAuthStore} from '@/stores/auth';
|
import { useAuthStore } from '@/stores/auth';
|
||||||
import {useCommonStore} from '@/stores/common';
|
import { useCommonStore } from '@/stores/common';
|
||||||
import MarkdownIt from 'markdown-it';
|
import MarkdownIt from 'markdown-it';
|
||||||
import { useI18n } from '@/i18n/composables';
|
import { useI18n } from '@/i18n/composables';
|
||||||
import { router } from '@/router';
|
import { router } from '@/router';
|
||||||
@@ -44,11 +44,11 @@ let installLoading = ref(false);
|
|||||||
let tab = ref(0);
|
let tab = ref(0);
|
||||||
|
|
||||||
const releasesHeader = computed(() => [
|
const releasesHeader = computed(() => [
|
||||||
{title: t('core.header.updateDialog.table.tag'), key: 'tag_name'},
|
{ title: t('core.header.updateDialog.table.tag'), key: 'tag_name' },
|
||||||
{title: t('core.header.updateDialog.table.publishDate'), key: 'published_at'},
|
{ title: t('core.header.updateDialog.table.publishDate'), key: 'published_at' },
|
||||||
{title: t('core.header.updateDialog.table.content'), key: 'body'},
|
{ title: t('core.header.updateDialog.table.content'), key: 'body' },
|
||||||
{title: t('core.header.updateDialog.table.sourceUrl'), key: 'zipball_url'},
|
{ title: t('core.header.updateDialog.table.sourceUrl'), key: 'zipball_url' },
|
||||||
{title: t('core.header.updateDialog.table.actions'), key: 'switch'}
|
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Form validation
|
// Form validation
|
||||||
@@ -291,16 +291,19 @@ commonStore.getStartTime();
|
|||||||
<template>
|
<template>
|
||||||
<v-app-bar elevation="0" height="55">
|
<v-app-bar elevation="0" height="55">
|
||||||
|
|
||||||
<v-btn v-if="useCustomizerStore().uiTheme==='PurpleTheme'" style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
|
<v-btn v-if="useCustomizerStore().uiTheme === 'PurpleTheme'" style="margin-left: 22px;"
|
||||||
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm" variant="flat"
|
||||||
|
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||||
<v-icon>mdi-menu</v-icon>
|
<v-icon>mdi-menu</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn v-else style="margin-left: 22px; color: var(--v-theme-primaryText); background-color: var(--v-theme-secondary)" class="hidden-md-and-down" icon rounded="sm"
|
<v-btn v-else
|
||||||
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
style="margin-left: 22px; color: var(--v-theme-primaryText); background-color: var(--v-theme-secondary)"
|
||||||
|
class="hidden-md-and-down" icon rounded="sm" variant="flat"
|
||||||
|
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||||
<v-icon>mdi-menu</v-icon>
|
<v-icon>mdi-menu</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn v-if="useCustomizerStore().uiTheme==='PurpleTheme'" class="hidden-lg-and-up ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
|
<v-btn v-if="useCustomizerStore().uiTheme === 'PurpleTheme'" class="hidden-lg-and-up ms-3" color="lightsecondary"
|
||||||
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
|
icon rounded="sm" variant="flat" @click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
|
||||||
<v-icon>mdi-menu</v-icon>
|
<v-icon>mdi-menu</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn v-else class="hidden-lg-and-up ms-3" icon rounded="sm" variant="flat"
|
<v-btn v-else class="hidden-lg-and-up ms-3" icon rounded="sm" variant="flat"
|
||||||
@@ -308,12 +311,12 @@ commonStore.getStartTime();
|
|||||||
<v-icon>mdi-menu</v-icon>
|
<v-icon>mdi-menu</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
|
|
||||||
<div class="logo-container" :class="{'mobile-logo': $vuetify.display.xs}" @click="$router.push('/about')">
|
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs }" @click="$router.push('/about')">
|
||||||
<span class="logo-text">Astr<span class="logo-text-light">Bot</span></span>
|
<span class="logo-text">Astr<span class="logo-text-light">Bot</span></span>
|
||||||
<span class="version-text hidden-xs">{{ botCurrVersion }}</span>
|
<span class="version-text hidden-xs">{{ botCurrVersion }}</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<v-spacer/>
|
<v-spacer />
|
||||||
|
|
||||||
<!-- 版本提示信息 - 在手机上隐藏 -->
|
<!-- 版本提示信息 - 在手机上隐藏 -->
|
||||||
<div class="mr-4 hidden-xs">
|
<div class="mr-4 hidden-xs">
|
||||||
@@ -329,19 +332,19 @@ commonStore.getStartTime();
|
|||||||
<LanguageSwitcher variant="header" />
|
<LanguageSwitcher variant="header" />
|
||||||
|
|
||||||
<!-- 主题切换按钮 -->
|
<!-- 主题切换按钮 -->
|
||||||
<v-btn size="small" @click="toggleDarkMode();" class="action-btn"
|
<v-btn size="small" @click="toggleDarkMode();" class="action-btn" color="var(--v-theme-surface)" variant="flat"
|
||||||
color="var(--v-theme-surface)" variant="flat" rounded="sm">
|
rounded="sm" icon>
|
||||||
<v-icon v-if="useCustomizerStore().uiTheme === 'PurpleThemeDark'">mdi-weather-night</v-icon>
|
<v-icon v-if="useCustomizerStore().uiTheme === 'PurpleThemeDark'">mdi-weather-night</v-icon>
|
||||||
<v-icon v-else>mdi-white-balance-sunny</v-icon>
|
<v-icon v-else>mdi-white-balance-sunny</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
|
|
||||||
<!-- 更新对话框 -->
|
<!-- 更新对话框 -->
|
||||||
<v-dialog v-model="updateStatusDialog" :width="$vuetify.display.smAndDown ? '100%' : '1200'" :fullscreen="$vuetify.display.xs">
|
<v-dialog v-model="updateStatusDialog" :width="$vuetify.display.smAndDown ? '100%' : '1200'"
|
||||||
|
:fullscreen="$vuetify.display.xs">
|
||||||
<template v-slot:activator="{ props }">
|
<template v-slot:activator="{ props }">
|
||||||
<v-btn size="small" @click="checkUpdate(); getReleases(); getDevCommits();" class="action-btn"
|
<v-btn size="small" @click="checkUpdate(); getReleases(); getDevCommits();" class="action-btn"
|
||||||
color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props">
|
color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props" icon>
|
||||||
<v-icon class="hidden-sm-and-up">mdi-update</v-icon>
|
<v-icon>mdi-arrow-up-circle</v-icon>
|
||||||
<span class="hidden-xs">{{ t('core.header.buttons.update') }}</span>
|
|
||||||
</v-btn>
|
</v-btn>
|
||||||
</template>
|
</template>
|
||||||
<v-card>
|
<v-card>
|
||||||
@@ -380,15 +383,13 @@ commonStore.getStartTime();
|
|||||||
<v-tabs-window-item key="0" v-show="tab == 0">
|
<v-tabs-window-item key="0" v-show="tab == 0">
|
||||||
<div class="mb-4">
|
<div class="mb-4">
|
||||||
<small>{{ t('core.header.updateDialog.dockerTip') }} <a
|
<small>{{ t('core.header.updateDialog.dockerTip') }} <a
|
||||||
href="https://containrrr.dev/watchtower/usage-overview/">{{ t('core.header.updateDialog.dockerTipLink') }}</a> {{ t('core.header.updateDialog.dockerTipContinue') }}</small>
|
href="https://containrrr.dev/watchtower/usage-overview/">{{
|
||||||
|
t('core.header.updateDialog.dockerTipLink')
|
||||||
|
}}</a> {{ t('core.header.updateDialog.dockerTipContinue') }}</small>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<v-alert
|
<v-alert v-if="releases.some(item => isPreRelease(item['tag_name']))" type="warning" variant="tonal"
|
||||||
v-if="releases.some(item => isPreRelease(item['tag_name']))"
|
border="start">
|
||||||
type="warning"
|
|
||||||
variant="tonal"
|
|
||||||
border="start"
|
|
||||||
>
|
|
||||||
<template v-slot:prepend>
|
<template v-slot:prepend>
|
||||||
<v-icon>mdi-alert-circle-outline</v-icon>
|
<v-icon>mdi-alert-circle-outline</v-icon>
|
||||||
</template>
|
</template>
|
||||||
@@ -406,13 +407,8 @@ commonStore.getStartTime();
|
|||||||
<template v-slot:item.tag_name="{ item }: { item: { tag_name: string } }">
|
<template v-slot:item.tag_name="{ item }: { item: { tag_name: string } }">
|
||||||
<div class="d-flex align-center">
|
<div class="d-flex align-center">
|
||||||
<span>{{ item.tag_name }}</span>
|
<span>{{ item.tag_name }}</span>
|
||||||
<v-chip
|
<v-chip v-if="isPreRelease(item.tag_name)" size="x-small" color="warning" variant="tonal"
|
||||||
v-if="isPreRelease(item.tag_name)"
|
class="ml-2">
|
||||||
size="x-small"
|
|
||||||
color="warning"
|
|
||||||
variant="tonal"
|
|
||||||
class="ml-2"
|
|
||||||
>
|
|
||||||
{{ t('core.header.updateDialog.preRelease') }}
|
{{ t('core.header.updateDialog.preRelease') }}
|
||||||
</v-chip>
|
</v-chip>
|
||||||
</div>
|
</div>
|
||||||
@@ -420,7 +416,8 @@ commonStore.getStartTime();
|
|||||||
<template v-slot:item.body="{ item }: { item: { body: string } }">
|
<template v-slot:item.body="{ item }: { item: { body: string } }">
|
||||||
<v-tooltip :text="item.body">
|
<v-tooltip :text="item.body">
|
||||||
<template v-slot:activator="{ props }">
|
<template v-slot:activator="{ props }">
|
||||||
<v-btn v-bind="props" rounded="xl" variant="tonal" color="primary" size="x-small">{{ t('core.header.updateDialog.table.view') }}</v-btn>
|
<v-btn v-bind="props" rounded="xl" variant="tonal" color="primary" size="x-small">{{
|
||||||
|
t('core.header.updateDialog.table.view') }}</v-btn>
|
||||||
</template>
|
</template>
|
||||||
</v-tooltip>
|
</v-tooltip>
|
||||||
</template>
|
</template>
|
||||||
@@ -435,14 +432,12 @@ commonStore.getStartTime();
|
|||||||
<!-- 开发版 -->
|
<!-- 开发版 -->
|
||||||
<v-tabs-window-item key="1" v-show="tab == 1">
|
<v-tabs-window-item key="1" v-show="tab == 1">
|
||||||
<div style="margin-top: 16px;">
|
<div style="margin-top: 16px;">
|
||||||
<v-data-table
|
<v-data-table :headers="[
|
||||||
:headers="[
|
|
||||||
{ title: t('core.header.updateDialog.table.sha'), key: 'sha' },
|
{ title: t('core.header.updateDialog.table.sha'), key: 'sha' },
|
||||||
{ title: t('core.header.updateDialog.table.date'), key: 'date' },
|
{ title: t('core.header.updateDialog.table.date'), key: 'date' },
|
||||||
{ title: t('core.header.updateDialog.table.message'), key: 'message' },
|
{ title: t('core.header.updateDialog.table.message'), key: 'message' },
|
||||||
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
|
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
|
||||||
]"
|
]" :items="devCommits" item-key="sha">
|
||||||
:items="devCommits" item-key="sha">
|
|
||||||
<template v-slot:item.switch="{ item }: { item: { sha: string } }">
|
<template v-slot:item.switch="{ item }: { item: { sha: string } }">
|
||||||
<v-btn @click="switchVersion(item.sha)" rounded="xl" variant="plain" color="primary">
|
<v-btn @click="switchVersion(item.sha)" rounded="xl" variant="plain" color="primary">
|
||||||
{{ t('core.header.updateDialog.table.switch') }}
|
{{ t('core.header.updateDialog.table.switch') }}
|
||||||
@@ -461,7 +456,8 @@ commonStore.getStartTime();
|
|||||||
<div class="mb-4">
|
<div class="mb-4">
|
||||||
<small>{{ t('core.header.updateDialog.manualInput.hint') }}</small>
|
<small>{{ t('core.header.updateDialog.manualInput.hint') }}</small>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>{{ t('core.header.updateDialog.manualInput.linkText') }}</small></a>
|
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>{{
|
||||||
|
t('core.header.updateDialog.manualInput.linkText') }}</small></a>
|
||||||
</div>
|
</div>
|
||||||
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
|
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
|
||||||
{{ t('core.header.updateDialog.manualInput.confirm') }}
|
{{ t('core.header.updateDialog.manualInput.confirm') }}
|
||||||
@@ -471,7 +467,8 @@ commonStore.getStartTime();
|
|||||||
<div style="margin-top: 16px;">
|
<div style="margin-top: 16px;">
|
||||||
<h3 class="mb-4">{{ t('core.header.updateDialog.dashboardUpdate.title') }}</h3>
|
<h3 class="mb-4">{{ t('core.header.updateDialog.dashboardUpdate.title') }}</h3>
|
||||||
<div class="mb-4">
|
<div class="mb-4">
|
||||||
<small>{{ t('core.header.updateDialog.dashboardUpdate.currentVersion') }} {{ dashboardCurrentVersion }}</small>
|
<small>{{ t('core.header.updateDialog.dashboardUpdate.currentVersion') }} {{ dashboardCurrentVersion
|
||||||
|
}}</small>
|
||||||
<br>
|
<br>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
@@ -504,9 +501,9 @@ commonStore.getStartTime();
|
|||||||
<!-- 账户对话框 -->
|
<!-- 账户对话框 -->
|
||||||
<v-dialog v-model="dialog" persistent :max-width="$vuetify.display.xs ? '90%' : '500'">
|
<v-dialog v-model="dialog" persistent :max-width="$vuetify.display.xs ? '90%' : '500'">
|
||||||
<template v-slot:activator="{ props }">
|
<template v-slot:activator="{ props }">
|
||||||
<v-btn size="small" class="action-btn mr-4" color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props">
|
<v-btn size="small" class="action-btn mr-4" color="var(--v-theme-surface)" variant="flat" rounded="sm"
|
||||||
|
v-bind="props" icon>
|
||||||
<v-icon>mdi-account</v-icon>
|
<v-icon>mdi-account</v-icon>
|
||||||
<span class="hidden-xs ml-1">{{ t('core.header.buttons.account') }}</span>
|
|
||||||
</v-btn>
|
</v-btn>
|
||||||
</template>
|
</template>
|
||||||
<v-card class="account-dialog">
|
<v-card class="account-dialog">
|
||||||
@@ -514,78 +511,34 @@ commonStore.getStartTime();
|
|||||||
<div class="d-flex flex-column align-center mb-6">
|
<div class="d-flex flex-column align-center mb-6">
|
||||||
<logo :title="t('core.header.logoTitle')" :subtitle="t('core.header.accountDialog.title')"></logo>
|
<logo :title="t('core.header.logoTitle')" :subtitle="t('core.header.accountDialog.title')"></logo>
|
||||||
</div>
|
</div>
|
||||||
<v-alert
|
<v-alert v-if="accountWarning" type="warning" variant="tonal" border="start" class="mb-4">
|
||||||
v-if="accountWarning"
|
|
||||||
type="warning"
|
|
||||||
variant="tonal"
|
|
||||||
border="start"
|
|
||||||
class="mb-4"
|
|
||||||
>
|
|
||||||
<strong>{{ t('core.header.accountDialog.securityWarning') }}</strong>
|
<strong>{{ t('core.header.accountDialog.securityWarning') }}</strong>
|
||||||
</v-alert>
|
</v-alert>
|
||||||
|
|
||||||
<v-alert
|
<v-alert v-if="accountEditStatus.success" type="success" variant="tonal" border="start" class="mb-4">
|
||||||
v-if="accountEditStatus.success"
|
|
||||||
type="success"
|
|
||||||
variant="tonal"
|
|
||||||
border="start"
|
|
||||||
class="mb-4"
|
|
||||||
>
|
|
||||||
{{ accountEditStatus.message }}
|
{{ accountEditStatus.message }}
|
||||||
</v-alert>
|
</v-alert>
|
||||||
|
|
||||||
<v-alert
|
<v-alert v-if="accountEditStatus.error" type="error" variant="tonal" border="start" class="mb-4">
|
||||||
v-if="accountEditStatus.error"
|
|
||||||
type="error"
|
|
||||||
variant="tonal"
|
|
||||||
border="start"
|
|
||||||
class="mb-4"
|
|
||||||
>
|
|
||||||
{{ accountEditStatus.message }}
|
{{ accountEditStatus.message }}
|
||||||
</v-alert>
|
</v-alert>
|
||||||
|
|
||||||
<v-form v-model="formValid" @submit.prevent="accountEdit">
|
<v-form v-model="formValid" @submit.prevent="accountEdit">
|
||||||
<v-text-field
|
<v-text-field v-model="password" :append-inner-icon="showPassword ? 'mdi-eye-off' : 'mdi-eye'"
|
||||||
v-model="password"
|
:type="showPassword ? 'text' : 'password'" :label="t('core.header.accountDialog.form.currentPassword')"
|
||||||
:append-inner-icon="showPassword ? 'mdi-eye-off' : 'mdi-eye'"
|
variant="outlined" required clearable @click:append-inner="showPassword = !showPassword"
|
||||||
:type="showPassword ? 'text' : 'password'"
|
prepend-inner-icon="mdi-lock-outline" hide-details="auto" class="mb-4"></v-text-field>
|
||||||
:label="t('core.header.accountDialog.form.currentPassword')"
|
|
||||||
variant="outlined"
|
|
||||||
required
|
|
||||||
clearable
|
|
||||||
@click:append-inner="showPassword = !showPassword"
|
|
||||||
prepend-inner-icon="mdi-lock-outline"
|
|
||||||
hide-details="auto"
|
|
||||||
class="mb-4"
|
|
||||||
></v-text-field>
|
|
||||||
|
|
||||||
<v-text-field
|
<v-text-field v-model="newPassword" :append-inner-icon="showNewPassword ? 'mdi-eye-off' : 'mdi-eye'"
|
||||||
v-model="newPassword"
|
:type="showNewPassword ? 'text' : 'password'" :rules="passwordRules"
|
||||||
:append-inner-icon="showNewPassword ? 'mdi-eye-off' : 'mdi-eye'"
|
:label="t('core.header.accountDialog.form.newPassword')" variant="outlined" required clearable
|
||||||
:type="showNewPassword ? 'text' : 'password'"
|
@click:append-inner="showNewPassword = !showNewPassword" prepend-inner-icon="mdi-lock-plus-outline"
|
||||||
:rules="passwordRules"
|
:hint="t('core.header.accountDialog.form.passwordHint')" persistent-hint class="mb-4"></v-text-field>
|
||||||
:label="t('core.header.accountDialog.form.newPassword')"
|
|
||||||
variant="outlined"
|
|
||||||
required
|
|
||||||
clearable
|
|
||||||
@click:append-inner="showNewPassword = !showNewPassword"
|
|
||||||
prepend-inner-icon="mdi-lock-plus-outline"
|
|
||||||
:hint="t('core.header.accountDialog.form.passwordHint')"
|
|
||||||
persistent-hint
|
|
||||||
class="mb-4"
|
|
||||||
></v-text-field>
|
|
||||||
|
|
||||||
<v-text-field
|
<v-text-field v-model="newUsername" :rules="usernameRules"
|
||||||
v-model="newUsername"
|
:label="t('core.header.accountDialog.form.newUsername')" variant="outlined" clearable
|
||||||
:rules="usernameRules"
|
prepend-inner-icon="mdi-account-edit-outline" :hint="t('core.header.accountDialog.form.usernameHint')"
|
||||||
:label="t('core.header.accountDialog.form.newUsername')"
|
persistent-hint class="mb-3"></v-text-field>
|
||||||
variant="outlined"
|
|
||||||
clearable
|
|
||||||
prepend-inner-icon="mdi-account-edit-outline"
|
|
||||||
:hint="t('core.header.accountDialog.form.usernameHint')"
|
|
||||||
persistent-hint
|
|
||||||
class="mb-3"
|
|
||||||
></v-text-field>
|
|
||||||
</v-form>
|
</v-form>
|
||||||
|
|
||||||
<div class="text-caption text-medium-emphasis mt-2">
|
<div class="text-caption text-medium-emphasis mt-2">
|
||||||
@@ -597,22 +550,12 @@ commonStore.getStartTime();
|
|||||||
|
|
||||||
<v-card-actions class="pa-4">
|
<v-card-actions class="pa-4">
|
||||||
<v-spacer></v-spacer>
|
<v-spacer></v-spacer>
|
||||||
<v-btn
|
<v-btn v-if="!accountWarning" variant="tonal" color="secondary" @click="dialog = false"
|
||||||
v-if="!accountWarning"
|
:disabled="accountEditStatus.loading">
|
||||||
variant="tonal"
|
|
||||||
color="secondary"
|
|
||||||
@click="dialog = false"
|
|
||||||
:disabled="accountEditStatus.loading"
|
|
||||||
>
|
|
||||||
{{ t('core.header.accountDialog.actions.cancel') }}
|
{{ t('core.header.accountDialog.actions.cancel') }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn
|
<v-btn color="primary" @click="accountEdit" :loading="accountEditStatus.loading" :disabled="!formValid"
|
||||||
color="primary"
|
prepend-icon="mdi-content-save">
|
||||||
@click="accountEdit"
|
|
||||||
:loading="accountEditStatus.loading"
|
|
||||||
:disabled="!formValid"
|
|
||||||
prepend-icon="mdi-content-save"
|
|
||||||
>
|
|
||||||
{{ t('core.header.accountDialog.actions.save') }}
|
{{ t('core.header.accountDialog.actions.save') }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
</v-card-actions>
|
</v-card-actions>
|
||||||
|
|||||||
31
dashboard/src/stores/toast.js
Normal file
31
dashboard/src/stores/toast.js
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import { defineStore } from 'pinia'
|
||||||
|
import { ref, computed } from 'vue'
|
||||||
|
|
||||||
|
export const useToastStore = defineStore('toast', () => {
|
||||||
|
const queue = ref([])
|
||||||
|
const current = computed(() => queue.value[0])
|
||||||
|
|
||||||
|
function add({
|
||||||
|
message,
|
||||||
|
color = 'info', // Vuetify 颜色
|
||||||
|
timeout = 3000,
|
||||||
|
closable = true,
|
||||||
|
multiLine = false,
|
||||||
|
location = 'top center'
|
||||||
|
}) {
|
||||||
|
queue.value.push({
|
||||||
|
message,
|
||||||
|
color,
|
||||||
|
timeout,
|
||||||
|
closable,
|
||||||
|
multiLine,
|
||||||
|
location
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function shift() {
|
||||||
|
queue.value.shift()
|
||||||
|
}
|
||||||
|
|
||||||
|
return { current, add, shift }
|
||||||
|
})
|
||||||
78
dashboard/src/utils/platformUtils.js
Normal file
78
dashboard/src/utils/platformUtils.js
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
/**
|
||||||
|
* 平台相关工具函数
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取平台图标
|
||||||
|
* @param {string} name - 平台名称或类型
|
||||||
|
* @returns {string|undefined} 图标URL
|
||||||
|
*/
|
||||||
|
export function getPlatformIcon(name) {
|
||||||
|
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
|
||||||
|
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
||||||
|
} else if (name === 'wecom') {
|
||||||
|
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
||||||
|
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
||||||
|
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
||||||
|
} else if (name === 'lark') {
|
||||||
|
return new URL('@/assets/images/platform_logos/lark.png', import.meta.url).href
|
||||||
|
} else if (name === 'dingtalk') {
|
||||||
|
return new URL('@/assets/images/platform_logos/dingtalk.svg', import.meta.url).href
|
||||||
|
} else if (name === 'telegram') {
|
||||||
|
return new URL('@/assets/images/platform_logos/telegram.svg', import.meta.url).href
|
||||||
|
} else if (name === 'discord') {
|
||||||
|
return new URL('@/assets/images/platform_logos/discord.svg', import.meta.url).href
|
||||||
|
} else if (name === 'slack') {
|
||||||
|
return new URL('@/assets/images/platform_logos/slack.svg', import.meta.url).href
|
||||||
|
} else if (name === 'kook') {
|
||||||
|
return new URL('@/assets/images/platform_logos/kook.png', import.meta.url).href
|
||||||
|
} else if (name === 'vocechat') {
|
||||||
|
return new URL('@/assets/images/platform_logos/vocechat.png', import.meta.url).href
|
||||||
|
} else if (name === 'satori' || name === 'Satori') {
|
||||||
|
return new URL('@/assets/images/platform_logos/satori.png', import.meta.url).href
|
||||||
|
} else if (name === 'misskey') {
|
||||||
|
return new URL('@/assets/images/platform_logos/misskey.png', import.meta.url).href
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取平台教程链接
|
||||||
|
* @param {string} platformType - 平台类型
|
||||||
|
* @returns {string} 教程链接
|
||||||
|
*/
|
||||||
|
export function getTutorialLink(platformType) {
|
||||||
|
const tutorialMap = {
|
||||||
|
"qq_official_webhook": "https://docs.astrbot.app/deploy/platform/qqofficial/webhook.html",
|
||||||
|
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
|
||||||
|
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
||||||
|
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
|
||||||
|
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
|
||||||
|
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
|
||||||
|
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
|
||||||
|
"wechatpadpro": "https://docs.astrbot.app/deploy/platform/wechat/wechatpadpro.html",
|
||||||
|
"weixin_official_account": "https://docs.astrbot.app/deploy/platform/weixin-official-account.html",
|
||||||
|
"discord": "https://docs.astrbot.app/deploy/platform/discord.html",
|
||||||
|
"slack": "https://docs.astrbot.app/deploy/platform/slack.html",
|
||||||
|
"kook": "https://docs.astrbot.app/deploy/platform/kook.html",
|
||||||
|
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
|
||||||
|
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
|
||||||
|
"misskey": "https://docs.astrbot.app/deploy/platform/misskey.html",
|
||||||
|
}
|
||||||
|
return tutorialMap[platformType] || "https://docs.astrbot.app";
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取平台描述
|
||||||
|
* @param {Object} template - 平台模板
|
||||||
|
* @param {string} name - 平台名称
|
||||||
|
* @returns {string} 平台描述
|
||||||
|
*/
|
||||||
|
export function getPlatformDescription(template, name) {
|
||||||
|
// special judge for community platforms
|
||||||
|
if (name.includes('vocechat')) {
|
||||||
|
return "由 @HikariFroya 提供。";
|
||||||
|
} else if (name.includes('kook')) {
|
||||||
|
return "由 @wuyan1003 提供。"
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
}
|
||||||
52
dashboard/src/utils/providerUtils.js
Normal file
52
dashboard/src/utils/providerUtils.js
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
/**
|
||||||
|
* 提供商相关的工具函数
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取提供商类型对应的图标
|
||||||
|
* @param {string} type - 提供商类型
|
||||||
|
* @returns {string} 图标 URL
|
||||||
|
*/
|
||||||
|
export function getProviderIcon(type) {
|
||||||
|
const icons = {
|
||||||
|
'openai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
|
||||||
|
'azure': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/azure.svg',
|
||||||
|
'xai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/xai.svg',
|
||||||
|
'anthropic': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/anthropic.svg',
|
||||||
|
'ollama': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ollama.svg',
|
||||||
|
'google': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
|
||||||
|
'deepseek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
|
||||||
|
'modelscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/modelscope.svg',
|
||||||
|
'zhipu': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
|
||||||
|
'siliconflow': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
|
||||||
|
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
|
||||||
|
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
|
||||||
|
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
|
||||||
|
"coze": "https://registry.npmmirror.com/@lobehub/icons-static-svg/1.66.0/files/icons/coze.svg",
|
||||||
|
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
|
||||||
|
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
|
||||||
|
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
|
||||||
|
'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',
|
||||||
|
'minimax': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/minimax.svg',
|
||||||
|
'302ai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
|
||||||
|
'microsoft': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/microsoft.svg',
|
||||||
|
'vllm': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/vllm.svg',
|
||||||
|
};
|
||||||
|
return icons[type] || '';
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取提供商简介
|
||||||
|
* @param {Object} template - 模板对象
|
||||||
|
* @param {string} name - 提供商名称
|
||||||
|
* @param {Function} tm - 翻译函数
|
||||||
|
* @returns {string} 提供商描述
|
||||||
|
*/
|
||||||
|
export function getProviderDescription(template, name, tm) {
|
||||||
|
if (name == 'OpenAI') {
|
||||||
|
return tm('providers.description.openai', { type: template.type });
|
||||||
|
} else if (name == 'vLLM Rerank') {
|
||||||
|
return tm('providers.description.vllm_rerank', { type: template.type });
|
||||||
|
}
|
||||||
|
return tm('providers.description.default', { type: template.type });
|
||||||
|
}
|
||||||
16
dashboard/src/utils/toast.js
Normal file
16
dashboard/src/utils/toast.js
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import { useToastStore } from '@/stores/toast'
|
||||||
|
|
||||||
|
export function useToast() {
|
||||||
|
const store = useToastStore()
|
||||||
|
|
||||||
|
const toast = (message, color = 'info', opts = {}) =>
|
||||||
|
store.add({ message, color, ...opts })
|
||||||
|
|
||||||
|
return {
|
||||||
|
toast,
|
||||||
|
success: (msg, opts) => toast(msg, 'success', opts),
|
||||||
|
error: (msg, opts) => toast(msg, 'error', opts),
|
||||||
|
info: (msg, opts) => toast(msg, 'primary', opts),
|
||||||
|
warning: (msg, opts) => toast(msg, 'warning', opts)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
<script setup>
|
<script setup>
|
||||||
import ChatPage from './ChatPage.vue';
|
import Chat from '@/components/chat/Chat.vue'
|
||||||
import { useCustomizerStore } from '@/stores/customizer';
|
import { useCustomizerStore } from '@/stores/customizer';
|
||||||
const customizer = useCustomizerStore();
|
const customizer = useCustomizerStore();
|
||||||
</script>
|
</script>
|
||||||
@@ -9,7 +9,7 @@ const customizer = useCustomizerStore();
|
|||||||
<div
|
<div
|
||||||
style="height: 100%; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
|
style="height: 100%; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
|
||||||
<div id="container">
|
<div id="container">
|
||||||
<ChatPage :chatbox-mode="true"></ChatPage>
|
<Chat :chatbox-mode="true"></Chat>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</v-app>
|
</v-app>
|
||||||
@@ -18,24 +18,6 @@ const customizer = useCustomizerStore();
|
|||||||
<style scoped>
|
<style scoped>
|
||||||
#container {
|
#container {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
height: 100%;
|
height: 100vh;
|
||||||
}
|
|
||||||
|
|
||||||
@media (min-width: 768px) {
|
|
||||||
#container {
|
|
||||||
min-width: 600px;
|
|
||||||
min-height: 370px;
|
|
||||||
max-width: 1100px;
|
|
||||||
max-height: 860px;
|
|
||||||
padding: 36px;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@media (max-width: 767px) {
|
|
||||||
#container {
|
|
||||||
width: 100%;
|
|
||||||
height: 100%;
|
|
||||||
padding: 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,50 +1,30 @@
|
|||||||
<template>
|
<template>
|
||||||
<div class="conversation-page">
|
<div class="conversation-page">
|
||||||
<v-container fluid class="pa-0">
|
<v-container fluid class="pa-0">
|
||||||
<!-- 页面标题 -->
|
<!-- 对话列表部分 -->
|
||||||
<v-row>
|
<v-card flat>
|
||||||
<v-col cols="12">
|
|
||||||
<h1 class="text-h4 font-weight-bold mb-2">
|
|
||||||
<v-icon size="x-large" color="primary" class="me-2">mdi-chat-processing</v-icon>{{ tm('title') }}
|
|
||||||
</h1>
|
|
||||||
<p class="text-subtitle-1 text-medium-emphasis mb-4">
|
|
||||||
{{ tm('subtitle') }}
|
|
||||||
</p>
|
|
||||||
</v-col>
|
|
||||||
</v-row>
|
|
||||||
|
|
||||||
<!-- 过滤器部分 -->
|
|
||||||
<v-card class="mb-4" elevation="2">
|
|
||||||
<v-card-title class="d-flex align-center py-3 px-4">
|
<v-card-title class="d-flex align-center py-3 px-4">
|
||||||
<v-icon color="primary" class="me-2">mdi-filter-variant</v-icon>
|
<span class="text-h4">{{ tm('history.title') }}</span>
|
||||||
<span class="text-h6">{{ tm('filters.title') }}</span>
|
<v-chip size="small" class="ml-2">{{ pagination.total || 0 }}</v-chip>
|
||||||
<v-spacer></v-spacer>
|
<v-row class="me-4 ms-4" dense>
|
||||||
<v-btn color="primary" variant="text" @click="resetFilters" class="ml-2">
|
|
||||||
<v-icon class="mr-1">mdi-refresh</v-icon>{{ tm('filters.reset') }}
|
|
||||||
</v-btn>
|
|
||||||
</v-card-title>
|
|
||||||
|
|
||||||
<v-divider></v-divider>
|
|
||||||
|
|
||||||
<v-card-text class="py-4">
|
|
||||||
<v-row>
|
|
||||||
<v-col cols="12" sm="6" md="4">
|
<v-col cols="12" sm="6" md="4">
|
||||||
<v-select v-model="platformFilter" :label="tm('filters.platform')" :items="availablePlatforms" chips multiple
|
<v-combobox v-model="platformFilter" :label="tm('filters.platform')"
|
||||||
clearable variant="outlined" density="compact" hide-details>
|
:items="availablePlatforms" chips multiple clearable variant="solo-filled" flat
|
||||||
|
density="compact" hide-details :disabled="loading">
|
||||||
<template v-slot:selection="{ item }">
|
<template v-slot:selection="{ item }">
|
||||||
<v-chip size="small" :color="getPlatformColor(item.value)" label>
|
<v-chip size="small" label>
|
||||||
{{ item.title }}
|
{{ item.title }}
|
||||||
</v-chip>
|
</v-chip>
|
||||||
</template>
|
</template>
|
||||||
</v-select>
|
</v-combobox>
|
||||||
</v-col>
|
</v-col>
|
||||||
|
|
||||||
<v-col cols="12" sm="6" md="4">
|
<v-col cols="12" sm="6" md="4">
|
||||||
<v-select v-model="messageTypeFilter" :label="tm('filters.type')" :items="messageTypeItems" chips multiple
|
<v-select v-model="messageTypeFilter" :label="tm('filters.type')" :items="messageTypeItems"
|
||||||
clearable variant="outlined" density="compact" hide-details>
|
chips multiple clearable variant="solo-filled" density="compact" hide-details flat
|
||||||
|
:disabled="loading">
|
||||||
<template v-slot:selection="{ item }">
|
<template v-slot:selection="{ item }">
|
||||||
<v-chip size="small" :color="getMessageTypeColor(item.value)" variant="outlined"
|
<v-chip size="small" variant="solo-filled" label>
|
||||||
label>
|
|
||||||
{{ item.title }}
|
{{ item.title }}
|
||||||
</v-chip>
|
</v-chip>
|
||||||
</template>
|
</template>
|
||||||
@@ -52,49 +32,49 @@
|
|||||||
</v-col>
|
</v-col>
|
||||||
|
|
||||||
<v-col cols="12" sm="12" md="4">
|
<v-col cols="12" sm="12" md="4">
|
||||||
<v-text-field v-model="search" prepend-inner-icon="mdi-magnify" :label="tm('filters.search')" hide-details
|
<v-text-field v-model="search" prepend-inner-icon="mdi-magnify"
|
||||||
density="compact" variant="outlined" clearable></v-text-field>
|
:label="tm('filters.search')" hide-details density="compact" variant="solo-filled" flat
|
||||||
|
clearable :disabled="loading"></v-text-field>
|
||||||
</v-col>
|
</v-col>
|
||||||
</v-row>
|
</v-row>
|
||||||
</v-card-text>
|
|
||||||
</v-card>
|
|
||||||
|
|
||||||
<!-- 对话列表部分 -->
|
|
||||||
<v-card class="mb-6" elevation="2">
|
|
||||||
<v-card-title class="d-flex align-center py-3 px-4">
|
|
||||||
<v-icon color="primary" class="me-2">mdi-message</v-icon>
|
|
||||||
<span class="text-h6">{{ tm('history.title') }}</span>
|
|
||||||
<v-chip color="info" size="small" class="ml-2">{{ pagination.total || 0 }}</v-chip>
|
|
||||||
<v-spacer></v-spacer>
|
|
||||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="fetchConversations"
|
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="fetchConversations"
|
||||||
:loading="loading">
|
:loading="loading" size="small" class="mr-2">
|
||||||
{{ tm('history.refresh') }}
|
{{ tm('history.refresh') }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
|
<v-btn
|
||||||
|
v-if="selectedItems.length > 0"
|
||||||
|
color="error"
|
||||||
|
prepend-icon="mdi-delete"
|
||||||
|
variant="tonal"
|
||||||
|
@click="confirmBatchDelete"
|
||||||
|
:disabled="loading"
|
||||||
|
size="small">
|
||||||
|
{{ tm('batch.deleteSelected', { count: selectedItems.length }) }}
|
||||||
|
</v-btn>
|
||||||
</v-card-title>
|
</v-card-title>
|
||||||
|
|
||||||
<v-divider></v-divider>
|
<v-divider></v-divider>
|
||||||
|
|
||||||
<v-card-text class="pa-0">
|
<v-card-text class="pa-0">
|
||||||
<v-data-table :headers="tableHeaders" :items="conversations" :loading="loading" density="comfortable"
|
<v-data-table v-model="selectedItems" :headers="tableHeaders" :items="conversations"
|
||||||
hide-default-footer items-per-page="10" class="elevation-0"
|
:loading="loading" style="font-size: 12px;" density="comfortable" hide-default-footer
|
||||||
:items-per-page="pagination.page_size" :items-per-page-options="[10, 20, 50, 100]"
|
class="elevation-0" :items-per-page="pagination.page_size"
|
||||||
@update:options="handleTableOptions">
|
:items-per-page-options="pageSizeOptions" show-select return-object
|
||||||
|
:disabled="loading" @update:options="handleTableOptions">
|
||||||
<template v-slot:item.title="{ item }">
|
<template v-slot:item.title="{ item }">
|
||||||
<div class="d-flex align-center">
|
<div class="d-flex align-center">
|
||||||
<v-icon color="primary" class="mr-2">mdi-chat</v-icon>
|
|
||||||
<span>{{ item.title || tm('status.noTitle') }}</span>
|
<span>{{ item.title || tm('status.noTitle') }}</span>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<template v-slot:item.platform="{ item }">
|
<template v-slot:item.platform="{ item }">
|
||||||
<v-chip :color="getPlatformColor(item.sessionInfo.platform)" size="small" label>
|
<v-chip size="small" label>
|
||||||
{{ item.sessionInfo.platform || tm('status.unknown') }}
|
{{ item.sessionInfo.platform || tm('status.unknown') }}
|
||||||
</v-chip>
|
</v-chip>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<template v-slot:item.messageType="{ item }">
|
<template v-slot:item.messageType="{ item }">
|
||||||
<v-chip :color="getMessageTypeColor(item.sessionInfo.messageType)" size="small"
|
<v-chip size="small" label>
|
||||||
variant="outlined" label>
|
|
||||||
{{ getMessageTypeDisplay(item.sessionInfo.messageType) }}
|
{{ getMessageTypeDisplay(item.sessionInfo.messageType) }}
|
||||||
</v-chip>
|
</v-chip>
|
||||||
</template>
|
</template>
|
||||||
@@ -113,17 +93,17 @@
|
|||||||
|
|
||||||
<template v-slot:item.actions="{ item }">
|
<template v-slot:item.actions="{ item }">
|
||||||
<div class="actions-wrapper">
|
<div class="actions-wrapper">
|
||||||
<v-btn color="primary" variant="flat" size="small" class="action-button"
|
<v-btn icon variant="plain" size="x-small" class="action-button"
|
||||||
@click="viewConversation(item)">
|
@click="viewConversation(item)" :disabled="loading">
|
||||||
<v-icon class="mr-1">mdi-eye</v-icon>{{ tm('actions.view') }}
|
<v-icon>mdi-eye</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn color="warning" variant="flat" size="small" class="action-button"
|
<v-btn icon variant="plain" size="x-small" class="action-button"
|
||||||
@click="editConversation(item)">
|
@click="editConversation(item)" :disabled="loading">
|
||||||
<v-icon class="mr-1">mdi-pencil</v-icon>{{ tm('actions.edit') }}
|
<v-icon>mdi-pencil</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn color="error" variant="flat" size="small" class="action-button"
|
<v-btn icon color="error" variant="plain" size="x-small" class="action-button"
|
||||||
@click="confirmDeleteConversation(item)">
|
@click="confirmDeleteConversation(item)" :disabled="loading">
|
||||||
<v-icon class="mr-1">mdi-delete</v-icon>{{ tm('actions.delete') }}
|
<v-icon>mdi-delete</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
@@ -137,9 +117,25 @@
|
|||||||
</v-data-table>
|
</v-data-table>
|
||||||
|
|
||||||
<!-- 分页控制 -->
|
<!-- 分页控制 -->
|
||||||
<div class="d-flex justify-end pa-4">
|
<div class="d-flex justify-center py-3">
|
||||||
|
<!-- 每页大小选择器 -->
|
||||||
|
<div class="d-flex justify-between align-center px-4 py-2 bg-grey-lighten-5">
|
||||||
|
<div class="d-flex align-center">
|
||||||
|
<span class="text-caption mr-2">{{ tm('pagination.itemsPerPage') }}:</span>
|
||||||
|
<v-select v-model="pagination.page_size" :items="pageSizeOptions" variant="outlined"
|
||||||
|
density="compact" hide-details style="max-width: 100px;"
|
||||||
|
:disabled="loading" @update:model-value="onPageSizeChange"></v-select>
|
||||||
|
</div>
|
||||||
|
<div class="text-caption ml-4">
|
||||||
|
{{ tm('pagination.showingItems', {
|
||||||
|
start: Math.min((pagination.page - 1) * pagination.page_size + 1, pagination.total),
|
||||||
|
end: Math.min(pagination.page * pagination.page_size, pagination.total),
|
||||||
|
total: pagination.total
|
||||||
|
}) }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
<v-pagination v-model="pagination.page" :length="pagination.total_pages" :disabled="loading"
|
<v-pagination v-model="pagination.page" :length="pagination.total_pages" :disabled="loading"
|
||||||
@update:model-value="fetchConversations" rounded="circle"></v-pagination>
|
@update:model-value="fetchConversations" rounded="circle" :total-visible="7"></v-pagination>
|
||||||
</div>
|
</div>
|
||||||
</v-card-text>
|
</v-card-text>
|
||||||
</v-card>
|
</v-card>
|
||||||
@@ -148,24 +144,20 @@
|
|||||||
<!-- 对话详情对话框 -->
|
<!-- 对话详情对话框 -->
|
||||||
<v-dialog v-model="dialogView" max-width="900px" scrollable>
|
<v-dialog v-model="dialogView" max-width="900px" scrollable>
|
||||||
<v-card class="conversation-detail-card">
|
<v-card class="conversation-detail-card">
|
||||||
<v-card-title class="bg-primary text-white py-3 d-flex align-center">
|
<v-card-title class="ml-2 mt-2 d-flex align-center">
|
||||||
<v-icon color="white" class="me-2">mdi-eye</v-icon>
|
|
||||||
<span class="text-truncate">{{ selectedConversation?.title || tm('status.noTitle') }}</span>
|
<span class="text-truncate">{{ selectedConversation?.title || tm('status.noTitle') }}</span>
|
||||||
<v-spacer></v-spacer>
|
<v-spacer></v-spacer>
|
||||||
|
|
||||||
<div class="d-flex align-center" v-if="selectedConversation?.sessionInfo">
|
<div class="d-flex align-center" v-if="selectedConversation?.sessionInfo">
|
||||||
<v-chip color="white" text-color="primary" size="small" class="mr-2">
|
<v-chip text-color="primary" size="small" class="mr-2" rounded="md">
|
||||||
{{ selectedConversation.sessionInfo.platform }}
|
{{ selectedConversation.sessionInfo.platform }}
|
||||||
</v-chip>
|
</v-chip>
|
||||||
<v-chip color="white" text-color="secondary" size="small">
|
<v-chip text-color="secondary" size="small" rounded="md">
|
||||||
{{ getMessageTypeDisplay(selectedConversation.sessionInfo.messageType) }}
|
{{ getMessageTypeDisplay(selectedConversation.sessionInfo.messageType) }}
|
||||||
</v-chip>
|
</v-chip>
|
||||||
</div>
|
</div>
|
||||||
</v-card-title>
|
</v-card-title>
|
||||||
|
|
||||||
<v-divider></v-divider>
|
<v-card-text>
|
||||||
|
|
||||||
<v-card-text class="py-4">
|
|
||||||
<div class="mb-4 d-flex align-center">
|
<div class="mb-4 d-flex align-center">
|
||||||
<v-btn color="secondary" variant="tonal" size="small" class="mr-2"
|
<v-btn color="secondary" variant="tonal" size="small" class="mr-2"
|
||||||
@click="isEditingHistory = !isEditingHistory">
|
@click="isEditingHistory = !isEditingHistory">
|
||||||
@@ -199,51 +191,11 @@
|
|||||||
<p class="text-disabled mt-2">{{ tm('status.emptyContent') }}</p>
|
<p class="text-disabled mt-2">{{ tm('status.emptyContent') }}</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 消息列表 -->
|
<!-- 消息列表组件 -->
|
||||||
<div v-else class="message-list">
|
<MessageList v-else :messages="formattedMessages" :isDark="false" />
|
||||||
<div class="message-item" v-for="(msg, index) in conversationHistory" :key="index">
|
|
||||||
<!-- 用户消息 -->
|
|
||||||
<div v-if="msg.role === 'user'" class="user-message">
|
|
||||||
<div class="message-bubble user-bubble">
|
|
||||||
<span v-html="formatMessage(msg.content)"></span>
|
|
||||||
|
|
||||||
<!-- 图片附件 -->
|
|
||||||
<div class="image-attachments" v-if="msg.image_url && msg.image_url.length > 0">
|
|
||||||
<div v-for="(img, imgIndex) in msg.image_url" :key="imgIndex"
|
|
||||||
class="image-attachment">
|
|
||||||
<img :src="img" class="attached-image" />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 音频附件 -->
|
|
||||||
<div class="audio-attachment" v-if="msg.audio_url">
|
|
||||||
<audio controls class="audio-player">
|
|
||||||
<source :src="msg.audio_url" type="audio/wav">
|
|
||||||
{{ tm('status.audioNotSupported') }}
|
|
||||||
</audio>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<v-avatar class="user-avatar" color="deep-purple-lighten-3" size="36">
|
|
||||||
<v-icon icon="mdi-account" />
|
|
||||||
</v-avatar>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 机器人消息 -->
|
|
||||||
<div v-else class="bot-message">
|
|
||||||
<v-avatar class="bot-avatar" color="deep-purple" size="36">
|
|
||||||
<span class="text-h6">✨</span>
|
|
||||||
</v-avatar>
|
|
||||||
<div class="message-bubble bot-bubble">
|
|
||||||
<div v-html="formatMessage(msg.content)" class="markdown-content"></div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</v-card-text>
|
</v-card-text>
|
||||||
|
|
||||||
<v-divider></v-divider>
|
|
||||||
|
|
||||||
<v-card-actions class="pa-4">
|
<v-card-actions class="pa-4">
|
||||||
<v-spacer></v-spacer>
|
<v-spacer></v-spacer>
|
||||||
<v-btn variant="text" @click="closeHistoryDialog">
|
<v-btn variant="text" @click="closeHistoryDialog">
|
||||||
@@ -263,8 +215,9 @@
|
|||||||
|
|
||||||
<v-card-text class="py-4">
|
<v-card-text class="py-4">
|
||||||
<v-form ref="form" v-model="valid">
|
<v-form ref="form" v-model="valid">
|
||||||
<v-text-field v-model="editedItem.title" :label="tm('dialogs.edit.titleLabel')" :placeholder="tm('dialogs.edit.titlePlaceholder')" variant="outlined"
|
<v-text-field v-model="editedItem.title" :label="tm('dialogs.edit.titleLabel')"
|
||||||
density="comfortable" class="mb-3"></v-text-field>
|
:placeholder="tm('dialogs.edit.titlePlaceholder')" variant="outlined" density="comfortable"
|
||||||
|
class="mb-3"></v-text-field>
|
||||||
</v-form>
|
</v-form>
|
||||||
</v-card-text>
|
</v-card-text>
|
||||||
|
|
||||||
@@ -291,7 +244,8 @@
|
|||||||
</v-card-title>
|
</v-card-title>
|
||||||
|
|
||||||
<v-card-text class="py-4">
|
<v-card-text class="py-4">
|
||||||
<p>{{ tm('dialogs.delete.message', { title: selectedConversation?.title || tm('status.noTitle') }) }}</p>
|
<p>{{ tm('dialogs.delete.message', { title: selectedConversation?.title || tm('status.noTitle') })
|
||||||
|
}}</p>
|
||||||
</v-card-text>
|
</v-card-text>
|
||||||
|
|
||||||
<v-divider></v-divider>
|
<v-divider></v-divider>
|
||||||
@@ -308,6 +262,48 @@
|
|||||||
</v-card>
|
</v-card>
|
||||||
</v-dialog>
|
</v-dialog>
|
||||||
|
|
||||||
|
<!-- 批量删除确认对话框 -->
|
||||||
|
<v-dialog v-model="dialogBatchDelete" max-width="600px">
|
||||||
|
<v-card>
|
||||||
|
<v-card-title class="bg-error text-white py-3">
|
||||||
|
<v-icon color="white" class="me-2">mdi-delete</v-icon>
|
||||||
|
<span>{{ tm('dialogs.batchDelete.title') }}</span>
|
||||||
|
</v-card-title>
|
||||||
|
|
||||||
|
<v-card-text class="py-4">
|
||||||
|
<p class="mb-3">{{ tm('dialogs.batchDelete.message', { count: selectedItems.length }) }}</p>
|
||||||
|
|
||||||
|
<!-- 显示前几个要删除的对话 -->
|
||||||
|
<div v-if="selectedItems.length > 0" class="mb-3">
|
||||||
|
<v-chip v-for="(item, index) in selectedItems.slice(0, 5)" :key="`${item.user_id}-${item.cid}`"
|
||||||
|
size="small" class="mr-1 mb-1" closable @click:close="removeFromSelection(item)"
|
||||||
|
:disabled="loading">
|
||||||
|
{{ item.title || tm('status.noTitle') }}
|
||||||
|
</v-chip>
|
||||||
|
<v-chip v-if="selectedItems.length > 5" size="small" class="mr-1 mb-1">
|
||||||
|
{{ tm('dialogs.batchDelete.andMore', { count: selectedItems.length - 5 }) }}
|
||||||
|
</v-chip>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<v-alert type="warning" variant="tonal" class="mb-3">
|
||||||
|
{{ tm('dialogs.batchDelete.warning') }}
|
||||||
|
</v-alert>
|
||||||
|
</v-card-text>
|
||||||
|
|
||||||
|
<v-divider></v-divider>
|
||||||
|
|
||||||
|
<v-card-actions class="pa-4">
|
||||||
|
<v-spacer></v-spacer>
|
||||||
|
<v-btn variant="text" @click="dialogBatchDelete = false" :disabled="loading">
|
||||||
|
{{ tm('dialogs.batchDelete.cancel') }}
|
||||||
|
</v-btn>
|
||||||
|
<v-btn color="error" @click="batchDeleteConversations" :loading="loading">
|
||||||
|
{{ tm('dialogs.batchDelete.confirm') }}
|
||||||
|
</v-btn>
|
||||||
|
</v-card-actions>
|
||||||
|
</v-card>
|
||||||
|
</v-dialog>
|
||||||
|
|
||||||
<!-- 消息提示 -->
|
<!-- 消息提示 -->
|
||||||
<v-snackbar :timeout="3000" elevation="24" :color="messageType" v-model="showMessage" location="top">
|
<v-snackbar :timeout="3000" elevation="24" :color="messageType" v-model="showMessage" location="top">
|
||||||
{{ message }}
|
{{ message }}
|
||||||
@@ -321,6 +317,7 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor';
|
|||||||
import MarkdownIt from 'markdown-it';
|
import MarkdownIt from 'markdown-it';
|
||||||
import { useCommonStore } from '@/stores/common';
|
import { useCommonStore } from '@/stores/common';
|
||||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||||
|
import MessageList from '@/components/chat/MessageList.vue';
|
||||||
|
|
||||||
// 配置markdown-it,默认安全设置
|
// 配置markdown-it,默认安全设置
|
||||||
const md = new MarkdownIt({
|
const md = new MarkdownIt({
|
||||||
@@ -333,7 +330,8 @@ const md = new MarkdownIt({
|
|||||||
export default {
|
export default {
|
||||||
name: 'ConversationPage',
|
name: 'ConversationPage',
|
||||||
components: {
|
components: {
|
||||||
VueMonacoEditor
|
VueMonacoEditor,
|
||||||
|
MessageList
|
||||||
},
|
},
|
||||||
|
|
||||||
setup() {
|
setup() {
|
||||||
@@ -353,32 +351,13 @@ export default {
|
|||||||
conversations: [],
|
conversations: [],
|
||||||
search: '',
|
search: '',
|
||||||
headers: [],
|
headers: [],
|
||||||
|
selectedItems: [], // 批量选择的项目
|
||||||
|
|
||||||
// 筛选条件
|
// 筛选条件
|
||||||
platformFilter: [],
|
platformFilter: [],
|
||||||
messageTypeFilter: [],
|
messageTypeFilter: [],
|
||||||
lastAppliedFilters: null, // 记录上次应用的筛选条件
|
lastAppliedFilters: null, // 记录上次应用的筛选条件
|
||||||
|
|
||||||
// 平台颜色映射
|
|
||||||
platformColors: {
|
|
||||||
'telegram': 'blue-lighten-1',
|
|
||||||
'qq_official': 'purple-lighten-1',
|
|
||||||
'qq_official_webhook': 'purple-lighten-2',
|
|
||||||
'aiocqhttp': 'deep-purple-lighten-1',
|
|
||||||
'lark': 'cyan-darken-1',
|
|
||||||
'wecom': 'green-darken-1',
|
|
||||||
'dingtalk': 'blue-darken-2',
|
|
||||||
'default': 'grey-lighten-1'
|
|
||||||
},
|
|
||||||
|
|
||||||
// 消息类型颜色映射
|
|
||||||
messageTypeColors: {
|
|
||||||
'GroupMessage': 'green',
|
|
||||||
'FriendMessage': 'blue',
|
|
||||||
'GuildMessage': 'purple',
|
|
||||||
'default': 'grey'
|
|
||||||
},
|
|
||||||
|
|
||||||
// 分页数据
|
// 分页数据
|
||||||
pagination: {
|
pagination: {
|
||||||
page: 1,
|
page: 1,
|
||||||
@@ -386,11 +365,13 @@ export default {
|
|||||||
total: 0,
|
total: 0,
|
||||||
total_pages: 0
|
total_pages: 0
|
||||||
},
|
},
|
||||||
|
pageSizeOptions: [10, 20, 50, 100], // 每页大小选项
|
||||||
|
|
||||||
// 对话框控制
|
// 对话框控制
|
||||||
dialogView: false,
|
dialogView: false,
|
||||||
dialogEdit: false,
|
dialogEdit: false,
|
||||||
dialogDelete: false,
|
dialogDelete: false,
|
||||||
|
dialogBatchDelete: false, // 批量删除对话框
|
||||||
|
|
||||||
// 选中的对话
|
// 选中的对话
|
||||||
selectedConversation: null,
|
selectedConversation: null,
|
||||||
@@ -402,11 +383,6 @@ export default {
|
|||||||
cid: '',
|
cid: '',
|
||||||
title: ''
|
title: ''
|
||||||
},
|
},
|
||||||
defaultItem: {
|
|
||||||
user_id: '',
|
|
||||||
cid: '',
|
|
||||||
title: ''
|
|
||||||
},
|
|
||||||
|
|
||||||
// 表单验证
|
// 表单验证
|
||||||
valid: true,
|
valid: true,
|
||||||
@@ -454,12 +430,18 @@ export default {
|
|||||||
tableHeaders() {
|
tableHeaders() {
|
||||||
return [
|
return [
|
||||||
{ title: this.tm('table.headers.title'), key: 'title', sortable: true },
|
{ title: this.tm('table.headers.title'), key: 'title', sortable: true },
|
||||||
|
{
|
||||||
|
title: this.tm('table.headers.sessionId'),
|
||||||
|
align: 'center',
|
||||||
|
children: [
|
||||||
{ title: this.tm('table.headers.platform'), key: 'platform', sortable: true, width: '120px' },
|
{ title: this.tm('table.headers.platform'), key: 'platform', sortable: true, width: '120px' },
|
||||||
{ title: this.tm('table.headers.type'), key: 'messageType', sortable: true, width: '100px' },
|
{ title: this.tm('table.headers.type'), key: 'messageType', sortable: true, width: '100px' },
|
||||||
{ title: this.tm('table.headers.sessionId'), key: 'sessionId', sortable: true, width: '100px' },
|
{ title: '会话 ID', key: 'sessionId', sortable: true, width: '100px' },
|
||||||
|
],
|
||||||
|
},
|
||||||
{ title: this.tm('table.headers.createdAt'), key: 'created_at', sortable: true, width: '180px' },
|
{ title: this.tm('table.headers.createdAt'), key: 'created_at', sortable: true, width: '180px' },
|
||||||
{ title: this.tm('table.headers.updatedAt'), key: 'updated_at', sortable: true, width: '180px' },
|
{ title: this.tm('table.headers.updatedAt'), key: 'updated_at', sortable: true, width: '180px' },
|
||||||
{ title: this.tm('table.headers.actions'), key: 'actions', sortable: false, align: 'center', width: '240px' }
|
{ title: this.tm('table.headers.actions'), key: 'actions', sortable: false, align: 'center' }
|
||||||
];
|
];
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -487,24 +469,40 @@ export default {
|
|||||||
];
|
];
|
||||||
},
|
},
|
||||||
|
|
||||||
// 筛选后的对话 - 现在只用于额外的客户端筛选(排除astrbot和webchat)
|
|
||||||
filteredConversations() {
|
|
||||||
return this.conversations.filter(conv => {
|
|
||||||
// 排除 user_id 为 astrbot 或 platform 为 webchat 的对话
|
|
||||||
if (conv.user_id === 'astrbot' || conv.sessionInfo?.platform === 'webchat') {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
},
|
|
||||||
|
|
||||||
// 当前的筛选条件对象
|
// 当前的筛选条件对象
|
||||||
currentFilters() {
|
currentFilters() {
|
||||||
|
const platforms = this.platformFilter.map(item =>
|
||||||
|
typeof item === 'object' ? item.value : item
|
||||||
|
);
|
||||||
return {
|
return {
|
||||||
platforms: this.platformFilter,
|
platforms: platforms,
|
||||||
messageTypes: this.messageTypeFilter,
|
messageTypes: this.messageTypeFilter,
|
||||||
search: this.search
|
search: this.search
|
||||||
};
|
};
|
||||||
|
},
|
||||||
|
|
||||||
|
// 将对话历史转换为 MessageList 组件期望的格式
|
||||||
|
formattedMessages() {
|
||||||
|
return this.conversationHistory.map(msg => {
|
||||||
|
console.log('处理消息:', msg.role, msg.image_url, msg.audio_url);
|
||||||
|
if (msg.role === 'user') {
|
||||||
|
return {
|
||||||
|
content: {
|
||||||
|
type: 'user',
|
||||||
|
message: this.extractTextFromContent(msg.content),
|
||||||
|
image_url: this.extractImagesFromContent(msg.content),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
content: {
|
||||||
|
type: 'bot',
|
||||||
|
message: this.extractTextFromContent(msg.content),
|
||||||
|
embedded_images: this.extractImagesFromContent(msg.content),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -541,16 +539,6 @@ export default {
|
|||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
|
||||||
// 重置过滤条件
|
|
||||||
resetFilters() {
|
|
||||||
this.platformFilter = [];
|
|
||||||
this.messageTypeFilter = [];
|
|
||||||
this.search = '';
|
|
||||||
// 立即应用筛选,不使用防抖
|
|
||||||
this.pagination.page = 1;
|
|
||||||
this.fetchConversations();
|
|
||||||
},
|
|
||||||
|
|
||||||
// 处理表格选项变更(页面大小等)
|
// 处理表格选项变更(页面大小等)
|
||||||
handleTableOptions(options) {
|
handleTableOptions(options) {
|
||||||
// 处理页面大小变更
|
// 处理页面大小变更
|
||||||
@@ -579,16 +567,6 @@ export default {
|
|||||||
return { platform: 'default', messageType: 'default', sessionId: userId };
|
return { platform: 'default', messageType: 'default', sessionId: userId };
|
||||||
},
|
},
|
||||||
|
|
||||||
// 获取平台对应的颜色
|
|
||||||
getPlatformColor(platform) {
|
|
||||||
return this.platformColors[platform] || this.platformColors.default;
|
|
||||||
},
|
|
||||||
|
|
||||||
// 获取消息类型对应的颜色
|
|
||||||
getMessageTypeColor(messageType) {
|
|
||||||
return this.messageTypeColors[messageType] || this.messageTypeColors.default;
|
|
||||||
},
|
|
||||||
|
|
||||||
// 获取消息类型的显示文本
|
// 获取消息类型的显示文本
|
||||||
getMessageTypeDisplay(messageType) {
|
getMessageTypeDisplay(messageType) {
|
||||||
const typeMap = {
|
const typeMap = {
|
||||||
@@ -610,9 +588,12 @@ export default {
|
|||||||
page_size: this.pagination.page_size
|
page_size: this.pagination.page_size
|
||||||
};
|
};
|
||||||
|
|
||||||
// 添加筛选条件
|
// 添加筛选条件 - 处理combobox的混合数据格式
|
||||||
if (this.platformFilter.length > 0) {
|
if (this.platformFilter.length > 0) {
|
||||||
params.platforms = this.platformFilter.join(',');
|
const platforms = this.platformFilter.map(item =>
|
||||||
|
typeof item === 'object' ? item.value : item
|
||||||
|
);
|
||||||
|
params.platforms = platforms.join(',');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.messageTypeFilter.length > 0) {
|
if (this.messageTypeFilter.length > 0) {
|
||||||
@@ -620,19 +601,15 @@ export default {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (this.search) {
|
if (this.search) {
|
||||||
params.search = this.search;
|
params.search = this.search.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加排除条件
|
// 添加排除条件
|
||||||
params.exclude_ids = 'astrbot';
|
params.exclude_ids = 'astrbot';
|
||||||
params.exclude_platforms = 'webchat';
|
params.exclude_platforms = 'webchat';
|
||||||
|
|
||||||
console.log(`正在请求对话列表: /api/conversation/list 参数:`, params);
|
|
||||||
|
|
||||||
const response = await axios.get('/api/conversation/list', { params });
|
const response = await axios.get('/api/conversation/list', { params });
|
||||||
|
|
||||||
console.log('收到对话列表响应:', response.data);
|
|
||||||
|
|
||||||
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
|
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
|
||||||
|
|
||||||
if (response.data.status === "ok") {
|
if (response.data.status === "ok") {
|
||||||
@@ -836,6 +813,88 @@ export default {
|
|||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.deleteError'));
|
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.deleteError'));
|
||||||
|
} finally {
|
||||||
|
this.loading = false;
|
||||||
|
this.selectedItems = this.selectedItems.filter(item =>
|
||||||
|
!(item.user_id === this.selectedConversation.user_id && item.cid === this.selectedConversation.cid)
|
||||||
|
);
|
||||||
|
this.selectedConversation = null;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// 处理页面大小变更
|
||||||
|
onPageSizeChange() {
|
||||||
|
this.pagination.page = 1; // 重置到第一页
|
||||||
|
this.fetchConversations();
|
||||||
|
},
|
||||||
|
|
||||||
|
// 确认批量删除
|
||||||
|
confirmBatchDelete() {
|
||||||
|
if (this.selectedItems.length === 0) {
|
||||||
|
this.showErrorMessage(this.tm('messages.noItemSelected'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.dialogBatchDelete = true;
|
||||||
|
},
|
||||||
|
|
||||||
|
// 从选择中移除项目
|
||||||
|
removeFromSelection(item) {
|
||||||
|
const index = this.selectedItems.findIndex(selected =>
|
||||||
|
selected.user_id === item.user_id && selected.cid === item.cid
|
||||||
|
);
|
||||||
|
if (index !== -1) {
|
||||||
|
this.selectedItems.splice(index, 1);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// 批量删除对话
|
||||||
|
async batchDeleteConversations() {
|
||||||
|
if (this.selectedItems.length === 0) {
|
||||||
|
this.showErrorMessage(this.tm('messages.noItemSelected'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.loading = true;
|
||||||
|
try {
|
||||||
|
// 准备批量删除的数据
|
||||||
|
const conversations = this.selectedItems.map(item => ({
|
||||||
|
user_id: item.user_id,
|
||||||
|
cid: item.cid
|
||||||
|
}));
|
||||||
|
|
||||||
|
const response = await axios.post('/api/conversation/delete', {
|
||||||
|
conversations: conversations
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.data.status === "ok") {
|
||||||
|
const result = response.data.data;
|
||||||
|
this.dialogBatchDelete = false;
|
||||||
|
this.selectedItems = []; // 清空选择
|
||||||
|
|
||||||
|
// 显示结果消息
|
||||||
|
if (result.failed_count > 0) {
|
||||||
|
this.showErrorMessage(
|
||||||
|
this.tm('messages.batchDeletePartial', {
|
||||||
|
deleted: result.deleted_count,
|
||||||
|
failed: result.failed_count
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
this.showSuccessMessage(
|
||||||
|
this.tm('messages.batchDeleteSuccess', {
|
||||||
|
count: result.deleted_count
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新列表
|
||||||
|
this.fetchConversations();
|
||||||
|
} else {
|
||||||
|
this.showErrorMessage(response.data.message || this.tm('messages.batchDeleteError'));
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('批量删除对话出错:', error);
|
||||||
|
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.batchDeleteError'));
|
||||||
} finally {
|
} finally {
|
||||||
this.loading = false;
|
this.loading = false;
|
||||||
}
|
}
|
||||||
@@ -858,35 +917,6 @@ export default {
|
|||||||
}).format(date);
|
}).format(date);
|
||||||
},
|
},
|
||||||
|
|
||||||
// 格式化消息内容
|
|
||||||
formatMessage(content) {
|
|
||||||
|
|
||||||
// content 可能是数组
|
|
||||||
// [{"type": "image_url", "image_url": {"url": url_or_base64}}, {"type": "text", "text": "text"}]
|
|
||||||
|
|
||||||
let final_content = content;
|
|
||||||
if (Array.isArray(content)) {
|
|
||||||
// 处理数组内容
|
|
||||||
final_content = content.map(item => {
|
|
||||||
if (item.type === 'image_url') {
|
|
||||||
return `<img src="${item.image_url.url}" alt="Image" />`;
|
|
||||||
} else if (item.type === 'text') {
|
|
||||||
return item.text;
|
|
||||||
}
|
|
||||||
return '';
|
|
||||||
}).join('\n');
|
|
||||||
} else if (typeof content === 'object') {
|
|
||||||
// 处理对象内容
|
|
||||||
final_content = Object.values(content).join('');
|
|
||||||
} else if (typeof content === 'string') {
|
|
||||||
// 处理字符串内容
|
|
||||||
final_content = content;
|
|
||||||
} else if (!final_content) return this.tm('status.emptyContent');
|
|
||||||
|
|
||||||
// 使用markdown-it处理,默认安全(html: false会禁用HTML标签)
|
|
||||||
return md.render(final_content);
|
|
||||||
},
|
|
||||||
|
|
||||||
// 显示成功消息
|
// 显示成功消息
|
||||||
showSuccessMessage(message) {
|
showSuccessMessage(message) {
|
||||||
this.message = message;
|
this.message = message;
|
||||||
@@ -899,16 +929,36 @@ export default {
|
|||||||
this.message = message;
|
this.message = message;
|
||||||
this.messageType = 'error';
|
this.messageType = 'error';
|
||||||
this.showMessage = true;
|
this.showMessage = true;
|
||||||
|
},
|
||||||
|
|
||||||
|
// 从内容中提取文本
|
||||||
|
extractTextFromContent(content) {
|
||||||
|
if (typeof content === 'string') {
|
||||||
|
return content;
|
||||||
|
} else if (Array.isArray(content)) {
|
||||||
|
return content.filter(item => item.type === 'text')
|
||||||
|
.map(item => item.text)
|
||||||
|
.join('\n');
|
||||||
|
} else if (typeof content === 'object') {
|
||||||
|
return Object.values(content).filter(val => typeof val === 'string').join('');
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
},
|
||||||
|
|
||||||
|
// 从内容中提取图片URL
|
||||||
|
extractImagesFromContent(content) {
|
||||||
|
if (Array.isArray(content)) {
|
||||||
|
return content.filter(item => item.type === 'image_url')
|
||||||
|
.map(item => item.image_url?.url)
|
||||||
|
.filter(url => url);
|
||||||
|
}
|
||||||
|
return [];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style>
|
<style>
|
||||||
.conversation-page {
|
|
||||||
padding: 20px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.actions-wrapper {
|
.actions-wrapper {
|
||||||
display: flex;
|
display: flex;
|
||||||
justify-content: flex-end;
|
justify-content: flex-end;
|
||||||
@@ -918,11 +968,6 @@ export default {
|
|||||||
.action-button {
|
.action-button {
|
||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
transition: all 0.2s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
.action-button:hover {
|
|
||||||
transform: translateY(-2px);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.monaco-editor-container {
|
.monaco-editor-container {
|
||||||
@@ -932,7 +977,7 @@ export default {
|
|||||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
|
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* 聊天消息样式 */
|
/* 聊天消息容器样式 */
|
||||||
.conversation-messages-container {
|
.conversation-messages-container {
|
||||||
max-height: 500px;
|
max-height: 500px;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
@@ -941,87 +986,6 @@ export default {
|
|||||||
background-color: #f9f9f9;
|
background-color: #f9f9f9;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-list {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
gap: 16px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-item {
|
|
||||||
margin-bottom: 8px;
|
|
||||||
animation: fadeIn 0.3s ease-out;
|
|
||||||
}
|
|
||||||
|
|
||||||
.user-message {
|
|
||||||
display: flex;
|
|
||||||
justify-content: flex-end;
|
|
||||||
align-items: flex-start;
|
|
||||||
gap: 12px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.bot-message {
|
|
||||||
display: flex;
|
|
||||||
justify-content: flex-start;
|
|
||||||
align-items: flex-start;
|
|
||||||
gap: 12px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-bubble {
|
|
||||||
padding: 12px 16px;
|
|
||||||
border-radius: 18px;
|
|
||||||
max-width: 80%;
|
|
||||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
|
|
||||||
}
|
|
||||||
|
|
||||||
.user-bubble {
|
|
||||||
background-color: #f0f4ff;
|
|
||||||
color: #333;
|
|
||||||
border-top-right-radius: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.bot-bubble {
|
|
||||||
background-color: #fff;
|
|
||||||
border: 1px solid #eaeaea;
|
|
||||||
color: #333;
|
|
||||||
border-top-left-radius: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.user-avatar,
|
|
||||||
.bot-avatar {
|
|
||||||
margin-top: 2px;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 附件样式 */
|
|
||||||
.image-attachments {
|
|
||||||
display: flex;
|
|
||||||
gap: 8px;
|
|
||||||
margin-top: 8px;
|
|
||||||
flex-wrap: wrap;
|
|
||||||
}
|
|
||||||
|
|
||||||
.attached-image {
|
|
||||||
width: 120px;
|
|
||||||
height: 120px;
|
|
||||||
object-fit: cover;
|
|
||||||
border-radius: 8px;
|
|
||||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
|
||||||
transition: transform 0.2s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
.attached-image:hover {
|
|
||||||
transform: scale(1.05);
|
|
||||||
}
|
|
||||||
|
|
||||||
.audio-attachment {
|
|
||||||
margin-top: 8px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.audio-player {
|
|
||||||
width: 100%;
|
|
||||||
height: 36px;
|
|
||||||
border-radius: 18px;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 对话详情卡片 */
|
/* 对话详情卡片 */
|
||||||
.conversation-detail-card {
|
.conversation-detail-card {
|
||||||
max-height: 90vh;
|
max-height: 90vh;
|
||||||
@@ -1029,95 +993,6 @@ export default {
|
|||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Markdown内容样式 */
|
|
||||||
.markdown-content {
|
|
||||||
font-family: inherit;
|
|
||||||
line-height: 1.6;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content h1,
|
|
||||||
.markdown-content h2,
|
|
||||||
.markdown-content h3,
|
|
||||||
.markdown-content h4,
|
|
||||||
.markdown-content h5,
|
|
||||||
.markdown-content h6 {
|
|
||||||
margin-top: 16px;
|
|
||||||
margin-bottom: 10px;
|
|
||||||
font-weight: 600;
|
|
||||||
color: #333;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content h1 {
|
|
||||||
font-size: 1.8em;
|
|
||||||
border-bottom: 1px solid #eee;
|
|
||||||
padding-bottom: 6px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content h2 {
|
|
||||||
font-size: 1.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content h3 {
|
|
||||||
font-size: 1.3em;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content li {
|
|
||||||
margin-left: 16px;
|
|
||||||
margin-bottom: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content p {
|
|
||||||
margin-top: 10px;
|
|
||||||
margin-bottom: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content pre {
|
|
||||||
background-color: #f8f8f8;
|
|
||||||
padding: 12px;
|
|
||||||
border-radius: 6px;
|
|
||||||
overflow-x: auto;
|
|
||||||
margin: 12px 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content code {
|
|
||||||
background-color: #f5f0ff;
|
|
||||||
padding: 2px 4px;
|
|
||||||
border-radius: 4px;
|
|
||||||
font-family: 'Fira Code', monospace;
|
|
||||||
font-size: 0.9em;
|
|
||||||
color: #673ab7;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content img {
|
|
||||||
max-width: 100%;
|
|
||||||
border-radius: 8px;
|
|
||||||
margin: 10px 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content blockquote {
|
|
||||||
border-left: 4px solid #673ab7;
|
|
||||||
padding-left: 16px;
|
|
||||||
color: #666;
|
|
||||||
margin: 16px 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content table {
|
|
||||||
border-collapse: collapse;
|
|
||||||
width: 100%;
|
|
||||||
margin: 16px 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content th,
|
|
||||||
.markdown-content td {
|
|
||||||
border: 1px solid #eee;
|
|
||||||
padding: 8px 12px;
|
|
||||||
text-align: left;
|
|
||||||
}
|
|
||||||
|
|
||||||
.markdown-content th {
|
|
||||||
background-color: #f5f0ff;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 动画 */
|
/* 动画 */
|
||||||
@keyframes fadeIn {
|
@keyframes fadeIn {
|
||||||
from {
|
from {
|
||||||
|
|||||||
@@ -518,27 +518,12 @@ onMounted(async () => {
|
|||||||
<v-row>
|
<v-row>
|
||||||
<v-col cols="12" md="12">
|
<v-col cols="12" md="12">
|
||||||
<v-card variant="flat">
|
<v-card variant="flat">
|
||||||
<v-card-item>
|
|
||||||
<template v-slot:prepend>
|
|
||||||
<div class="plugin-page-icon d-flex justify-center align-center rounded-lg mr-4">
|
|
||||||
<v-icon size="36" color="primary">mdi-puzzle</v-icon>
|
|
||||||
</div>
|
|
||||||
</template>
|
|
||||||
<v-card-title class="text-h4 font-weight-bold">
|
|
||||||
{{ tm('title') }}
|
|
||||||
</v-card-title>
|
|
||||||
<v-card-subtitle class="text-subtitle-1 mt-1 text-medium-emphasis">
|
|
||||||
{{ tm('subtitle') }}
|
|
||||||
</v-card-subtitle>
|
|
||||||
</v-card-item>
|
|
||||||
|
|
||||||
<!-- 标签页 -->
|
<!-- 标签页 -->
|
||||||
<v-card-text>
|
<v-card-text>
|
||||||
|
|
||||||
<!-- 标签栏和搜索栏 - 响应式布局 -->
|
<!-- 标签栏和搜索栏 - 响应式布局 -->
|
||||||
<div class="mb-4">
|
<div class="mb-4 d-flex flex-wrap">
|
||||||
<!-- 标签栏 -->
|
<!-- 标签栏 -->
|
||||||
<v-tabs v-model="activeTab" color="primary" class="mb-3">
|
<v-tabs v-model="activeTab" color="primary">
|
||||||
<v-tab value="installed">
|
<v-tab value="installed">
|
||||||
<v-icon class="mr-2">mdi-puzzle</v-icon>
|
<v-icon class="mr-2">mdi-puzzle</v-icon>
|
||||||
{{ tm('tabs.installed') }}
|
{{ tm('tabs.installed') }}
|
||||||
@@ -550,8 +535,7 @@ onMounted(async () => {
|
|||||||
</v-tabs>
|
</v-tabs>
|
||||||
|
|
||||||
<!-- 搜索栏 - 在移动端时独占一行 -->
|
<!-- 搜索栏 - 在移动端时独占一行 -->
|
||||||
<v-row class="mb-2">
|
<div style="flex-grow: 1; min-width: 250px; max-width: 400px; margin-left: auto; margin-top: 8px;">
|
||||||
<v-col cols="12" sm="6" md="4" lg="3">
|
|
||||||
<v-text-field v-if="activeTab == 'market'" v-model="marketSearch" density="compact"
|
<v-text-field v-if="activeTab == 'market'" v-model="marketSearch" density="compact"
|
||||||
:label="tm('search.marketPlaceholder')" prepend-inner-icon="mdi-magnify" variant="solo-filled" flat
|
:label="tm('search.marketPlaceholder')" prepend-inner-icon="mdi-magnify" variant="solo-filled" flat
|
||||||
hide-details single-line>
|
hide-details single-line>
|
||||||
@@ -559,8 +543,8 @@ onMounted(async () => {
|
|||||||
<v-text-field v-else v-model="pluginSearch" density="compact" :label="tm('search.placeholder')"
|
<v-text-field v-else v-model="pluginSearch" density="compact" :label="tm('search.placeholder')"
|
||||||
prepend-inner-icon="mdi-magnify" variant="solo-filled" flat hide-details single-line>
|
prepend-inner-icon="mdi-magnify" variant="solo-filled" flat hide-details single-line>
|
||||||
</v-text-field>
|
</v-text-field>
|
||||||
</v-col>
|
</div>
|
||||||
</v-row>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
@@ -776,18 +760,13 @@ onMounted(async () => {
|
|||||||
<div class="d-flex align-center mb-2" style="justify-content: space-between;">
|
<div class="d-flex align-center mb-2" style="justify-content: space-between;">
|
||||||
<h2>{{ tm('market.allPlugins') }}</h2>
|
<h2>{{ tm('market.allPlugins') }}</h2>
|
||||||
<div class="d-flex align-center">
|
<div class="d-flex align-center">
|
||||||
<v-btn
|
<v-btn variant="tonal" size="small" @click="refreshPluginMarket" :loading="refreshingMarket"
|
||||||
variant="tonal"
|
class="mr-2">
|
||||||
size="small"
|
|
||||||
@click="refreshPluginMarket"
|
|
||||||
:loading="refreshingMarket"
|
|
||||||
class="mr-2"
|
|
||||||
>
|
|
||||||
<v-icon>mdi-refresh</v-icon>
|
<v-icon>mdi-refresh</v-icon>
|
||||||
{{ tm('buttons.refresh') }}
|
{{ tm('buttons.refresh') }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-switch v-model="showPluginFullName" :label="tm('market.showFullName')" hide-details density="compact"
|
<v-switch v-model="showPluginFullName" :label="tm('market.showFullName')" hide-details
|
||||||
style="margin-left: 12px" />
|
density="compact" style="margin-left: 12px" />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -827,7 +806,7 @@ onMounted(async () => {
|
|||||||
<template v-slot:item.tags="{ item }">
|
<template v-slot:item.tags="{ item }">
|
||||||
<span v-if="item.tags.length === 0">-</span>
|
<span v-if="item.tags.length === 0">-</span>
|
||||||
<v-chip v-for="tag in item.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'"
|
<v-chip v-for="tag in item.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'"
|
||||||
size="x-small" v-show="tag !== 'danger'">
|
size="x-small" v-show="tag !== 'danger'" class="ma-1">
|
||||||
{{ tag }}</v-chip>
|
{{ tag }}</v-chip>
|
||||||
</template>
|
</template>
|
||||||
<template v-slot:item.actions="{ item }">
|
<template v-slot:item.actions="{ item }">
|
||||||
|
|||||||
@@ -23,7 +23,7 @@
|
|||||||
<!-- 人格卡片网格 -->
|
<!-- 人格卡片网格 -->
|
||||||
<v-row>
|
<v-row>
|
||||||
<v-col v-for="persona in personas" :key="persona.persona_id" cols="12" md="6" lg="4" xl="3">
|
<v-col v-for="persona in personas" :key="persona.persona_id" cols="12" md="6" lg="4" xl="3">
|
||||||
<v-card class="persona-card" elevation="2" rounded="lg" @click="viewPersona(persona)">
|
<v-card class="persona-card" rounded="md" @click="viewPersona(persona)">
|
||||||
<v-card-title class="d-flex justify-space-between align-center">
|
<v-card-title class="d-flex justify-space-between align-center">
|
||||||
<div class="text-truncate ml-2">
|
<div class="text-truncate ml-2">
|
||||||
{{ persona.persona_id }}
|
{{ persona.persona_id }}
|
||||||
@@ -296,9 +296,9 @@
|
|||||||
<v-card-text>
|
<v-card-text>
|
||||||
<div class="mb-4">
|
<div class="mb-4">
|
||||||
<h4 class="text-h6 mb-2">{{ tm('form.systemPrompt') }}</h4>
|
<h4 class="text-h6 mb-2">{{ tm('form.systemPrompt') }}</h4>
|
||||||
<div class="system-prompt-content">
|
<pre class="system-prompt-content">
|
||||||
{{ viewingPersona.system_prompt }}
|
{{ viewingPersona.system_prompt }}
|
||||||
</div>
|
</pre>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div v-if="viewingPersona.begin_dialogs && viewingPersona.begin_dialogs.length > 0" class="mb-4">
|
<div v-if="viewingPersona.begin_dialogs && viewingPersona.begin_dialogs.length > 0" class="mb-4">
|
||||||
@@ -759,10 +759,6 @@ export default {
|
|||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
}
|
}
|
||||||
|
|
||||||
.persona-card:hover {
|
|
||||||
box-shadow: 0 8px 25px 0 rgba(0, 0, 0, 0.15);
|
|
||||||
}
|
|
||||||
|
|
||||||
.system-prompt-preview {
|
.system-prompt-preview {
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
line-height: 1.4;
|
line-height: 1.4;
|
||||||
@@ -775,10 +771,10 @@ export default {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.system-prompt-content {
|
.system-prompt-content {
|
||||||
background-color: rgba(var(--v-theme-surface-variant), 0.3);
|
max-height: 400px;
|
||||||
|
overflow: auto;
|
||||||
padding: 12px;
|
padding: 12px;
|
||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
font-family: 'Roboto Mono', monospace;
|
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
line-height: 1.5;
|
line-height: 1.5;
|
||||||
white-space: pre-wrap;
|
white-space: pre-wrap;
|
||||||
|
|||||||
@@ -10,7 +10,8 @@
|
|||||||
{{ tm('subtitle') }}
|
{{ tm('subtitle') }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddPlatformDialog = true" rounded="xl" size="x-large">
|
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddPlatformDialog = true"
|
||||||
|
rounded="xl" size="x-large">
|
||||||
{{ tm('addAdapter') }}
|
{{ tm('addAdapter') }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
</v-row>
|
</v-row>
|
||||||
@@ -25,14 +26,9 @@
|
|||||||
|
|
||||||
<v-row v-else>
|
<v-row v-else>
|
||||||
<v-col v-for="(platform, index) in config_data.platform || []" :key="index" cols="12" md="6" lg="4" xl="3">
|
<v-col v-for="(platform, index) in config_data.platform || []" :key="index" cols="12" md="6" lg="4" xl="3">
|
||||||
<item-card
|
<item-card :item="platform" title-field="id" enabled-field="enable"
|
||||||
:item="platform"
|
:bglogo="getPlatformIcon(platform.type || platform.id)" @toggle-enabled="platformStatusChange"
|
||||||
title-field="id"
|
@delete="deletePlatform" @edit="editPlatform">
|
||||||
enabled-field="enable"
|
|
||||||
:bglogo="getPlatformIcon(platform.type || platform.id)"
|
|
||||||
@toggle-enabled="platformStatusChange"
|
|
||||||
@delete="deletePlatform"
|
|
||||||
@edit="editPlatform">
|
|
||||||
</item-card>
|
</item-card>
|
||||||
</v-col>
|
</v-col>
|
||||||
</v-row>
|
</v-row>
|
||||||
@@ -61,59 +57,13 @@
|
|||||||
</v-container>
|
</v-container>
|
||||||
|
|
||||||
<!-- 添加平台适配器对话框 -->
|
<!-- 添加平台适配器对话框 -->
|
||||||
<v-dialog v-model="showAddPlatformDialog" max-width="900px" min-height="80%">
|
<AddNewPlatform v-model:show="showAddPlatformDialog" :metadata="metadata"
|
||||||
<v-card class="platform-selection-dialog">
|
@select-template="selectPlatformTemplate" />
|
||||||
<v-card-title class="bg-primary text-white py-3 px-4" style="display: flex; align-items: center;">
|
|
||||||
<v-icon color="white" class="me-2">mdi-plus-circle</v-icon>
|
|
||||||
<span>{{ tm('dialog.addPlatform') }}</span>
|
|
||||||
<v-spacer></v-spacer>
|
|
||||||
<v-btn icon variant="text" color="white" @click="showAddPlatformDialog = false">
|
|
||||||
<v-icon>mdi-close</v-icon>
|
|
||||||
</v-btn>
|
|
||||||
</v-card-title>
|
|
||||||
|
|
||||||
<v-card-text class="pa-4" style="overflow-y: auto;">
|
|
||||||
<v-row class="mt-1">
|
|
||||||
<v-col v-for="(template, name) in metadata['platform_group']?.metadata?.platform?.config_template || {}"
|
|
||||||
:key="name" cols="12" sm="6" md="6">
|
|
||||||
<v-card variant="outlined" hover class="platform-card" @click="selectPlatformTemplate(name)">
|
|
||||||
<div class="platform-card-content">
|
|
||||||
<div class="platform-card-text">
|
|
||||||
<v-card-title class="platform-card-title">{{ tm('dialog.connectTitle', { name }) }}</v-card-title>
|
|
||||||
<v-card-text class="text-caption text-medium-emphasis platform-card-description">
|
|
||||||
{{ getPlatformDescription(template, name) }}
|
|
||||||
</v-card-text>
|
|
||||||
</div>
|
|
||||||
<div class="platform-card-logo">
|
|
||||||
<img :src="getPlatformIcon(template.type)" v-if="getPlatformIcon(template.type)" class="platform-logo-img">
|
|
||||||
<div v-else class="platform-logo-fallback">
|
|
||||||
{{ name[0].toUpperCase() }}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</v-card>
|
|
||||||
</v-col>
|
|
||||||
<v-col
|
|
||||||
v-if="Object.keys(metadata['platform_group']?.metadata?.platform?.config_template || {}).length === 0"
|
|
||||||
cols="12">
|
|
||||||
<v-alert type="info" variant="tonal">
|
|
||||||
{{ tm('dialog.noTemplates') }}
|
|
||||||
</v-alert>
|
|
||||||
</v-col>
|
|
||||||
</v-row>
|
|
||||||
</v-card-text>
|
|
||||||
</v-card>
|
|
||||||
</v-dialog>
|
|
||||||
|
|
||||||
<!-- 配置对话框 -->
|
<!-- 配置对话框 -->
|
||||||
<v-dialog v-model="showPlatformCfg" persistent width="900px" max-width="90%">
|
<v-dialog v-model="showPlatformCfg" persistent width="900px" max-width="90%">
|
||||||
<v-card>
|
<v-card
|
||||||
<v-card-title class="bg-primary text-white py-3">
|
:title="updatingMode ? tm('dialog.edit') : tm('dialog.add') + ` ${newSelectedPlatformName} ` + tm('dialog.adapter')">
|
||||||
<v-icon color="white" class="me-2">{{ updatingMode ? 'mdi-pencil' : 'mdi-plus' }}</v-icon>
|
|
||||||
<span>{{ updatingMode ? tm('dialog.edit') : tm('dialog.add') }} {{ newSelectedPlatformName }} {{
|
|
||||||
tm('dialog.adapter') }}</span>
|
|
||||||
</v-card-title>
|
|
||||||
|
|
||||||
<v-card-text class="py-4">
|
<v-card-text class="py-4">
|
||||||
<v-row>
|
<v-row>
|
||||||
<v-col cols="12">
|
<v-col cols="12">
|
||||||
@@ -177,7 +127,9 @@
|
|||||||
</v-card-title>
|
</v-card-title>
|
||||||
<v-card-text class="py-4">
|
<v-card-text class="py-4">
|
||||||
<p>{{ tm('dialog.securityWarning.aiocqhttpTokenMissing') }}</p>
|
<p>{{ tm('dialog.securityWarning.aiocqhttpTokenMissing') }}</p>
|
||||||
<span><a href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7" target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
|
<span><a
|
||||||
|
href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7"
|
||||||
|
target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
|
||||||
</v-card-text>
|
</v-card-text>
|
||||||
<v-card-actions class="px-4 pb-4">
|
<v-card-actions class="px-4 pb-4">
|
||||||
<v-spacer></v-spacer>
|
<v-spacer></v-spacer>
|
||||||
@@ -199,8 +151,10 @@ import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
|||||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||||
import ItemCard from '@/components/shared/ItemCard.vue';
|
import ItemCard from '@/components/shared/ItemCard.vue';
|
||||||
|
import AddNewPlatform from '@/components/platform/AddNewPlatform.vue';
|
||||||
import { useCommonStore } from '@/stores/common';
|
import { useCommonStore } from '@/stores/common';
|
||||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||||
|
import { getPlatformIcon, getTutorialLink } from '@/utils/platformUtils';
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: 'PlatformPage',
|
name: 'PlatformPage',
|
||||||
@@ -208,7 +162,8 @@ export default {
|
|||||||
AstrBotConfig,
|
AstrBotConfig,
|
||||||
WaitingForRestart,
|
WaitingForRestart,
|
||||||
ConsoleDisplayer,
|
ConsoleDisplayer,
|
||||||
ItemCard
|
ItemCard,
|
||||||
|
AddNewPlatform
|
||||||
},
|
},
|
||||||
setup() {
|
setup() {
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
@@ -285,66 +240,14 @@ export default {
|
|||||||
},
|
},
|
||||||
|
|
||||||
methods: {
|
methods: {
|
||||||
|
// 从工具函数导入
|
||||||
|
getPlatformIcon,
|
||||||
|
|
||||||
openTutorial() {
|
openTutorial() {
|
||||||
const tutorialUrl = this.getTutorialLink(this.newSelectedPlatformConfig.type);
|
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);
|
||||||
window.open(tutorialUrl, '_blank');
|
window.open(tutorialUrl, '_blank');
|
||||||
},
|
},
|
||||||
|
|
||||||
getPlatformIcon(name) {
|
|
||||||
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
|
|
||||||
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
|
||||||
} else if (name === 'wecom') {
|
|
||||||
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
|
||||||
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
|
||||||
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
|
||||||
} else if (name === 'lark') {
|
|
||||||
return new URL('@/assets/images/platform_logos/lark.png', import.meta.url).href
|
|
||||||
} else if (name === 'dingtalk') {
|
|
||||||
return new URL('@/assets/images/platform_logos/dingtalk.svg', import.meta.url).href
|
|
||||||
} else if (name === 'telegram') {
|
|
||||||
return new URL('@/assets/images/platform_logos/telegram.svg', import.meta.url).href
|
|
||||||
} else if (name === 'discord') {
|
|
||||||
return new URL('@/assets/images/platform_logos/discord.svg', import.meta.url).href
|
|
||||||
} else if (name === 'slack') {
|
|
||||||
return new URL('@/assets/images/platform_logos/slack.svg', import.meta.url).href
|
|
||||||
} else if (name === 'kook') {
|
|
||||||
return new URL('@/assets/images/platform_logos/kook.png', import.meta.url).href
|
|
||||||
} else if (name === 'vocechat') {
|
|
||||||
return new URL('@/assets/images/platform_logos/vocechat.png', import.meta.url).href
|
|
||||||
} else if (name === 'satori' || name === 'Satori') {
|
|
||||||
return new URL('@/assets/images/platform_logos/satori.png', import.meta.url).href
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
getTutorialLink(platform_type) {
|
|
||||||
let tutorial_map = {
|
|
||||||
"qq_official_webhook": "https://docs.astrbot.app/deploy/platform/qqofficial/webhook.html",
|
|
||||||
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
|
|
||||||
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
|
||||||
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
|
|
||||||
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
|
|
||||||
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
|
|
||||||
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
|
|
||||||
"wechatpadpro": "https://docs.astrbot.app/deploy/platform/wechat/wechatpadpro.html",
|
|
||||||
"weixin_official_account": "https://docs.astrbot.app/deploy/platform/weixin-official-account.html",
|
|
||||||
"discord": "https://docs.astrbot.app/deploy/platform/discord.html",
|
|
||||||
"slack": "https://docs.astrbot.app/deploy/platform/slack.html",
|
|
||||||
"kook": "https://docs.astrbot.app/deploy/platform/kook.html",
|
|
||||||
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
|
|
||||||
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
|
|
||||||
}
|
|
||||||
return tutorial_map[platform_type] || "https://docs.astrbot.app";
|
|
||||||
},
|
|
||||||
|
|
||||||
getPlatformDescription(template, name) {
|
|
||||||
// special judge for community platforms
|
|
||||||
if (name.includes('vocechat')) {
|
|
||||||
return "由 @HikariFroya 提供。";
|
|
||||||
} else if (name.includes('kook')) {
|
|
||||||
return "由 @wuyan1003 提供。"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
getConfig() {
|
getConfig() {
|
||||||
axios.get('/api/config/get').then((res) => {
|
axios.get('/api/config/get').then((res) => {
|
||||||
this.config_data = res.data.data.config;
|
this.config_data = res.data.data.config;
|
||||||
@@ -355,7 +258,7 @@ export default {
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
// 添加一个新方法来选择平台模板
|
// 选择平台模板
|
||||||
selectPlatformTemplate(name) {
|
selectPlatformTemplate(name) {
|
||||||
this.newSelectedPlatformName = name;
|
this.newSelectedPlatformName = name;
|
||||||
this.showPlatformCfg = true;
|
this.showPlatformCfg = true;
|
||||||
@@ -363,7 +266,6 @@ export default {
|
|||||||
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(
|
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(
|
||||||
this.metadata['platform_group']?.metadata?.platform?.config_template[name] || {}
|
this.metadata['platform_group']?.metadata?.platform?.config_template[name] || {}
|
||||||
));
|
));
|
||||||
this.showAddPlatformDialog = false;
|
|
||||||
},
|
},
|
||||||
|
|
||||||
addFromDefaultConfigTmpl(index) {
|
addFromDefaultConfigTmpl(index) {
|
||||||
@@ -532,84 +434,4 @@ export default {
|
|||||||
padding: 20px;
|
padding: 20px;
|
||||||
padding-top: 8px;
|
padding-top: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.platform-selection-dialog .v-card-title {
|
|
||||||
border-top-left-radius: 4px;
|
|
||||||
border-top-right-radius: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-card {
|
|
||||||
transition: all 0.3s ease;
|
|
||||||
height: 100%;
|
|
||||||
cursor: pointer;
|
|
||||||
overflow: hidden;
|
|
||||||
position: relative;
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-card:hover {
|
|
||||||
transform: translateY(-4px);
|
|
||||||
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
|
|
||||||
border-color: var(--v-primary-base);
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-card-content {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
height: 100px;
|
|
||||||
padding: 16px;
|
|
||||||
position: relative;
|
|
||||||
z-index: 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-card-text {
|
|
||||||
flex: 1;
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
justify-content: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-card-title {
|
|
||||||
font-size: 15px;
|
|
||||||
font-weight: 600;
|
|
||||||
margin-bottom: 4px;
|
|
||||||
padding: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-card-description {
|
|
||||||
padding: 0;
|
|
||||||
margin: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-card-logo {
|
|
||||||
position: absolute;
|
|
||||||
right: 0;
|
|
||||||
top: 0;
|
|
||||||
bottom: 0;
|
|
||||||
width: 80px;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
z-index: 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-logo-img {
|
|
||||||
max-width: 60px;
|
|
||||||
max-height: 60px;
|
|
||||||
opacity: 0.6;
|
|
||||||
object-fit: contain;
|
|
||||||
}
|
|
||||||
|
|
||||||
.platform-logo-fallback {
|
|
||||||
width: 50px;
|
|
||||||
height: 50px;
|
|
||||||
border-radius: 50%;
|
|
||||||
background-color: var(--v-primary-base);
|
|
||||||
color: white;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
font-size: 24px;
|
|
||||||
font-weight: bold;
|
|
||||||
opacity: 0.3;
|
|
||||||
}
|
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
@@ -63,7 +63,9 @@
|
|||||||
@toggle-enabled="providerStatusChange"
|
@toggle-enabled="providerStatusChange"
|
||||||
:bglogo="getProviderIcon(provider.provider)"
|
:bglogo="getProviderIcon(provider.provider)"
|
||||||
@delete="deleteProvider"
|
@delete="deleteProvider"
|
||||||
@edit="configExistingProvider">
|
@edit="configExistingProvider"
|
||||||
|
@copy="copyProvider"
|
||||||
|
:show-copy-button="true">
|
||||||
<template v-slot:details="{ item }">
|
<template v-slot:details="{ item }">
|
||||||
</template>
|
</template>
|
||||||
</item-card>
|
</item-card>
|
||||||
@@ -153,86 +155,15 @@
|
|||||||
</v-container>
|
</v-container>
|
||||||
|
|
||||||
<!-- 添加提供商对话框 -->
|
<!-- 添加提供商对话框 -->
|
||||||
<v-dialog v-model="showAddProviderDialog" max-width="1100px" min-height="95%">
|
<AddNewProvider
|
||||||
<v-card class="provider-selection-dialog">
|
v-model:show="showAddProviderDialog"
|
||||||
<v-card-title class="bg-primary text-white py-3 px-4" style="display: flex; align-items: center;">
|
:metadata="metadata"
|
||||||
<v-icon color="white" class="me-2">mdi-plus-circle</v-icon>
|
@select-template="selectProviderTemplate"
|
||||||
<span>{{ tm('dialogs.addProvider.title') }}</span>
|
/>
|
||||||
<v-spacer></v-spacer>
|
|
||||||
<v-btn icon variant="text" color="white" @click="showAddProviderDialog = false">
|
|
||||||
<v-icon>mdi-close</v-icon>
|
|
||||||
</v-btn>
|
|
||||||
</v-card-title>
|
|
||||||
|
|
||||||
<v-card-text class="pa-4" style="overflow-y: auto;">
|
|
||||||
<v-tabs v-model="activeProviderTab" grow slider-color="primary" bg-color="background">
|
|
||||||
<v-tab value="chat_completion" class="font-weight-medium px-3">
|
|
||||||
<v-icon start>mdi-message-text</v-icon>
|
|
||||||
{{ tm('dialogs.addProvider.tabs.basic') }}
|
|
||||||
</v-tab>
|
|
||||||
<v-tab value="speech_to_text" class="font-weight-medium px-3">
|
|
||||||
<v-icon start>mdi-microphone-message</v-icon>
|
|
||||||
{{ tm('dialogs.addProvider.tabs.speechToText') }}
|
|
||||||
</v-tab>
|
|
||||||
<v-tab value="text_to_speech" class="font-weight-medium px-3">
|
|
||||||
<v-icon start>mdi-volume-high</v-icon>
|
|
||||||
{{ tm('dialogs.addProvider.tabs.textToSpeech') }}
|
|
||||||
</v-tab>
|
|
||||||
<v-tab value="embedding" class="font-weight-medium px-3">
|
|
||||||
<v-icon start>mdi-code-json</v-icon>
|
|
||||||
{{ tm('dialogs.addProvider.tabs.embedding') }}
|
|
||||||
</v-tab>
|
|
||||||
<v-tab value="rerank" class="font-weight-medium px-3">
|
|
||||||
<v-icon start>mdi-compare-vertical</v-icon>
|
|
||||||
{{ tm('dialogs.addProvider.tabs.rerank') }}
|
|
||||||
</v-tab>
|
|
||||||
</v-tabs>
|
|
||||||
|
|
||||||
<v-window v-model="activeProviderTab" class="mt-4">
|
|
||||||
<v-window-item v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
|
|
||||||
:key="tabType"
|
|
||||||
:value="tabType">
|
|
||||||
<v-row class="mt-1">
|
|
||||||
<v-col v-for="(template, name) in getTemplatesByType(tabType)"
|
|
||||||
:key="name"
|
|
||||||
cols="12" sm="6" md="4">
|
|
||||||
<v-card variant="outlined" hover class="provider-card" @click="selectProviderTemplate(name)">
|
|
||||||
<div class="provider-card-content">
|
|
||||||
<div class="provider-card-text">
|
|
||||||
<v-card-title class="provider-card-title">接入 {{ name }}</v-card-title>
|
|
||||||
<v-card-text class="text-caption text-medium-emphasis provider-card-description">
|
|
||||||
{{ getProviderDescription(template, name) }}
|
|
||||||
</v-card-text>
|
|
||||||
</div>
|
|
||||||
<div class="provider-card-logo">
|
|
||||||
<img :src="getProviderIcon(template.provider)" v-if="getProviderIcon(template.provider)" class="provider-logo-img">
|
|
||||||
<div v-else class="provider-logo-fallback">
|
|
||||||
{{ name[0].toUpperCase() }}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</v-card>
|
|
||||||
</v-col>
|
|
||||||
<v-col v-if="Object.keys(getTemplatesByType(tabType)).length === 0" cols="12">
|
|
||||||
<v-alert type="info" variant="tonal">
|
|
||||||
{{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
|
|
||||||
</v-alert>
|
|
||||||
</v-col>
|
|
||||||
</v-row>
|
|
||||||
</v-window-item>
|
|
||||||
</v-window>
|
|
||||||
</v-card-text>
|
|
||||||
</v-card>
|
|
||||||
</v-dialog>
|
|
||||||
|
|
||||||
<!-- 配置对话框 -->
|
<!-- 配置对话框 -->
|
||||||
<v-dialog v-model="showProviderCfg" width="900" persistent>
|
<v-dialog v-model="showProviderCfg" width="900" persistent>
|
||||||
<v-card>
|
<v-card :title="updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') + ` ${newSelectedProviderName} ` + tm('dialogs.config.provider')">
|
||||||
<v-card-title class="bg-primary text-white py-3">
|
|
||||||
<v-icon color="white" class="me-2">{{ updatingMode ? 'mdi-pencil' : 'mdi-plus' }}</v-icon>
|
|
||||||
<span>{{ updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') }} {{ newSelectedProviderName }} {{ tm('dialogs.config.provider') }}</span>
|
|
||||||
</v-card-title>
|
|
||||||
|
|
||||||
<v-card-text class="py-4">
|
<v-card-text class="py-4">
|
||||||
<AstrBotConfig
|
<AstrBotConfig
|
||||||
:iterable="newSelectedProviderConfig"
|
:iterable="newSelectedProviderConfig"
|
||||||
@@ -307,7 +238,9 @@ import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
|||||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||||
import ItemCard from '@/components/shared/ItemCard.vue';
|
import ItemCard from '@/components/shared/ItemCard.vue';
|
||||||
|
import AddNewProvider from '@/components/provider/AddNewProvider.vue';
|
||||||
import { useModuleI18n } from '@/i18n/composables';
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
|
import { getProviderIcon } from '@/utils/providerUtils';
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: 'ProviderPage',
|
name: 'ProviderPage',
|
||||||
@@ -315,7 +248,8 @@ export default {
|
|||||||
AstrBotConfig,
|
AstrBotConfig,
|
||||||
WaitingForRestart,
|
WaitingForRestart,
|
||||||
ConsoleDisplayer,
|
ConsoleDisplayer,
|
||||||
ItemCard
|
ItemCard,
|
||||||
|
AddNewProvider
|
||||||
},
|
},
|
||||||
setup() {
|
setup() {
|
||||||
const { tm } = useModuleI18n('features/provider');
|
const { tm } = useModuleI18n('features/provider');
|
||||||
@@ -358,7 +292,6 @@ export default {
|
|||||||
|
|
||||||
// 新增提供商对话框相关
|
// 新增提供商对话框相关
|
||||||
showAddProviderDialog: false,
|
showAddProviderDialog: false,
|
||||||
activeProviderTab: 'chat_completion',
|
|
||||||
|
|
||||||
// 添加提供商类型分类
|
// 添加提供商类型分类
|
||||||
activeProviderTypeTab: 'all',
|
activeProviderTypeTab: 'all',
|
||||||
@@ -370,6 +303,7 @@ export default {
|
|||||||
"googlegenai_chat_completion": "chat_completion",
|
"googlegenai_chat_completion": "chat_completion",
|
||||||
"zhipu_chat_completion": "chat_completion",
|
"zhipu_chat_completion": "chat_completion",
|
||||||
"dify": "chat_completion",
|
"dify": "chat_completion",
|
||||||
|
"coze": "chat_completion",
|
||||||
"dashscope": "chat_completion",
|
"dashscope": "chat_completion",
|
||||||
"openai_whisper_api": "speech_to_text",
|
"openai_whisper_api": "speech_to_text",
|
||||||
"openai_whisper_selfhost": "speech_to_text",
|
"openai_whisper_selfhost": "speech_to_text",
|
||||||
@@ -472,6 +406,9 @@ export default {
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// 从工具函数导入
|
||||||
|
getProviderIcon,
|
||||||
|
|
||||||
// 获取空列表文本
|
// 获取空列表文本
|
||||||
getEmptyText() {
|
getEmptyText() {
|
||||||
if (this.activeProviderTypeTab === 'all') {
|
if (this.activeProviderTypeTab === 'all') {
|
||||||
@@ -481,63 +418,11 @@ export default {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
// 按提供商类型获取模板列表
|
|
||||||
getTemplatesByType(type) {
|
|
||||||
const templates = this.metadata['provider_group']?.metadata?.provider?.config_template || {};
|
|
||||||
const filtered = {};
|
|
||||||
|
|
||||||
for (const [name, template] of Object.entries(templates)) {
|
|
||||||
if (template.provider_type === type) {
|
|
||||||
filtered[name] = template;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return filtered;
|
|
||||||
},
|
|
||||||
|
|
||||||
// 获取提供商类型对应的图标
|
|
||||||
getProviderIcon(type) {
|
|
||||||
const icons = {
|
|
||||||
'openai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
|
|
||||||
'azure': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/azure.svg',
|
|
||||||
'xai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/xai.svg',
|
|
||||||
'anthropic': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/anthropic.svg',
|
|
||||||
'ollama': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ollama.svg',
|
|
||||||
'google': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
|
|
||||||
'deepseek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
|
|
||||||
'modelscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/modelscope.svg',
|
|
||||||
'zhipu': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
|
|
||||||
'siliconflow': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
|
|
||||||
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
|
|
||||||
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
|
|
||||||
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
|
|
||||||
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
|
|
||||||
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
|
|
||||||
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
|
|
||||||
'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',
|
|
||||||
'minimax': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/minimax.svg',
|
|
||||||
'302ai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
|
|
||||||
'microsoft': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/microsoft.svg',
|
|
||||||
'vllm': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/vllm.svg',
|
|
||||||
};
|
|
||||||
return icons[type] || '';
|
|
||||||
},
|
|
||||||
|
|
||||||
// 获取Tab类型的中文名称
|
// 获取Tab类型的中文名称
|
||||||
getTabTypeName(tabType) {
|
getTabTypeName(tabType) {
|
||||||
return this.messages.tabTypes[tabType] || tabType;
|
return this.messages.tabTypes[tabType] || tabType;
|
||||||
},
|
},
|
||||||
|
|
||||||
// 获取提供商简介
|
|
||||||
getProviderDescription(template, name) {
|
|
||||||
if (name == 'OpenAI') {
|
|
||||||
return this.tm('providers.description.openai', { type: template.type });
|
|
||||||
} else if (name == 'vLLM Rerank') {
|
|
||||||
return this.tm('providers.description.vllm_rerank', { type: template.type });
|
|
||||||
}
|
|
||||||
return this.tm('providers.description.default', { type: template.type });
|
|
||||||
},
|
|
||||||
|
|
||||||
// 选择提供商模板
|
// 选择提供商模板
|
||||||
selectProviderTemplate(name) {
|
selectProviderTemplate(name) {
|
||||||
this.newSelectedProviderName = name;
|
this.newSelectedProviderName = name;
|
||||||
@@ -546,7 +431,6 @@ export default {
|
|||||||
this.newSelectedProviderConfig = JSON.parse(JSON.stringify(
|
this.newSelectedProviderConfig = JSON.parse(JSON.stringify(
|
||||||
this.metadata['provider_group']?.metadata?.provider?.config_template[name] || {}
|
this.metadata['provider_group']?.metadata?.provider?.config_template[name] || {}
|
||||||
));
|
));
|
||||||
this.showAddProviderDialog = false;
|
|
||||||
},
|
},
|
||||||
|
|
||||||
configExistingProvider(provider) {
|
configExistingProvider(provider) {
|
||||||
@@ -657,6 +541,40 @@ export default {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
async copyProvider(providerToCopy) {
|
||||||
|
console.log('copyProvider triggered for:', providerToCopy);
|
||||||
|
// 1. 创建深拷贝
|
||||||
|
const newProviderConfig = JSON.parse(JSON.stringify(providerToCopy));
|
||||||
|
|
||||||
|
// 2. 生成唯一的 ID
|
||||||
|
const generateUniqueId = (baseId) => {
|
||||||
|
let newId = `${baseId}_copy`;
|
||||||
|
let counter = 1;
|
||||||
|
const existingIds = this.config_data.provider.map(p => p.id);
|
||||||
|
while (existingIds.includes(newId)) {
|
||||||
|
newId = `${baseId}_copy_${counter}`;
|
||||||
|
counter++;
|
||||||
|
}
|
||||||
|
return newId;
|
||||||
|
};
|
||||||
|
newProviderConfig.id = generateUniqueId(providerToCopy.id);
|
||||||
|
|
||||||
|
// 3. 设置为禁用状态,等待用户手动开启
|
||||||
|
newProviderConfig.enable = false;
|
||||||
|
|
||||||
|
this.loading = true;
|
||||||
|
try {
|
||||||
|
// 4. 调用后端接口创建
|
||||||
|
const res = await axios.post('/api/config/provider/new', newProviderConfig);
|
||||||
|
this.showSuccess(res.data.message || `成功复制并创建了 ${newProviderConfig.id}`);
|
||||||
|
this.getConfig(); // 5. 刷新列表
|
||||||
|
} catch (err) {
|
||||||
|
this.showError(err.response?.data?.message || err.message);
|
||||||
|
} finally {
|
||||||
|
this.loading = false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
deleteProvider(provider) {
|
deleteProvider(provider) {
|
||||||
if (confirm(this.tm('messages.confirm.delete', { id: provider.id }))) {
|
if (confirm(this.tm('messages.confirm.delete', { id: provider.id }))) {
|
||||||
axios.post('/api/config/provider/delete', { id: provider.id }).then((res) => {
|
axios.post('/api/config/provider/delete', { id: provider.id }).then((res) => {
|
||||||
@@ -818,89 +736,6 @@ export default {
|
|||||||
padding-top: 8px;
|
padding-top: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.provider-card {
|
|
||||||
transition: all 0.3s ease;
|
|
||||||
height: 100%;
|
|
||||||
cursor: pointer;
|
|
||||||
overflow: hidden;
|
|
||||||
position: relative;
|
|
||||||
}
|
|
||||||
|
|
||||||
.provider-card:hover {
|
|
||||||
transform: translateY(-4px);
|
|
||||||
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
|
|
||||||
border-color: var(--v-primary-base);
|
|
||||||
}
|
|
||||||
|
|
||||||
.provider-card-content {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
height: 100px;
|
|
||||||
padding: 16px;
|
|
||||||
position: relative;
|
|
||||||
z-index: 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
.provider-card-text {
|
|
||||||
flex: 1;
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
justify-content: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.provider-card-title {
|
|
||||||
font-size: 15px;
|
|
||||||
font-weight: 600;
|
|
||||||
margin-bottom: 4px;
|
|
||||||
padding: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.provider-card-description {
|
|
||||||
padding: 0;
|
|
||||||
margin: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.provider-card-logo {
|
|
||||||
position: absolute;
|
|
||||||
right: 0;
|
|
||||||
top: 0;
|
|
||||||
bottom: 0;
|
|
||||||
width: 80px;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
z-index: 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
.provider-logo-img {
|
|
||||||
width: 60px;
|
|
||||||
height: 60px;
|
|
||||||
opacity: 0.6;
|
|
||||||
object-fit: contain;
|
|
||||||
}
|
|
||||||
|
|
||||||
.provider-logo-fallback {
|
|
||||||
width: 50px;
|
|
||||||
height: 50px;
|
|
||||||
border-radius: 50%;
|
|
||||||
background-color: var(--v-primary-base);
|
|
||||||
color: white;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
font-size: 24px;
|
|
||||||
font-weight: bold;
|
|
||||||
opacity: 0.3;
|
|
||||||
}
|
|
||||||
|
|
||||||
.v-tabs {
|
|
||||||
border-radius: 8px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.v-window {
|
|
||||||
border-radius: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-card {
|
.status-card {
|
||||||
height: 120px;
|
height: 120px;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,5 @@
|
|||||||
<template>
|
<template>
|
||||||
<div class="dashboard-container">
|
<div class="dashboard-container">
|
||||||
<div class="dashboard-header">
|
|
||||||
<h1 class="dashboard-title">{{ t('title') }}</h1>
|
|
||||||
<div class="dashboard-subtitle">{{ t('subtitle') }}</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<v-slide-y-transition>
|
<v-slide-y-transition>
|
||||||
<v-row v-if="noticeTitle && noticeContent" class="notice-row">
|
<v-row v-if="noticeTitle && noticeContent" class="notice-row">
|
||||||
<v-alert
|
<v-alert
|
||||||
@@ -166,29 +161,10 @@ export default {
|
|||||||
background-color: var(--v-theme-background);
|
background-color: var(--v-theme-background);
|
||||||
min-height: calc(100vh - 64px);
|
min-height: calc(100vh - 64px);
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
.dashboard-header {
|
|
||||||
margin-bottom: 24px;
|
|
||||||
padding-bottom: 16px;
|
|
||||||
border-bottom: 1px solid rgba(0, 0, 0, 0.06);
|
|
||||||
}
|
|
||||||
|
|
||||||
.dashboard-title {
|
|
||||||
font-size: 24px;
|
|
||||||
font-weight: 600;
|
|
||||||
color: var(--v-theme-primaryText);
|
|
||||||
margin-bottom: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.dashboard-subtitle {
|
|
||||||
font-size: 14px;
|
|
||||||
color: var(--v-theme-secondaryText);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.notice-row {
|
.notice-row {
|
||||||
margin-bottom: 20px;
|
margin-bottom: 16px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dashboard-alert {
|
.dashboard-alert {
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ export default {
|
|||||||
|
|
||||||
.stat-value-wrapper {
|
.stat-value-wrapper {
|
||||||
display: flex;
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
align-items: baseline;
|
align-items: baseline;
|
||||||
justify-content: space-between;
|
justify-content: space-between;
|
||||||
margin-bottom: 4px;
|
margin-bottom: 4px;
|
||||||
|
|||||||
@@ -44,7 +44,7 @@
|
|||||||
<div class="stat-box" :class="{'trend-up': growthRate > 0, 'trend-down': growthRate < 0}">
|
<div class="stat-box" :class="{'trend-up': growthRate > 0, 'trend-down': growthRate < 0}">
|
||||||
<div class="stat-label">{{ t('charts.messageTrend.growthRate') }}</div>
|
<div class="stat-label">{{ t('charts.messageTrend.growthRate') }}</div>
|
||||||
<div class="stat-number">
|
<div class="stat-number">
|
||||||
<v-icon size="small" :icon="growthRate > 0 ? 'mdi-arrow-up' : 'mdi-arrow-down'"></v-icon>
|
<v-icon v-show="growthRate !== 0" size="small" :icon="growthRate > 0 ? 'mdi-arrow-up' : 'mdi-arrow-down'"></v-icon>
|
||||||
{{ Math.abs(growthRate) }}%
|
{{ Math.abs(growthRate) }}%
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -303,8 +303,10 @@ export default {
|
|||||||
|
|
||||||
.chart-header {
|
.chart-header {
|
||||||
display: flex;
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
justify-content: space-between;
|
justify-content: space-between;
|
||||||
align-items: flex-start;
|
align-items: flex-start;
|
||||||
|
gap: 10px;
|
||||||
margin-bottom: 20px;
|
margin-bottom: 20px;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,7 +323,7 @@ export default {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.time-select {
|
.time-select {
|
||||||
max-width: 150px;
|
max-width: fit-content;
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -349,6 +351,7 @@ export default {
|
|||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
color: var(--v-theme-primaryText);
|
color: var(--v-theme-primaryText);
|
||||||
display: flex;
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -527,12 +527,11 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
return
|
return
|
||||||
|
|
||||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||||
if provider and provider.meta().type == "dify":
|
if provider and provider.meta().type in ["dify", "coze"]:
|
||||||
assert isinstance(provider, ProviderDify)
|
|
||||||
await provider.forget(message.unified_msg_origin)
|
await provider.forget(message.unified_msg_origin)
|
||||||
message.set_result(
|
message.set_result(
|
||||||
MessageEventResult().message(
|
MessageEventResult().message(
|
||||||
"已重置当前 Dify 会话,新聊天将更换到新的会话。"
|
"已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -755,8 +754,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
创建新对话
|
创建新对话
|
||||||
"""
|
"""
|
||||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||||
if provider and provider.meta().type == "dify":
|
if provider and provider.meta().type in ["dify", "coze"]:
|
||||||
assert isinstance(provider, ProviderDify)
|
|
||||||
await provider.forget(message.unified_msg_origin)
|
await provider.forget(message.unified_msg_origin)
|
||||||
message.set_result(
|
message.set_result(
|
||||||
MessageEventResult().message("成功,下次聊天将是新对话。")
|
MessageEventResult().message("成功,下次聊天将是新对话。")
|
||||||
@@ -783,8 +781,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
|
async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
|
||||||
"""创建新群聊对话"""
|
"""创建新群聊对话"""
|
||||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||||
if provider and provider.meta().type == "dify":
|
if provider and provider.meta().type in ["dify", "coze"]:
|
||||||
assert isinstance(provider, ProviderDify)
|
|
||||||
await provider.forget(message.unified_msg_origin)
|
await provider.forget(message.unified_msg_origin)
|
||||||
message.set_result(
|
message.set_result(
|
||||||
MessageEventResult().message("成功,下次聊天将是新对话。")
|
MessageEventResult().message("成功,下次聊天将是新对话。")
|
||||||
@@ -823,7 +820,6 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
|
|
||||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||||
if provider and provider.meta().type == "dify":
|
if provider and provider.meta().type == "dify":
|
||||||
assert isinstance(provider, ProviderDify)
|
|
||||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||||
if not data["data"]:
|
if not data["data"]:
|
||||||
message.set_result(MessageEventResult().message("未找到任何对话。"))
|
message.set_result(MessageEventResult().message("未找到任何对话。"))
|
||||||
@@ -1214,6 +1210,12 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||||
req.prompt = user_info + req.prompt
|
req.prompt = user_info + req.prompt
|
||||||
|
|
||||||
|
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||||
|
group_name = event.message_obj.group.group_name
|
||||||
|
|
||||||
|
if group_name:
|
||||||
|
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||||
|
|
||||||
# 启用附加时间戳
|
# 启用附加时间戳
|
||||||
if cfg.get("datetime_system_prompt"):
|
if cfg.get("datetime_system_prompt"):
|
||||||
current_time = None
|
current_time = None
|
||||||
@@ -1230,6 +1232,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
)
|
)
|
||||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||||
|
|
||||||
|
img_cap_prov_id = cfg.get("default_image_caption_provider_id")
|
||||||
if req.conversation:
|
if req.conversation:
|
||||||
# persona inject
|
# persona inject
|
||||||
persona_id = req.conversation.persona_id or cfg.get("default_personality")
|
persona_id = req.conversation.persona_id or cfg.get("default_personality")
|
||||||
@@ -1270,7 +1273,6 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
|
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
|
||||||
|
|
||||||
# image caption
|
# image caption
|
||||||
img_cap_prov_id = cfg.get("default_image_caption_provider_id")
|
|
||||||
if img_cap_prov_id and req.image_urls:
|
if img_cap_prov_id and req.image_urls:
|
||||||
img_cap_prompt = cfg.get(
|
img_cap_prompt = cfg.get(
|
||||||
"image_caption_prompt", "Please describe the image."
|
"image_caption_prompt", "Please describe the image."
|
||||||
@@ -1307,9 +1309,12 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
break
|
break
|
||||||
if image_seg:
|
if image_seg:
|
||||||
try:
|
try:
|
||||||
if prov := self.context.get_using_provider(
|
prov = None
|
||||||
event.unified_msg_origin
|
if img_cap_prov_id:
|
||||||
):
|
prov = self.context.get_provider_by_id(img_cap_prov_id)
|
||||||
|
if prov is None:
|
||||||
|
prov = self.context.get_using_provider(event.unified_msg_origin)
|
||||||
|
if prov:
|
||||||
llm_resp = await prov.text_chat(
|
llm_resp = await prov.text_chat(
|
||||||
prompt="Please describe the image content.",
|
prompt="Please describe the image content.",
|
||||||
image_urls=[await image_seg.convert_to_file_path()],
|
image_urls=[await image_seg.convert_to_file_path()],
|
||||||
@@ -1318,6 +1323,8 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
req.system_prompt += (
|
req.system_prompt += (
|
||||||
f"Image Caption: {llm_resp.completion_text}\n"
|
f"Image Caption: {llm_resp.completion_text}\n"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("No provider found for image captioning.")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(f"处理引用图片失败: {e}")
|
logger.error(f"处理引用图片失败: {e}")
|
||||||
|
|
||||||
@@ -1337,22 +1344,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
logger.error(f"ltm: {e}")
|
logger.error(f"ltm: {e}")
|
||||||
|
|
||||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||||
@filter.command("alter_cmd")
|
@filter.command("alter_cmd", alias={"alter"})
|
||||||
async def alter_cmd(self, event: AstrMessageEvent):
|
async def alter_cmd(self, event: AstrMessageEvent):
|
||||||
# token = event.message_str.split(" ")
|
|
||||||
token = self.parse_commands(event.message_str)
|
token = self.parse_commands(event.message_str)
|
||||||
if token.len < 2:
|
if token.len < 3:
|
||||||
yield event.plain_result(
|
yield event.plain_result(
|
||||||
"可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令\n /alter_cmd reset config 打开reset权限配置"
|
"该指令用于设置指令或指令组的权限。\n"
|
||||||
|
"格式: /alter_cmd <cmd_name> <admin/member>\n"
|
||||||
|
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
|
||||||
|
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
|
||||||
|
"/alter_cmd reset config 打开 reset 权限配置"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
cmd_name = token.get(1)
|
cmd_name = " ".join(token.tokens[1:-1])
|
||||||
cmd_type = token.get(2)
|
cmd_type = token.get(-1)
|
||||||
|
|
||||||
# ============================
|
|
||||||
# 对reset权限进行特殊处理
|
|
||||||
# ============================
|
|
||||||
if cmd_name == "reset" and cmd_type == "config":
|
if cmd_name == "reset" and cmd_type == "config":
|
||||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||||
plugin_ = alter_cmd_cfg.get("astrbot", {})
|
plugin_ = alter_cmd_cfg.get("astrbot", {})
|
||||||
@@ -1402,16 +1409,18 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
|
|
||||||
# 查找指令
|
# 查找指令
|
||||||
found_command = None
|
found_command = None
|
||||||
|
cmd_group = False
|
||||||
for handler in star_handlers_registry:
|
for handler in star_handlers_registry:
|
||||||
assert isinstance(handler, StarHandlerMetadata)
|
assert isinstance(handler, StarHandlerMetadata)
|
||||||
for filter_ in handler.event_filters:
|
for filter_ in handler.event_filters:
|
||||||
if isinstance(filter_, CommandFilter):
|
if isinstance(filter_, CommandFilter):
|
||||||
if filter_.command_name == cmd_name:
|
if filter_.equals(cmd_name):
|
||||||
found_command = handler
|
found_command = handler
|
||||||
break
|
break
|
||||||
elif isinstance(filter_, CommandGroupFilter):
|
elif isinstance(filter_, CommandGroupFilter):
|
||||||
if cmd_name == filter_.group_name:
|
if filter_.equals(cmd_name):
|
||||||
found_command = handler
|
found_command = handler
|
||||||
|
cmd_group = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if not found_command:
|
if not found_command:
|
||||||
@@ -1448,8 +1457,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
|||||||
else filter.PermissionType.MEMBER
|
else filter.PermissionType.MEMBER
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
cmd_group_str = "指令组" if cmd_group else "指令"
|
||||||
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
|
yield event.plain_result(
|
||||||
|
f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。"
|
||||||
|
)
|
||||||
|
|
||||||
async def update_reset_permission(self, scene_key: str, perm_type: str):
|
async def update_reset_permission(self, scene_key: str, perm_type: str):
|
||||||
"""更新reset命令在特定场景下的权限设置
|
"""更新reset命令在特定场景下的权限设置
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class Main(star.Star):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
@filter.command("websearch")
|
@filter.command("websearch")
|
||||||
async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str:
|
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
|
||||||
event.set_result(
|
event.set_result(
|
||||||
MessageEventResult().message(
|
MessageEventResult().message(
|
||||||
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。"
|
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。"
|
||||||
@@ -210,7 +210,7 @@ class Main(star.Star):
|
|||||||
processed_results = await asyncio.gather(*tasks, return_exceptions=True)
|
processed_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
ret = ""
|
ret = ""
|
||||||
for processed_result in processed_results:
|
for processed_result in processed_results:
|
||||||
if isinstance(processed_result, Exception):
|
if isinstance(processed_result, BaseException):
|
||||||
logger.error(f"Error processing search result: {processed_result}")
|
logger.error(f"Error processing search result: {processed_result}")
|
||||||
continue
|
continue
|
||||||
ret += processed_result
|
ret += processed_result
|
||||||
@@ -335,7 +335,7 @@ class Main(star.Star):
|
|||||||
@filter.on_llm_request(priority=-10000)
|
@filter.on_llm_request(priority=-10000)
|
||||||
async def edit_web_search_tools(
|
async def edit_web_search_tools(
|
||||||
self, event: AstrMessageEvent, req: ProviderRequest
|
self, event: AstrMessageEvent, req: ProviderRequest
|
||||||
) -> str:
|
):
|
||||||
"""Get the session conversation for the given event."""
|
"""Get the session conversation for the given event."""
|
||||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||||
prov_settings = cfg.get("provider_settings", {})
|
prov_settings = cfg.get("provider_settings", {})
|
||||||
@@ -347,6 +347,9 @@ class Main(star.Star):
|
|||||||
req.func_tool = tool_set.get_full_tool_set()
|
req.func_tool = tool_set.get_full_tool_set()
|
||||||
tool_set = req.func_tool
|
tool_set = req.func_tool
|
||||||
|
|
||||||
|
if not tool_set:
|
||||||
|
return
|
||||||
|
|
||||||
if not websearch_enable:
|
if not websearch_enable:
|
||||||
# pop tools
|
# pop tools
|
||||||
for tool_name in self.TOOLS:
|
for tool_name in self.TOOLS:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "AstrBot"
|
name = "AstrBot"
|
||||||
version = "4.1.2"
|
version = "4.2.0"
|
||||||
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
@@ -49,6 +49,8 @@ dependencies = [
|
|||||||
"watchfiles>=1.0.5",
|
"watchfiles>=1.0.5",
|
||||||
"websockets>=15.0.1",
|
"websockets>=15.0.1",
|
||||||
"wechatpy>=1.8.18",
|
"wechatpy>=1.8.18",
|
||||||
|
"audioop-lts ; python_full_version >= '3.13'",
|
||||||
|
"click>=8.2.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -43,3 +43,4 @@ pydub
|
|||||||
sqlmodel
|
sqlmodel
|
||||||
deprecated
|
deprecated
|
||||||
sqlalchemy[asyncio]
|
sqlalchemy[asyncio]
|
||||||
|
audioop-lts; python_version>='3.13'
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
from quart import Quart
|
from quart import Quart
|
||||||
from astrbot.dashboard.server import AstrBotDashboard
|
from astrbot.dashboard.server import AstrBotDashboard
|
||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
@@ -9,36 +11,46 @@ from astrbot.core.star.star_handler import star_handlers_registry
|
|||||||
from astrbot.core.star.star import star_registry
|
from astrbot.core.star.star import star_registry
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest_asyncio.fixture(scope="module")
|
||||||
def core_lifecycle_td():
|
async def core_lifecycle_td(tmp_path_factory):
|
||||||
db = SQLiteDatabase("data/data_v3.db")
|
"""Creates and initializes a core lifecycle instance with a temporary database."""
|
||||||
|
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
|
||||||
|
db = SQLiteDatabase(str(tmp_db_path))
|
||||||
log_broker = LogBroker()
|
log_broker = LogBroker()
|
||||||
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
|
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
|
||||||
return core_lifecycle_td
|
await core_lifecycle.initialize()
|
||||||
|
return core_lifecycle
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def app(core_lifecycle_td):
|
def app(core_lifecycle_td: AstrBotCoreLifecycle):
|
||||||
db = SQLiteDatabase("data/data_v3.db")
|
"""Creates a Quart app instance for testing."""
|
||||||
server = AstrBotDashboard(core_lifecycle_td, db)
|
shutdown_event = asyncio.Event()
|
||||||
|
# The db instance is already part of the core_lifecycle_td
|
||||||
|
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
|
||||||
return server.app
|
return server.app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest_asyncio.fixture(scope="module")
|
||||||
def header():
|
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||||
return {}
|
"""Handles login and returns an authenticated header."""
|
||||||
|
test_client = app.test_client()
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/auth/login",
|
||||||
|
json={
|
||||||
|
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
|
||||||
|
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
token = data["data"]["token"]
|
||||||
|
return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_init_core_lifecycle_td(core_lifecycle_td):
|
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||||
await core_lifecycle_td.initialize()
|
"""Tests the login functionality with both wrong and correct credentials."""
|
||||||
assert core_lifecycle_td is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auth_login(
|
|
||||||
app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict
|
|
||||||
):
|
|
||||||
test_client = app.test_client()
|
test_client = app.test_client()
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/auth/login", json={"username": "wrong", "password": "password"}
|
"/api/auth/login", json={"username": "wrong", "password": "password"}
|
||||||
@@ -55,31 +67,32 @@ async def test_auth_login(
|
|||||||
)
|
)
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
assert data["status"] == "ok" and "token" in data["data"]
|
assert data["status"] == "ok" and "token" in data["data"]
|
||||||
header["Authorization"] = f"Bearer {data['data']['token']}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_stat(app: Quart, header: dict):
|
async def test_get_stat(app: Quart, authenticated_header: dict):
|
||||||
test_client = app.test_client()
|
test_client = app.test_client()
|
||||||
response = await test_client.get("/api/stat/get")
|
response = await test_client.get("/api/stat/get")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
response = await test_client.get("/api/stat/get", headers=header)
|
response = await test_client.get("/api/stat/get", headers=authenticated_header)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
assert data["status"] == "ok" and "platform" in data["data"]
|
assert data["status"] == "ok" and "platform" in data["data"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_plugins(app: Quart, header: dict):
|
async def test_plugins(app: Quart, authenticated_header: dict):
|
||||||
test_client = app.test_client()
|
test_client = app.test_client()
|
||||||
# 已经安装的插件
|
# 已经安装的插件
|
||||||
response = await test_client.get("/api/plugin/get", headers=header)
|
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
assert data["status"] == "ok"
|
assert data["status"] == "ok"
|
||||||
|
|
||||||
# 插件市场
|
# 插件市场
|
||||||
response = await test_client.get("/api/plugin/market_list", headers=header)
|
response = await test_client.get(
|
||||||
|
"/api/plugin/market_list", headers=authenticated_header
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
assert data["status"] == "ok"
|
assert data["status"] == "ok"
|
||||||
@@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict):
|
|||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/plugin/install",
|
"/api/plugin/install",
|
||||||
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
|
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
|
||||||
headers=header,
|
headers=authenticated_header,
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
@@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict):
|
|||||||
|
|
||||||
# 插件更新
|
# 插件更新
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header
|
"/api/plugin/update",
|
||||||
|
json={"name": "astrbot_plugin_essential"},
|
||||||
|
headers=authenticated_header,
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
@@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict):
|
|||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/plugin/uninstall",
|
"/api/plugin/uninstall",
|
||||||
json={"name": "astrbot_plugin_essential"},
|
json={"name": "astrbot_plugin_essential"},
|
||||||
headers=header,
|
headers=authenticated_header,
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
@@ -132,9 +147,9 @@ async def test_plugins(app: Quart, header: dict):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_update(app: Quart, header: dict):
|
async def test_check_update(app: Quart, authenticated_header: dict):
|
||||||
test_client = app.test_client()
|
test_client = app.test_client()
|
||||||
response = await test_client.get("/api/update/check", headers=header)
|
response = await test_client.get("/api/update/check", headers=authenticated_header)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
assert data["status"] == "success"
|
assert data["status"] == "success"
|
||||||
@@ -142,24 +157,45 @@ async def test_check_update(app: Quart, header: dict):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_do_update(
|
async def test_do_update(
|
||||||
app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle
|
app: Quart,
|
||||||
|
authenticated_header: dict,
|
||||||
|
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path_factory,
|
||||||
):
|
):
|
||||||
global VERSION
|
|
||||||
test_client = app.test_client()
|
test_client = app.test_client()
|
||||||
os.makedirs("data/astrbot_release", exist_ok=True)
|
|
||||||
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
|
# Use a temporary path for the mock update to avoid side effects
|
||||||
VERSION = "114.514.1919810"
|
temp_release_dir = tmp_path_factory.mktemp("release")
|
||||||
response = await test_client.post(
|
release_path = temp_release_dir / "astrbot"
|
||||||
"/api/update/do", headers=header, json={"version": "latest"}
|
|
||||||
|
async def mock_update(*args, **kwargs):
|
||||||
|
"""Mocks the update process by creating a directory in the temp path."""
|
||||||
|
os.makedirs(release_path, exist_ok=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def mock_download_dashboard(*args, **kwargs):
|
||||||
|
"""Mocks the dashboard download to prevent network access."""
|
||||||
|
return
|
||||||
|
|
||||||
|
async def mock_pip_install(*args, **kwargs):
|
||||||
|
"""Mocks pip install to prevent actual installation."""
|
||||||
|
return
|
||||||
|
|
||||||
|
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
|
||||||
data = await response.get_json()
|
|
||||||
assert data["status"] == "error" # 已经是最新版本
|
|
||||||
|
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False}
|
"/api/update/do",
|
||||||
|
headers=authenticated_header,
|
||||||
|
json={"version": "v3.4.0", "reboot": False},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
assert data["status"] == "ok"
|
assert data["status"] == "ok"
|
||||||
assert os.path.exists("data/astrbot_release/astrbot")
|
assert os.path.exists(release_path)
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
# 将项目根目录添加到 sys.path
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from main import check_env, check_dashboard_files
|
from main import check_env, check_dashboard_files
|
||||||
@@ -27,29 +31,58 @@ def test_check_env(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_dashboard_files(monkeypatch):
|
async def test_check_dashboard_files_not_exists(monkeypatch):
|
||||||
|
"""Tests dashboard download when files do not exist."""
|
||||||
monkeypatch.setattr(os.path, "exists", lambda x: False)
|
monkeypatch.setattr(os.path, "exists", lambda x: False)
|
||||||
|
|
||||||
async def mock_get(*args, **kwargs):
|
with mock.patch("main.download_dashboard") as mock_download:
|
||||||
class MockResponse:
|
|
||||||
status = 200
|
|
||||||
|
|
||||||
async def read(self):
|
|
||||||
return b"content"
|
|
||||||
|
|
||||||
return MockResponse()
|
|
||||||
|
|
||||||
with mock.patch("aiohttp.ClientSession.get", new=mock_get):
|
|
||||||
with mock.patch("builtins.open", mock.mock_open()) as mock_file:
|
|
||||||
with mock.patch("zipfile.ZipFile.extractall") as mock_extractall:
|
|
||||||
|
|
||||||
async def mock_aenter(_):
|
|
||||||
await check_dashboard_files()
|
await check_dashboard_files()
|
||||||
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
|
mock_download.assert_called_once()
|
||||||
mock_extractall.assert_called_once()
|
|
||||||
|
|
||||||
async def mock_aexit(obj, exc_type, exc, tb):
|
|
||||||
return
|
|
||||||
|
|
||||||
mock_extractall.__aenter__ = mock_aenter
|
@pytest.mark.asyncio
|
||||||
mock_extractall.__aexit__ = mock_aexit
|
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
|
||||||
|
"""Tests that dashboard is not downloaded when it exists and version matches."""
|
||||||
|
# Mock os.path.exists to return True
|
||||||
|
monkeypatch.setattr(os.path, "exists", lambda x: True)
|
||||||
|
|
||||||
|
# Mock get_dashboard_version to return the current version
|
||||||
|
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||||
|
# We need to import VERSION from main's context
|
||||||
|
from main import VERSION
|
||||||
|
|
||||||
|
mock_get_version.return_value = f"v{VERSION}"
|
||||||
|
|
||||||
|
with mock.patch("main.download_dashboard") as mock_download:
|
||||||
|
await check_dashboard_files()
|
||||||
|
# Assert that download_dashboard was NOT called
|
||||||
|
mock_download.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
|
||||||
|
"""Tests that a warning is logged when dashboard version mismatches."""
|
||||||
|
monkeypatch.setattr(os.path, "exists", lambda x: True)
|
||||||
|
|
||||||
|
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||||
|
mock_get_version.return_value = "v0.0.1" # A different version
|
||||||
|
|
||||||
|
with mock.patch("main.logger.warning") as mock_logger_warning:
|
||||||
|
await check_dashboard_files()
|
||||||
|
mock_logger_warning.assert_called_once()
|
||||||
|
call_args, _ = mock_logger_warning.call_args
|
||||||
|
assert "不符" in call_args[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
|
||||||
|
"""Tests that providing a valid webui_dir skips all checks."""
|
||||||
|
valid_dir = "/tmp/my-custom-webui"
|
||||||
|
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
|
||||||
|
|
||||||
|
with mock.patch("main.download_dashboard") as mock_download:
|
||||||
|
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||||
|
result = await check_dashboard_files(webui_dir=valid_dir)
|
||||||
|
assert result == valid_dir
|
||||||
|
mock_download.assert_not_called()
|
||||||
|
mock_get_version.assert_not_called()
|
||||||
|
|||||||
@@ -1,285 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import asyncio
|
|
||||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
|
||||||
from astrbot.core.star import PluginManager
|
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
||||||
from astrbot.core.platform.astrbot_message import (
|
|
||||||
AstrBotMessage,
|
|
||||||
MessageMember,
|
|
||||||
MessageType,
|
|
||||||
)
|
|
||||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
|
||||||
from astrbot.core.message.components import Plain, At
|
|
||||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
|
||||||
from astrbot.core.platform.manager import PlatformManager
|
|
||||||
from astrbot.core.provider.manager import ProviderManager
|
|
||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
|
||||||
from astrbot.core.star.context import Context
|
|
||||||
from asyncio import Queue
|
|
||||||
|
|
||||||
SESSION_ID_IN_WHITELIST = "test_sid_wl"
|
|
||||||
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
|
|
||||||
TEST_LLM_PROVIDER = {
|
|
||||||
"id": "zhipu_default",
|
|
||||||
"type": "openai_chat_completion",
|
|
||||||
"enable": True,
|
|
||||||
"key": [os.getenv("ZHIPU_API_KEY")],
|
|
||||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
|
||||||
"model_config": {
|
|
||||||
"model": "glm-4-flash",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_COMMANDS = [
|
|
||||||
["help", "已注册的 AstrBot 内置指令"],
|
|
||||||
["tool ls", "函数工具"],
|
|
||||||
["tool on websearch", "激活工具"],
|
|
||||||
["tool off websearch", "停用工具"],
|
|
||||||
["plugin", "已加载的插件"],
|
|
||||||
["t2i", "文本转图片模式"],
|
|
||||||
["sid", "此 ID 可用于设置会话白名单。"],
|
|
||||||
["op test_op", "授权成功。"],
|
|
||||||
["deop test_op", "取消授权成功。"],
|
|
||||||
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
|
|
||||||
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
|
|
||||||
["provider", "当前载入的 LLM 提供商"],
|
|
||||||
["reset", "重置成功"],
|
|
||||||
# ["model", "查看、切换提供商模型列表"],
|
|
||||||
["history", "历史记录:"],
|
|
||||||
["key", "当前 Key"],
|
|
||||||
["persona", "[Persona]"],
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class FakeAstrMessageEvent(AstrMessageEvent):
|
|
||||||
def __init__(self, abm: AstrBotMessage = None):
|
|
||||||
meta = PlatformMetadata("test_platform", "test")
|
|
||||||
super().__init__(
|
|
||||||
message_str=abm.message_str,
|
|
||||||
message_obj=abm,
|
|
||||||
platform_meta=meta,
|
|
||||||
session_id=abm.session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
|
||||||
await super().send(message)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_fake_event(
|
|
||||||
message_str: str,
|
|
||||||
session_id: str = "test_sid",
|
|
||||||
is_at: bool = False,
|
|
||||||
is_group: bool = False,
|
|
||||||
sender_id: str = "123456",
|
|
||||||
):
|
|
||||||
abm = AstrBotMessage()
|
|
||||||
abm.message_str = message_str
|
|
||||||
abm.group_id = "test"
|
|
||||||
abm.message = [Plain(message_str)]
|
|
||||||
if is_at:
|
|
||||||
abm.message.append(At(qq="bot"))
|
|
||||||
abm.self_id = "bot"
|
|
||||||
abm.sender = MessageMember(sender_id, "mika")
|
|
||||||
abm.timestamp = 1234567890
|
|
||||||
abm.message_id = "test"
|
|
||||||
abm.session_id = session_id
|
|
||||||
if is_group:
|
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
|
||||||
else:
|
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
|
||||||
return FakeAstrMessageEvent(abm)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def event_queue():
|
|
||||||
return Queue()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def config():
|
|
||||||
cfg = AstrBotConfig()
|
|
||||||
cfg["platform_settings"]["id_whitelist"] = [
|
|
||||||
"test_platform:FriendMessage:test_sid_wl",
|
|
||||||
"test_platform:GroupMessage:test_sid_wl",
|
|
||||||
]
|
|
||||||
cfg["admins_id"] = ["123456"]
|
|
||||||
cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"]
|
|
||||||
cfg["provider"] = [TEST_LLM_PROVIDER]
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def db():
|
|
||||||
return SQLiteDatabase("data/data_v3.db")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def platform_manager(event_queue, config):
|
|
||||||
return PlatformManager(config, event_queue)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def provider_manager(config, db):
|
|
||||||
return ProviderManager(config, db)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def star_context(event_queue, config, db, platform_manager, provider_manager):
|
|
||||||
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
|
|
||||||
return star_context
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def plugin_manager(star_context, config):
|
|
||||||
plugin_manager = PluginManager(star_context, config)
|
|
||||||
# await plugin_manager.reload()
|
|
||||||
asyncio.run(plugin_manager.reload())
|
|
||||||
return plugin_manager
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def pipeline_context(config, plugin_manager):
|
|
||||||
return PipelineContext(config, plugin_manager)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def pipeline_scheduler(pipeline_context):
|
|
||||||
return PipelineScheduler(pipeline_context)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_platform_initialization(platform_manager: PlatformManager):
|
|
||||||
await platform_manager.initialize()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_provider_initialization(provider_manager: ProviderManager):
|
|
||||||
await provider_manager.initialize()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
|
|
||||||
await pipeline_scheduler.initialize()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
|
||||||
"""测试唤醒"""
|
|
||||||
# 群聊无 @ 无指令
|
|
||||||
caplog.clear()
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
|
|
||||||
with caplog.at_level(logging.DEBUG):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any(
|
|
||||||
"执行阶段 WhitelistCheckStage" not in message for message in caplog.messages
|
|
||||||
)
|
|
||||||
# 群聊有 @ 无指令
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
"test", is_group=True, is_at=True
|
|
||||||
)
|
|
||||||
with caplog.at_level(logging.DEBUG):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
|
|
||||||
# 群聊有指令
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
"/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST
|
|
||||||
)
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert mock_event._has_send_oper is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pipeline_wl(
|
|
||||||
pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog
|
|
||||||
):
|
|
||||||
caplog.clear()
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
"test", SESSION_ID_IN_WHITELIST, sender_id="123"
|
|
||||||
)
|
|
||||||
with caplog.at_level(logging.INFO):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any(
|
|
||||||
"不在会话白名单中,已终止事件传播。" not in message
|
|
||||||
for message in caplog.messages
|
|
||||||
), "日志中未找到预期的消息"
|
|
||||||
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
|
|
||||||
with caplog.at_level(logging.INFO):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any(
|
|
||||||
"不在会话白名单中,已终止事件传播。" in message for message in caplog.messages
|
|
||||||
), "日志中未找到预期的消息"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
|
|
||||||
# 测试默认屏蔽词
|
|
||||||
caplog.clear()
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
"色情", session_id=SESSION_ID_IN_WHITELIST
|
|
||||||
) # 测试需要。
|
|
||||||
with caplog.at_level(logging.INFO):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any("内容安全检查不通过" in message for message in caplog.messages), (
|
|
||||||
"日志中未找到预期的消息"
|
|
||||||
)
|
|
||||||
# 测试额外屏蔽词
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
"TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
|
|
||||||
)
|
|
||||||
with caplog.at_level(logging.INFO):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any("内容安全检查不通过" in message for message in caplog.messages), (
|
|
||||||
"日志中未找到预期的消息"
|
|
||||||
)
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
"_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
|
|
||||||
)
|
|
||||||
with caplog.at_level(logging.INFO):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any("内容安全检查不通过" not in message for message in caplog.messages)
|
|
||||||
# TODO: 测试 百度AI 的内容安全检查
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
|
||||||
caplog.clear()
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
"just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST
|
|
||||||
)
|
|
||||||
with caplog.at_level(logging.DEBUG):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any("请求 LLM" in message for message in caplog.messages)
|
|
||||||
assert mock_event.get_result() is not None
|
|
||||||
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
|
|
||||||
caplog.clear()
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
"help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST
|
|
||||||
)
|
|
||||||
with caplog.at_level(logging.DEBUG):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
assert any("请求 LLM" in message for message in caplog.messages)
|
|
||||||
assert any(
|
|
||||||
"web_searcher - search_from_search_engine" in message
|
|
||||||
for message in caplog.messages
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
|
|
||||||
for command in TEST_COMMANDS:
|
|
||||||
caplog.clear()
|
|
||||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
|
||||||
command[0], session_id=SESSION_ID_IN_WHITELIST
|
|
||||||
)
|
|
||||||
with caplog.at_level(logging.DEBUG):
|
|
||||||
await pipeline_scheduler.execute(mock_event)
|
|
||||||
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
|
|
||||||
assert any(command[1] in message for message in caplog.messages)
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import os
|
import os
|
||||||
|
from unittest.mock import MagicMock
|
||||||
from astrbot.core.star.star_manager import PluginManager
|
from astrbot.core.star.star_manager import PluginManager
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry
|
from astrbot.core.star.star_handler import star_handlers_registry
|
||||||
from astrbot.core.star.star import star_registry
|
from astrbot.core.star.star import star_registry
|
||||||
@@ -8,18 +9,51 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
|||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
|
|
||||||
event_queue = Queue()
|
|
||||||
|
|
||||||
config = AstrBotConfig()
|
|
||||||
|
|
||||||
db = SQLiteDatabase("data/data_v3.db")
|
|
||||||
|
|
||||||
star_context = Context(event_queue, config, db)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def plugin_manager_pm():
|
def plugin_manager_pm(tmp_path):
|
||||||
return PluginManager(star_context, config)
|
"""
|
||||||
|
Provides a fully isolated PluginManager instance for testing.
|
||||||
|
- Uses a temporary directory for plugins.
|
||||||
|
- Uses a temporary database.
|
||||||
|
- Creates a fresh context for each test.
|
||||||
|
"""
|
||||||
|
# Create temporary resources
|
||||||
|
temp_plugins_path = tmp_path / "plugins"
|
||||||
|
temp_plugins_path.mkdir()
|
||||||
|
temp_db_path = tmp_path / "test_db.db"
|
||||||
|
|
||||||
|
# Create fresh, isolated instances for the context
|
||||||
|
event_queue = Queue()
|
||||||
|
config = AstrBotConfig()
|
||||||
|
db = SQLiteDatabase(str(temp_db_path))
|
||||||
|
|
||||||
|
# Set the plugin store path in the config to the temporary directory
|
||||||
|
config.plugin_store_path = str(temp_plugins_path)
|
||||||
|
|
||||||
|
# Mock dependencies for the context
|
||||||
|
provider_manager = MagicMock()
|
||||||
|
platform_manager = MagicMock()
|
||||||
|
conversation_manager = MagicMock()
|
||||||
|
message_history_manager = MagicMock()
|
||||||
|
persona_manager = MagicMock()
|
||||||
|
astrbot_config_mgr = MagicMock()
|
||||||
|
|
||||||
|
star_context = Context(
|
||||||
|
event_queue,
|
||||||
|
config,
|
||||||
|
db,
|
||||||
|
provider_manager,
|
||||||
|
platform_manager,
|
||||||
|
conversation_manager,
|
||||||
|
message_history_manager,
|
||||||
|
persona_manager,
|
||||||
|
astrbot_config_mgr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the PluginManager instance
|
||||||
|
manager = PluginManager(star_context, config)
|
||||||
|
yield manager
|
||||||
|
|
||||||
|
|
||||||
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||||
@@ -36,48 +70,76 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_plugin_crud(plugin_manager_pm: PluginManager):
|
async def test_install_plugin(plugin_manager_pm: PluginManager):
|
||||||
"""测试插件安装和重载"""
|
"""Tests successful plugin installation in an isolated environment."""
|
||||||
os.makedirs("data/plugins", exist_ok=True)
|
|
||||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
|
plugin_info = await plugin_manager_pm.install_plugin(test_repo)
|
||||||
exists = False
|
plugin_path = os.path.join(
|
||||||
for md in star_registry:
|
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
|
||||||
if md.name == "astrbot_plugin_essential":
|
)
|
||||||
exists = True
|
|
||||||
break
|
assert plugin_info is not None
|
||||||
assert plugin_path is not None
|
|
||||||
assert os.path.exists(plugin_path)
|
assert os.path.exists(plugin_path)
|
||||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
assert any(md.name == "astrbot_plugin_essential" for md in star_registry), (
|
||||||
# shutil.rmtree(plugin_path)
|
"Plugin 'astrbot_plugin_essential' was not loaded into star_registry."
|
||||||
|
)
|
||||||
|
|
||||||
# install plugin which is not exists
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||||
|
"""Tests that installing a non-existent plugin raises an exception."""
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
|
await plugin_manager_pm.install_plugin(
|
||||||
|
"https://github.com/Soulter/non_existent_repo"
|
||||||
|
)
|
||||||
|
|
||||||
# update
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_plugin(plugin_manager_pm: PluginManager):
|
||||||
|
"""Tests updating an existing plugin in an isolated environment."""
|
||||||
|
# First, install the plugin
|
||||||
|
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||||
|
await plugin_manager_pm.install_plugin(test_repo)
|
||||||
|
|
||||||
|
# Then, update it
|
||||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
|
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
|
|
||||||
|
|
||||||
# uninstall
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||||
|
"""Tests that updating a non-existent plugin raises an exception."""
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await plugin_manager_pm.update_plugin("non_existent_plugin")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_uninstall_plugin(plugin_manager_pm: PluginManager):
|
||||||
|
"""Tests successful plugin uninstallation in an isolated environment."""
|
||||||
|
# First, install the plugin
|
||||||
|
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||||
|
await plugin_manager_pm.install_plugin(test_repo)
|
||||||
|
plugin_path = os.path.join(
|
||||||
|
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
|
||||||
|
)
|
||||||
|
assert os.path.exists(plugin_path) # Pre-condition
|
||||||
|
|
||||||
|
# Then, uninstall it
|
||||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
|
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
|
||||||
|
|
||||||
assert not os.path.exists(plugin_path)
|
assert not os.path.exists(plugin_path)
|
||||||
exists = False
|
assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), (
|
||||||
for md in star_registry:
|
"Plugin 'astrbot_plugin_essential' was not unloaded from star_registry."
|
||||||
if md.name == "astrbot_plugin_essential":
|
)
|
||||||
exists = True
|
assert not any(
|
||||||
break
|
"astrbot_plugin_essential" in md.handler_module_path
|
||||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
for md in star_handlers_registry
|
||||||
exists = False
|
), (
|
||||||
for md in star_handlers_registry:
|
"Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry."
|
||||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
)
|
||||||
exists = True
|
|
||||||
break
|
|
||||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||||
|
"""Tests that uninstalling a non-existent plugin raises an exception."""
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
|
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")
|
||||||
|
|
||||||
# TODO: file installation
|
|
||||||
|
|||||||
Reference in New Issue
Block a user