Compare commits

..

50 Commits

Author SHA1 Message Date
Soulter
3c8c28ebd5 chore: bump version to 4.2.0 2025-09-27 20:45:50 +08:00
Soulter
524285f767 feat: add cancel button with localized text to AddNewPlatform and update close button in AddNewProvider
fixes: #2889
2025-09-27 20:41:45 +08:00
Soulter
c2a34475f1 feat: 支持删除指定会话以及部分会话管理优化 (#2895)
* feat: add toast notification system with snackbar component

* feat: add session deletion functionality

* feat: support batch operations for updating session persona, provider, LLM, and TTS statuses

fix: #2263

* feat: 修复对话状态关闭,删除对话管理库会导致对话无法恢复

fixes: #2309
2025-09-27 20:36:30 +08:00
Soulter
a69195a02b fix: webchat streaming queue interrupted after user closing tab (#2892)
* feat: add toast notification system with snackbar component

* feat: enhance chat functionality with conversation running state and notifications

* fix: update bot message avatar rendering during streaming

* feat: implement conversation tracking context manager for webchat

* fix: update conversation tracking to remove conversation ID on exit
2025-09-27 17:57:12 +08:00
RC-CHN
19d7438499 fix: unit tests (#2760)
* fix:修复了main和plugin_manager部分单元测试

* fix: 修复了dashboard部分测试

* remove: 删除暂无用的配置测试脚本

* perf:拆分插件增查删改为独立的单元测试

* refactor: 重构插件管理器测试,使用临时环境隔离测试实例

* test: 增加对仪表板文件检查的单元测试,涵盖不同情况

* style: format code

* remove: 删除未使用的导入语句

* delete: remove unused test file for pipeline

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-27 14:43:04 +08:00
anka
ccb380ce06 feat: 支持接入 Coze (#2858)
* feat: 适配 coze 供应商
1. 支持文件上传
2. 支持多模态
3. 支持流式传输
4. 支持 API 端的上下文保存历史记录
5. 支持类似 dify 的 forget 接口

* style: format code

* fix: type checking error

* fix: 修复:
1. 使用coze api端的上下文时, 现在不会重复传递上下文
2. 使用 AstrBot 的上下文时, 正确处理其中的图片信息
3. 上传图片时, 提供一个非持久化的缓存避免重复上传(在解析上下文并将文件转化为file_id传递给coze api时, 如果没有缓存会导致很多的网络资源浪费)
4. 修复reset等指令不能正确重置上下文的问题

* fix: 移除某些地方多余的针对 dify 的断言, 以兼容 Coze

* style: 修改配置项显示/webchat平台对于非预期的类型的处理

* fix: 让conversation_id放到请求中正确的位置

* refactor: extract coze api client

* refactor: improve image processing logic in ProviderCoze

* chore: remove file ext guessing

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-27 14:23:29 +08:00
Ding Jiatong
a35c439bbd fix: 使用增量解码器修复 Dify 流式返回结果偶现的解码错误 (#2888)
* fix: 修复linux下utf-8解码错误的问题

* feat: use incremental decoder

* fix: add type hint for response parameter in _stream_sse and refactor file upload method

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-26 23:04:58 +08:00
Soulter
09d1f96603 fix: 修复 /alter_cmd 指令无法控制指令组、子指令组和子指令组下子指令的问题 (#2873)
* fix: revert changes in command_group.py at 782c036 to fix command group permission check

* fix: 不传递 GroupCommand handler

* perf: alter_cmd 指令支持对子指令、指令组进行配置

* chore: remove test commands and subcommands from test_group

* chore: add cache for complete command names list in CommandFilter and CommandGroupFilter

---------

Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-26 14:16:50 +08:00
鸦羽
26aa18d980 Merge pull request #2881 from Raven95676/fix/2879
fix: add missing id field
2025-09-26 11:31:28 +08:00
Raven95676
d10b542797 chore: format 2025-09-26 11:05:32 +08:00
Raven95676
ce4e4fb8dd fix: add missing id field 2025-09-26 10:59:11 +08:00
Soulter
8f4a31cf8c chore: bump version to 4.1.7 2025-09-23 22:16:36 +08:00
Soulter
23549f13d6 Feature: 支持批量删除对话历史 (#2859)
* feat: 支持批量删除对话

closes: #2784

* feat: 添加加载状态禁用功能,优化用户交互体验
2025-09-23 22:10:56 +08:00
Soulter
869d11f9a6 perf: 优化验证配置时的性能,移除配置隐式类型转换
fixes: #2646
2025-09-23 21:04:14 +08:00
Soulter
02e73b82ee fix: 修复无法打开更新对话框的问题 2025-09-23 20:29:10 +08:00
Soulter
f85f87f545 feat: WebChat 支持手动填写模型名
closes: #2830
2025-09-23 15:32:54 +08:00
Soulter
1fff5713f3 refactor: 解耦 PlatformPage 和 ProviderPage 的部分组件 2025-09-23 15:32:54 +08:00
Soulter
8453ec36f0 docs: Revise links for documentation and blog in README
Updated links in the README for documentation and blog.
2025-09-23 14:12:05 +08:00
Soulter
d5b3ce8424 fix: update download_dashboard to log specific dashboard release URLs 2025-09-23 13:10:33 +08:00
Soulter
80cbbfa5ca chore: bump version to 4.1.6 2025-09-23 13:02:06 +08:00
Soulter
9177bb660f fix: improve error handling in run_agent for streaming responses 2025-09-23 10:34:24 +08:00
Soulter
a3df39a01a perf: unified button styles
closes: #2748
2025-09-23 10:27:52 +08:00
Soulter
25dce05cbb refactor: improve webchat UI (#2853) 2025-09-23 10:19:26 +08:00
Soulter
1542ea3e03 fix: context.get_provider_by_id issue 2025-09-22 17:22:50 +08:00
Soulter
6084abbcfe feat: add user_id search capability in get_filtered_conversations 2025-09-21 22:45:55 +08:00
Soulter
ed19b63914 chore: bump version to v4.1.5 2025-09-21 21:47:14 +08:00
Soulter
4efeb85296 chore: remove uv.lock file 2025-09-21 21:47:06 +08:00
shangxue
fc76665615 feat: Satori适配器引用消息无法正确识别 (#2686)
* Update PlatformPage.vue

* Update PlatformPage.vue

* Update PlatformPage.vue

* Update satori_adapter.py

* Update satori_event.py

* Update default.py

* Update satori_adapter.py

* Update satori_adapter.py

* style: format code

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-21 21:45:35 +08:00
Soulter
3a044bb71a fix: 修复 Telegram 下流式传输时,第一次输出的内容会被覆盖掉的问题 (#2838)
fixes: #2481
2025-09-21 21:24:47 +08:00
Soulter
cddd606562 perf: 优化 ExtensionPage 2025-09-21 21:10:03 +08:00
Soulter
7a5bc51c11 fix: 识别引用消息的图片时优先使用默认图片转述提供商 (#2836)
* fix: 识别引用消息的图片时优先使用默认图片转述提供商

closes: #2821

* fix: 添加日志记录以处理未找到图片标题提供者的情况

* style: format code
2025-09-21 20:55:32 +08:00
Soulter
9f939b4b6f fix: 修复对话管理页面的关键词搜索功能失效的问题并优化一些 UI 样式 (#2837)
* fix: 修复对话管理页面的关键词搜索功能失效的问题并优化一些 UI 样式

fixes: #2782

* style: format code

* fix: remove debug print statements from conversation retrieval methods
2025-09-21 20:55:15 +08:00
Soulter
80a86f5b1b fix: 修复 astrbot.core.star 等包下的 type checking error (#2787)
* fix: 修复 astrbot.core.star 等包下的 type checking error

* refactor: improve type checking and annotations

* chore: ruff format
2025-09-21 18:10:04 +08:00
yitaikarma
a0ce1855ab fix: 优化统计页内存占用和消息数据趋势的样式 (#2826)
* fix: 调整统计页内存占用和消息趋势分析的布局,优化响应式显示

* fix: 隐藏增长率为零时的趋势图标
2025-09-21 17:06:47 +08:00
anka
a4b43b884a fix: 修复aiocqhttp适配器at会获取群昵称而消息不会获取的逻辑不一致 (#2769)
* fix: 修复at会获取群昵称而消息不会获取的逻辑不一致

* style: format code
2025-09-19 13:04:51 +08:00
PaloMiku
824c0f6667 feat: 新增 Misskey 平台适配器 (#2774)
* feat: add Misskey platform adapter

* fix: 修复 Misskey 配置项的大小写问题

* feat: 添加消息链序列化功能和可见性解析逻辑

* chore: 删除损坏的 Misskey 平台适配器工具函数文件

* docs: 更新 Misskey 消息适配器设置描述信息

* feat: Misskey 单用户连续上下文对话支持

* feat: 为 Astrbot 添加 Misskey 平台适配器的 ID 配置

* feat: 重构 Misskey 平台适配器,提取通用工具函数并优化消息处理逻辑

* refactor: 清理 Misskey 平台适配器和 API 代码,移除冗余注释

* fix: 修复了使用中和使用者反馈的多个问题

* fix: 修改提及格式,确保提及在新行开始,提升帖子美观和易读性。

* feat: 添加默认可见性和本地仅限设置,优化 Misskey 平台适配器的配置

* fix: 更新 Misskey 平台适配器配置,使用前缀以防止和其他适配器未来可能的冲突问题

* chore: rename 'misskey' to 'Misskey' in config

* feat: Misskey 适配器添加聊天消息响应功能,重构接收和发送逻辑为 Websockets 处理

* fix: 增强 Misskey WebSocket 消息日志输出

* refactor: 优化 Misskey 适配器的消息处理和日志输出

* fix: 增强 Misskey WebSocket 重连接逻辑

* feat: 增强 Misskey 适配器的消息处理,支持房间消息和相关功能,重构通用函数,清理代码重复冗余

* fix: 不屏蔽唤醒前缀对默认 LLM 的唤醒

* fix: 透传所有的群聊消息事件

* fix: 修复 message_type

* perf: 实现 send_streaming 以支援流式请求

* docs(README): update README.md

* fix: super().send(message) 被忽略

* fix: 修正 session 结构

: 作为分隔符可能会导致 umo 组装出现问题

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-18 23:34:41 +08:00
Soulter
a030fe8491 feat: add audioop-lts dependencies (#2809)
pydub needs audioop as a requirement but this builtin package has been removed in 3.13
2025-09-18 23:32:04 +08:00
Soulter
3a9429e8ef fix: on_tool_end hook unavailable 2025-09-17 15:48:57 +08:00
anka
c4eb1ab748 chore: bump version to 4.1.4 2025-09-16 20:09:11 +08:00
anka
29ed19d600 Merge pull request #2783 from AstrBotDevs/revert-2778-fix-handler-type
Revert "fix: parameter type/default handling in CommandFilter"
2025-09-16 20:01:23 +08:00
anka
0cc65513a5 Revert "fix: parameter type/default handling in CommandFilter" 2025-09-16 20:01:05 +08:00
Soulter
debc048659 chore: bump version to 4.1.3 2025-09-16 13:16:21 +08:00
邹永赫
92f5c918dd Merge pull request #2778 from MliKiowa/fix-handler-type
fix: parameter type/default handling in CommandFilter
2025-09-16 13:43:53 +09:00
手瓜一十雪
9519f1e8e2 fix: parameter type/default handling in CommandFilter
Adjusts logic to prioritize type annotations over default values when setting handler_params in CommandFilter. This ensures that parameter types are correctly inferred when available.
2025-09-16 11:49:27 +08:00
Soulter
a8f874bf05 fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 (#2757) 2025-09-16 10:45:39 +08:00
anka
9d9917e45b feat: 增加群名称识别到 system prompt, 并提供相应的配置 (#2770)
* feat🤖: 增加群名称识别到system prompt, 并提供相应的配置

* feat: 优化实现方式, 重构AstrBotMessage, 向后兼容

* style: format
2025-09-16 10:23:08 +08:00
Soulter
91ee0a870d fix: handle image value correctly for mcp BlobResourceContents (#2753) 2025-09-16 08:22:18 +08:00
dependabot[bot]
6cbbffc5a9 chore(deps): bump the github-actions group with 2 updates (#2771)
Bumps the github-actions group with 2 updates: [actions/checkout](https://github.com/actions/checkout) and [actions/setup-python](https://github.com/actions/setup-python).


Updates `actions/checkout` from 4 to 5
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

Updates `actions/setup-python` from 5 to 6
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-16 08:19:31 +08:00
Yokami
8f26fd34d1 feat: add copy button for service providers (#2767) 2025-09-15 22:17:00 +08:00
Soulter
fda655f6d7 fix: 修复配置默认 TTS 或者 STT 模型之后仍无法生效的问题 (#2758)
fixes: #2731
2025-09-15 22:08:40 +08:00
100 changed files with 8040 additions and 8458 deletions

View File

@@ -12,10 +12,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'

View File

@@ -18,7 +18,8 @@
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
@@ -110,7 +111,6 @@ uv run main.py
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ⚡ 消息平台支持情况
| 平台 | 支持性 |
@@ -127,6 +127,8 @@ uv run main.py
| Discord | ✔ |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
## ⚡ 提供商支持情况
@@ -172,7 +174,6 @@ pip install pre-commit
pre-commit install
```
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
@@ -200,14 +201,11 @@ pre-commit install
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
</details>
_私は、高性能ですから!_

View File

@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str, FunctionTool] | None = None
tools: list[str | FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -92,7 +92,7 @@ class MCPClient:
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
@@ -198,6 +198,8 @@ class MCPClient:
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response

View File

@@ -258,7 +258,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
yield MessageChain(
type="tool_direct_result"
).base64_image(res.content[0].data)
).base64_image(resource.blob)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -269,17 +269,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
yield MessageChain().message("返回的数据类型不受支持。")
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool_name,
func_tool_args,
resp,
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
@@ -289,27 +278,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
else:
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool, func_tool_args, None
)
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
self.run_context.event.clear_result()
except Exception as e:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Literal, Any, Optional
from typing import Awaitable, Callable, Literal, Any, Optional
from .mcp_client import MCPClient
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str | None = None
name: str
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
handler: Callable[..., Awaitable[Any]] | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
@@ -51,7 +51,7 @@ class ToolSet:
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] = None):
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
@@ -79,7 +79,13 @@ class ToolSet:
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
@@ -104,7 +110,7 @@ class ToolSet:
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> list[FunctionTool]:
def get_func(self, name: str) -> FunctionTool | None:
"""Get all function tools."""
return self.get_tool(name)
@@ -125,7 +131,11 @@ class ToolSet:
},
}
if tool.parameters.get("properties") or not omit_empty_parameter_field:
if (
tool.parameters
and tool.parameters.get("properties")
or not omit_empty_parameter_field
):
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
@@ -135,14 +145,14 @@ class ToolSet:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
"input_schema": input_schema,
}
result.append(tool_def)
return result
@@ -210,14 +220,15 @@ class ToolSet:
return result
tools = [
{
tools = []
for tool in self.tools:
d = {
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools.append(d)
declarations = {}
if tools:

View File

@@ -6,7 +6,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.1.2"
VERSION = "4.2.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置
@@ -60,6 +60,7 @@ DEFAULT_CONFIG = {
"web_search_link": False,
"display_reasoning_text": False,
"identifier": False,
"group_name_display": False,
"datetime_system_prompt": True,
"default_personality": "default",
"persona_pool": ["*"],
@@ -235,6 +236,16 @@ CONFIG_METADATA_2 = {
"discord_guild_id_for_debug": "",
"discord_activity_name": "",
},
"Misskey": {
"id": "misskey",
"type": "misskey",
"enable": False,
"misskey_instance_url": "https://misskey.example",
"misskey_token": "",
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
},
"Slack": {
"id": "slack",
"type": "slack",
@@ -252,7 +263,7 @@ CONFIG_METADATA_2 = {
"type": "satori",
"enable": False,
"satori_api_base_url": "http://localhost:5140/satori/v1",
"satori_endpoint": "ws://127.0.0.1:5140/satori/v1/events",
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
"satori_token": "",
"satori_auto_reconnect": True,
"satori_heartbeat_interval": 10,
@@ -261,34 +272,34 @@ CONFIG_METADATA_2 = {
},
"items": {
"satori_api_base_url": {
"description": "Satori API Base URL",
"description": "Satori API 终结点",
"type": "string",
"hint": "The base URL for the Satori API.",
"hint": "Satori API 的基础地址。",
},
"satori_endpoint": {
"description": "Satori WebSocket Endpoint",
"description": "Satori WebSocket 终结点",
"type": "string",
"hint": "The WebSocket endpoint for Satori events.",
"hint": "Satori 事件的 WebSocket 端点。",
},
"satori_token": {
"description": "Satori Token",
"description": "Satori 令牌",
"type": "string",
"hint": "The token used for authenticating with the Satori API.",
"hint": "用于 Satori API 身份验证的令牌。",
},
"satori_auto_reconnect": {
"description": "Enable Auto Reconnect",
"description": "启用自动重连",
"type": "bool",
"hint": "Whether to automatically reconnect the WebSocket on disconnection.",
"hint": "断开连接时是否自动重新连接 WebSocket。",
},
"satori_heartbeat_interval": {
"description": "Satori Heartbeat Interval",
"description": "Satori 心跳间隔",
"type": "int",
"hint": "The interval (in seconds) for sending heartbeat messages.",
"hint": "发送心跳消息的间隔(秒)。",
},
"satori_reconnect_delay": {
"description": "Satori Reconnect Delay",
"description": "Satori 重连延迟",
"type": "int",
"hint": "The delay (in seconds) before attempting to reconnect.",
"hint": "尝试重新连接前的延迟时间(秒)。",
},
"slack_connection_mode": {
"description": "Slack Connection Mode",
@@ -336,6 +347,32 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
},
"misskey_instance_url": {
"description": "Misskey 实例 URL",
"type": "string",
"hint": "例如 https://misskey.example填写 Bot 账号所在的 Misskey 实例地址",
},
"misskey_token": {
"description": "Misskey Access Token",
"type": "string",
"hint": "连接服务设置生成的 API 鉴权访问令牌Access token",
},
"misskey_default_visibility": {
"description": "默认帖子可见性",
"type": "string",
"options": ["public", "home", "followers"],
"hint": "机器人发帖时的默认可见性设置。public公开home主页时间线followers仅关注者。",
},
"misskey_local_only": {
"description": "仅限本站(不参与联合)",
"type": "bool",
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
},
"misskey_enable_chat": {
"description": "启用聊天消息响应",
"type": "bool",
"hint": "启用后,机器人将会监听和响应私信聊天消息",
},
"telegram_command_register": {
"description": "Telegram 命令注册",
"type": "bool",
@@ -832,6 +869,18 @@ CONFIG_METADATA_2 = {
"timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
},
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "chat_completion",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
"auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
@@ -1698,6 +1747,26 @@ CONFIG_METADATA_2 = {
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
"obvious": True,
},
"coze_api_key": {
"description": "Coze API Key",
"type": "string",
"hint": "Coze API 密钥,用于访问 Coze 服务。",
},
"bot_id": {
"description": "Bot ID",
"type": "string",
"hint": "Coze 机器人的 ID在 Coze 平台上创建机器人后获得。",
},
"coze_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
},
"auto_save_history": {
"description": "由 Coze 管理对话记录",
"type": "bool",
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
},
},
},
"provider_settings": {
@@ -1724,6 +1793,9 @@ CONFIG_METADATA_2 = {
"identifier": {
"type": "bool",
},
"group_name_display": {
"type": "bool",
},
"datetime_system_prompt": {
"type": "bool",
},
@@ -1903,17 +1975,31 @@ CONFIG_METADATA_3 = {
"_special": "select_provider",
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
},
"provider_stt_settings.enable": {
"description": "默认启用语音转文本",
"type": "bool",
},
"provider_stt_settings.provider_id": {
"description": "语音转文本模型",
"type": "string",
"hint": "留空代表不使用。",
"_special": "select_provider_stt",
"condition": {
"provider_stt_settings.enable": True,
},
},
"provider_tts_settings.enable": {
"description": "默认启用文本转语音",
"type": "bool",
},
"provider_tts_settings.provider_id": {
"description": "文本转语音模型",
"type": "string",
"hint": "留空代表不使用。",
"_special": "select_provider_tts",
"condition": {
"provider_tts_settings.enable": True,
},
},
"provider_settings.image_caption_prompt": {
"description": "图片转述提示词",
@@ -1983,6 +2069,11 @@ CONFIG_METADATA_3 = {
"description": "用户识别",
"type": "bool",
},
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
},
"provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知",
"type": "bool",

View File

@@ -87,17 +87,25 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
f = False
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
f = True
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
if f:
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
if curr_cid == conversation_id:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
"""删除会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
"""
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID

View File

@@ -154,6 +154,11 @@ class BaseDatabase(abc.ABC):
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def delete_conversations_by_user_id(self, user_id: str) -> None:
"""Delete all conversations for a specific user."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,

View File

@@ -18,6 +18,7 @@ from astrbot.core.db.po import (
from sqlalchemy import select, update, delete, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
from sqlalchemy import or_
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
@@ -153,8 +154,22 @@ class SQLiteDatabase(BaseDatabase):
ConversationV2.platform_id.in_(platform_ids)
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
base_query = base_query.where(
ConversationV2.title.ilike(f"%{search_query}%")
or_(
ConversationV2.title.ilike(f"%{search_query}%"),
ConversationV2.content.ilike(f"%{search_query}%"),
ConversationV2.user_id.ilike(f"%{search_query}%"),
)
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
ConversationV2.user_id.ilike(f"%:{msg_type}:%")
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
ConversationV2.platform_id.in_(kwargs["platforms"])
)
# Get total count matching the filters
@@ -234,6 +249,14 @@ class SQLiteDatabase(BaseDatabase):
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(ConversationV2.user_id == user_id)
)
async def insert_platform_message_history(
self,
platform_id,

View File

@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
self.strategy_selector = StrategySelector(config)
async def process(
self, event: AstrMessageEvent, check_text: str = None
self, event: AstrMessageEvent, check_text: str | None = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()

View File

@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
def check(self, content: str):
def check(self, content: str) -> tuple[bool, str]:
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
return False, ""

View File

@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
def check(self, content: str) -> bool:
def check(self, content: str) -> tuple[bool, str]:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"

View File

@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Awaitable,
handler: T.Callable[..., T.Awaitable[T.Any]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -36,6 +36,9 @@ async def call_handler(
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:

View File

@@ -7,6 +7,7 @@ import copy
import json
import traceback
from typing import AsyncGenerator, Union
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
@@ -133,6 +134,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
@@ -148,7 +158,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
)
yield mcp.types.CallToolResult(content=[text_content])
else:
yield mcp.types.TextContent(
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
@@ -200,7 +210,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await tool.mcp_client.session.call_tool(
session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
)
@@ -271,19 +285,12 @@ async def run_agent(
except Exception as e:
logger.error(traceback.format_exc())
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
class LLMRequestSubStage(Stage):
@@ -325,7 +332,7 @@ class LLMRequestSubStage(Stage):
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent):
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
@@ -337,6 +344,8 @@ class LLMRequestSubStage(Stage):
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation
async def process(
@@ -444,7 +453,10 @@ class LLMRequestSubStage(Stage):
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
plugin = star_map.get(tool.handler_module_path)
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
@@ -505,6 +517,14 @@ class LLMRequestSubStage(Stage):
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
@@ -517,7 +537,23 @@ class LLMRequestSubStage(Stage):
latest_pair = messages[-2:]
if not latest_pair:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
content = latest_pair[0].get("content", "")
if isinstance(content, list):
# 多模态
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "image":
text_parts.append("[图片]")
elif isinstance(item, str):
text_parts.append(item)
cleaned_text = "User: " + " ".join(text_parts).strip()
elif isinstance(content, str):
cleaned_text = "User: " + content.strip()
else:
return
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",

View File

@@ -34,12 +34,14 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
md = star_map.get(handler.handler_module_path)
if not md:
logger.warning(
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
)
continue
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
try:
wrapper = call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
@@ -49,7 +51,7 @@ class StarRequestSubStage(Stage):
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()

View File

@@ -1,17 +1,15 @@
import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from ..context import PipelineContext, call_event_hook
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.components import BaseMessageComponent, ComponentType
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
@@ -114,6 +112,43 @@ class RespondStage(Stage):
# 如果所有组件都为空
return True
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
"""检查是否需要分段回复"""
if not self.enable_seg:
return False
if self.only_llm_result and not event.get_result().is_llm_result():
return False
if event.get_platform_name() in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
return False
return True
def _extract_comp(
self,
raw_chain: list[BaseMessageComponent],
extract_types: set[ComponentType],
modify_raw_chain: bool = True,
):
extracted = []
if modify_raw_chain:
remaining = []
for comp in raw_chain:
if comp.type in extract_types:
extracted.append(comp)
else:
remaining.append(comp)
raw_chain[:] = remaining
else:
extracted = [comp for comp in raw_chain if comp.type in extract_types]
return extracted
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
@@ -123,7 +158,14 @@ class RespondStage(Stage):
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
logger.info(
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
if result.result_content_type == ResultContentType.STREAMING_RESULT:
if result.async_stream is None:
logger.warning("async_stream 为空,跳过发送。")
return
# 流式结果直接交付平台适配器处理
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
@@ -148,87 +190,71 @@ class RespondStage(Stage):
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
non_record_comps = [
c for c in result.chain if not isinstance(c, Comp.Record)
]
if (
self.enable_seg
and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
# 发送消息链
# Record 需要强制单独发送
need_separately = {ComponentType.Record}
if self.is_seg_reply_required(event):
header_comps = self._extract_comp(
result.chain,
{ComponentType.Reply, ComponentType.At},
modify_raw_chain=True,
)
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, Comp.At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Comp.Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
# leverage lock to guarentee the order of message sending among different events
if not result.chain or len(result.chain) == 0:
# may fix #2670
logger.warning(
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
)
return
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in non_record_comps:
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([*decorated_comps, comp]))
decorated_comps = [] # 清空已发送的装饰组件
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
else:
for rcomp in record_comps:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}
for comp in result.chain
):
# may fix #2670
logger.warning(
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
)
return
sep_comps = self._extract_comp(
result.chain,
need_separately,
modify_raw_chain=True,
)
for comp in sep_comps:
chain = MessageChain([comp])
try:
await event.send(MessageChain([rcomp]))
await event.send(chain)
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
chain = MessageChain(result.chain)
if result.chain and len(result.chain) > 0:
try:
await event.send(chain)
except Exception as e:
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
try:
await event.send(MessageChain(non_record_comps))
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
return
event.clear_result()

View File

@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
pass
self.ctx = ctx
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
async def process(
self, event: AstrMessageEvent
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
conv_id = await self.conv_mgr.get_curr_conversation_id(
event.unified_msg_origin
)
if not conv_id:
await self.conv_mgr.new_conversation(
event.unified_msg_origin, platform_id=event.get_platform_id()
)
event.stop_event()

View File

@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
activated_handlers.append(handler)
if "parsed_params" in event.get_extra():
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
is_group_cmd_handler = any(
isinstance(f, CommandGroupFilter) for f in handler.event_filters
)
if not is_group_cmd_handler:
activated_handlers.append(handler)
if "parsed_params" in event.get_extra(default={}):
handlers_parsed_params[handler.handler_full_name] = (
event.get_extra("parsed_params")
)
event._extras.pop("parsed_params", None)

View File

@@ -4,7 +4,7 @@ import re
import hashlib
import uuid
from typing import List, Union, Optional, AsyncGenerator
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
from astrbot import logger
from astrbot.core.db.po import Conversation
@@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
_VT = TypeVar("_VT")
class AstrMessageEvent(abc.ABC):
def __init__(
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
"""是否唤醒(是否通过 WakingStage)"""
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras = {}
self._extras: dict[str, Any] = {}
self.session = MessageSesion(
platform_name=platform_meta.id,
message_type=message_obj.type,
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
)
self.unified_msg_origin = str(self.session)
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self._result: MessageEventResult = None
self._result: MessageEventResult | None = None
"""消息事件的结果"""
self._has_send_oper = False
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(self, key=None):
def get_extra(
self, key: str | None = None, default: _VT = None
) -> dict[str, Any] | _VT:
"""
获取额外的信息。
"""
if key is None:
return self._extras
return self._extras.get(key, None)
return self._extras.get(key, default)
def clear_extra(self):
"""

View File

@@ -55,7 +55,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group_id: str = "" # 群组id如果为私聊则为空
group: Group # 群组
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -64,6 +64,28 @@ class AstrBotMessage:
def __init__(self) -> None:
self.timestamp = int(time.time())
self.group = None
def __str__(self) -> str:
return str(self.__dict__)
@property
def group_id(self) -> str:
"""
向后兼容的 group_id 属性
群组id如果为私聊则为空
"""
if self.group:
return self.group.group_id
return ""
@group_id.setter
def group_id(self, value: str):
"""设置 group_id"""
if value:
if self.group:
self.group.group_id = value
else:
self.group = Group(group_id=value)
else:
self.group = None

View File

@@ -90,6 +90,10 @@ class PlatformManager:
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "misskey":
from .sources.misskey.misskey_adapter import (
MisskeyPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
case "satori":

View File

@@ -182,11 +182,13 @@ class AiocqhttpAdapter(Platform):
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
str(event.sender["user_id"]), event.sender["nickname"]
str(event.sender["user_id"]),
event.sender.get("card") or event.sender.get("nickname", "N/A"),
)
if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:

View File

@@ -0,0 +1,391 @@
import asyncio
import json
from typing import Dict, Any, Optional, Awaitable
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import (
AstrBotMessage,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
import astrbot.api.message_components as Comp
from .misskey_api import MisskeyAPI
from .misskey_event import MisskeyPlatformEvent
from .misskey_utils import (
serialize_message_chain,
resolve_message_visibility,
is_valid_user_session_id,
is_valid_room_session_id,
add_at_mention_if_needed,
process_files,
extract_sender_info,
create_base_message,
process_at_mention,
cache_user_info,
cache_room_info,
)
@register_platform_adapter("misskey", "Misskey 平台适配器")
class MisskeyPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config or {}
self.settings = platform_settings or {}
self.instance_url = self.config.get("misskey_instance_url", "")
self.access_token = self.config.get("misskey_token", "")
self.max_message_length = self.config.get("max_message_length", 3000)
self.default_visibility = self.config.get(
"misskey_default_visibility", "public"
)
self.local_only = self.config.get("misskey_local_only", False)
self.enable_chat = self.config.get("misskey_enable_chat", True)
self.unique_session = platform_settings["unique_session"]
self.api: Optional[MisskeyAPI] = None
self._running = False
self.client_self_id = ""
self._bot_username = ""
self._user_cache = {}
def meta(self) -> PlatformMetadata:
default_config = {
"misskey_instance_url": "",
"misskey_token": "",
"max_message_length": 3000,
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
}
default_config.update(self.config)
return PlatformMetadata(
name="misskey",
description="Misskey 平台适配器",
id=self.config.get("id", "misskey"),
default_config_tmpl=default_config,
)
async def run(self):
if not self.instance_url or not self.access_token:
logger.error("[Misskey] 配置不完整,无法启动")
return
self.api = MisskeyAPI(self.instance_url, self.access_token)
self._running = True
try:
user_info = await self.api.get_current_user()
self.client_self_id = str(user_info.get("id", ""))
self._bot_username = user_info.get("username", "")
logger.info(
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
)
except Exception as e:
logger.error(f"[Misskey] 获取用户信息失败: {e}")
self._running = False
return
await self._start_websocket_connection()
async def _start_websocket_connection(self):
backoff_delay = 1.0
max_backoff = 300.0
backoff_multiplier = 1.5
connection_attempts = 0
while self._running:
try:
connection_attempts += 1
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
break
streaming = self.api.get_streaming_client()
streaming.add_message_handler("notification", self._handle_notification)
if self.enable_chat:
streaming.add_message_handler(
"newChatMessage", self._handle_chat_message
)
streaming.add_message_handler("_debug", self._debug_handler)
if await streaming.connect():
logger.info(
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
)
connection_attempts = 0 # 重置计数器
await streaming.subscribe_channel("main")
if self.enable_chat:
await streaming.subscribe_channel("messaging")
await streaming.subscribe_channel("messagingIndex")
logger.info("[Misskey] 聊天频道已订阅")
backoff_delay = 1.0 # 重置延迟
await streaming.listen()
else:
logger.error(
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
)
except Exception as e:
logger.error(
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
)
if self._running:
logger.info(
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
)
await asyncio.sleep(backoff_delay)
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
async def _handle_notification(self, data: Dict[str, Any]):
try:
logger.debug(
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
notification_type = data.get("type")
if notification_type in ["mention", "reply", "quote"]:
note = data.get("note")
if note and self._is_bot_mentioned(note):
logger.info(
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
)
message = await self.convert_message(note)
event = MisskeyPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.api,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理通知失败: {e}")
async def _handle_chat_message(self, data: Dict[str, Any]):
try:
logger.debug(
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
sender_id = str(
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
)
if sender_id == self.client_self_id:
return
room_id = data.get("toRoomId")
if room_id:
raw_text = data.get("text", "")
logger.debug(
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
)
message = await self.convert_room_message(data)
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
else:
message = await self.convert_chat_message(data)
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
event = MisskeyPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.api,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
async def _debug_handler(self, data: Dict[str, Any]):
logger.debug(
f"[Misskey] 收到未处理事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
text = note.get("text", "")
if not text:
return False
mentions = note.get("mentions", [])
if self._bot_username and f"@{self._bot_username}" in text:
return True
if self.client_self_id in [str(uid) for uid in mentions]:
return True
reply = note.get("reply")
if reply and isinstance(reply, dict):
reply_user_id = str(reply.get("user", {}).get("id", ""))
if reply_user_id == self.client_self_id:
return bool(self._bot_username and f"@{self._bot_username}" in text)
return False
async def send_by_session(
self, session: MessageSession, message_chain: MessageChain
) -> Awaitable[Any]:
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain)
try:
session_id = session.session_id
text, has_at_user = serialize_message_chain(message_chain.chain)
if not has_at_user and session_id:
user_info = self._user_cache.get(session_id)
text = add_at_mention_if_needed(text, user_info, has_at_user)
if not text or not text.strip():
logger.warning("[Misskey] 消息内容为空,跳过发送")
return await super().send_by_session(session, message_chain)
if len(text) > self.max_message_length:
text = text[: self.max_message_length] + "..."
if session_id and is_valid_user_session_id(session_id):
from .misskey_utils import extract_user_id_from_session_id
user_id = extract_user_id_from_session_id(session_id)
await self.api.send_message(user_id, text)
elif session_id and is_valid_room_session_id(session_id):
from .misskey_utils import extract_room_id_from_session_id
room_id = extract_room_id_from_session_id(session_id)
await self.api.send_room_message(room_id, text)
else:
visibility, visible_user_ids = resolve_message_visibility(
user_id=session_id,
user_cache=self._user_cache,
self_id=self.client_self_id,
default_visibility=self.default_visibility,
)
await self.api.create_note(
text,
visibility=visibility,
visible_user_ids=visible_user_ids,
local_only=self.local_only,
)
except Exception as e:
logger.error(f"[Misskey] 发送消息失败: {e}")
return await super().send_by_session(session, message_chain)
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=False)
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=False,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
)
message_parts = []
raw_text = raw_data.get("text", "")
if raw_text:
text_parts, processed_text = process_at_mention(
message, raw_text, self._bot_username, self.client_self_id
)
message_parts.extend(text_parts)
files = raw_data.get("files", [])
file_parts = process_files(message, files)
message_parts.extend(file_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
else ""
)
return message
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=True)
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=True,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
)
raw_text = raw_data.get("text", "")
if raw_text:
message.message.append(Comp.Plain(raw_text))
files = raw_data.get("files", [])
process_files(message, files, include_text_parts=False)
message.message_str = raw_text if raw_text else ""
return message
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=True)
room_id = raw_data.get("toRoomId", "")
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=False,
room_id=room_id,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
)
cache_room_info(self._user_cache, raw_data, self.client_self_id)
raw_text = raw_data.get("text", "")
message_parts = []
if raw_text:
if self._bot_username and f"@{self._bot_username}" in raw_text:
text_parts, processed_text = process_at_mention(
message, raw_text, self._bot_username, self.client_self_id
)
message_parts.extend(text_parts)
else:
message.message.append(Comp.Plain(raw_text))
message_parts.append(raw_text)
files = raw_data.get("files", [])
file_parts = process_files(message, files)
message_parts.extend(file_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
else ""
)
return message
async def terminate(self):
self._running = False
if self.api:
await self.api.close()
def get_client(self) -> Any:
return self.api

View File

@@ -0,0 +1,404 @@
import json
from typing import Any, Optional, Dict, List, Callable, Awaitable
import uuid
try:
import aiohttp
import websockets
except ImportError as e:
raise ImportError(
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
) from e
from astrbot.api import logger
# Constants
API_MAX_RETRIES = 3
HTTP_OK = 200
class APIError(Exception):
"""Misskey API 基础异常"""
pass
class APIConnectionError(APIError):
"""网络连接异常"""
pass
class APIRateLimitError(APIError):
"""API 频率限制异常"""
pass
class AuthenticationError(APIError):
"""认证失败异常"""
pass
class WebSocketError(APIError):
"""WebSocket 连接异常"""
pass
class StreamingClient:
def __init__(self, instance_url: str, access_token: str):
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self.websocket: Optional[Any] = None
self.is_connected = False
self.message_handlers: Dict[str, Callable] = {}
self.channels: Dict[str, str] = {}
self._running = False
self._last_pong = None
async def connect(self) -> bool:
try:
ws_url = self.instance_url.replace("https://", "wss://").replace(
"http://", "ws://"
)
ws_url += f"/streaming?i={self.access_token}"
self.websocket = await websockets.connect(
ws_url, ping_interval=30, ping_timeout=10
)
self.is_connected = True
self._running = True
logger.info("[Misskey WebSocket] 已连接")
return True
except Exception as e:
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
self.is_connected = False
return False
async def disconnect(self):
self._running = False
if self.websocket:
await self.websocket.close()
self.websocket = None
self.is_connected = False
logger.info("[Misskey WebSocket] 连接已断开")
async def subscribe_channel(
self, channel_type: str, params: Optional[Dict] = None
) -> str:
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
channel_id = str(uuid.uuid4())
message = {
"type": "connect",
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
}
await self.websocket.send(json.dumps(message))
self.channels[channel_id] = channel_type
return channel_id
async def unsubscribe_channel(self, channel_id: str):
if (
not self.is_connected
or not self.websocket
or channel_id not in self.channels
):
return
message = {"type": "disconnect", "body": {"id": channel_id}}
await self.websocket.send(json.dumps(message))
del self.channels[channel_id]
def add_message_handler(
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
):
self.message_handlers[event_type] = handler
async def listen(self):
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
try:
async for message in self.websocket:
if not self._running:
break
try:
data = json.loads(message)
await self._handle_message(data)
except json.JSONDecodeError as e:
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
except Exception as e:
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
self.is_connected = False
except websockets.exceptions.ConnectionClosed as e:
logger.warning(
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
)
self.is_connected = False
except websockets.exceptions.InvalidHandshake as e:
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
self.is_connected = False
except Exception as e:
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
self.is_connected = False
async def _handle_message(self, data: Dict[str, Any]):
message_type = data.get("type")
body = data.get("body", {})
logger.debug(
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
)
if message_type == "channel":
channel_id = body.get("id")
event_type = body.get("type")
event_body = body.get("body", {})
logger.debug(
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
)
if channel_id in self.channels:
channel_type = self.channels[channel_id]
handler_key = f"{channel_type}:{event_type}"
if handler_key in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
await self.message_handlers[handler_key](event_body)
elif event_type in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
await self.message_handlers[event_type](event_body)
else:
logger.debug(
f"[Misskey WebSocket] 未找到处理器: {handler_key}{event_type}"
)
if "_debug" in self.message_handlers:
await self.message_handlers["_debug"](
{
"type": event_type,
"body": event_body,
"channel": channel_type,
}
)
elif message_type in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
await self.message_handlers[message_type](body)
else:
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
if "_debug" in self.message_handlers:
await self.message_handlers["_debug"](data)
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
def decorator(func):
async def wrapper(*args, **kwargs):
last_exc = None
for _ in range(max_retries):
try:
return await func(*args, **kwargs)
except retryable_exceptions as e:
last_exc = e
continue
if last_exc:
raise last_exc
return wrapper
return decorator
class MisskeyAPI:
def __init__(self, instance_url: str, access_token: str):
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self._session: Optional[aiohttp.ClientSession] = None
self.streaming: Optional[StreamingClient] = None
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
return False
async def close(self) -> None:
if self.streaming:
await self.streaming.disconnect()
self.streaming = None
if self._session:
await self._session.close()
self._session = None
logger.debug("[Misskey API] 客户端已关闭")
def get_streaming_client(self) -> StreamingClient:
if not self.streaming:
self.streaming = StreamingClient(self.instance_url, self.access_token)
return self.streaming
@property
def session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
headers = {"Authorization": f"Bearer {self.access_token}"}
self._session = aiohttp.ClientSession(headers=headers)
return self._session
def _handle_response_status(self, status: int, endpoint: str):
"""处理 HTTP 响应状态码"""
if status == 400:
logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
raise APIError(f"Bad request for {endpoint}")
elif status in (401, 403):
logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
raise AuthenticationError(f"Authentication failed for {endpoint}")
elif status == 429:
logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
else:
logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
raise APIConnectionError(f"HTTP {status} for {endpoint}")
async def _process_response(
self, response: aiohttp.ClientResponse, endpoint: str
) -> Any:
"""处理 API 响应"""
if response.status == HTTP_OK:
try:
result = await response.json()
if endpoint == "i/notifications":
notifications_data = (
result
if isinstance(result, list)
else result.get("notifications", [])
if isinstance(result, dict)
else []
)
if notifications_data:
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
else:
logger.debug(f"API 请求成功: {endpoint}")
return result
except json.JSONDecodeError as e:
logger.error(f"响应不是有效的 JSON 格式: {e}")
raise APIConnectionError("Invalid JSON response") from e
else:
try:
error_text = await response.text()
logger.error(
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
)
except Exception:
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
self._handle_response_status(response.status, endpoint)
raise APIConnectionError(f"Request failed for {endpoint}")
@retry_async(
max_retries=API_MAX_RETRIES,
retryable_exceptions=(APIConnectionError, APIRateLimitError),
)
async def _make_request(
self, endpoint: str, data: Optional[Dict[str, Any]] = None
) -> Any:
url = f"{self.instance_url}/api/{endpoint}"
payload = {"i": self.access_token}
if data:
payload.update(data)
try:
async with self.session.post(url, json=payload) as response:
return await self._process_response(response, endpoint)
except aiohttp.ClientError as e:
logger.error(f"HTTP 请求错误: {e}")
raise APIConnectionError(f"HTTP request failed: {e}") from e
async def create_note(
self,
text: str,
visibility: str = "public",
reply_id: Optional[str] = None,
visible_user_ids: Optional[List[str]] = None,
local_only: bool = False,
) -> Dict[str, Any]:
"""创建新贴文"""
data: Dict[str, Any] = {
"text": text,
"visibility": visibility,
"localOnly": local_only,
}
if reply_id:
data["replyId"] = reply_id
if visible_user_ids and visibility == "specified":
data["visibleUserIds"] = visible_user_ids
result = await self._make_request("notes/create", data)
note_id = result.get("createdNote", {}).get("id", "unknown")
logger.debug(f"发帖成功note_id: {note_id}")
return result
async def get_current_user(self) -> Dict[str, Any]:
"""获取当前用户信息"""
return await self._make_request("i", {})
async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
"""发送聊天消息"""
result = await self._make_request(
"chat/messages/create-to-user", {"toUserId": user_id, "text": text}
)
message_id = result.get("id", "unknown")
logger.debug(f"聊天发送成功message_id: {message_id}")
return result
async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
"""发送房间消息"""
result = await self._make_request(
"chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
)
message_id = result.get("id", "unknown")
logger.debug(f"房间消息发送成功message_id: {message_id}")
return result
async def get_messages(
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""获取聊天消息历史"""
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
if since_id:
data["sinceId"] = since_id
result = await self._make_request("chat/messages/user-timeline", data)
if isinstance(result, list):
return result
else:
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
return []
async def get_mentions(
self, limit: int = 10, since_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""获取提及通知"""
data: Dict[str, Any] = {"limit": limit}
if since_id:
data["sinceId"] = since_id
data["includeTypes"] = ["mention", "reply", "quote"]
result = await self._make_request("i/notifications", data)
if isinstance(result, list):
return result
elif isinstance(result, dict) and "notifications" in result:
return result["notifications"]
else:
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
return []

View File

@@ -0,0 +1,123 @@
import asyncio
import re
from typing import AsyncGenerator
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
from astrbot.api.message_components import Plain
from .misskey_utils import (
serialize_message_chain,
resolve_visibility_from_raw_message,
is_valid_user_session_id,
is_valid_room_session_id,
add_at_mention_if_needed,
extract_user_id_from_session_id,
extract_room_id_from_session_id,
)
class MisskeyPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
def _is_system_command(self, message_str: str) -> bool:
"""检测是否为系统指令"""
if not message_str or not message_str.strip():
return False
system_prefixes = ["/", "!", "#", ".", "^"]
message_trimmed = message_str.strip()
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
async def send(self, message: MessageChain):
content, has_at = serialize_message_chain(message.chain)
if not content:
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
return
try:
original_message_id = getattr(self.message_obj, "message_id", None)
raw_message = getattr(self.message_obj, "raw_message", {})
if raw_message and not has_at:
user_data = raw_message.get("user", {})
user_info = {
"username": user_data.get("username", ""),
"nickname": user_data.get("name", user_data.get("username", "")),
}
content = add_at_mention_if_needed(content, user_info, has_at)
# 根据会话类型选择发送方式
if hasattr(self.client, "send_message") and is_valid_user_session_id(
self.session_id
):
user_id = extract_user_id_from_session_id(self.session_id)
await self.client.send_message(user_id, content)
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
self.session_id
):
room_id = extract_room_id_from_session_id(self.session_id)
await self.client.send_room_message(room_id, content)
elif original_message_id and hasattr(self.client, "create_note"):
visibility, visible_user_ids = resolve_visibility_from_raw_message(
raw_message
)
await self.client.create_note(
content,
reply_id=original_message_id,
visibility=visibility,
visible_user_ids=visible_user_ids,
)
elif hasattr(self.client, "create_note"):
logger.debug("[MisskeyEvent] 创建新帖子")
await self.client.create_note(content)
await super().send(message)
except Exception as e:
logger.error(f"[MisskeyEvent] 发送失败: {e}")
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,327 @@
"""Misskey 平台适配器通用工具函数"""
from typing import Dict, Any, List, Tuple, Optional, Union
import astrbot.api.message_components as Comp
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
"""将消息链序列化为文本字符串"""
text_parts = []
has_at = False
def process_component(component):
nonlocal has_at
if isinstance(component, Comp.Plain):
return component.text
elif isinstance(component, Comp.File):
file_name = getattr(component, "name", "文件")
return f"[文件: {file_name}]"
elif isinstance(component, Comp.At):
has_at = True
return f"@{component.qq}"
elif hasattr(component, "text"):
text = getattr(component, "text", "")
if "@" in text:
has_at = True
return text
else:
return str(component)
for component in chain:
if isinstance(component, Comp.Node) and component.content:
for node_comp in component.content:
result = process_component(node_comp)
if result:
text_parts.append(result)
else:
result = process_component(component)
if result:
text_parts.append(result)
return "".join(text_parts), has_at
def resolve_message_visibility(
user_id: Optional[str],
user_cache: Dict[str, Any],
self_id: Optional[str],
default_visibility: str = "public",
) -> Tuple[str, Optional[List[str]]]:
"""解析 Misskey 消息的可见性设置"""
visibility = default_visibility
visible_user_ids = None
if user_id and user_cache:
user_info = user_cache.get(user_id)
if user_info:
original_visibility = user_info.get("visibility", default_visibility)
if original_visibility == "specified":
visibility = "specified"
original_visible_users = user_info.get("visible_user_ids", [])
users_to_include = [user_id]
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
def resolve_visibility_from_raw_message(
raw_message: Dict[str, Any], self_id: Optional[str] = None
) -> Tuple[str, Optional[List[str]]]:
"""从原始消息数据中解析可见性设置"""
visibility = "public"
visible_user_ids = None
if not raw_message:
return visibility, visible_user_ids
original_visibility = raw_message.get("visibility", "public")
if original_visibility == "specified":
visibility = "specified"
original_visible_users = raw_message.get("visibleUserIds", [])
sender_id = raw_message.get("userId", "")
users_to_include = []
if sender_id:
users_to_include.append(sender_id)
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "chat"
and bool(parts[1])
and parts[1] != "unknown"
)
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "room"
and bool(parts[1])
and parts[1] != "unknown"
)
def extract_user_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取用户 ID"""
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2:
return parts[1]
return session_id
def extract_room_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取房间 ID"""
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2 and parts[0] == "room":
return parts[1]
return session_id
def add_at_mention_if_needed(
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
) -> str:
"""如果需要且没有@用户,则添加@用户"""
if has_at or not user_info:
return text
username = user_info.get("username")
nickname = user_info.get("nickname")
if username:
mention = f"@{username}"
if not text.startswith(mention):
text = f"{mention}\n{text}".strip()
elif nickname:
mention = f"@{nickname}"
if not text.startswith(mention):
text = f"{mention}\n{text}".strip()
return text
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
"""创建文件组件和描述文本"""
file_url = file_info.get("url", "")
file_name = file_info.get("name", "未知文件")
file_type = file_info.get("type", "")
if file_type.startswith("image/"):
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
elif file_type.startswith("audio/"):
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
elif file_type.startswith("video/"):
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
else:
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
def process_files(
message: AstrBotMessage, files: list, include_text_parts: bool = True
) -> list:
"""处理文件列表,添加到消息组件中并返回文本描述"""
file_parts = []
for file_info in files:
component, part_text = create_file_component(file_info)
message.message.append(component)
if include_text_parts:
file_parts.append(part_text)
return file_parts
def extract_sender_info(
raw_data: Dict[str, Any], is_chat: bool = False
) -> Dict[str, Any]:
"""提取发送者信息"""
if is_chat:
sender = raw_data.get("fromUser", {})
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
else:
sender = raw_data.get("user", {})
sender_id = str(sender.get("id", ""))
return {
"sender": sender,
"sender_id": sender_id,
"nickname": sender.get("name", sender.get("username", "")),
"username": sender.get("username", ""),
}
def create_base_message(
raw_data: Dict[str, Any],
sender_info: Dict[str, Any],
client_self_id: str,
is_chat: bool = False,
room_id: Optional[str] = None,
unique_session: bool = False,
) -> AstrBotMessage:
"""创建基础消息对象"""
message = AstrBotMessage()
message.raw_message = raw_data
message.message = []
message.sender = MessageMember(
user_id=sender_info["sender_id"],
nickname=sender_info["nickname"],
)
if room_id:
session_prefix = "room"
session_id = f"{session_prefix}%{room_id}"
if unique_session:
session_id += f"_{sender_info['sender_id']}"
message.type = MessageType.GROUP_MESSAGE
message.group_id = room_id
elif is_chat:
session_prefix = "chat"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.FRIEND_MESSAGE
else:
session_prefix = "note"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.FRIEND_MESSAGE
message.session_id = (
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
)
message.message_id = str(raw_data.get("id", ""))
message.self_id = client_self_id
return message
def process_at_mention(
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
) -> Tuple[List[str], str]:
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
message_parts = []
if not raw_text:
return message_parts, ""
if bot_username and raw_text.startswith(f"@{bot_username}"):
at_mention = f"@{bot_username}"
message.message.append(Comp.At(qq=client_self_id))
remaining_text = raw_text[len(at_mention) :].strip()
if remaining_text:
message.message.append(Comp.Plain(remaining_text))
message_parts.append(remaining_text)
return message_parts, remaining_text
else:
message.message.append(Comp.Plain(raw_text))
message_parts.append(raw_text)
return message_parts, raw_text
def cache_user_info(
user_cache: Dict[str, Any],
sender_info: Dict[str, Any],
raw_data: Dict[str, Any],
client_self_id: str,
is_chat: bool = False,
):
"""缓存用户信息"""
if is_chat:
user_cache_data = {
"username": sender_info["username"],
"nickname": sender_info["nickname"],
"visibility": "specified",
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
}
else:
user_cache_data = {
"username": sender_info["username"],
"nickname": sender_info["nickname"],
"visibility": raw_data.get("visibility", "public"),
"visible_user_ids": raw_data.get("visibleUserIds", []),
}
user_cache[sender_info["sender_id"]] = user_cache_data
def cache_room_info(
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
):
"""缓存房间信息"""
room_data = raw_data.get("toRoom")
room_id = raw_data.get("toRoomId")
if room_data and room_id:
room_cache_key = f"room:{room_id}"
user_cache[room_cache_key] = {
"room_id": room_id,
"room_name": room_data.get("name", ""),
"room_description": room_data.get("description", ""),
"owner_id": room_data.get("ownerId", ""),
"visibility": "specified",
"visible_user_ids": [client_self_id],
}

View File

@@ -17,7 +17,14 @@ from astrbot.api.platform import (
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.api.message_components import Plain, Image, At, File, Record
from astrbot.api.message_components import (
Plain,
Image,
At,
File,
Record,
Reply,
)
from xml.etree import ElementTree as ET
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
)
self.token = self.config.get("satori_token", "")
self.endpoint = self.config.get(
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
)
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
self.metadata = PlatformMetadata(
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
)
self.ws: Optional[ClientConnection] = None
self.session: Optional[ClientSession] = None
self.sequence = 0
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
return self.metadata
def _is_websocket_closed(self, ws) -> bool:
"""检查WebSocket连接是否已关闭"""
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
abm.self_id = login.get("user", {}).get("id", "")
content = message.get("content", "")
abm.message = await self.parse_satori_elements(content)
# 消息链
abm.message = []
content = message.get("content", "")
quote = message.get("quote")
content_for_parsing = content # 副本
# 提取<quote>标签
if "<quote" in content:
try:
quote_info = await self._extract_quote_element(content)
if quote_info:
quote = quote_info["quote"]
content_for_parsing = quote_info["content_without_quote"]
except Exception as e:
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
if quote:
# 引用消息
quote_abm = await self._convert_quote_message(quote)
if quote_abm:
sender_id = quote_abm.sender.user_id
if isinstance(sender_id, str) and sender_id.isdigit():
sender_id = int(sender_id)
elif not isinstance(sender_id, int):
sender_id = 0 # 默认值
reply_component = Reply(
id=quote_abm.message_id,
chain=quote_abm.message,
sender_id=quote_abm.sender.user_id,
sender_nickname=quote_abm.sender.nickname,
time=quote_abm.timestamp,
message_str=quote_abm.message_str,
text=quote_abm.message_str,
qq=sender_id,
)
abm.message.append(reply_component)
# 解析消息内容
content_elements = await self.parse_satori_elements(content_for_parsing)
abm.message.extend(content_elements)
# parse message_str
abm.message_str = ""
for comp in abm.message:
for comp in content_elements:
if isinstance(comp, Plain):
abm.message_str += comp.text
@@ -333,6 +386,163 @@ class SatoriPlatformAdapter(Platform):
logger.error(f"转换 Satori 消息失败: {e}")
return None
def _extract_namespace_prefixes(self, content: str) -> set:
"""提取XML内容中的命名空间前缀"""
prefixes = set()
# 查找所有标签
i = 0
while i < len(content):
# 查找开始标签
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 1 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content and "xmlns:" not in tag_content:
# 分割标签名
parts = tag_content.split()
if parts:
tag_name = parts[0]
if ":" in tag_name:
prefix = tag_name.split(":")[0]
# 确保是有效的命名空间前缀
if (
prefix.isalnum()
or prefix.replace("_", "").isalnum()
):
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
# 查找结束标签
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 2 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content:
prefix = tag_content.split(":")[0]
# 确保是有效的命名空间前缀
if prefix.isalnum() or prefix.replace("_", "").isalnum():
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
else:
i += 1
return prefixes
async def _extract_quote_element(self, content: str) -> Optional[dict]:
"""提取<quote>标签信息"""
try:
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
# 查找<quote>标签
quote_element = None
for elem in root.iter():
tag_name = elem.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
if tag_name.lower() == "quote":
quote_element = elem
break
if quote_element is not None:
# 提取quote标签的属性
quote_id = quote_element.get("id", "")
# 提取<quote>标签内部的内容
inner_content = ""
if quote_element.text:
inner_content += quote_element.text
for child in quote_element:
inner_content += ET.tostring(
child, encoding="unicode", method="xml"
)
if child.tail:
inner_content += child.tail
# 构造移除了<quote>标签的内容
content_without_quote = content.replace(
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
)
return {
"quote": {"id": quote_id, "content": inner_content},
"content_without_quote": content_without_quote,
}
return None
except Exception as e:
logger.error(f"提取<quote>标签时发生错误: {e}")
return None
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
"""转换引用消息"""
try:
quote_abm = AstrBotMessage()
quote_abm.message_id = quote.get("id", "")
# 解析引用消息的发送者
quote_author = quote.get("author", {})
if quote_author:
quote_abm.sender = MessageMember(
user_id=quote_author.get("id", ""),
nickname=quote_author.get("nick", quote_author.get("name", "")),
)
else:
# 如果没有作者信息,使用默认值
quote_abm.sender = MessageMember(
user_id=quote.get("user_id", ""),
nickname="内容",
)
# 解析引用消息内容
quote_content = quote.get("content", "")
quote_abm.message = await self.parse_satori_elements(quote_content)
quote_abm.message_str = ""
for comp in quote_abm.message:
if isinstance(comp, Plain):
quote_abm.message_str += comp.text
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
# 如果没有任何内容,使用默认文本
if not quote_abm.message_str.strip():
quote_abm.message_str = "[引用消息]"
return quote_abm
except Exception as e:
logger.error(f"转换引用消息失败: {e}")
return None
async def parse_satori_elements(self, content: str) -> list:
"""解析 Satori 消息元素"""
elements = []
@@ -341,12 +551,35 @@ class SatoriPlatformAdapter(Platform):
return elements
try:
wrapped_content = f"<root>{content}</root>"
root = ET.fromstring(wrapped_content)
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
await self._parse_xml_node(root, elements)
except ET.ParseError as e:
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
# 如果解析失败,将整个内容当作纯文本
if content.strip():
elements.append(Plain(text=content))
except Exception as e:
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
raise e
# 如果没有解析到任何元素,将整个内容当作纯文本
@@ -361,7 +594,12 @@ class SatoriPlatformAdapter(Platform):
elements.append(Plain(text=node.text))
for child in node:
tag_name = child.tag.lower()
# 获取标签名,去除命名空间前缀
tag_name = child.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
tag_name = tag_name.lower()
attrs = child.attrib
if tag_name == "at":
@@ -372,31 +610,59 @@ class SatoriPlatformAdapter(Platform):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:image/"):
src = src.split(",")[1]
elements.append(Image.fromBase64(src))
elif src.startswith("http"):
elements.append(Image.fromURL(src))
else:
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
elements.append(Image(file=src))
elif tag_name == "file":
src = attrs.get("src", "")
name = attrs.get("name", "文件")
if src:
elements.append(File(file=src, name=name))
elements.append(File(name=name, file=src))
elif tag_name in ("audio", "record"):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:audio/"):
src = src.split(",")[1]
elements.append(Record.fromBase64(src))
elif src.startswith("http"):
elements.append(Record.fromURL(src))
elements.append(Record(file=src))
elif tag_name == "quote":
# quote标签已经被特殊处理
pass
elif tag_name == "face":
face_id = attrs.get("id", "")
face_name = attrs.get("name", "")
face_type = attrs.get("type", "")
if face_name:
elements.append(Plain(text=f"[表情:{face_name}]"))
elif face_id and face_type:
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
elif face_id:
elements.append(Plain(text=f"[表情ID:{face_id}]"))
else:
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
elements.append(Plain(text="[表情]"))
elif tag_name == "ark":
# 作为纯文本添加到消息链中
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[ARK卡片]"))
elif tag_name == "json":
# JSON标签 视为ARK卡片消息
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[JSON卡片]"))
else:
# 未知标签,递归处理其内容

View File

@@ -17,6 +17,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
session_id: str,
adapter: "SatoriPlatformAdapter",
):
# 更新平台元数据
if adapter and hasattr(adapter, "logins") and adapter.logins:
current_login = adapter.logins[0]
platform_name = current_login.get("platform", "satori")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
if not platform_meta.id and user_id:
platform_meta.id = f"{platform_name}({user_id})"
super().__init__(message_str, message_obj, platform_meta, session_id)
self.adapter = adapter
self.platform = None

View File

@@ -218,7 +218,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
delta = ""
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id

View File

@@ -185,6 +185,7 @@ class WecomPlatformAdapter(Platform):
return PlatformMetadata(
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
)
@override

View File

@@ -184,6 +184,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
return PlatformMetadata(
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
)
@override

View File

@@ -65,13 +65,16 @@ class AssistantMessageSegment:
role: str = "assistant"
def to_dict(self):
ret = {
ret: dict[str, str | list[dict]] = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
if self.tool_calls:
ret["tool_calls"] = self.tool_calls
tool_calls_dict = [
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
]
ret["tool_calls"] = tool_calls_dict
return ret
@@ -117,7 +120,14 @@ class ProviderRequest:
"""模型名称,为 None 时使用提供商的默认模型"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
return (
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
f"image_count={len(self.image_urls or [])}, "
f"func_tool={self.func_tool}, "
f"contexts={self._print_friendly_context()}, "
f"system_prompt={self.system_prompt}, "
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
)
def __str__(self):
return self.__repr__()

View File

@@ -4,7 +4,7 @@ import os
import asyncio
import aiohttp
from typing import Dict, List, Awaitable
from typing import Dict, List, Awaitable, Callable, Any
from astrbot import logger
from astrbot.core import sp
@@ -109,7 +109,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
@@ -132,7 +132,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> None:
"""添加函数调用工具
@@ -220,7 +220,7 @@ class FunctionToolManager:
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future = None,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:

View File

@@ -7,7 +7,13 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
from .provider import (
Provider,
STTProvider,
TTSProvider,
EmbeddingProvider,
RerankProvider,
)
from .register import llm_tools, provider_cls_map
from ..persona_mgr import PersonaManager
@@ -38,7 +44,12 @@ class ProviderManager:
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map: dict[str, Provider] = {}
self.rerank_provider_insts: List[RerankProvider] = []
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -87,19 +98,31 @@ class ProviderManager:
)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
prov = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
prov, TTSProvider
):
self.curr_tts_provider_inst = prov
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.SPEECH_TO_TEXT:
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov, STTProvider
):
self.curr_stt_provider_inst = prov
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.CHAT_COMPLETION:
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov, Provider
):
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(self, provider_type: ProviderType, umo=None):
def get_using_provider(
self, provider_type: ProviderType, umo=None
) -> Provider | STTProvider | TTSProvider | None:
"""获取正在使用的提供商实例。
Args:
@@ -211,6 +234,8 @@ class ProviderManager:
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "coze":
from .sources.coze_source import ProviderCoze as ProviderCoze
case "dashscope":
from .sources.dashscope_source import (
ProviderDashscope as ProviderDashscope,
@@ -303,12 +328,14 @@ class ProviderManager:
provider_metadata = provider_cls_map[provider_config["type"]]
try:
# 按任务实例化提供商
cls_type = provider_metadata.cls_type
if not cls_type:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -327,9 +354,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -345,7 +370,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
inst = cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
@@ -366,16 +391,16 @@ class ProviderManager:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type in [
ProviderType.EMBEDDING,
ProviderType.RERANK,
]:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
@@ -430,11 +455,17 @@ class ProviderManager:
)
if self.inst_map[provider_id] in self.provider_insts:
self.provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, Provider):
self.provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.stt_provider_insts:
self.stt_provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, STTProvider):
self.stt_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.tts_provider_insts:
self.tts_provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, TTSProvider):
self.tts_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] == self.curr_provider_inst:
self.curr_provider_inst = None

View File

@@ -0,0 +1,314 @@
import json
import asyncio
import aiohttp
import io
from typing import Dict, List, Any, AsyncGenerator
from astrbot.core import logger
class CozeAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
self.api_key = api_key
self.api_base = api_base
self.session = None
async def _ensure_session(self):
"""确保HTTP session存在"""
if self.session is None:
connector = aiohttp.TCPConnector(
ssl=False if self.api_base.startswith("http://") else True,
limit=100,
limit_per_host=30,
keepalive_timeout=30,
enable_cleanup_closed=True,
)
timeout = aiohttp.ClientTimeout(
total=120, # 默认超时时间
connect=30,
sock_read=120,
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "text/event-stream",
}
self.session = aiohttp.ClientSession(
headers=headers, timeout=timeout, connector=connector
)
return self.session
async def upload_file(
self,
file_data: bytes,
) -> str:
"""上传文件到 Coze 并返回 file_id
Args:
file_data (bytes): 文件的二进制数据
Returns:
str: 上传成功后返回的 file_id
"""
session = await self._ensure_session()
url = f"{self.api_base}/v1/files/upload"
try:
file_io = io.BytesIO(file_data)
async with session.post(
url,
data={
"file": file_io,
},
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
response_text = await response.text()
logger.debug(
f"文件上传响应状态: {response.status}, 内容: {response_text}"
)
if response.status != 200:
raise Exception(
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
)
try:
result = await response.json()
except json.JSONDecodeError:
raise Exception(f"文件上传响应解析失败: {response_text}")
if result.get("code") != 0:
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
file_id = result["data"]["id"]
logger.debug(f"[Coze] 图片上传成功file_id: {file_id}")
return file_id
except asyncio.TimeoutError:
logger.error("文件上传超时")
raise Exception("文件上传超时")
except Exception as e:
logger.error(f"文件上传失败: {str(e)}")
raise Exception(f"文件上传失败: {str(e)}")
async def download_image(self, image_url: str) -> bytes:
"""下载图片并返回字节数据
Args:
image_url (str): 图片的URL
Returns:
bytes: 图片的二进制数据
"""
session = await self._ensure_session()
try:
async with session.get(image_url) as response:
if response.status != 200:
raise Exception(f"下载图片失败,状态码: {response.status}")
image_data = await response.read()
return image_data
except Exception as e:
logger.error(f"下载图片失败 {image_url}: {str(e)}")
raise Exception(f"下载图片失败: {str(e)}")
async def chat_messages(
self,
bot_id: str,
user_id: str,
additional_messages: List[Dict] | None = None,
conversation_id: str | None = None,
auto_save_history: bool = True,
stream: bool = True,
timeout: float = 120,
) -> AsyncGenerator[Dict[str, Any], None]:
"""发送聊天消息并返回流式响应
Args:
bot_id: Bot ID
user_id: 用户ID
additional_messages: 额外消息列表
conversation_id: 会话ID
auto_save_history: 是否自动保存历史
stream: 是否流式响应
timeout: 超时时间
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/chat"
payload = {
"bot_id": bot_id,
"user_id": user_id,
"stream": stream,
"auto_save_history": auto_save_history,
}
if additional_messages:
payload["additional_messages"] = additional_messages
params = {}
if conversation_id:
params["conversation_id"] = conversation_id
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
try:
async with session.post(
url,
json=payload,
params=params,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
# SSE
buffer = ""
event_type = None
event_data = None
async for chunk in response.content:
if chunk:
buffer += chunk.decode("utf-8", errors="ignore")
lines = buffer.split("\n")
buffer = lines[-1]
for line in lines[:-1]:
line = line.strip()
if not line:
if event_type and event_data:
yield {"event": event_type, "data": event_data}
event_type = None
event_data = None
elif line.startswith("event:"):
event_type = line[6:].strip()
elif line.startswith("data:"):
data_str = line[5:].strip()
if data_str and data_str != "[DONE]":
try:
event_data = json.loads(data_str)
except json.JSONDecodeError:
event_data = {"content": data_str}
except asyncio.TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
except Exception as e:
raise Exception(f"Coze API 流式请求失败: {str(e)}")
async def clear_context(self, conversation_id: str):
"""清空会话上下文
Args:
conversation_id: 会话ID
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/clear_context"
payload = {"conversation_id": conversation_id}
try:
async with session.post(url, json=payload) as response:
response_text = await response.text()
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
try:
return json.loads(response_text)
except json.JSONDecodeError:
raise Exception("Coze API 返回非JSON格式")
except asyncio.TimeoutError:
raise Exception("Coze API 请求超时")
except aiohttp.ClientError as e:
raise Exception(f"Coze API 请求失败: {str(e)}")
async def get_message_list(
self,
conversation_id: str,
order: str = "desc",
limit: int = 10,
offset: int = 0,
):
"""获取消息列表
Args:
conversation_id: 会话ID
order: 排序方式 (asc/desc)
limit: 限制数量
offset: 偏移量
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/list"
params = {
"conversation_id": conversation_id,
"order": order,
"limit": limit,
"offset": offset,
}
try:
async with session.get(url, params=params) as response:
response.raise_for_status()
return await response.json()
except Exception as e:
logger.error(f"获取Coze消息列表失败: {str(e)}")
raise Exception(f"获取Coze消息列表失败: {str(e)}")
async def close(self):
"""关闭会话"""
if self.session:
await self.session.close()
self.session = None
if __name__ == "__main__":
import os
import asyncio
async def test_coze_api_client():
api_key = os.getenv("COZE_API_KEY", "")
bot_id = os.getenv("COZE_BOT_ID", "")
client = CozeAPIClient(api_key=api_key)
try:
with open("README.md", "rb") as f:
file_data = f.read()
file_id = await client.upload_file(file_data)
print(f"Uploaded file_id: {file_id}")
async for event in client.chat_messages(
bot_id=bot_id,
user_id="test_user",
additional_messages=[
{
"role": "user",
"content": json.dumps(
[
{"type": "text", "text": "这是什么"},
{"type": "file", "file_id": file_id},
],
ensure_ascii=False,
),
"content_type": "object_string",
},
],
stream=True,
):
print(f"Event: {event}")
finally:
await client.close()
asyncio.run(test_coze_api_client())

View File

@@ -0,0 +1,635 @@
import json
import os
import base64
import hashlib
from typing import AsyncGenerator, Dict
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.entities import LLMResponse
from ..register import register_provider_adapter
from .coze_api_client import CozeAPIClient
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://")
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
self.conversation_ids: Dict[str, str] = {}
self.file_id_cache: Dict[str, Dict[str, str]] = {}
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
"""生成统一的缓存键
Args:
data: 图片数据或路径
is_base64: 是否是 base64 数据
Returns:
str: 缓存键
"""
try:
if is_base64 and data.startswith("data:image/"):
try:
header, encoded = data.split(",", 1)
image_bytes = base64.b64decode(encoded)
cache_key = hashlib.md5(image_bytes).hexdigest()
return cache_key
except Exception:
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
return cache_key
else:
if data.startswith(("http://", "https://")):
# URL图片使用URL作为缓存键
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
else:
clean_path = (
data.split("_")[0]
if "_" in data and len(data.split("_")) >= 3
else data
)
if os.path.exists(clean_path):
with open(clean_path, "rb") as f:
file_content = f.read()
cache_key = hashlib.md5(file_content).hexdigest()
return cache_key
else:
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
return cache_key
except Exception as e:
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
return cache_key
async def _upload_file(
self,
file_data: bytes,
session_id: str | None = None,
cache_key: str | None = None,
) -> str:
"""上传文件到 Coze 并返回 file_id"""
# 使用 API 客户端上传文件
file_id = await self.api_client.upload_file(file_data)
# 缓存 file_id
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
async def _download_and_upload_image(
self, image_url: str, session_id: str | None = None
) -> str:
"""下载图片并上传到 Coze返回 file_id"""
# 计算哈希实现缓存
cache_key = self._generate_cache_key(image_url) if session_id else None
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self._upload_file(image_data, session_id, cache_key)
if session_id and cache_key:
self.file_id_cache[session_id][cache_key] = file_id
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {str(e)}")
raise Exception(f"处理图片失败: {str(e)}")
async def _process_context_images(
self, content: str | list, session_id: str
) -> str:
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
try:
if isinstance(content, str):
return content
processed_content = []
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
for item in content:
if not isinstance(item, dict):
processed_content.append(item)
continue
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片逻辑
if "file_id" in item:
# 已经有 file_id
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
processed_content.append(item)
else:
# 获取图片数据
image_data = ""
if "image_url" in item and isinstance(item["image_url"], dict):
image_data = item["image_url"].get("url", "")
elif "data" in item:
image_data = item.get("data", "")
elif "url" in item:
image_data = item.get("url", "")
if not image_data:
continue
# 计算哈希用于缓存
cache_key = self._generate_cache_key(
image_data, is_base64=image_data.startswith("data:image/")
)
# 检查缓存
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
processed_content.append(
{"type": "image", "file_id": file_id}
)
else:
# 上传图片并缓存
if image_data.startswith("data:image/"):
# base64 处理
_, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
elif image_data.startswith(("http://", "https://")):
# URL 图片
file_id = await self._download_and_upload_image(
image_data, session_id
)
# 为URL图片也添加缓存
self.file_id_cache[session_id][cache_key] = file_id
elif os.path.exists(image_data):
# 本地文件
with open(image_data, "rb") as f:
image_bytes = f.read()
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
else:
logger.warning(
f"无法处理的图片格式: {image_data[:50]}..."
)
continue
processed_content.append(
{"type": "image", "file_id": file_id}
)
result = json.dumps(processed_content, ensure_ascii=False)
return result
except Exception as e:
logger.error(f"处理上下文图片失败: {str(e)}")
if isinstance(content, str):
return content
else:
return json.dumps(content, ensure_ascii=False)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
"""文本对话, 内部使用流式接口实现非流式
Args:
prompt (str): 用户提示词
session_id (str): 会话ID
image_urls (List[str]): 图片URL列表
func_tool (FuncCall): 函数调用工具(不支持)
contexts (List): 上下文列表
system_prompt (str): 系统提示语
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
model (str): 模型名称(不支持)
Returns:
LLMResponse: LLM响应对象
"""
accumulated_content = ""
final_response = None
async for llm_response in self.text_chat_stream(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
model=model,
**kwargs,
):
if llm_response.is_chunk:
if llm_response.completion_text:
accumulated_content += llm_response.completion_text
else:
final_response = llm_response
if final_response:
return final_response
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
return LLMResponse(role="assistant", result_chain=chain)
else:
return LLMResponse(role="assistant", completion_text="")
async def text_chat_stream(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话接口"""
# 用户ID参数(参考文档, 可以自定义)
user_id = session_id or kwargs.get("user", "default_user")
# 获取或创建会话ID
conversation_id = self.conversation_ids.get(user_id)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{"role": "system", "content": system_prompt, "content_type": "text"}
)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
content = ctx["content"]
content_type = ctx.get("content_type", "text")
# 处理可能包含图片的上下文
if (
content_type == "object_string"
or (isinstance(content, str) and content.startswith("["))
or (
isinstance(content, list)
and any(
isinstance(item, dict)
and item.get("type") == "image_url"
for item in content
)
)
):
processed_content = await self._process_context_images(
content, user_id
)
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
}
)
else:
# 纯文本
additional_messages.append(
{
"role": ctx["role"],
"content": (
content
if isinstance(content, str)
else json.dumps(content, ensure_ascii=False)
),
"content_type": "text",
}
)
else:
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
try:
if url.startswith(("http://", "https://")):
# 网络图片
file_id = await self._download_and_upload_image(
url, user_id
)
else:
# 本地文件或 base64
if url.startswith("data:image/"):
# base64
_, encoded = url.split(",", 1)
image_data = base64.b64decode(encoded)
cache_key = self._generate_cache_key(
url, is_base64=True
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
# 本地文件
if os.path.exists(url):
with open(url, "rb") as f:
image_data = f.read()
# 用文件路径和修改时间来缓存
file_stat = os.stat(url)
cache_key = self._generate_cache_key(
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
is_base64=False,
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
logger.warning(f"图片文件不存在: {url}")
continue
object_string_content.append(
{
"type": "image",
"file_id": file_id,
}
)
except Exception as e:
logger.error(f"处理图片失败 {url}: {str(e)}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
}
)
else:
# 纯文本
if prompt:
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
}
)
try:
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
self.conversation_ids[user_id] = data["conversation_id"]
elif event_type == "conversation.message.delta":
if isinstance(data, dict):
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
message_started = True
accumulated_content += content
yield LLMResponse(
role="assistant",
completion_text=content,
is_chunk=True,
)
elif event_type == "conversation.message.completed":
if isinstance(data, dict):
msg_type = data.get("type")
if msg_type == "answer" and data.get("role") == "assistant":
final_content = data.get("content", "")
if not accumulated_content and final_content:
chain = MessageChain(chain=[Comp.Plain(final_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
elif event_type == "conversation.chat.completed":
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
break
elif event_type == "done":
break
elif event_type == "error":
error_msg = (
data.get("message", "未知错误")
if isinstance(data, dict)
else str(data)
)
logger.error(f"Coze 流式响应错误: {error_msg}")
yield LLMResponse(
role="err",
completion_text=f"Coze 错误: {error_msg}",
is_chunk=False,
)
break
if not message_started and not accumulated_content:
yield LLMResponse(
role="assistant",
completion_text="LLM 未响应任何内容。",
is_chunk=False,
)
elif message_started and accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
except Exception as e:
logger.error(f"Coze 流式请求失败: {str(e)}")
yield LLMResponse(
role="err",
completion_text=f"Coze 流式请求失败: {str(e)}",
is_chunk=False,
)
async def forget(self, session_id: str):
"""清空指定会话的上下文"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if user_id in self.file_id_cache:
self.file_id_cache.pop(user_id, None)
if not conversation_id:
return True
try:
response = await self.api_client.clear_context(conversation_id)
if "code" in response and response["code"] == 0:
self.conversation_ids.pop(user_id, None)
return True
else:
logger.warning(f"清空 Coze 会话上下文失败: {response}")
return False
except Exception as e:
logger.error(f"清空 Coze 会话失败: {str(e)}")
return False
async def get_current_key(self):
"""获取当前API Key"""
return self.api_key
async def set_key(self, key: str):
"""设置新的API Key"""
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
async def get_models(self):
"""获取可用模型列表"""
return [f"bot_{self.bot_id}"]
def get_model(self):
"""获取当前模型"""
return f"bot_{self.bot_id}"
def set_model(self, model: str):
"""设置模型在Coze中是Bot ID"""
if model.startswith("bot_"):
self.bot_id = model[4:]
else:
self.bot_id = model
async def get_human_readable_context(
self, session_id: str, page: int = 1, page_size: int = 10
):
"""获取人类可读的上下文历史"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if not conversation_id:
return []
try:
data = await self.api_client.get_message_list(
conversation_id=conversation_id,
order="desc",
limit=page_size,
offset=(page - 1) * page_size,
)
if data.get("code") != 0:
logger.warning(f"获取 Coze 消息历史失败: {data}")
return []
messages = data.get("data", {}).get("messages", [])
readable_history = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
msg_type = msg.get("type", "")
if role == "user":
readable_history.append(f"用户: {content}")
elif role == "assistant" and msg_type == "answer":
readable_history.append(f"助手: {content}")
return readable_history
except Exception as e:
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
return []
async def terminate(self):
"""清理资源"""
await self.api_client.close()

View File

@@ -6,6 +6,7 @@ from astrbot.core.provider.provider import (
TTSProvider,
STTProvider,
EmbeddingProvider,
RerankProvider,
)
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
@@ -23,7 +24,7 @@ from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from typing import Awaitable, Any, Callable
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
@@ -103,9 +104,14 @@ class Context:
"""
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
def get_provider_by_id(
self, provider_id: str
) -> (
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
):
"""通过 ID 获取对应的 LLM Provider。"""
prov = self.provider_manager.inst_map.get(provider_id)
return prov
def get_all_providers(self) -> List[Provider]:
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
@@ -130,34 +136,43 @@ class Context:
Args:
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
if prov and not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型")
return prov
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
"""
获取当前使用的用于 TTS 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.TEXT_TO_SPEECH,
umo=umo,
)
if prov and not isinstance(prov, TTSProvider):
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
return prov
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
"""
获取当前使用的用于 STT 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
if prov and not isinstance(prov, STTProvider):
raise ValueError("返回的 Provider 不是 STTProvider 类型")
return prov
def get_config(self, umo: str | None = None) -> AstrBotConfig:
"""获取 AstrBot 的配置。"""
@@ -245,7 +260,11 @@ class Context:
"""
def register_llm_tool(
self, name: str, func_args: list, desc: str, func_obj: Awaitable
self,
name: str,
func_args: list,
desc: str,
func_obj: Callable[..., Awaitable[Any]],
) -> None:
"""
为函数调用function-calling / tools-use添加工具。
@@ -267,9 +286,7 @@ class Context:
desc=desc,
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(
name, func_args, desc, func_obj, func_obj
)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
def unregister_llm_tool(self, name: str) -> None:
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
@@ -281,7 +298,7 @@ class Context:
command_name: str,
desc: str,
priority: int,
awaitable: Awaitable,
awaitable: Callable[..., Awaitable[Any]],
use_regex=False,
ignore_prefix=False,
):

View File

@@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter):
self.init_handler_md(handler_md)
self.custom_filter_list: List[CustomFilter] = []
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def print_types(self):
result = ""
for k, v in self.handler_params.items():
@@ -136,6 +139,28 @@ class CommandFilter(HandlerFilter):
)
return result
def get_complete_command_names(self):
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
self._cmpl_cmd_names = [
f"{parent} {cmd}" if parent else cmd
for cmd in [self.command_name] + list(self.alias)
for parent in self.parent_command_names or [""]
]
return self._cmpl_cmd_names
def startswith(self, message_str: str) -> bool:
for full_cmd in self.get_complete_command_names():
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
return True
return False
def equals(self, message_str: str) -> bool:
for full_cmd in self.get_complete_command_names():
if message_str == full_cmd:
return True
return False
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -145,19 +170,7 @@ class CommandFilter(HandlerFilter):
# 检查是否以指令开头
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
candidates = [self.command_name] + list(self.alias)
ok = False
for candidate in candidates:
for parent_command_name in self.parent_command_names:
if parent_command_name:
_full = f"{parent_command_name} {candidate}"
else:
_full = candidate
if message_str.startswith(f"{_full} ") or message_str == _full:
message_str = message_str[len(_full) :].strip()
ok = True
break
if not ok:
if not self.startswith(message_str):
return False
# 分割为列表

View File

@@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter):
def __init__(
self,
group_name: str,
alias: set = None,
parent_group: CommandGroupFilter = None,
alias: set | None = None,
parent_group: CommandGroupFilter | None = None,
):
self.group_name = group_name
self.alias = alias if alias else set()
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
self.custom_filter_list: List[CustomFilter] = []
self.parent_group = parent_group
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def add_sub_command_filter(
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
):
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
"""遍历父节点获取完整的指令名。
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
parent_cmd_names = (
self.parent_group.get_complete_command_names() if self.parent_group else []
)
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
for parent_cmd_name in parent_cmd_names:
for candidate in candidates:
result.append(parent_cmd_name + " " + candidate)
self._cmpl_cmd_names = result
return result
# 以树的形式打印出来
@@ -54,8 +61,8 @@ class CommandGroupFilter(HandlerFilter):
self,
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
prefix: str = "",
event: AstrMessageEvent = None,
cfg: AstrBotConfig = None,
event: AstrMessageEvent | None = None,
cfg: AstrBotConfig | None = None,
) -> str:
result = ""
for sub_filter in sub_command_filters:
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
return False
return True
def startswith(self, message_str: str) -> bool:
return message_str.startswith(tuple(self.get_complete_command_names()))
def equals(self, message_str: str) -> bool:
return message_str in self.get_complete_command_names()
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
if not self.custom_filter_ok(event, cfg):
return False
complete_command_names = self.get_complete_command_names()
if event.message_str.strip() in complete_command_names:
if self.equals(event.message_str.strip()):
tree = (
self.group_name
+ "\n"
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
)
# complete_command_names = [name + " " for name in complete_command_names]
# return event.message_str.startswith(tuple(complete_command_names))
return False
return self.startswith(event.message_str)

View File

@@ -2,7 +2,6 @@ import enum
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from typing import Union
class PlatformAdapterType(enum.Flag):
@@ -19,6 +18,7 @@ class PlatformAdapterType(enum.Flag):
VOCECHAT = enum.auto()
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
SATORI = enum.auto()
MISSKEY = enum.auto()
ALL = (
AIOCQHTTP
| QQOFFICIAL
@@ -33,6 +33,7 @@ class PlatformAdapterType(enum.Flag):
| VOCECHAT
| WEIXIN_OFFICIAL_ACCOUNT
| SATORI
| MISSKEY
)
@@ -50,15 +51,19 @@ ADAPTER_NAME_2_TYPE = {
"vocechat": PlatformAdapterType.VOCECHAT,
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
"satori": PlatformAdapterType.SATORI,
"misskey": PlatformAdapterType.MISSKEY,
}
class PlatformAdapterTypeFilter(HandlerFilter):
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
self.type_or_str = platform_adapter_type_or_str
def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str):
if isinstance(platform_adapter_type_or_str, str):
self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str)
else:
self.platform_type = platform_adapter_type_or_str
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
adapter_name = event.get_platform_name()
if adapter_name in ADAPTER_NAME_2_TYPE:
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None:
return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type)
return False

View File

@@ -5,7 +5,9 @@ from astrbot.core.star import StarMetadata, star_map
_warned_register_star = False
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
def register_star(
name: str, author: str, desc: str, version: str, repo: str | None = None
):
"""注册一个插件(Star)。
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。

View File

@@ -12,7 +12,7 @@ from ..filter.platform_adapter_type import (
from ..filter.permission import PermissionTypeFilter, PermissionType
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
from ..filter.regex import RegexFilter
from typing import Awaitable
from typing import Awaitable, Any, Callable
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core.agent.agent import Agent
@@ -20,15 +20,19 @@ from astrbot.core.agent.tool import FunctionTool
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core import logger
def get_handler_full_name(awaitable: Awaitable) -> str:
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
"""获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(
handler: Awaitable, event_type: EventType, dont_add=False, **kwargs
handler: Callable[..., Awaitable[Any]],
event_type: EventType,
dont_add=False,
**kwargs,
) -> StarHandlerMetadata:
"""获取 Handler 或者创建一个新的 Handler"""
handler_full_name = get_handler_full_name(handler)
@@ -59,22 +63,35 @@ def get_handler_or_create(
def register_command(
command_name: str = None, sub_command: str = None, alias: set = None, **kwargs
command_name: str | None = None,
sub_command: str | None = None,
alias: set | None = None,
**kwargs,
):
"""注册一个 Command."""
new_command = None
add_to_event_filters = False
if isinstance(command_name, RegisteringCommandable):
# 子指令
parent_command_names = command_name.parent_group.get_complete_command_names()
new_command = CommandFilter(
sub_command, alias, None, parent_command_names=parent_command_names
)
command_name.parent_group.add_sub_command_filter(new_command)
if sub_command is not None:
parent_command_names = (
command_name.parent_group.get_complete_command_names()
)
new_command = CommandFilter(
sub_command, alias, None, parent_command_names=parent_command_names
)
command_name.parent_group.add_sub_command_filter(new_command)
else:
logger.warning(
f"注册指令{command_name} 的子指令时未提供 sub_command 参数。"
)
else:
# 裸指令
new_command = CommandFilter(command_name, alias, None)
add_to_event_filters = True
if command_name is None:
logger.warning("注册裸指令时未提供 command_name 参数。")
else:
new_command = CommandFilter(command_name, alias, None)
add_to_event_filters = True
def decorator(awaitable):
if not add_to_event_filters:
@@ -84,8 +101,9 @@ def register_command(
handler_md = get_handler_or_create(
awaitable, EventType.AdapterMessageEvent, **kwargs
)
new_command.init_handler_md(handler_md)
handler_md.event_filters.append(new_command)
if new_command:
new_command.init_handler_md(handler_md)
handler_md.event_filters.append(new_command)
return awaitable
return decorator
@@ -163,26 +181,38 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
def register_command_group(
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
command_group_name: str | None = None,
sub_command: str | None = None,
alias: set | None = None,
**kwargs,
):
"""注册一个 CommandGroup"""
new_group = None
if isinstance(command_group_name, RegisteringCommandable):
# 子指令组
new_group = CommandGroupFilter(
sub_command, alias, parent_group=command_group_name.parent_group
)
command_group_name.parent_group.add_sub_command_filter(new_group)
if sub_command is None:
logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
else:
new_group = CommandGroupFilter(
sub_command, alias, parent_group=command_group_name.parent_group
)
command_group_name.parent_group.add_sub_command_filter(new_group)
else:
# 根指令组
new_group = CommandGroupFilter(command_group_name, alias)
if command_group_name is None:
logger.warning("根指令组的名称未指定")
else:
new_group = CommandGroupFilter(command_group_name, alias)
def decorator(obj):
# 根指令组
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(new_group)
if new_group:
handler_md = get_handler_or_create(
obj, EventType.AdapterMessageEvent, **kwargs
)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
return RegisteringCommandable(new_group)
return decorator
@@ -323,7 +353,7 @@ def register_on_llm_response(**kwargs):
return decorator
def register_llm_tool(name: str = None, **kwargs):
def register_llm_tool(name: str | None = None, **kwargs):
"""为函数调用function-calling / tools-use添加工具。
请务必按照以下格式编写一个工具包括函数注释AstrBot 会尝试解析该函数注释)
@@ -361,9 +391,10 @@ def register_llm_tool(name: str = None, **kwargs):
if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Awaitable):
def decorator(awaitable: Callable[..., Awaitable[Any]]):
llm_tool_name = name_ if name_ else awaitable.__name__
docstring = docstring_parser.parse(awaitable.__doc__)
func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
@@ -379,20 +410,18 @@ def register_llm_tool(name: str = None, **kwargs):
)
# print(llm_tool_name, registering_agent)
if not registering_agent:
doc_desc = docstring.description.strip() if docstring.description else ""
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(
llm_tool_name, args, docstring.description.strip(), md.handler
)
llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler)
else:
assert isinstance(registering_agent, RegisteringAgent)
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
if registering_agent._agent.tools is None:
registering_agent._agent.tools = []
registering_agent._agent.tools.append(
llm_tools.spec_to_func(
llm_tool_name, args, docstring.description.strip(), awaitable
)
)
desc = docstring.description.strip() if docstring.description else ""
tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable)
registering_agent._agent.tools.append(tool)
return awaitable
@@ -413,8 +442,8 @@ class RegisteringAgent:
def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
):
"""注册一个 Agent
@@ -426,7 +455,7 @@ def register_agent(
"""
tools_ = tools or []
def decorator(awaitable: Awaitable):
def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent(
name=name,

View File

@@ -52,10 +52,6 @@ class SessionServiceManager:
"session_service_config", session_config, scope="umo", scope_id=session_id
)
logger.info(
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
)
@staticmethod
def should_process_llm_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理LLM请求

View File

@@ -140,6 +140,9 @@ class SessionPluginManager:
filtered_handlers.append(handler)
continue
if plugin.name is None:
continue
# 检查插件是否在当前会话中启用
if SessionPluginManager.is_plugin_enabled_for_session(
session_id, plugin.name

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import Awaitable, List, Dict, TypeVar, Generic
from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
@@ -60,7 +60,7 @@ class StarHandlerRegistry(Generic[T]):
handlers.append(handler)
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None:
return self.star_handlers_map.get(full_name, None)
def get_handlers_by_module_name(
@@ -87,7 +87,7 @@ class StarHandlerRegistry(Generic[T]):
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry()
star_handlers_registry = StarHandlerRegistry() # type: ignore
class EventType(enum.Enum):
@@ -123,7 +123,7 @@ class StarHandlerMetadata:
handler_module_path: str
"""Handler 所在的模块路径。"""
handler: Awaitable
handler: Callable[..., Awaitable[Any]]
"""Handler 的函数对象,应当是一个异步函数"""
event_filters: List[HandlerFilter]

View File

@@ -43,7 +43,7 @@ class PluginManager:
self.updator = PluginUpdator()
self.context = context
self.context._star_manager = self
self.context._star_manager = self # type: ignore
self.config = config
self.plugin_store_path = get_astrbot_plugin_path()
@@ -478,9 +478,10 @@ class PluginManager:
if isinstance(func_tool, HandoffTool):
need_apply = []
sub_tools = func_tool.agent.tools
for sub_tool in sub_tools:
if isinstance(sub_tool, FunctionTool):
need_apply.append(sub_tool)
if sub_tools:
for sub_tool in sub_tools:
if isinstance(sub_tool, FunctionTool):
need_apply.append(sub_tool)
else:
need_apply = [func_tool]
@@ -686,6 +687,9 @@ class PluginManager:
)
# 从 star_registry 和 star_map 中删除
if plugin.module_path is None or root_dir_name is None:
raise Exception(f"插件 {plugin_name} 数据不完整,无法卸载。")
await self._unbind_plugin(plugin_name, plugin.module_path)
try:
@@ -800,6 +804,8 @@ class PluginManager:
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if plugin is None:
raise Exception(f"插件 {plugin_name} 不存在。")
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:

View File

@@ -22,7 +22,7 @@ import inspect
import os
import uuid
from pathlib import Path
from typing import Union, Awaitable, List, Optional, ClassVar
from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
@@ -221,7 +221,11 @@ class StarTools:
@classmethod
def register_llm_tool(
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
cls,
name: str,
func_args: list,
desc: str,
func_obj: Callable[..., Awaitable[Any]],
) -> None:
"""
为函数调用function-calling/tools-use添加工具

View File

@@ -32,6 +32,9 @@ class PluginUpdator(RepoZipUpdator):
if not repo_url:
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
if not plugin.root_dir_name:
raise Exception(f"插件 {plugin.name} 的根目录名未指定。")
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")

View File

@@ -1,9 +1,33 @@
import codecs
import json
from astrbot.core import logger
from aiohttp import ClientSession
from aiohttp import ClientSession, ClientResponse
from typing import Dict, List, Any, AsyncGenerator
async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
decoder = codecs.getincrementaldecoder("utf-8")()
buffer = ""
async for chunk in resp.content.iter_chunked(8192):
buffer += decoder.decode(chunk)
while "\n\n" in buffer:
block, buffer = buffer.split("\n\n", 1)
if block.strip().startswith("data:"):
try:
yield json.loads(block[5:])
except json.JSONDecodeError:
logger.warning(f"Drop invalid dify json data: {block[5:]}")
continue
# flush any remaining text
buffer += decoder.decode(b"", final=True)
if buffer.strip().startswith("data:"):
try:
yield json.loads(buffer[5:])
except json.JSONDecodeError:
logger.warning(f"Drop invalid dify json data: {buffer[5:]}")
pass
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
self.api_key = api_key
@@ -33,31 +57,11 @@ class DifyAPIClient:
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
buffer = ""
while True:
# 保持原有的8192字节限制防止数据过大导致高水位报错
chunk = await resp.content.read(8192)
if not chunk:
break
buffer += chunk.decode("utf-8")
blocks = buffer.split("\n\n")
# 处理完整的数据块
for block in blocks[:-1]:
if block.strip() and block.startswith("data:"):
try:
json_str = block[5:] # 移除 "data:" 前缀
json_obj = json.loads(json_str)
yield json_obj
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}")
logger.error(f"原始数据块: {json_str}")
# 保留最后一个可能不完整的块
buffer = blocks[-1] if blocks else ""
raise Exception(
f"Dify /chat-messages 接口请求失败:{resp.status}. {text}"
)
async for event in _stream_sse(resp):
yield event
async def workflow_run(
self,
@@ -77,31 +81,11 @@ class DifyAPIClient:
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
buffer = ""
while True:
# 保持原有的8192字节限制防止数据过大导致高水位报错
chunk = await resp.content.read(8192)
if not chunk:
break
buffer += chunk.decode("utf-8")
blocks = buffer.split("\n\n")
# 处理完整的数据块
for block in blocks[:-1]:
if block.strip() and block.startswith("data:"):
try:
json_str = block[5:] # 移除 "data:" 前缀
json_obj = json.loads(json_str)
yield json_obj
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}")
logger.error(f"原始数据块: {json_str}")
# 保留最后一个可能不完整的块
buffer = blocks[-1] if blocks else ""
raise Exception(
f"Dify /workflows/run 接口请求失败:{resp.status}. {text}"
)
async for event in _stream_sse(resp):
yield event
async def file_upload(
self,
@@ -109,12 +93,15 @@ class DifyAPIClient:
user: str,
) -> Dict[str, Any]:
url = f"{self.api_base}/files/upload"
payload = {
"user": user,
"file": open(file_path, "rb"),
}
async with self.session.post(url, data=payload, headers=self.headers) as resp:
return await resp.json() # {"id": "xxx", ...}
with open(file_path, "rb") as f:
payload = {
"user": user,
"file": f,
}
async with self.session.post(
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()

View File

@@ -227,9 +227,11 @@ async def download_dashboard(
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
if latest or len(str(version)) != 40:
logger.info(f"准备下载 {version} 发行版本的 AstrBot WebUI 文件")
ver_name = "latest" if latest else version
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
logger.info(
f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}"
)
try:
await download_file(dashboard_release_url, path, show_progress=True)
except BaseException as _:
@@ -241,24 +243,10 @@ async def download_dashboard(
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
await download_file(dashboard_release_url, path, show_progress=True)
else:
logger.info(f"准备下载指定版本的 AstrBot WebUI: {version}")
url = (
"https://api.github.com/repos/AstrBotDevs/astrbot-release-harbour/releases"
)
url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
if proxy:
url = f"{proxy}/{url}"
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as resp:
if resp.status == 200:
releases = await resp.json()
for release in releases:
if version in release["tag_name"]:
download_url = release["assets"][0]["browser_download_url"]
await download_file(download_url, path, show_progress=True)
else:
logger.warning(f"未找到指定的版本的 Dashboard 构建文件: {version}")
return
await download_file(url, path, show_progress=True)
with zipfile.ZipFile(path, "r") as z:
z.extractall(extract_path)

View File

@@ -1,17 +1,27 @@
import uuid
import json
import os
import asyncio
from contextlib import asynccontextmanager
from .route import Route, Response, RouteContext
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from quart import request, Response as QuartResponse, g, make_response
from astrbot.core.db import BaseDatabase
import asyncio
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.platform.astr_message_event import MessageSession
@asynccontextmanager
async def track_conversation(convs: dict, conv_id: str):
convs[conv_id] = True
try:
yield
finally:
convs.pop(conv_id, None)
class ChatRoute(Route):
def __init__(
self,
@@ -40,6 +50,8 @@ class ChatRoute(Route):
self.conv_mgr = core_lifecycle.conversation_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.running_convs: dict[str, bool] = {}
async def get_file(self):
filename = request.args.get("filename")
if not filename:
@@ -139,42 +151,63 @@ class ChatRoute(Route):
)
async def stream():
client_disconnected = False
try:
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=10)
except asyncio.TimeoutError:
continue
async with track_conversation(self.running_convs, webchat_conv_id):
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True
except Exception as e:
logger.error(f"WebChat stream error: {e}")
if not result:
continue
if not result:
continue
result_text = result["data"]
type = result.get("type")
streaming = result.get("streaming", False)
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.05)
result_text = result["data"]
type = result.get("type")
streaming = result.get("streaming", False)
if type == "end":
break
elif (
(streaming and type == "complete")
or not streaming
or type == "break"
):
# append bot message
new_his = {"type": "bot", "message": result_text}
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
try:
if not client_disconnected:
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
except Exception as e:
if not client_disconnected:
logger.debug(
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
)
client_disconnected = True
except BaseException as _:
logger.debug(f"用户 {username} 断开聊天长连接。")
return
try:
if not client_disconnected:
await asyncio.sleep(0.05)
except asyncio.CancelledError:
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True
if type == "end":
break
elif (
(streaming and type == "complete")
or not streaming
or type == "break"
):
# append bot message
new_his = {"type": "bot", "message": result_text}
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
# Put message to conversation-specific queue
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
@@ -291,6 +324,7 @@ class ChatRoute(Route):
.ok(
data={
"history": history_res,
"is_running": self.running_convs.get(webchat_conv_id, False),
}
)
.__dict__

View File

@@ -51,24 +51,6 @@ def validate_config(
def validate(data: dict, metadata: dict = schema, path=""):
for key, value in data.items():
if key not in metadata:
# 无 schema 的配置项,执行类型猜测
if isinstance(value, str):
try:
data[key] = int(value)
continue
except ValueError:
pass
try:
data[key] = float(value)
continue
except ValueError:
pass
if value.lower() == "true":
data[key] = True
elif value.lower() == "false":
data[key] = False
continue
meta = metadata[key]
if "type" not in meta:
@@ -127,12 +109,12 @@ def validate_config(
)
if is_core:
for key, group in schema.items():
group_meta = group.get("metadata")
if not group_meta:
continue
# logger.info(f"验证配置: 组 {key} ...")
validate(data, group_meta, path=f"{key}.")
meta_all = {
**schema["platform_group"]["metadata"],
**schema["provider_group"]["metadata"],
**schema["misc_config_group"]["metadata"],
}
validate(data, meta_all)
else:
validate(data, schema)
@@ -142,6 +124,7 @@ def validate_config(
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
"""验证并保存配置"""
errors = None
logger.info(f"Saving config, is_core={is_core}")
try:
if is_core:
errors, post_config = validate_config(

View File

@@ -169,15 +169,65 @@ class ConversationRoute(Route):
"""删除对话"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
await self.core_lifecycle.conversation_manager.delete_conversation(
unified_msg_origin=user_id, conversation_id=cid
)
return Response().ok({"message": "对话删除成功"}).__dict__
# 检查是否是批量删除
if "conversations" in data:
# 批量删除
conversations = data.get("conversations", [])
if not conversations:
return (
Response().error("批量删除时conversations参数不能为空").__dict__
)
deleted_count = 0
failed_items = []
for conv in conversations:
user_id = conv.get("user_id")
cid = conv.get("cid")
if not user_id or not cid:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 缺少必要参数"
)
continue
try:
await self.core_lifecycle.conversation_manager.delete_conversation(
unified_msg_origin=user_id, conversation_id=cid
)
deleted_count += 1
except Exception as e:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {str(e)}")
message = f"成功删除 {deleted_count} 个对话"
if failed_items:
message += f",失败 {len(failed_items)}"
return (
Response()
.ok(
{
"message": message,
"deleted_count": deleted_count,
"failed_count": len(failed_items),
"failed_items": failed_items,
}
)
.__dict__
)
else:
# 单个删除
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
await self.core_lifecycle.conversation_manager.delete_conversation(
unified_msg_origin=user_id, conversation_id=cid
)
return Response().ok({"message": "对话删除成功"}).__dict__
except Exception as e:
logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")

View File

@@ -30,6 +30,7 @@ class SessionManagementRoute(Route):
"/session/update_tts": ("POST", self.update_session_tts),
"/session/update_name": ("POST", self.update_session_name),
"/session/update_status": ("POST", self.update_session_status),
"/session/delete": ("POST", self.delete_session),
}
self.conv_mgr = core_lifecycle.conversation_manager
self.core_lifecycle = core_lifecycle
@@ -180,60 +181,132 @@ class SessionManagementRoute(Route):
logger.error(error_msg)
return Response().error(f"获取会话列表失败: {str(e)}").__dict__
async def _update_single_session_persona(self, session_id: str, persona_name: str):
"""更新单个会话的 persona 的内部方法"""
conversation_manager = self.core_lifecycle.star_context.conversation_manager
conversation_id = await conversation_manager.get_curr_conversation_id(
session_id
)
conv = None
if conversation_id:
conv = await conversation_manager.get_conversation(
unified_msg_origin=session_id,
conversation_id=conversation_id,
)
if not conv or not conversation_id:
conversation_id = await conversation_manager.new_conversation(session_id)
# 更新 persona
await conversation_manager.update_conversation_persona_id(
session_id, persona_name
)
async def _handle_batch_operation(
self, session_ids: list, operation_func, operation_name: str, **kwargs
):
"""通用的批量操作处理方法"""
success_count = 0
error_sessions = []
for session_id in session_ids:
try:
await operation_func(session_id, **kwargs)
success_count += 1
except Exception as e:
logger.error(f"批量{operation_name} 会话 {session_id} 失败: {str(e)}")
error_sessions.append(session_id)
if error_sessions:
return (
Response()
.ok(
{
"message": f"批量更新完成,成功: {success_count},失败: {len(error_sessions)}",
"success_count": success_count,
"error_count": len(error_sessions),
"error_sessions": error_sessions,
}
)
.__dict__
)
else:
return (
Response()
.ok(
{
"message": f"成功批量{operation_name} {success_count} 个会话",
"success_count": success_count,
}
)
.__dict__
)
async def update_session_persona(self):
"""更新指定会话的 persona"""
"""更新指定会话的 persona,支持批量操作"""
try:
data = await request.get_json()
session_id = data.get("session_id")
is_batch = data.get("is_batch", False)
persona_name = data.get("persona_name")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
if persona_name is None:
return Response().error("缺少必要参数: persona_name").__dict__
# 获取会话当前的对话 ID
conversation_manager = self.core_lifecycle.star_context.conversation_manager
conversation_id = await conversation_manager.get_curr_conversation_id(
session_id
)
if is_batch:
session_ids = data.get("session_ids", [])
if not session_ids:
return Response().error("缺少必要参数: session_ids").__dict__
if not conversation_id:
# 如果没有对话,创建一个新的对话
conversation_id = await conversation_manager.new_conversation(
session_id
return await self._handle_batch_operation(
session_ids,
self._update_single_session_persona,
"更新人格",
persona_name=persona_name,
)
else:
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
# 更新 persona
await conversation_manager.update_conversation_persona_id(
session_id, persona_name
)
return (
Response()
.ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"})
.__dict__
)
await self._update_single_session_persona(session_id, persona_name)
return (
Response()
.ok(
{
"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"
}
)
.__dict__
)
except Exception as e:
error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"更新会话人格失败: {str(e)}").__dict__
async def _update_single_session_provider(
self, session_id: str, provider_id: str, provider_type_enum
):
"""更新单个会话的 provider 的内部方法"""
provider_manager = self.core_lifecycle.star_context.provider_manager
await provider_manager.set_provider(
provider_id=provider_id,
provider_type=provider_type_enum,
umo=session_id,
)
async def update_session_provider(self):
"""更新指定会话的 provider"""
"""更新指定会话的 provider,支持批量操作"""
try:
data = await request.get_json()
session_id = data.get("session_id")
is_batch = data.get("is_batch", False)
provider_id = data.get("provider_id")
# "chat_completion", "speech_to_text", "text_to_speech"
provider_type = data.get("provider_type")
if not session_id or not provider_id or not provider_type:
if not provider_id or not provider_type:
return (
Response()
.error("缺少必要参数: session_id, provider_id, provider_type")
.error("缺少必要参数: provider_id, provider_type")
.__dict__
)
@@ -251,23 +324,35 @@ class SessionManagementRoute(Route):
.__dict__
)
# 设置 provider
provider_manager = self.core_lifecycle.star_context.provider_manager
await provider_manager.set_provider(
provider_id=provider_id,
provider_type=provider_type_enum,
umo=session_id,
)
if is_batch:
session_ids = data.get("session_ids", [])
if not session_ids:
return Response().error("缺少必要参数: session_ids").__dict__
return (
Response()
.ok(
{
"message": f"成功更新会话 {session_id}{provider_type} 提供商为 {provider_id}"
}
return await self._handle_batch_operation(
session_ids,
self._update_single_session_provider,
f"更新 {provider_type} 提供商",
provider_id=provider_id,
provider_type_enum=provider_type_enum,
)
else:
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
await self._update_single_session_provider(
session_id, provider_id, provider_type_enum
)
return (
Response()
.ok(
{
"message": f"成功更新会话 {session_id}{provider_type} 提供商为 {provider_id}"
}
)
.__dict__
)
.__dict__
)
except Exception as e:
error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}"
@@ -376,66 +461,98 @@ class SessionManagementRoute(Route):
logger.error(error_msg)
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
async def _update_single_session_llm(self, session_id: str, enabled: bool):
"""更新单个会话的LLM状态的内部方法"""
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
async def update_session_llm(self):
"""更新指定会话的LLM启停状态"""
"""更新指定会话的LLM启停状态,支持批量操作"""
try:
data = await request.get_json()
session_id = data.get("session_id")
is_batch = data.get("is_batch", False)
enabled = data.get("enabled")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
if enabled is None:
return Response().error("缺少必要参数: enabled").__dict__
# 使用 SessionServiceManager 更新LLM状态
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
if is_batch:
session_ids = data.get("session_ids", [])
if not session_ids:
return Response().error("缺少必要参数: session_ids").__dict__
return (
Response()
.ok(
{
"message": f"LLM已{'启用' if enabled else '禁用'}",
"session_id": session_id,
"llm_enabled": enabled,
}
result = await self._handle_batch_operation(
session_ids,
self._update_single_session_llm,
f"{'启用' if enabled else '禁用'}LLM",
enabled=enabled,
)
return result
else:
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
await self._update_single_session_llm(session_id, enabled)
return (
Response()
.ok(
{
"message": f"LLM已{'启用' if enabled else '禁用'}",
"session_id": session_id,
"llm_enabled": enabled,
}
)
.__dict__
)
.__dict__
)
except Exception as e:
error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
async def _update_single_session_tts(self, session_id: str, enabled: bool):
"""更新单个会话的TTS状态的内部方法"""
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
async def update_session_tts(self):
"""更新指定会话的TTS启停状态"""
"""更新指定会话的TTS启停状态,支持批量操作"""
try:
data = await request.get_json()
session_id = data.get("session_id")
is_batch = data.get("is_batch", False)
enabled = data.get("enabled")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
if enabled is None:
return Response().error("缺少必要参数: enabled").__dict__
# 使用 SessionServiceManager 更新TTS状态
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
if is_batch:
session_ids = data.get("session_ids", [])
if not session_ids:
return Response().error("缺少必要参数: session_ids").__dict__
return (
Response()
.ok(
{
"message": f"TTS已{'启用' if enabled else '禁用'}",
"session_id": session_id,
"tts_enabled": enabled,
}
result = await self._handle_batch_operation(
session_ids,
self._update_single_session_tts,
f"{'启用' if enabled else '禁用'}TTS",
enabled=enabled,
)
return result
else:
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
await self._update_single_session_tts(session_id, enabled)
return (
Response()
.ok(
{
"message": f"TTS已{'启用' if enabled else '禁用'}",
"session_id": session_id,
"tts_enabled": enabled,
}
)
.__dict__
)
.__dict__
)
except Exception as e:
error_msg = f"更新会话TTS状态失败: {str(e)}\n{traceback.format_exc()}"
@@ -507,3 +624,43 @@ class SessionManagementRoute(Route):
error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__
async def delete_session(self):
"""删除指定会话及其所有相关数据"""
try:
data = await request.get_json()
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
# 删除会话的所有相关数据
conversation_manager = self.core_lifecycle.conversation_manager
# 1. 删除会话的所有对话
try:
await conversation_manager.delete_conversations_by_user_id(session_id)
except Exception as e:
logger.warning(f"删除会话 {session_id} 的对话失败: {str(e)}")
# 2. 清除会话的偏好设置数据(清空该会话的所有配置)
try:
await sp.clear_async("umo", session_id)
except Exception as e:
logger.warning(f"清除会话 {session_id} 的偏好设置失败: {str(e)}")
return (
Response()
.ok(
{
"message": f"会话 {session_id} 及其相关所有对话数据已成功删除",
"session_id": session_id,
}
)
.__dict__
)
except Exception as e:
error_msg = f"删除会话失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"删除会话失败: {str(e)}").__dict__

8
changelogs/v4.1.3.md Normal file
View File

@@ -0,0 +1,8 @@
# What's Changed
0. ‼️ fix: 修复 4.0.0 版本之后,配置默认 TTS 或者 STT 模型之后仍无法生效的问题 ([#2758](https://github.com/Soulter/AstrBot/issues/2758))
1. ‼️ fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 ([#2757](https://github.com/Soulter/AstrBot/issues/2757))
2. feat: 支持在 WebUI 复制提供商配置以简化操作 ([#2767](https://github.com/Soulter/AstrBot/issues/2767))
3. fix: handle image value correctly for mcp BlobResourceContents ([#2753](https://github.com/Soulter/AstrBot/issues/2753))
4. feat: 增加 QQ 群名称识别到 system prompt, 并提供相应的配置 ([#2770](https://github.com/Soulter/AstrBot/issues/2770))
5. fix: parameter type/default handling in CommandFilter

10
changelogs/v4.1.4.md Normal file
View File

@@ -0,0 +1,10 @@
# What's Changed
0. ‼️ fix: 修复 4.0.0 版本之后,配置默认 TTS 或者 STT 模型之后仍无法生效的问题 ([#2758](https://github.com/Soulter/AstrBot/issues/2758))
1. ‼️ fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 ([#2757](https://github.com/Soulter/AstrBot/issues/2757))
2. feat: 支持在 WebUI 复制提供商配置以简化操作 ([#2767](https://github.com/Soulter/AstrBot/issues/2767))
3. fix: handle image value correctly for mcp BlobResourceContents ([#2753](https://github.com/Soulter/AstrBot/issues/2753))
4. feat: 增加 QQ 群名称识别到 system prompt, 并提供相应的配置 ([#2770](https://github.com/Soulter/AstrBot/issues/2770))
5. fix: 修复 4.1.3 的异常问题
**总之上个版本有很严重的 bug 赶快更新!**

11
changelogs/v4.1.5.md Normal file
View File

@@ -0,0 +1,11 @@
# What's Changed
0. feat: 新增 Misskey 平台适配器 ([#2774](https://github.com/AstrBotDevs/AstrBot/issues/2774))
1. fix: 修复aiocqhttp适配器at会获取群昵称而消息不会获取的逻辑不一致 ([#2769](https://github.com/AstrBotDevs/AstrBot/issues/2769))
2. fix: 修复「对话管理」页面的关键词搜索功能失效的问题并优化一些 UI 样式 ([#2837](https://github.com/AstrBotDevs/AstrBot/issues/2837))
3. fix: 识别「引用消息」的图片时优先使用默认图片转述提供商 ([#2836](https://github.com/AstrBotDevs/AstrBot/issues/2836))
5. fix: 修复 Telegram 下流式传输时,第一次输出的内容会被覆盖掉的问题
6. perf: 优化统计页内存占用和消息数据趋势的样式 ([#2826](https://github.com/AstrBotDevs/AstrBot/issues/2826))
7. perf: 优化 「插件页」、「对话管理页」、「会话管理页」的样式
8. fix: on_tool_end hook unavailable
9. feat: add audioop-lts dependencies ([#2809](https://github.com/AstrBotDevs/AstrBot/issues/2809))

3
changelogs/v4.1.6.md Normal file
View File

@@ -0,0 +1,3 @@
# What's Changed
1. fix: 修复在某些情况下,出现 「返回的 Provider 不是 Provider 类型的错误」

8
changelogs/v4.1.7.md Normal file
View File

@@ -0,0 +1,8 @@
# What's Changed
1. perf: 优化 WebChat 等组件的 UI 风格
2. fix: 修复 4.1.6 版本可能无法点击更新按钮的问题
3. fix: 修复更新开发版的时候,可能无法同时更新 WebUI 的问题
4. feat: 支持在「对话数据」页批量删除对话
5. fix: 修复部分错误地显示「格式校验未通过」的问题
6. perf: WebChat 支持手动填写模型名称

1
changelogs/v4.2.0.md Normal file
View File

@@ -0,0 +1 @@
# What's Changed

View File

@@ -1,7 +1,28 @@
<template>
<RouterView></RouterView>
<!-- 全局唯一 snackbar -->
<v-snackbar v-if="toastStore.current" v-model="snackbarShow" :color="toastStore.current.color"
:timeout="toastStore.current.timeout" :multi-line="toastStore.current.multiLine"
:location="toastStore.current.location" close-on-back>
{{ toastStore.current.message }}
<template #actions v-if="toastStore.current.closable">
<v-btn variant="text" @click="snackbarShow = false">关闭</v-btn>
</template>
</v-snackbar>
</template>
<script setup lang="ts">
<script setup>
import { RouterView } from 'vue-router';
import { computed } from 'vue'
import { useToastStore } from '@/stores/toast'
const toastStore = useToastStore()
const snackbarShow = computed({
get: () => !!toastStore.current,
set: (val) => {
if (!val) toastStore.shift()
}
})
</script>

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,774 @@
<template>
<div class="messages-container" ref="messageContainer">
<!-- 聊天消息列表 -->
<div class="message-list">
<div class="message-item fade-in" v-for="(msg, index) in messages" :key="index">
<!-- 用户消息 -->
<div v-if="msg.content.type == 'user'" class="user-message">
<div class="message-bubble user-bubble" :class="{ 'has-audio': msg.content.audio_url }"
:style="{ backgroundColor: isDark ? '#2d2e30' : '#e7ebf4' }">
<pre
style="font-family: inherit; white-space: pre-wrap; word-wrap: break-word;">{{ msg.content.message }}</pre>
<!-- 图片附件 -->
<div class="image-attachments" v-if="msg.content.image_url && msg.content.image_url.length > 0">
<div v-for="(img, index) in msg.content.image_url" :key="index" class="image-attachment">
<img :src="img" class="attached-image" @click="$emit('openImagePreview', img)" />
</div>
</div>
<!-- 音频附件 -->
<div class="audio-attachment" v-if="msg.content.audio_url && msg.content.audio_url.length > 0">
<audio controls class="audio-player">
<source :src="msg.content.audio_url" type="audio/wav">
{{ t('messages.errors.browser.audioNotSupported') }}
</audio>
</div>
</div>
</div>
<!-- Bot Messages -->
<div v-else class="bot-message">
<v-avatar class="bot-avatar" size="36">
<v-progress-circular :index="index" v-if="isStreaming && index === messages.length - 1" indeterminate size="28"
width="2"></v-progress-circular>
<span v-else-if="messages[index - 1]?.content.type !== 'bot'" class="text-h2"></span>
</v-avatar>
<div class="bot-message-content">
<div class="message-bubble bot-bubble">
<!-- Text -->
<div v-if="msg.content.message && msg.content.message.trim()"
v-html="md.render(msg.content.message)" class="markdown-content"></div>
<!-- Image -->
<div class="embedded-images"
v-if="msg.content.embedded_images && msg.content.embedded_images.length > 0">
<div v-for="(img, imgIndex) in msg.content.embedded_images" :key="imgIndex"
class="embedded-image">
<img :src="img" class="bot-embedded-image"
@click="$emit('openImagePreview', img)" />
</div>
</div>
<!-- Audio -->
<div class="embedded-audio" v-if="msg.content.embedded_audio">
<audio controls class="audio-player">
<source :src="msg.content.embedded_audio" type="audio/wav">
{{ t('messages.errors.browser.audioNotSupported') }}
</audio>
</div>
</div>
<div class="message-actions">
<v-btn :icon="getCopyIcon(index)" size="small" variant="text" class="copy-message-btn"
:class="{ 'copy-success': isCopySuccess(index) }"
@click="copyBotMessage(msg.content.message, index)" :title="t('core.common.copy')" />
</div>
</div>
</div>
</div>
</div>
</div>
</template>
<script>
import { useI18n, useModuleI18n } from '@/i18n/composables';
import MarkdownIt from 'markdown-it';
import hljs from 'highlight.js';
import 'highlight.js/styles/github.css';
const md = new MarkdownIt({
html: false,
breaks: true,
linkify: true,
highlight: function (code, lang) {
if (lang && hljs.getLanguage(lang)) {
try {
return hljs.highlight(code, { language: lang }).value;
} catch (err) {
console.error('Highlight error:', err);
}
}
return hljs.highlightAuto(code).value;
}
});
export default {
name: 'MessageList',
props: {
messages: {
type: Array,
required: true
},
isDark: {
type: Boolean,
default: false
},
isStreaming: {
type: Boolean,
default: false
}
},
emits: ['openImagePreview'],
setup() {
const { t } = useI18n();
const { tm } = useModuleI18n('features/chat');
return {
t,
tm,
md
};
},
data() {
return {
copiedMessages: new Set(),
isUserNearBottom: true,
scrollThreshold: 1,
scrollTimer: null
};
},
mounted() {
this.initCodeCopyButtons();
this.initImageClickEvents();
this.addScrollListener();
this.scrollToBottom();
},
updated() {
this.initCodeCopyButtons();
this.initImageClickEvents();
if (this.isUserNearBottom) {
this.scrollToBottom();
}
},
methods: {
// 复制代码到剪贴板
copyCodeToClipboard(code) {
navigator.clipboard.writeText(code).then(() => {
console.log('代码已复制到剪贴板');
}).catch(err => {
console.error('复制失败:', err);
// 如果现代API失败使用传统方法
const textArea = document.createElement('textarea');
textArea.value = code;
document.body.appendChild(textArea);
textArea.select();
try {
document.execCommand('copy');
console.log('代码已复制到剪贴板 (fallback)');
} catch (fallbackErr) {
console.error('复制失败 (fallback):', fallbackErr);
}
document.body.removeChild(textArea);
});
},
// 复制bot消息到剪贴板
copyBotMessage(message, messageIndex) {
// 获取对应的消息对象
const msgObj = this.messages[messageIndex].content;
let textToCopy = '';
// 如果有文本消息,添加到复制内容中
if (message && message.trim()) {
// 移除HTML标签获取纯文本
const tempDiv = document.createElement('div');
tempDiv.innerHTML = message;
textToCopy = tempDiv.textContent || tempDiv.innerText || message;
}
// 如果有内嵌图片,添加说明
if (msgObj && msgObj.embedded_images && msgObj.embedded_images.length > 0) {
if (textToCopy) textToCopy += '\n\n';
textToCopy += `[包含 ${msgObj.embedded_images.length} 张图片]`;
}
// 如果有内嵌音频,添加说明
if (msgObj && msgObj.embedded_audio) {
if (textToCopy) textToCopy += '\n\n';
textToCopy += '[包含音频内容]';
}
// 如果没有任何内容,使用默认文本
if (!textToCopy.trim()) {
textToCopy = '[媒体内容]';
}
navigator.clipboard.writeText(textToCopy).then(() => {
console.log('消息已复制到剪贴板');
this.showCopySuccess(messageIndex);
}).catch(err => {
console.error('复制失败:', err);
// 如果现代API失败使用传统方法
const textArea = document.createElement('textarea');
textArea.value = textToCopy;
document.body.appendChild(textArea);
textArea.select();
try {
document.execCommand('copy');
console.log('消息已复制到剪贴板 (fallback)');
this.showCopySuccess(messageIndex);
} catch (fallbackErr) {
console.error('复制失败 (fallback):', fallbackErr);
}
document.body.removeChild(textArea);
});
},
// 显示复制成功提示
showCopySuccess(messageIndex) {
this.copiedMessages.add(messageIndex);
// 2秒后移除成功状态
setTimeout(() => {
this.copiedMessages.delete(messageIndex);
}, 2000);
},
// 获取复制按钮图标
getCopyIcon(messageIndex) {
return this.copiedMessages.has(messageIndex) ? 'mdi-check' : 'mdi-content-copy';
},
// 检查是否为复制成功状态
isCopySuccess(messageIndex) {
return this.copiedMessages.has(messageIndex);
},
// 获取复制图标SVG
getCopyIconSvg() {
return '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>';
},
// 获取成功图标SVG
getSuccessIconSvg() {
return '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><polyline points="20,6 9,17 4,12"></polyline></svg>';
},
// 初始化代码块复制按钮
initCodeCopyButtons() {
this.$nextTick(() => {
const codeBlocks = this.$refs.messageContainer?.querySelectorAll('pre code') || [];
codeBlocks.forEach((codeBlock, index) => {
const pre = codeBlock.parentElement;
if (pre && !pre.querySelector('.copy-code-btn')) {
const button = document.createElement('button');
button.className = 'copy-code-btn';
button.innerHTML = this.getCopyIconSvg();
button.title = '复制代码';
button.addEventListener('click', () => {
this.copyCodeToClipboard(codeBlock.textContent);
// 显示复制成功提示
button.innerHTML = this.getSuccessIconSvg();
button.style.color = '#4caf50';
setTimeout(() => {
button.innerHTML = this.getCopyIconSvg();
button.style.color = '';
}, 2000);
});
pre.style.position = 'relative';
pre.appendChild(button);
}
});
});
},
initImageClickEvents() {
this.$nextTick(() => {
// 查找所有动态生成的图片在markdown-content中
const images = document.querySelectorAll('.markdown-content img');
images.forEach((img) => {
if (!img.hasAttribute('data-click-enabled')) {
img.style.cursor = 'pointer';
img.setAttribute('data-click-enabled', 'true');
img.onclick = () => this.$emit('openImagePreview', img.src);
}
});
});
},
scrollToBottom() {
this.$nextTick(() => {
const container = this.$refs.messageContainer;
if (container) {
container.scrollTop = container.scrollHeight;
this.isUserNearBottom = true; // 程序滚动到底部后标记用户在底部
}
});
},
// 添加滚动事件监听器
addScrollListener() {
const container = this.$refs.messageContainer;
if (container) {
container.addEventListener('scroll', this.throttledHandleScroll);
}
},
// 节流处理滚动事件
throttledHandleScroll() {
if (this.scrollTimer) return;
this.scrollTimer = setTimeout(() => {
this.handleScroll();
this.scrollTimer = null;
}, 50); // 50ms 节流
},
// 处理滚动事件
handleScroll() {
const container = this.$refs.messageContainer;
if (container) {
const { scrollTop, scrollHeight, clientHeight } = container;
const distanceFromBottom = scrollHeight - (scrollTop + clientHeight);
// 判断用户是否在底部附近
this.isUserNearBottom = distanceFromBottom <= this.scrollThreshold;
}
},
// 组件销毁时移除监听器
beforeUnmount() {
const container = this.$refs.messageContainer;
if (container) {
container.removeEventListener('scroll', this.throttledHandleScroll);
}
// 清理定时器
if (this.scrollTimer) {
clearTimeout(this.scrollTimer);
this.scrollTimer = null;
}
}
}
}
</script>
<style scoped>
/* 基础动画 */
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.messages-container {
height: 100%;
max-height: 100%;
overflow-y: auto;
padding: 16px;
display: flex;
flex-direction: column;
flex: 1;
min-height: 0;
}
/* 消息列表样式 */
.message-list {
max-width: 900px;
margin: 0 auto;
width: 100%;
}
.message-item {
margin-bottom: 24px;
animation: fadeIn 0.3s ease-out;
}
.user-message {
display: flex;
justify-content: flex-end;
align-items: flex-start;
gap: 12px;
}
.bot-message {
display: flex;
justify-content: flex-start;
align-items: flex-start;
gap: 12px;
}
.bot-message-content {
display: flex;
flex-direction: column;
align-items: flex-start;
max-width: 80%;
position: relative;
}
.message-actions {
display: flex;
gap: 4px;
opacity: 0;
transition: opacity 0.2s ease;
margin-left: 8px;
}
.bot-message:hover .message-actions {
opacity: 1;
}
.copy-message-btn {
opacity: 0.6;
transition: all 0.2s ease;
color: var(--v-theme-secondary);
}
.copy-message-btn:hover {
opacity: 1;
background-color: rgba(103, 58, 183, 0.1);
}
.copy-message-btn.copy-success {
color: #4caf50;
opacity: 1;
}
.copy-message-btn.copy-success:hover {
color: #4caf50;
background-color: rgba(76, 175, 80, 0.1);
}
.message-bubble {
padding: 2px 16px;
border-radius: 12px;
}
.user-bubble {
color: var(--v-theme-primaryText);
padding: 12px 18px;
font-size: 15px;
max-width: 60%;
border-radius: 1.5rem;
}
.bot-bubble {
border: 1px solid var(--v-theme-border);
color: var(--v-theme-primaryText);
font-size: 15px;
max-width: 100%;
}
.user-avatar,
.bot-avatar {
align-self: flex-start;
margin-top: 6px;
}
/* 附件样式 */
.image-attachments {
display: flex;
gap: 8px;
margin-top: 8px;
flex-wrap: wrap;
}
.image-attachment {
position: relative;
display: inline-block;
}
.attached-image {
width: 120px;
height: 120px;
object-fit: cover;
border-radius: 12px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
transition: transform 0.2s ease;
}
.audio-attachment {
margin-top: 8px;
min-width: 250px;
}
/* 包含音频的消息气泡最小宽度 */
.message-bubble.has-audio {
min-width: 280px;
}
.audio-player {
width: 100%;
height: 36px;
border-radius: 18px;
}
.embedded-images {
margin-top: 8px;
display: flex;
flex-direction: column;
gap: 8px;
}
.embedded-image {
display: flex;
justify-content: flex-start;
}
.bot-embedded-image {
max-width: 80%;
width: auto;
height: auto;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
cursor: pointer;
transition: transform 0.2s ease;
}
.bot-embedded-image:hover {
transform: scale(1.02);
}
.embedded-audio {
width: 300px;
margin-top: 8px;
}
.embedded-audio .audio-player {
width: 100%;
max-width: 300px;
}
/* 动画类 */
.fade-in {
animation: fadeIn 0.3s ease-in-out;
}
</style>
<style>
/* Markdown内容样式 - 需要全局样式 */
.markdown-content {
font-family: inherit;
line-height: 1.6;
}
.markdown-content h1,
.markdown-content h2,
.markdown-content h3,
.markdown-content h4,
.markdown-content h5,
.markdown-content h6 {
margin-top: 16px;
margin-bottom: 10px;
font-weight: 600;
color: var(--v-theme-primaryText);
}
.markdown-content h1 {
font-size: 1.8em;
border-bottom: 1px solid var(--v-theme-border);
padding-bottom: 6px;
}
.markdown-content h2 {
font-size: 1.5em;
}
.markdown-content h3 {
font-size: 1.3em;
}
.markdown-content li {
margin-left: 16px;
margin-bottom: 4px;
}
.markdown-content p {
margin-top: .5rem;
margin-bottom: .5rem;
}
.markdown-content pre {
background-color: var(--v-theme-surface);
padding: 12px;
border-radius: 6px;
overflow-x: auto;
margin: 12px 0;
position: relative;
}
.markdown-content code {
background-color: rgb(var(--v-theme-codeBg));
padding: 2px 4px;
border-radius: 4px;
font-family: 'Fira Code', monospace;
font-size: 0.9em;
color: var(--v-theme-code);
}
/* 代码块中的code标签样式 */
.markdown-content pre code {
background-color: transparent;
padding: 0;
border-radius: 0;
font-family: 'Fira Code', 'Consolas', 'Monaco', 'Courier New', monospace;
font-size: 0.85em;
color: inherit;
display: block;
overflow-x: auto;
line-height: 1.5;
}
/* 自定义代码高亮样式 */
.markdown-content pre {
border: 1px solid var(--v-theme-border);
background-color: rgb(var(--v-theme-preBg));
border-radius: 16px;
padding: 16px;
}
/* 确保highlight.js的样式正确应用 */
.markdown-content pre code.hljs {
background: transparent !important;
color: inherit;
}
/* 亮色主题下的代码高亮 */
.v-theme--light .markdown-content pre {
background-color: #f6f8fa;
}
/* 暗色主题下的代码块样式 */
.v-theme--dark .markdown-content pre {
background-color: #0d1117 !important;
border-color: rgba(255, 255, 255, 0.1);
}
.v-theme--dark .markdown-content pre code {
color: #e6edf3 !important;
}
/* 暗色主题下的highlight.js样式覆盖 */
.v-theme--dark .hljs {
background: #0d1117 !important;
color: #e6edf3 !important;
}
.v-theme--dark .hljs-keyword,
.v-theme--dark .hljs-selector-tag,
.v-theme--dark .hljs-built_in,
.v-theme--dark .hljs-name,
.v-theme--dark .hljs-tag {
color: #ff7b72 !important;
}
.v-theme--dark .hljs-string,
.v-theme--dark .hljs-title,
.v-theme--dark .hljs-section,
.v-theme--dark .hljs-attribute,
.v-theme--dark .hljs-literal,
.v-theme--dark .hljs-template-tag,
.v-theme--dark .hljs-template-variable,
.v-theme--dark .hljs-type,
.v-theme--dark .hljs-addition {
color: #a5d6ff !important;
}
.v-theme--dark .hljs-comment,
.v-theme--dark .hljs-quote,
.v-theme--dark .hljs-deletion,
.v-theme--dark .hljs-meta {
color: #8b949e !important;
}
.v-theme--dark .hljs-number,
.v-theme--dark .hljs-regexp,
.v-theme--dark .hljs-symbol,
.v-theme--dark .hljs-variable,
.v-theme--dark .hljs-template-variable,
.v-theme--dark .hljs-link,
.v-theme--dark .hljs-selector-attr,
.v-theme--dark .hljs-selector-pseudo {
color: #79c0ff !important;
}
.v-theme--dark .hljs-function,
.v-theme--dark .hljs-class,
.v-theme--dark .hljs-title.class_ {
color: #d2a8ff !important;
}
/* 复制按钮样式 */
.copy-code-btn {
position: absolute;
top: 8px;
right: 8px;
background: rgba(255, 255, 255, 0.9);
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: 4px;
padding: 6px;
cursor: pointer;
opacity: 0;
transition: all 0.2s ease;
display: flex;
align-items: center;
justify-content: center;
color: #666;
font-size: 12px;
z-index: 10;
backdrop-filter: blur(4px);
}
.copy-code-btn:hover {
background: rgba(255, 255, 255, 1);
color: #333;
transform: scale(1.05);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
}
.copy-code-btn:active {
transform: scale(0.95);
}
.markdown-content pre:hover .copy-code-btn {
opacity: 1;
}
.v-theme--dark .copy-code-btn {
background: rgba(45, 45, 45, 0.9);
border-color: rgba(255, 255, 255, 0.15);
color: #ccc;
}
.v-theme--dark .copy-code-btn:hover {
background: rgba(45, 45, 45, 1);
color: #fff;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
}
.markdown-content img {
max-width: 100%;
border-radius: 8px;
margin: 10px 0;
}
.markdown-content blockquote {
border-left: 4px solid var(--v-theme-secondary);
padding-left: 16px;
color: var(--v-theme-secondaryText);
margin: 16px 0;
}
.markdown-content table {
border-collapse: collapse;
width: 100%;
margin: 16px 0;
}
.markdown-content th,
.markdown-content td {
border: 1px solid var(--v-theme-background);
padding: 8px 12px;
text-align: left;
}
.markdown-content th {
background-color: var(--v-theme-containerBg);
}
</style>

View File

@@ -1,21 +1,11 @@
<template>
<div>
<!-- 选择提供商和模型按钮 -->
<v-btn
class="text-none"
variant="tonal"
rounded="xl"
size="small"
v-if="selectedProviderId && selectedModelName"
@click="showDialog = true">
<v-btn class="text-none" variant="tonal" rounded="xl" size="small"
v-if="selectedProviderId && selectedModelName" @click="openDialog">
{{ selectedProviderId }} / {{ selectedModelName }}
</v-btn>
<v-btn
variant="tonal"
rounded="xl"
size="small"
v-else
@click="showDialog = true">
<v-btn variant="tonal" rounded="xl" size="small" v-else @click="openDialog">
选择模型
</v-btn>
@@ -33,16 +23,12 @@
<h4>提供商</h4>
</div>
<v-list density="compact" nav class="provider-list">
<v-list-item
v-for="provider in providerConfigs"
:key="provider.id"
:value="provider.id"
@click="selectProvider(provider)"
:active="selectedProviderId === provider.id"
rounded="lg"
class="provider-item">
<v-list-item v-for="provider in providerConfigs" :key="provider.id" :value="provider.id"
@click="selectProvider(provider)" :active="tempSelectedProviderId === provider.id"
rounded="lg" class="provider-item">
<v-list-item-title>{{ provider.id }}</v-list-item-title>
<v-list-item-subtitle v-if="provider.api_base">{{ provider.api_base }}</v-list-item-subtitle>
<v-list-item-subtitle v-if="provider.api_base">{{ provider.api_base
}}</v-list-item-subtitle>
</v-list-item>
</v-list>
<div v-if="providerConfigs.length === 0" class="empty-state">
@@ -55,33 +41,28 @@
<div class="model-list-panel">
<div class="panel-header">
<h4>模型</h4>
<v-btn
v-if="selectedProviderId"
icon="mdi-refresh"
size="small"
variant="text"
@click="refreshModels"
:loading="loadingModels">
<v-btn v-if="tempSelectedProviderId" icon="mdi-refresh" size="small" variant="text"
@click="refreshModels" :loading="loadingModels">
</v-btn>
</div>
<v-list density="compact" nav class="model-list" v-if="selectedProviderId">
<v-list-item
v-for="model in modelList"
:key="model"
:value="model"
@click="selectModel(model)"
:active="selectedModelName === model"
rounded="lg"
<v-list density="compact" nav class="model-list" v-if="tempSelectedProviderId">
<v-text-field v-model="tempSelectedModelName" placeholder="自定义模型" hide-details solo variant="outlined" density="compact" class="mb-2 mx-2"></v-text-field>
<v-list-item v-for="model in modelList" :key="model" :value="model"
@click="selectModel(model)" :active="tempSelectedModelName === model" rounded="lg"
class="model-item">
<v-list-item-title>{{ model }}</v-list-item-title>
<v-list-item-subtitle v-if="model.description">{{ model.description }}</v-list-item-subtitle>
<v-list-item-subtitle v-if="model.description">{{ model.description
}}</v-list-item-subtitle>
</v-list-item>
</v-list>
<div v-else class="empty-state">
<v-icon icon="mdi-robot-outline" size="large" color="grey-lighten-1"></v-icon>
<div class="empty-text">请先选择提供商</div>
</div>
<div v-if="selectedProviderId && modelList.length === 0 && !loadingModels" class="empty-state">
<div v-if="tempSelectedProviderId && modelList.length === 0 && !loadingModels"
class="empty-state">
<v-icon icon="mdi-robot-off-outline" size="large" color="grey-lighten-1"></v-icon>
<div class="empty-text">该提供商暂无可用模型</div>
</div>
@@ -91,11 +72,8 @@
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="closeDialog" color="grey-darken-1">取消</v-btn>
<v-btn
text
@click="confirmSelection"
color="primary"
:disabled="!selectedProviderId || !selectedModelName">
<v-btn text @click="confirmSelection" color="primary"
:disabled="!tempSelectedProviderId || !tempSelectedModelName">
确认选择
</v-btn>
</v-card-actions>
@@ -127,12 +105,17 @@ export default {
modelList: [],
selectedProviderId: '',
selectedModelName: '',
// 临时选择状态,用于对话框内的选择
tempSelectedProviderId: '',
tempSelectedModelName: '',
loadingModels: false
};
},
mounted() {
// 从localStorage加载保存的选择
this.loadFromStorage();
// 初始化临时选择
this.resetTempSelection();
// 获取提供商列表
this.loadProviderConfigs();
// 如果有保存的选择,加载对应的模型列表
@@ -145,13 +128,13 @@ export default {
loadFromStorage() {
const savedProvider = localStorage.getItem('selectedProvider');
const savedModel = localStorage.getItem('selectedModel');
if (savedProvider) {
this.selectedProviderId = savedProvider;
} else if (this.initialProvider) {
this.selectedProviderId = this.initialProvider;
}
if (savedModel) {
this.selectedModelName = savedModel;
} else if (this.initialModel) {
@@ -215,36 +198,40 @@ export default {
// 选择提供商
selectProvider(provider) {
this.selectedProviderId = provider.id;
this.selectedModelName = ''; // 清空已选择的模型
this.tempSelectedProviderId = provider.id;
this.tempSelectedModelName = ''; // 清空已选择的模型
this.modelList = []; // 清空模型列表
this.getProviderModels(provider.id); // 获取该提供商的模型列表
},
// 选择模型
selectModel(model) {
this.selectedModelName = model;
this.tempSelectedModelName = model;
},
// 刷新模型列表
refreshModels() {
if (this.selectedProviderId) {
this.getProviderModels(this.selectedProviderId);
if (this.tempSelectedProviderId) {
this.getProviderModels(this.tempSelectedProviderId);
}
},
// 确认选择
confirmSelection() {
if (this.selectedProviderId && this.selectedModelName) {
if (this.tempSelectedProviderId && this.tempSelectedModelName) {
// 将临时选择应用到正式选择
this.selectedProviderId = this.tempSelectedProviderId;
this.selectedModelName = this.tempSelectedModelName;
// 保存到localStorage
this.saveToStorage();
// 触发事件通知父组件
this.$emit('selection-changed', {
providerId: this.selectedProviderId,
modelName: this.selectedModelName
});
this.closeDialog();
}
},
@@ -252,6 +239,24 @@ export default {
// 关闭对话框
closeDialog() {
this.showDialog = false;
// 重置临时选择为当前选择
this.resetTempSelection();
},
// 重置临时选择
resetTempSelection() {
this.tempSelectedProviderId = this.selectedProviderId;
this.tempSelectedModelName = this.selectedModelName;
// 如果有临时选择的提供商,重新加载模型列表
if (this.tempSelectedProviderId) {
this.getProviderModels(this.tempSelectedProviderId);
}
},
// 打开对话框
openDialog() {
this.resetTempSelection();
this.showDialog = true;
},
// 公开方法:获取当前选择

View File

@@ -0,0 +1,173 @@
<template>
<v-dialog v-model="showDialog" max-width="900px" min-height="80%">
<v-card class="platform-selection-dialog" :title="tm('dialog.addPlatform')">
<v-card-text class="pa-4" style="overflow-y: auto;">
<v-row style="padding: 0px 8px;">
<v-col v-for="(template, name) in platformTemplates"
:key="name" cols="12" sm="6" md="6">
<v-card variant="outlined" hover class="platform-card" @click="selectTemplate(name)">
<div class="platform-card-content">
<div class="platform-card-text">
<v-card-title class="platform-card-title">{{ tm('dialog.connectTitle', { name }) }}</v-card-title>
<v-card-text class="text-caption text-medium-emphasis platform-card-description">
{{ getPlatformDescription(template, name) }}
</v-card-text>
</div>
<div class="platform-card-logo">
<img :src="getPlatformIcon(template.type)" v-if="getPlatformIcon(template.type)" class="platform-logo-img">
<div v-else class="platform-logo-fallback">
{{ name[0].toUpperCase() }}
</div>
</div>
</div>
</v-card>
</v-col>
<v-col
v-if="Object.keys(platformTemplates).length === 0"
cols="12">
<v-alert type="info" variant="tonal">
{{ tm('dialog.noTemplates') }}
</v-alert>
</v-col>
</v-row>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="closeDialog">{{ tm('dialog.cancel') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<script>
import { useModuleI18n } from '@/i18n/composables';
import { getPlatformIcon, getPlatformDescription } from '@/utils/platformUtils';
export default {
name: 'AddNewPlatform',
emits: ['update:show', 'select-template'],
props: {
show: {
type: Boolean,
default: false
},
metadata: {
type: Object,
default: () => ({})
}
},
setup() {
const { tm } = useModuleI18n('features/platform');
return { tm };
},
computed: {
showDialog: {
get() {
return this.show;
},
set(value) {
this.$emit('update:show', value);
}
},
platformTemplates() {
return this.metadata['platform_group']?.metadata?.platform?.config_template || {};
}
},
methods: {
// 从工具函数导入
getPlatformIcon,
getPlatformDescription,
selectTemplate(name) {
this.$emit('select-template', name);
this.closeDialog();
},
closeDialog() {
this.showDialog = false;
}
}
}
</script>
<style scoped>
.platform-selection-dialog .v-card-title {
border-top-left-radius: 4px;
border-top-right-radius: 4px;
}
.platform-card {
transition: all 0.3s ease;
height: 100%;
cursor: pointer;
overflow: hidden;
position: relative;
}
.platform-card:hover {
transform: translateY(-4px);
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
border-color: var(--v-primary-base);
}
.platform-card-content {
display: flex;
align-items: center;
height: 100px;
padding: 16px;
position: relative;
z-index: 2;
}
.platform-card-text {
flex: 1;
display: flex;
flex-direction: column;
justify-content: center;
}
.platform-card-title {
font-size: 15px;
font-weight: 600;
margin-bottom: 4px;
padding: 0;
}
.platform-card-description {
padding: 0;
margin: 0;
}
.platform-card-logo {
position: absolute;
right: 0;
top: 0;
bottom: 0;
width: 80px;
display: flex;
align-items: center;
justify-content: center;
z-index: 1;
}
.platform-logo-img {
max-width: 60px;
max-height: 60px;
opacity: 0.6;
object-fit: contain;
}
.platform-logo-fallback {
width: 50px;
height: 50px;
border-radius: 50%;
background-color: var(--v-primary-base);
color: white;
display: flex;
align-items: center;
justify-content: center;
font-size: 24px;
font-weight: bold;
opacity: 0.3;
}
</style>

View File

@@ -0,0 +1,237 @@
<template>
<v-dialog v-model="showDialog" max-width="1100px" min-height="95%">
<v-card :title="tm('dialogs.addProvider.title')">
<v-card-text style="overflow-y: auto;">
<v-tabs v-model="activeProviderTab" grow>
<v-tab value="chat_completion" class="font-weight-medium px-3">
<v-icon start>mdi-message-text</v-icon>
{{ tm('dialogs.addProvider.tabs.basic') }}
</v-tab>
<v-tab value="speech_to_text" class="font-weight-medium px-3">
<v-icon start>mdi-microphone-message</v-icon>
{{ tm('dialogs.addProvider.tabs.speechToText') }}
</v-tab>
<v-tab value="text_to_speech" class="font-weight-medium px-3">
<v-icon start>mdi-volume-high</v-icon>
{{ tm('dialogs.addProvider.tabs.textToSpeech') }}
</v-tab>
<v-tab value="embedding" class="font-weight-medium px-3">
<v-icon start>mdi-code-json</v-icon>
{{ tm('dialogs.addProvider.tabs.embedding') }}
</v-tab>
<v-tab value="rerank" class="font-weight-medium px-3">
<v-icon start>mdi-compare-vertical</v-icon>
{{ tm('dialogs.addProvider.tabs.rerank') }}
</v-tab>
</v-tabs>
<v-window v-model="activeProviderTab" class="mt-4">
<v-window-item
v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
:key="tabType" :value="tabType">
<v-row class="mt-1">
<v-col v-for="(template, name) in getTemplatesByType(tabType)" :key="name" cols="12" sm="6"
md="4">
<v-card variant="outlined" hover class="provider-card"
@click="selectProviderTemplate(name)">
<div class="provider-card-content">
<div class="provider-card-text">
<v-card-title class="provider-card-title">接入 {{ name }}</v-card-title>
<v-card-text
class="text-caption text-medium-emphasis provider-card-description">
{{ getProviderDescription(template, name) }}
</v-card-text>
</div>
<div class="provider-card-logo">
<img :src="getProviderIcon(template.provider)"
v-if="getProviderIcon(template.provider)" class="provider-logo-img">
<div v-else class="provider-logo-fallback">
{{ name[0].toUpperCase() }}
</div>
</div>
</div>
</v-card>
</v-col>
<v-col v-if="Object.keys(getTemplatesByType(tabType)).length === 0" cols="12">
<v-alert type="info" variant="tonal">
{{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
</v-alert>
</v-col>
</v-row>
</v-window-item>
</v-window>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="closeDialog">{{ tm('dialogs.config.cancel') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<script>
import { useModuleI18n } from '@/i18n/composables';
import { getProviderIcon, getProviderDescription } from '@/utils/providerUtils';
export default {
name: 'AddNewProvider',
props: {
show: {
type: Boolean,
default: false
},
metadata: {
type: Object,
default: () => ({})
}
},
emits: ['update:show', 'select-template'],
setup() {
const { tm } = useModuleI18n('features/provider');
return { tm };
},
data() {
return {
activeProviderTab: 'chat_completion'
};
},
computed: {
showDialog: {
get() {
return this.show;
},
set(value) {
this.$emit('update:show', value);
}
},
// 翻译消息的计算属性
messages() {
return {
tabTypes: {
'chat_completion': this.tm('providers.tabs.chatCompletion'),
'speech_to_text': this.tm('providers.tabs.speechToText'),
'text_to_speech': this.tm('providers.tabs.textToSpeech'),
'embedding': this.tm('providers.tabs.embedding'),
'rerank': this.tm('providers.tabs.rerank')
}
};
}
},
methods: {
closeDialog() {
this.showDialog = false;
},
// 按提供商类型获取模板列表
getTemplatesByType(type) {
const templates = this.metadata['provider_group']?.metadata?.provider?.config_template || {};
const filtered = {};
for (const [name, template] of Object.entries(templates)) {
if (template.provider_type === type) {
filtered[name] = template;
}
}
return filtered;
},
// 从工具函数导入
getProviderIcon,
// 获取Tab类型的中文名称
getTabTypeName(tabType) {
return this.messages.tabTypes[tabType] || tabType;
},
// 获取提供商简介
getProviderDescription(template, name) {
return getProviderDescription(template, name, this.tm);
},
// 选择提供商模板
selectProviderTemplate(name) {
this.$emit('select-template', name);
this.closeDialog();
}
}
}
</script>
<style scoped>
.provider-card {
transition: all 0.3s ease;
height: 100%;
cursor: pointer;
overflow: hidden;
position: relative;
}
.provider-card:hover {
transform: translateY(-4px);
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
border-color: var(--v-primary-base);
}
.provider-card-content {
display: flex;
align-items: center;
height: 100px;
padding: 16px;
position: relative;
z-index: 2;
}
.provider-card-text {
flex: 1;
display: flex;
flex-direction: column;
justify-content: center;
}
.provider-card-title {
font-size: 15px;
font-weight: 600;
margin-bottom: 4px;
padding: 0;
}
.provider-card-description {
padding: 0;
margin: 0;
}
.provider-card-logo {
position: absolute;
right: 0;
top: 0;
bottom: 0;
width: 80px;
display: flex;
align-items: center;
justify-content: center;
z-index: 1;
}
.provider-logo-img {
width: 60px;
height: 60px;
opacity: 0.6;
object-fit: contain;
}
.provider-logo-fallback {
width: 50px;
height: 50px;
border-radius: 50%;
background-color: var(--v-primary-base);
color: white;
display: flex;
align-items: center;
justify-content: center;
font-size: 24px;
font-weight: bold;
opacity: 0.3;
}
</style>

View File

@@ -4,28 +4,28 @@
<span class="text-h2 text-truncate" :title="getItemTitle()">{{ getItemTitle() }}</span>
<v-tooltip location="top">
<template v-slot:activator="{ props }">
<v-switch
color="primary"
hide-details
density="compact"
<v-switch
color="primary"
hide-details
density="compact"
:model-value="getItemEnabled()"
:loading="loading"
:disabled="loading"
v-bind="props"
v-bind="props"
@update:model-value="toggleEnabled"
></v-switch>
</template>
<span>{{ getItemEnabled() ? t('core.common.itemCard.enabled') : t('core.common.itemCard.disabled') }}</span>
</v-tooltip>
</v-card-title>
<v-card-text>
<slot name="item-details" :item="item"></slot>
</v-card-text>
<v-card-actions style="margin: 8px;">
<v-btn
variant="outlined"
variant="outlined"
color="error"
rounded="xl"
@click="$emit('delete', item)"
@@ -40,6 +40,15 @@
>
{{ t('core.common.itemCard.edit') }}
</v-btn>
<v-btn
v-if="showCopyButton"
variant="tonal"
color="secondary"
rounded="xl"
@click="$emit('copy', item)"
>
{{ t('core.common.itemCard.copy') }}
</v-btn>
<v-spacer></v-spacer>
</v-card-actions>
@@ -83,9 +92,13 @@ export default {
loading: {
type: Boolean,
default: false
},
showCopyButton: {
type: Boolean,
default: false
}
},
emits: ['toggle-enabled', 'delete', 'edit'],
emits: ['toggle-enabled', 'delete', 'edit', 'copy'],
methods: {
getItemTitle() {
return this.item[this.titleField];

View File

@@ -1,20 +0,0 @@
<script setup lang="ts">
const props = defineProps({
title: String
});
</script>
<template>
<v-card variant="outlined" elevation="0" class="withbg">
<v-card-item>
<div class="d-sm-flex align-center justify-space-between">
<v-card-title>{{ props.title }}</v-card-title>
<slot name="action"></slot>
</div>
</v-card-item>
<v-divider></v-divider>
<v-card-text>
<slot />
</v-card-text>
</v-card>
</template>

View File

@@ -12,6 +12,13 @@
"title": "Conversation History",
"refresh": "Refresh"
},
"batch": {
"deleteSelected": "Delete Selected ({count})"
},
"pagination": {
"itemsPerPage": "Items per page",
"showingItems": "Showing {start}-{end} of {total} items"
},
"table": {
"headers": {
"title": "Conversation Title",
@@ -61,6 +68,13 @@
"message": "Are you sure you want to delete conversation {title}? This action cannot be undone.",
"cancel": "Cancel",
"confirm": "Delete"
},
"batchDelete": {
"title": "Batch Delete Confirmation",
"message": "Are you sure you want to delete the selected {count} conversations? This action cannot be undone, please proceed with caution!",
"andMore": "and {count} more",
"cancel": "Cancel",
"confirm": "Batch Delete"
}
},
"messages": {
@@ -72,6 +86,10 @@
"historyError": "Failed to fetch conversation history",
"historySaveSuccess": "Conversation history saved successfully",
"historySaveError": "Failed to save conversation history",
"invalidJson": "Invalid JSON format"
"invalidJson": "Invalid JSON format",
"noItemSelected": "Please select conversations to delete first",
"batchDeleteSuccess": "Successfully deleted {count} conversations",
"batchDeleteError": "Batch delete failed",
"batchDeletePartial": "Delete completed: {deleted} successful, {failed} failed"
}
}

View File

@@ -7,7 +7,8 @@
"apply": "Apply Batch Settings",
"editName": "Edit Session Name",
"save": "Save",
"cancel": "Cancel"
"cancel": "Cancel",
"delete": "Delete"
},
"sessions": {
"activeSessions": "Active Sessions",
@@ -29,7 +30,8 @@
"ttsProvider": "TTS Provider",
"llmStatus": "LLM Status",
"ttsStatus": "TTS Status",
"pluginManagement": "Plugin Management"
"pluginManagement": "Plugin Management",
"actions": "Actions"
}
},
"status": {
@@ -65,6 +67,10 @@
"fullSessionId": "Full Session ID",
"hint": "Custom names help you easily identify sessions. The small information icon (!) will show the actual UMO when hovering."
},
"deleteConfirm": {
"message": "Are you sure you want to delete session {sessionName}?",
"warning": "This action will permanently delete all chat history and preference settings for this session (except for data linked via plugins), and this cannot be undone. Continue?"
},
"messages": {
"refreshSuccess": "Session list refreshed",
"personaUpdateSuccess": "Persona updated successfully",
@@ -82,6 +88,8 @@
"pluginStatusSuccess": "Plugin {name} {status}",
"pluginStatusError": "Failed to update plugin status",
"nameUpdateSuccess": "Session name updated successfully",
"nameUpdateError": "Failed to update session name"
"nameUpdateError": "Failed to update session name",
"deleteSuccess": "Session deleted successfully",
"deleteError": "Failed to delete session"
}
}

View File

@@ -73,6 +73,7 @@
"disabled": "已禁用",
"delete": "删除",
"edit": "编辑",
"copy": "复制",
"noData": "暂无数据"
}
}
}

View File

@@ -3,7 +3,7 @@
"subtitle": "管理和查看用户对话历史记录",
"filters": {
"title": "筛选条件",
"platform": "平台",
"platform": "消息平台 ID",
"type": "类型",
"search": "搜索关键词",
"reset": "重置"
@@ -12,12 +12,19 @@
"title": "对话历史",
"refresh": "刷新"
},
"batch": {
"deleteSelected": "删除选中 ({count})"
},
"pagination": {
"itemsPerPage": "每页",
"showingItems": "显示 {start}-{end} 项,共 {total} 项"
},
"table": {
"headers": {
"title": "对话标题",
"platform": "平台",
"platform": "消息平台 ID",
"type": "类型",
"sessionId": "ID",
"sessionId": "ID (UMO)",
"createdAt": "创建时间",
"updatedAt": "更新时间",
"actions": "操作"
@@ -61,6 +68,13 @@
"message": "确定要删除对话 {title} 吗?此操作不可恢复。",
"cancel": "取消",
"confirm": "删除"
},
"batchDelete": {
"title": "批量删除确认",
"message": "确定要删除选中的 {count} 个对话吗?此操作不可恢复,请谨慎操作!",
"andMore": "等 {count} 个",
"cancel": "取消",
"confirm": "批量删除"
}
},
"messages": {
@@ -72,6 +86,10 @@
"historyError": "获取对话历史失败",
"historySaveSuccess": "对话历史保存成功",
"historySaveError": "对话历史保存失败",
"invalidJson": "JSON格式无效"
"invalidJson": "JSON格式无效",
"noItemSelected": "请先选择要删除的对话",
"batchDeleteSuccess": "成功删除 {count} 个对话",
"batchDeleteError": "批量删除失败",
"batchDeletePartial": "删除完成:成功 {deleted} 个,失败 {failed} 个"
}
}

View File

@@ -5,9 +5,10 @@
"refresh": "刷新",
"edit": "编辑",
"apply": "应用批量设置",
"editName": "编辑会话名称",
"editName": "备注",
"save": "保存",
"cancel": "取消"
"cancel": "取消",
"delete": "删除"
},
"sessions": {
"activeSessions": "活跃会话",
@@ -22,14 +23,15 @@
"table": {
"headers": {
"sessionStatus": "会话状态",
"sessionInfo": "会话信息",
"sessionInfo": "ID (UMO)",
"persona": "人格",
"chatProvider": "Chat Provider",
"sttProvider": "STT Provider",
"ttsProvider": "TTS Provider",
"llmStatus": "LLM启停",
"ttsStatus": "TTS启停",
"pluginManagement": "插件管理"
"chatProvider": "聊天模型",
"sttProvider": "语音识别模型",
"ttsProvider": "语音合成模型",
"llmStatus": "启用 LLM",
"ttsStatus": "启用 TTS",
"pluginManagement": "插件管理",
"actions": "操作"
}
},
"status": {
@@ -65,6 +67,10 @@
"fullSessionId": "完整会话ID",
"hint": "自定义名称帮助您轻松识别会话。当设置了自定义名称时会显示一个小感叹号标识鼠标悬停时会显示实际的UMO。"
},
"deleteConfirm": {
"message": "确定要删除会话 {sessionName} 吗?",
"warning": "此操作将永久删除本次会话的「全部对话记录」与「偏好设置」(插件对会话的关联数据除外),且无法恢复。确认继续?"
},
"messages": {
"refreshSuccess": "会话列表已刷新",
"personaUpdateSuccess": "人格更新成功",
@@ -82,6 +88,8 @@
"pluginStatusSuccess": "插件 {name} {status}",
"pluginStatusError": "插件状态更新失败",
"nameUpdateSuccess": "会话名称更新成功",
"nameUpdateError": "会话名称更新失败"
"nameUpdateError": "会话名称更新失败",
"deleteSuccess": "会话删除成功",
"deleteError": "会话删除失败"
}
}

View File

@@ -1,22 +1,22 @@
<script setup lang="ts">
import {ref, computed} from 'vue';
import {useCustomizerStore} from '@/stores/customizer';
import { ref, computed } from 'vue';
import { useCustomizerStore } from '@/stores/customizer';
import axios from 'axios';
import Logo from '@/components/shared/Logo.vue';
import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue';
import {md5} from 'js-md5';
import {useAuthStore} from '@/stores/auth';
import {useCommonStore} from '@/stores/common';
import { md5 } from 'js-md5';
import { useAuthStore } from '@/stores/auth';
import { useCommonStore } from '@/stores/common';
import MarkdownIt from 'markdown-it';
import { useI18n } from '@/i18n/composables';
import { router } from '@/router';
// 配置markdown-it默认安全设置
const md = new MarkdownIt({
html: true, // 启用HTML标签
breaks: true, // 换行转<br>
linkify: true, // 自动转链接
typographer: false // 禁用智能引号
html: true, // 启用HTML标签
breaks: true, // 换行转<br>
linkify: true, // 自动转链接
typographer: false // 禁用智能引号
});
const customizer = useCustomizerStore();
@@ -44,11 +44,11 @@ let installLoading = ref(false);
let tab = ref(0);
const releasesHeader = computed(() => [
{title: t('core.header.updateDialog.table.tag'), key: 'tag_name'},
{title: t('core.header.updateDialog.table.publishDate'), key: 'published_at'},
{title: t('core.header.updateDialog.table.content'), key: 'body'},
{title: t('core.header.updateDialog.table.sourceUrl'), key: 'zipball_url'},
{title: t('core.header.updateDialog.table.actions'), key: 'switch'}
{ title: t('core.header.updateDialog.table.tag'), key: 'tag_name' },
{ title: t('core.header.updateDialog.table.publishDate'), key: 'published_at' },
{ title: t('core.header.updateDialog.table.content'), key: 'body' },
{ title: t('core.header.updateDialog.table.sourceUrl'), key: 'zipball_url' },
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
]);
// Form validation
@@ -103,90 +103,90 @@ function accountEdit() {
new_password: newPassword.value,
new_username: newUsername.value ? newUsername.value : username
})
.then((res) => {
if (res.data.status == 'error') {
accountEditStatus.value.error = true;
accountEditStatus.value.message = res.data.message;
password.value = '';
newPassword.value = '';
return;
}
accountEditStatus.value.success = true;
accountEditStatus.value.message = res.data.message;
setTimeout(() => {
dialog.value = !dialog.value;
const authStore = useAuthStore();
authStore.logout();
}, 2000);
})
.catch((err) => {
console.log(err);
.then((res) => {
if (res.data.status == 'error') {
accountEditStatus.value.error = true;
accountEditStatus.value.message = typeof err === 'string' ? err : t('core.header.accountDialog.messages.updateFailed');
accountEditStatus.value.message = res.data.message;
password.value = '';
newPassword.value = '';
})
.finally(() => {
accountEditStatus.value.loading = false;
});
return;
}
accountEditStatus.value.success = true;
accountEditStatus.value.message = res.data.message;
setTimeout(() => {
dialog.value = !dialog.value;
const authStore = useAuthStore();
authStore.logout();
}, 2000);
})
.catch((err) => {
console.log(err);
accountEditStatus.value.error = true;
accountEditStatus.value.message = typeof err === 'string' ? err : t('core.header.accountDialog.messages.updateFailed');
password.value = '';
newPassword.value = '';
})
.finally(() => {
accountEditStatus.value.loading = false;
});
}
function getVersion() {
axios.get('/api/stat/version')
.then((res) => {
botCurrVersion.value = "v" + res.data.data.version;
dashboardCurrentVersion.value = res.data.data?.dashboard_version;
let change_pwd_hint = res.data.data?.change_pwd_hint;
if (change_pwd_hint) {
dialog.value = true;
accountWarning.value = true;
localStorage.setItem('change_pwd_hint', 'true');
} else {
localStorage.removeItem('change_pwd_hint');
}
})
.catch((err) => {
console.log(err);
});
.then((res) => {
botCurrVersion.value = "v" + res.data.data.version;
dashboardCurrentVersion.value = res.data.data?.dashboard_version;
let change_pwd_hint = res.data.data?.change_pwd_hint;
if (change_pwd_hint) {
dialog.value = true;
accountWarning.value = true;
localStorage.setItem('change_pwd_hint', 'true');
} else {
localStorage.removeItem('change_pwd_hint');
}
})
.catch((err) => {
console.log(err);
});
}
function checkUpdate() {
updateStatus.value = t('core.header.updateDialog.status.checking');
axios.get('/api/update/check')
.then((res) => {
hasNewVersion.value = res.data.data.has_new_version;
.then((res) => {
hasNewVersion.value = res.data.data.has_new_version;
if (res.data.data.has_new_version) {
releaseMessage.value = res.data.message;
updateStatus.value = t('core.header.version.hasNewVersion');
} else {
updateStatus.value = res.data.message;
}
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
})
.catch((err) => {
if (err.response && err.response.status == 401) {
console.log("401");
const authStore = useAuthStore();
authStore.logout();
return;
}
console.log(err);
updateStatus.value = err
});
if (res.data.data.has_new_version) {
releaseMessage.value = res.data.message;
updateStatus.value = t('core.header.version.hasNewVersion');
} else {
updateStatus.value = res.data.message;
}
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
})
.catch((err) => {
if (err.response && err.response.status == 401) {
console.log("401");
const authStore = useAuthStore();
authStore.logout();
return;
}
console.log(err);
updateStatus.value = err
});
}
function getReleases() {
axios.get('/api/update/releases')
.then((res) => {
releases.value = res.data.data.map((item: any) => {
item.published_at = new Date(item.published_at).toLocaleString();
return item;
})
.then((res) => {
releases.value = res.data.data.map((item: any) => {
item.published_at = new Date(item.published_at).toLocaleString();
return item;
})
.catch((err) => {
console.log(err);
});
})
.catch((err) => {
console.log(err);
});
}
function getDevCommits() {
@@ -209,10 +209,10 @@ function getDevCommits() {
.then(data => {
devCommits.value = Array.isArray(data)
? data.map((commit: any) => ({
sha: commit.sha,
date: new Date(commit.commit.author.date).toLocaleString(),
message: commit.commit.message
}))
sha: commit.sha,
date: new Date(commit.commit.author.date).toLocaleString(),
message: commit.commit.message
}))
: [];
})
.catch(err => {
@@ -239,40 +239,40 @@ function switchVersion(version: string) {
version: version,
proxy: localStorage.getItem('selectedGitHubProxy') || ''
})
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
}).finally(() => {
installLoading.value = false;
});
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
}).finally(() => {
installLoading.value = false;
});
}
function updateDashboard() {
updatingDashboardLoading.value = true;
updateStatus.value = t('core.header.updateDialog.status.updating');
axios.post('/api/update/dashboard')
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
}).finally(() => {
updatingDashboardLoading.value = false;
});
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
}).finally(() => {
updatingDashboardLoading.value = false;
});
}
function toggleDarkMode() {
@@ -291,29 +291,32 @@ commonStore.getStartTime();
<template>
<v-app-bar elevation="0" height="55">
<v-btn v-if="useCustomizerStore().uiTheme==='PurpleTheme'" style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-btn v-if="useCustomizerStore().uiTheme === 'PurpleTheme'" style="margin-left: 22px;"
class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm" variant="flat"
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<v-btn v-else style="margin-left: 22px; color: var(--v-theme-primaryText); background-color: var(--v-theme-secondary)" class="hidden-md-and-down" icon rounded="sm"
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-btn v-else
style="margin-left: 22px; color: var(--v-theme-primaryText); background-color: var(--v-theme-secondary)"
class="hidden-md-and-down" icon rounded="sm" variant="flat"
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<v-btn v-if="useCustomizerStore().uiTheme==='PurpleTheme'" class="hidden-lg-and-up ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
<v-btn v-if="useCustomizerStore().uiTheme === 'PurpleTheme'" class="hidden-lg-and-up ms-3" color="lightsecondary"
icon rounded="sm" variant="flat" @click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<v-btn v-else class="hidden-lg-and-up ms-3" icon rounded="sm" variant="flat"
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<div class="logo-container" :class="{'mobile-logo': $vuetify.display.xs}" @click="$router.push('/about')">
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs }" @click="$router.push('/about')">
<span class="logo-text">Astr<span class="logo-text-light">Bot</span></span>
<span class="version-text hidden-xs">{{ botCurrVersion }}</span>
</div>
<v-spacer/>
<v-spacer />
<!-- 版本提示信息 - 在手机上隐藏 -->
<div class="mr-4 hidden-xs">
@@ -329,19 +332,19 @@ commonStore.getStartTime();
<LanguageSwitcher variant="header" />
<!-- 主题切换按钮 -->
<v-btn size="small" @click="toggleDarkMode();" class="action-btn"
color="var(--v-theme-surface)" variant="flat" rounded="sm">
<v-btn size="small" @click="toggleDarkMode();" class="action-btn" color="var(--v-theme-surface)" variant="flat"
rounded="sm" icon>
<v-icon v-if="useCustomizerStore().uiTheme === 'PurpleThemeDark'">mdi-weather-night</v-icon>
<v-icon v-else>mdi-white-balance-sunny</v-icon>
</v-btn>
<!-- 更新对话框 -->
<v-dialog v-model="updateStatusDialog" :width="$vuetify.display.smAndDown ? '100%' : '1200'" :fullscreen="$vuetify.display.xs">
<v-dialog v-model="updateStatusDialog" :width="$vuetify.display.smAndDown ? '100%' : '1200'"
:fullscreen="$vuetify.display.xs">
<template v-slot:activator="{ props }">
<v-btn size="small" @click="checkUpdate(); getReleases(); getDevCommits();" class="action-btn"
color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props">
<v-icon class="hidden-sm-and-up">mdi-update</v-icon>
<span class="hidden-xs">{{ t('core.header.buttons.update') }}</span>
color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props" icon>
<v-icon>mdi-arrow-up-circle</v-icon>
</v-btn>
</template>
<v-card>
@@ -361,8 +364,8 @@ commonStore.getStartTime();
</div>
<div v-if="releaseMessage"
style="background-color: #646cff24; padding: 16px; border-radius: 10px; font-size: 14px; max-height: 400px; overflow-y: auto;"
v-html="md.render(releaseMessage)" class="markdown-content">
style="background-color: #646cff24; padding: 16px; border-radius: 10px; font-size: 14px; max-height: 400px; overflow-y: auto;"
v-html="md.render(releaseMessage)" class="markdown-content">
</div>
<div class="mb-4 mt-4">
@@ -380,15 +383,13 @@ commonStore.getStartTime();
<v-tabs-window-item key="0" v-show="tab == 0">
<div class="mb-4">
<small>{{ t('core.header.updateDialog.dockerTip') }} <a
href="https://containrrr.dev/watchtower/usage-overview/">{{ t('core.header.updateDialog.dockerTipLink') }}</a> {{ t('core.header.updateDialog.dockerTipContinue') }}</small>
href="https://containrrr.dev/watchtower/usage-overview/">{{
t('core.header.updateDialog.dockerTipLink')
}}</a> {{ t('core.header.updateDialog.dockerTipContinue') }}</small>
</div>
<v-alert
v-if="releases.some(item => isPreRelease(item['tag_name']))"
type="warning"
variant="tonal"
border="start"
>
<v-alert v-if="releases.some(item => isPreRelease(item['tag_name']))" type="warning" variant="tonal"
border="start">
<template v-slot:prepend>
<v-icon>mdi-alert-circle-outline</v-icon>
</template>
@@ -406,13 +407,8 @@ commonStore.getStartTime();
<template v-slot:item.tag_name="{ item }: { item: { tag_name: string } }">
<div class="d-flex align-center">
<span>{{ item.tag_name }}</span>
<v-chip
v-if="isPreRelease(item.tag_name)"
size="x-small"
color="warning"
variant="tonal"
class="ml-2"
>
<v-chip v-if="isPreRelease(item.tag_name)" size="x-small" color="warning" variant="tonal"
class="ml-2">
{{ t('core.header.updateDialog.preRelease') }}
</v-chip>
</div>
@@ -420,7 +416,8 @@ commonStore.getStartTime();
<template v-slot:item.body="{ item }: { item: { body: string } }">
<v-tooltip :text="item.body">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" rounded="xl" variant="tonal" color="primary" size="x-small">{{ t('core.header.updateDialog.table.view') }}</v-btn>
<v-btn v-bind="props" rounded="xl" variant="tonal" color="primary" size="x-small">{{
t('core.header.updateDialog.table.view') }}</v-btn>
</template>
</v-tooltip>
</template>
@@ -435,14 +432,12 @@ commonStore.getStartTime();
<!-- 开发版 -->
<v-tabs-window-item key="1" v-show="tab == 1">
<div style="margin-top: 16px;">
<v-data-table
:headers="[
{ title: t('core.header.updateDialog.table.sha'), key: 'sha' },
{ title: t('core.header.updateDialog.table.date'), key: 'date' },
{ title: t('core.header.updateDialog.table.message'), key: 'message' },
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
]"
:items="devCommits" item-key="sha">
<v-data-table :headers="[
{ title: t('core.header.updateDialog.table.sha'), key: 'sha' },
{ title: t('core.header.updateDialog.table.date'), key: 'date' },
{ title: t('core.header.updateDialog.table.message'), key: 'message' },
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
]" :items="devCommits" item-key="sha">
<template v-slot:item.switch="{ item }: { item: { sha: string } }">
<v-btn @click="switchVersion(item.sha)" rounded="xl" variant="plain" color="primary">
{{ t('core.header.updateDialog.table.switch') }}
@@ -457,11 +452,12 @@ commonStore.getStartTime();
<h3 class="mb-4">{{ t('core.header.updateDialog.manualInput.title') }}</h3>
<v-text-field :label="t('core.header.updateDialog.manualInput.placeholder')" v-model="version" required
variant="outlined"></v-text-field>
variant="outlined"></v-text-field>
<div class="mb-4">
<small>{{ t('core.header.updateDialog.manualInput.hint') }}</small>
<br>
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>{{ t('core.header.updateDialog.manualInput.linkText') }}</small></a>
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>{{
t('core.header.updateDialog.manualInput.linkText') }}</small></a>
</div>
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
{{ t('core.header.updateDialog.manualInput.confirm') }}
@@ -471,7 +467,8 @@ commonStore.getStartTime();
<div style="margin-top: 16px;">
<h3 class="mb-4">{{ t('core.header.updateDialog.dashboardUpdate.title') }}</h3>
<div class="mb-4">
<small>{{ t('core.header.updateDialog.dashboardUpdate.currentVersion') }} {{ dashboardCurrentVersion }}</small>
<small>{{ t('core.header.updateDialog.dashboardUpdate.currentVersion') }} {{ dashboardCurrentVersion
}}</small>
<br>
</div>
@@ -486,7 +483,7 @@ commonStore.getStartTime();
</div>
<v-btn color="primary" style="border-radius: 10px;" @click="updateDashboard()"
:disabled="!dashboardHasNewVersion" :loading="updatingDashboardLoading">
:disabled="!dashboardHasNewVersion" :loading="updatingDashboardLoading">
{{ t('core.header.updateDialog.dashboardUpdate.downloadAndUpdate') }}
</v-btn>
</div>
@@ -504,9 +501,9 @@ commonStore.getStartTime();
<!-- 账户对话框 -->
<v-dialog v-model="dialog" persistent :max-width="$vuetify.display.xs ? '90%' : '500'">
<template v-slot:activator="{ props }">
<v-btn size="small" class="action-btn mr-4" color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props">
<v-btn size="small" class="action-btn mr-4" color="var(--v-theme-surface)" variant="flat" rounded="sm"
v-bind="props" icon>
<v-icon>mdi-account</v-icon>
<span class="hidden-xs ml-1">{{ t('core.header.buttons.account') }}</span>
</v-btn>
</template>
<v-card class="account-dialog">
@@ -514,105 +511,51 @@ commonStore.getStartTime();
<div class="d-flex flex-column align-center mb-6">
<logo :title="t('core.header.logoTitle')" :subtitle="t('core.header.accountDialog.title')"></logo>
</div>
<v-alert
v-if="accountWarning"
type="warning"
variant="tonal"
border="start"
class="mb-4"
>
<v-alert v-if="accountWarning" type="warning" variant="tonal" border="start" class="mb-4">
<strong>{{ t('core.header.accountDialog.securityWarning') }}</strong>
</v-alert>
<v-alert
v-if="accountEditStatus.success"
type="success"
variant="tonal"
border="start"
class="mb-4"
>
<v-alert v-if="accountEditStatus.success" type="success" variant="tonal" border="start" class="mb-4">
{{ accountEditStatus.message }}
</v-alert>
<v-alert
v-if="accountEditStatus.error"
type="error"
variant="tonal"
border="start"
class="mb-4"
>
<v-alert v-if="accountEditStatus.error" type="error" variant="tonal" border="start" class="mb-4">
{{ accountEditStatus.message }}
</v-alert>
<v-form v-model="formValid" @submit.prevent="accountEdit">
<v-text-field
v-model="password"
:append-inner-icon="showPassword ? 'mdi-eye-off' : 'mdi-eye'"
:type="showPassword ? 'text' : 'password'"
:label="t('core.header.accountDialog.form.currentPassword')"
variant="outlined"
required
clearable
@click:append-inner="showPassword = !showPassword"
prepend-inner-icon="mdi-lock-outline"
hide-details="auto"
class="mb-4"
></v-text-field>
<v-text-field v-model="password" :append-inner-icon="showPassword ? 'mdi-eye-off' : 'mdi-eye'"
:type="showPassword ? 'text' : 'password'" :label="t('core.header.accountDialog.form.currentPassword')"
variant="outlined" required clearable @click:append-inner="showPassword = !showPassword"
prepend-inner-icon="mdi-lock-outline" hide-details="auto" class="mb-4"></v-text-field>
<v-text-field
v-model="newPassword"
:append-inner-icon="showNewPassword ? 'mdi-eye-off' : 'mdi-eye'"
:type="showNewPassword ? 'text' : 'password'"
:rules="passwordRules"
:label="t('core.header.accountDialog.form.newPassword')"
variant="outlined"
required
clearable
@click:append-inner="showNewPassword = !showNewPassword"
prepend-inner-icon="mdi-lock-plus-outline"
:hint="t('core.header.accountDialog.form.passwordHint')"
persistent-hint
class="mb-4"
></v-text-field>
<v-text-field v-model="newPassword" :append-inner-icon="showNewPassword ? 'mdi-eye-off' : 'mdi-eye'"
:type="showNewPassword ? 'text' : 'password'" :rules="passwordRules"
:label="t('core.header.accountDialog.form.newPassword')" variant="outlined" required clearable
@click:append-inner="showNewPassword = !showNewPassword" prepend-inner-icon="mdi-lock-plus-outline"
:hint="t('core.header.accountDialog.form.passwordHint')" persistent-hint class="mb-4"></v-text-field>
<v-text-field
v-model="newUsername"
:rules="usernameRules"
:label="t('core.header.accountDialog.form.newUsername')"
variant="outlined"
clearable
prepend-inner-icon="mdi-account-edit-outline"
:hint="t('core.header.accountDialog.form.usernameHint')"
persistent-hint
class="mb-3"
></v-text-field>
<v-text-field v-model="newUsername" :rules="usernameRules"
:label="t('core.header.accountDialog.form.newUsername')" variant="outlined" clearable
prepend-inner-icon="mdi-account-edit-outline" :hint="t('core.header.accountDialog.form.usernameHint')"
persistent-hint class="mb-3"></v-text-field>
</v-form>
<div class="text-caption text-medium-emphasis mt-2">
{{ t('core.header.accountDialog.form.defaultCredentials') }}
</div>
</v-card-text>
<v-divider></v-divider>
<v-card-actions class="pa-4">
<v-spacer></v-spacer>
<v-btn
v-if="!accountWarning"
variant="tonal"
color="secondary"
@click="dialog = false"
:disabled="accountEditStatus.loading"
>
<v-btn v-if="!accountWarning" variant="tonal" color="secondary" @click="dialog = false"
:disabled="accountEditStatus.loading">
{{ t('core.header.accountDialog.actions.cancel') }}
</v-btn>
<v-btn
color="primary"
@click="accountEdit"
:loading="accountEditStatus.loading"
:disabled="!formValid"
prepend-icon="mdi-content-save"
>
<v-btn color="primary" @click="accountEdit" :loading="accountEditStatus.loading" :disabled="!formValid"
prepend-icon="mdi-content-save">
{{ t('core.header.accountDialog.actions.save') }}
</v-btn>
</v-card-actions>
@@ -665,9 +608,9 @@ commonStore.getStartTime();
/* 响应式布局样式 */
.logo-container {
margin-left: 16px;
display: flex;
align-items: center;
margin-left: 16px;
display: flex;
align-items: center;
gap: 8px;
cursor: pointer;
}
@@ -678,7 +621,7 @@ commonStore.getStartTime();
}
.logo-text {
font-size: 24px;
font-size: 24px;
font-weight: 1000;
}
@@ -687,7 +630,7 @@ commonStore.getStartTime();
}
.version-text {
font-size: 12px;
font-size: 12px;
color: var(--v-theme-secondaryText);
}
@@ -707,7 +650,7 @@ commonStore.getStartTime();
.logo-text {
font-size: 20px;
}
.action-btn {
margin-right: 4px;
min-width: 32px !important;
@@ -717,11 +660,11 @@ commonStore.getStartTime();
.v-card-title {
padding: 12px 16px;
}
.v-card-text {
padding: 16px;
}
.v-tabs .v-tab {
padding: 0 10px;
font-size: 0.9rem;

View File

@@ -0,0 +1,31 @@
import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
export const useToastStore = defineStore('toast', () => {
const queue = ref([])
const current = computed(() => queue.value[0])
function add({
message,
color = 'info', // Vuetify 颜色
timeout = 3000,
closable = true,
multiLine = false,
location = 'top center'
}) {
queue.value.push({
message,
color,
timeout,
closable,
multiLine,
location
})
}
function shift() {
queue.value.shift()
}
return { current, add, shift }
})

View File

@@ -0,0 +1,78 @@
/**
* 平台相关工具函数
*/
/**
* 获取平台图标
* @param {string} name - 平台名称或类型
* @returns {string|undefined} 图标URL
*/
export function getPlatformIcon(name) {
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
} else if (name === 'wecom') {
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
} else if (name === 'lark') {
return new URL('@/assets/images/platform_logos/lark.png', import.meta.url).href
} else if (name === 'dingtalk') {
return new URL('@/assets/images/platform_logos/dingtalk.svg', import.meta.url).href
} else if (name === 'telegram') {
return new URL('@/assets/images/platform_logos/telegram.svg', import.meta.url).href
} else if (name === 'discord') {
return new URL('@/assets/images/platform_logos/discord.svg', import.meta.url).href
} else if (name === 'slack') {
return new URL('@/assets/images/platform_logos/slack.svg', import.meta.url).href
} else if (name === 'kook') {
return new URL('@/assets/images/platform_logos/kook.png', import.meta.url).href
} else if (name === 'vocechat') {
return new URL('@/assets/images/platform_logos/vocechat.png', import.meta.url).href
} else if (name === 'satori' || name === 'Satori') {
return new URL('@/assets/images/platform_logos/satori.png', import.meta.url).href
} else if (name === 'misskey') {
return new URL('@/assets/images/platform_logos/misskey.png', import.meta.url).href
}
}
/**
* 获取平台教程链接
* @param {string} platformType - 平台类型
* @returns {string} 教程链接
*/
export function getTutorialLink(platformType) {
const tutorialMap = {
"qq_official_webhook": "https://docs.astrbot.app/deploy/platform/qqofficial/webhook.html",
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
"wechatpadpro": "https://docs.astrbot.app/deploy/platform/wechat/wechatpadpro.html",
"weixin_official_account": "https://docs.astrbot.app/deploy/platform/weixin-official-account.html",
"discord": "https://docs.astrbot.app/deploy/platform/discord.html",
"slack": "https://docs.astrbot.app/deploy/platform/slack.html",
"kook": "https://docs.astrbot.app/deploy/platform/kook.html",
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
"misskey": "https://docs.astrbot.app/deploy/platform/misskey.html",
}
return tutorialMap[platformType] || "https://docs.astrbot.app";
}
/**
* 获取平台描述
* @param {Object} template - 平台模板
* @param {string} name - 平台名称
* @returns {string} 平台描述
*/
export function getPlatformDescription(template, name) {
// special judge for community platforms
if (name.includes('vocechat')) {
return "由 @HikariFroya 提供。";
} else if (name.includes('kook')) {
return "由 @wuyan1003 提供。"
}
return '';
}

View File

@@ -0,0 +1,52 @@
/**
* 提供商相关的工具函数
*/
/**
* 获取提供商类型对应的图标
* @param {string} type - 提供商类型
* @returns {string} 图标 URL
*/
export function getProviderIcon(type) {
const icons = {
'openai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
'azure': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/azure.svg',
'xai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/xai.svg',
'anthropic': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/anthropic.svg',
'ollama': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ollama.svg',
'google': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
'deepseek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
'modelscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/modelscope.svg',
'zhipu': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
'siliconflow': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
"coze": "https://registry.npmmirror.com/@lobehub/icons-static-svg/1.66.0/files/icons/coze.svg",
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',
'minimax': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/minimax.svg',
'302ai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
'microsoft': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/microsoft.svg',
'vllm': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/vllm.svg',
};
return icons[type] || '';
}
/**
* 获取提供商简介
* @param {Object} template - 模板对象
* @param {string} name - 提供商名称
* @param {Function} tm - 翻译函数
* @returns {string} 提供商描述
*/
export function getProviderDescription(template, name, tm) {
if (name == 'OpenAI') {
return tm('providers.description.openai', { type: template.type });
} else if (name == 'vLLM Rerank') {
return tm('providers.description.vllm_rerank', { type: template.type });
}
return tm('providers.description.default', { type: template.type });
}

View File

@@ -0,0 +1,16 @@
import { useToastStore } from '@/stores/toast'
export function useToast() {
const store = useToastStore()
const toast = (message, color = 'info', opts = {}) =>
store.add({ message, color, ...opts })
return {
toast,
success: (msg, opts) => toast(msg, 'success', opts),
error: (msg, opts) => toast(msg, 'error', opts),
info: (msg, opts) => toast(msg, 'primary', opts),
warning: (msg, opts) => toast(msg, 'warning', opts)
}
}

View File

@@ -1,5 +1,5 @@
<script setup>
import ChatPage from './ChatPage.vue';
import Chat from '@/components/chat/Chat.vue'
import { useCustomizerStore } from '@/stores/customizer';
const customizer = useCustomizerStore();
</script>
@@ -9,7 +9,7 @@ const customizer = useCustomizerStore();
<div
style="height: 100%; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
<div id="container">
<ChatPage :chatbox-mode="true"></ChatPage>
<Chat :chatbox-mode="true"></Chat>
</div>
</div>
</v-app>
@@ -18,24 +18,6 @@ const customizer = useCustomizerStore();
<style scoped>
#container {
width: 100%;
height: 100%;
}
@media (min-width: 768px) {
#container {
min-width: 600px;
min-height: 370px;
max-width: 1100px;
max-height: 860px;
padding: 36px;
}
}
@media (max-width: 767px) {
#container {
width: 100%;
height: 100%;
padding: 0;
}
height: 100vh;
}
</style>

File diff suppressed because it is too large Load Diff

View File

@@ -1,50 +1,30 @@
<template>
<div class="conversation-page">
<v-container fluid class="pa-0">
<!-- 页面标题 -->
<v-row>
<v-col cols="12">
<h1 class="text-h4 font-weight-bold mb-2">
<v-icon size="x-large" color="primary" class="me-2">mdi-chat-processing</v-icon>{{ tm('title') }}
</h1>
<p class="text-subtitle-1 text-medium-emphasis mb-4">
{{ tm('subtitle') }}
</p>
</v-col>
</v-row>
<!-- 过滤器部分 -->
<v-card class="mb-4" elevation="2">
<!-- 对话列表部分 -->
<v-card flat>
<v-card-title class="d-flex align-center py-3 px-4">
<v-icon color="primary" class="me-2">mdi-filter-variant</v-icon>
<span class="text-h6">{{ tm('filters.title') }}</span>
<v-spacer></v-spacer>
<v-btn color="primary" variant="text" @click="resetFilters" class="ml-2">
<v-icon class="mr-1">mdi-refresh</v-icon>{{ tm('filters.reset') }}
</v-btn>
</v-card-title>
<v-divider></v-divider>
<v-card-text class="py-4">
<v-row>
<span class="text-h4">{{ tm('history.title') }}</span>
<v-chip size="small" class="ml-2">{{ pagination.total || 0 }}</v-chip>
<v-row class="me-4 ms-4" dense>
<v-col cols="12" sm="6" md="4">
<v-select v-model="platformFilter" :label="tm('filters.platform')" :items="availablePlatforms" chips multiple
clearable variant="outlined" density="compact" hide-details>
<v-combobox v-model="platformFilter" :label="tm('filters.platform')"
:items="availablePlatforms" chips multiple clearable variant="solo-filled" flat
density="compact" hide-details :disabled="loading">
<template v-slot:selection="{ item }">
<v-chip size="small" :color="getPlatformColor(item.value)" label>
<v-chip size="small" label>
{{ item.title }}
</v-chip>
</template>
</v-select>
</v-combobox>
</v-col>
<v-col cols="12" sm="6" md="4">
<v-select v-model="messageTypeFilter" :label="tm('filters.type')" :items="messageTypeItems" chips multiple
clearable variant="outlined" density="compact" hide-details>
<v-select v-model="messageTypeFilter" :label="tm('filters.type')" :items="messageTypeItems"
chips multiple clearable variant="solo-filled" density="compact" hide-details flat
:disabled="loading">
<template v-slot:selection="{ item }">
<v-chip size="small" :color="getMessageTypeColor(item.value)" variant="outlined"
label>
<v-chip size="small" variant="solo-filled" label>
{{ item.title }}
</v-chip>
</template>
@@ -52,49 +32,49 @@
</v-col>
<v-col cols="12" sm="12" md="4">
<v-text-field v-model="search" prepend-inner-icon="mdi-magnify" :label="tm('filters.search')" hide-details
density="compact" variant="outlined" clearable></v-text-field>
<v-text-field v-model="search" prepend-inner-icon="mdi-magnify"
:label="tm('filters.search')" hide-details density="compact" variant="solo-filled" flat
clearable :disabled="loading"></v-text-field>
</v-col>
</v-row>
</v-card-text>
</v-card>
<!-- 对话列表部分 -->
<v-card class="mb-6" elevation="2">
<v-card-title class="d-flex align-center py-3 px-4">
<v-icon color="primary" class="me-2">mdi-message</v-icon>
<span class="text-h6">{{ tm('history.title') }}</span>
<v-chip color="info" size="small" class="ml-2">{{ pagination.total || 0 }}</v-chip>
<v-spacer></v-spacer>
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="fetchConversations"
:loading="loading">
:loading="loading" size="small" class="mr-2">
{{ tm('history.refresh') }}
</v-btn>
<v-btn
v-if="selectedItems.length > 0"
color="error"
prepend-icon="mdi-delete"
variant="tonal"
@click="confirmBatchDelete"
:disabled="loading"
size="small">
{{ tm('batch.deleteSelected', { count: selectedItems.length }) }}
</v-btn>
</v-card-title>
<v-divider></v-divider>
<v-card-text class="pa-0">
<v-data-table :headers="tableHeaders" :items="conversations" :loading="loading" density="comfortable"
hide-default-footer items-per-page="10" class="elevation-0"
:items-per-page="pagination.page_size" :items-per-page-options="[10, 20, 50, 100]"
@update:options="handleTableOptions">
<v-data-table v-model="selectedItems" :headers="tableHeaders" :items="conversations"
:loading="loading" style="font-size: 12px;" density="comfortable" hide-default-footer
class="elevation-0" :items-per-page="pagination.page_size"
:items-per-page-options="pageSizeOptions" show-select return-object
:disabled="loading" @update:options="handleTableOptions">
<template v-slot:item.title="{ item }">
<div class="d-flex align-center">
<v-icon color="primary" class="mr-2">mdi-chat</v-icon>
<span>{{ item.title || tm('status.noTitle') }}</span>
</div>
</template>
<template v-slot:item.platform="{ item }">
<v-chip :color="getPlatformColor(item.sessionInfo.platform)" size="small" label>
<v-chip size="small" label>
{{ item.sessionInfo.platform || tm('status.unknown') }}
</v-chip>
</template>
<template v-slot:item.messageType="{ item }">
<v-chip :color="getMessageTypeColor(item.sessionInfo.messageType)" size="small"
variant="outlined" label>
<v-chip size="small" label>
{{ getMessageTypeDisplay(item.sessionInfo.messageType) }}
</v-chip>
</template>
@@ -113,17 +93,17 @@
<template v-slot:item.actions="{ item }">
<div class="actions-wrapper">
<v-btn color="primary" variant="flat" size="small" class="action-button"
@click="viewConversation(item)">
<v-icon class="mr-1">mdi-eye</v-icon>{{ tm('actions.view') }}
<v-btn icon variant="plain" size="x-small" class="action-button"
@click="viewConversation(item)" :disabled="loading">
<v-icon>mdi-eye</v-icon>
</v-btn>
<v-btn color="warning" variant="flat" size="small" class="action-button"
@click="editConversation(item)">
<v-icon class="mr-1">mdi-pencil</v-icon>{{ tm('actions.edit') }}
<v-btn icon variant="plain" size="x-small" class="action-button"
@click="editConversation(item)" :disabled="loading">
<v-icon>mdi-pencil</v-icon>
</v-btn>
<v-btn color="error" variant="flat" size="small" class="action-button"
@click="confirmDeleteConversation(item)">
<v-icon class="mr-1">mdi-delete</v-icon>{{ tm('actions.delete') }}
<v-btn icon color="error" variant="plain" size="x-small" class="action-button"
@click="confirmDeleteConversation(item)" :disabled="loading">
<v-icon>mdi-delete</v-icon>
</v-btn>
</div>
</template>
@@ -137,9 +117,25 @@
</v-data-table>
<!-- 分页控制 -->
<div class="d-flex justify-end pa-4">
<div class="d-flex justify-center py-3">
<!-- 每页大小选择器 -->
<div class="d-flex justify-between align-center px-4 py-2 bg-grey-lighten-5">
<div class="d-flex align-center">
<span class="text-caption mr-2">{{ tm('pagination.itemsPerPage') }}:</span>
<v-select v-model="pagination.page_size" :items="pageSizeOptions" variant="outlined"
density="compact" hide-details style="max-width: 100px;"
:disabled="loading" @update:model-value="onPageSizeChange"></v-select>
</div>
<div class="text-caption ml-4">
{{ tm('pagination.showingItems', {
start: Math.min((pagination.page - 1) * pagination.page_size + 1, pagination.total),
end: Math.min(pagination.page * pagination.page_size, pagination.total),
total: pagination.total
}) }}
</div>
</div>
<v-pagination v-model="pagination.page" :length="pagination.total_pages" :disabled="loading"
@update:model-value="fetchConversations" rounded="circle"></v-pagination>
@update:model-value="fetchConversations" rounded="circle" :total-visible="7"></v-pagination>
</div>
</v-card-text>
</v-card>
@@ -148,24 +144,20 @@
<!-- 对话详情对话框 -->
<v-dialog v-model="dialogView" max-width="900px" scrollable>
<v-card class="conversation-detail-card">
<v-card-title class="bg-primary text-white py-3 d-flex align-center">
<v-icon color="white" class="me-2">mdi-eye</v-icon>
<v-card-title class="ml-2 mt-2 d-flex align-center">
<span class="text-truncate">{{ selectedConversation?.title || tm('status.noTitle') }}</span>
<v-spacer></v-spacer>
<div class="d-flex align-center" v-if="selectedConversation?.sessionInfo">
<v-chip color="white" text-color="primary" size="small" class="mr-2">
<v-chip text-color="primary" size="small" class="mr-2" rounded="md">
{{ selectedConversation.sessionInfo.platform }}
</v-chip>
<v-chip color="white" text-color="secondary" size="small">
<v-chip text-color="secondary" size="small" rounded="md">
{{ getMessageTypeDisplay(selectedConversation.sessionInfo.messageType) }}
</v-chip>
</div>
</v-card-title>
<v-divider></v-divider>
<v-card-text class="py-4">
<v-card-text>
<div class="mb-4 d-flex align-center">
<v-btn color="secondary" variant="tonal" size="small" class="mr-2"
@click="isEditingHistory = !isEditingHistory">
@@ -199,51 +191,11 @@
<p class="text-disabled mt-2">{{ tm('status.emptyContent') }}</p>
</div>
<!-- 消息列表 -->
<div v-else class="message-list">
<div class="message-item" v-for="(msg, index) in conversationHistory" :key="index">
<!-- 用户消息 -->
<div v-if="msg.role === 'user'" class="user-message">
<div class="message-bubble user-bubble">
<span v-html="formatMessage(msg.content)"></span>
<!-- 图片附件 -->
<div class="image-attachments" v-if="msg.image_url && msg.image_url.length > 0">
<div v-for="(img, imgIndex) in msg.image_url" :key="imgIndex"
class="image-attachment">
<img :src="img" class="attached-image" />
</div>
</div>
<!-- 音频附件 -->
<div class="audio-attachment" v-if="msg.audio_url">
<audio controls class="audio-player">
<source :src="msg.audio_url" type="audio/wav">
{{ tm('status.audioNotSupported') }}
</audio>
</div>
</div>
<v-avatar class="user-avatar" color="deep-purple-lighten-3" size="36">
<v-icon icon="mdi-account" />
</v-avatar>
</div>
<!-- 机器人消息 -->
<div v-else class="bot-message">
<v-avatar class="bot-avatar" color="deep-purple" size="36">
<span class="text-h6"></span>
</v-avatar>
<div class="message-bubble bot-bubble">
<div v-html="formatMessage(msg.content)" class="markdown-content"></div>
</div>
</div>
</div>
</div>
<!-- 消息列表组件 -->
<MessageList v-else :messages="formattedMessages" :isDark="false" />
</div>
</v-card-text>
<v-divider></v-divider>
<v-card-actions class="pa-4">
<v-spacer></v-spacer>
<v-btn variant="text" @click="closeHistoryDialog">
@@ -263,8 +215,9 @@
<v-card-text class="py-4">
<v-form ref="form" v-model="valid">
<v-text-field v-model="editedItem.title" :label="tm('dialogs.edit.titleLabel')" :placeholder="tm('dialogs.edit.titlePlaceholder')" variant="outlined"
density="comfortable" class="mb-3"></v-text-field>
<v-text-field v-model="editedItem.title" :label="tm('dialogs.edit.titleLabel')"
:placeholder="tm('dialogs.edit.titlePlaceholder')" variant="outlined" density="comfortable"
class="mb-3"></v-text-field>
</v-form>
</v-card-text>
@@ -291,7 +244,8 @@
</v-card-title>
<v-card-text class="py-4">
<p>{{ tm('dialogs.delete.message', { title: selectedConversation?.title || tm('status.noTitle') }) }}</p>
<p>{{ tm('dialogs.delete.message', { title: selectedConversation?.title || tm('status.noTitle') })
}}</p>
</v-card-text>
<v-divider></v-divider>
@@ -308,6 +262,48 @@
</v-card>
</v-dialog>
<!-- 批量删除确认对话框 -->
<v-dialog v-model="dialogBatchDelete" max-width="600px">
<v-card>
<v-card-title class="bg-error text-white py-3">
<v-icon color="white" class="me-2">mdi-delete</v-icon>
<span>{{ tm('dialogs.batchDelete.title') }}</span>
</v-card-title>
<v-card-text class="py-4">
<p class="mb-3">{{ tm('dialogs.batchDelete.message', { count: selectedItems.length }) }}</p>
<!-- 显示前几个要删除的对话 -->
<div v-if="selectedItems.length > 0" class="mb-3">
<v-chip v-for="(item, index) in selectedItems.slice(0, 5)" :key="`${item.user_id}-${item.cid}`"
size="small" class="mr-1 mb-1" closable @click:close="removeFromSelection(item)"
:disabled="loading">
{{ item.title || tm('status.noTitle') }}
</v-chip>
<v-chip v-if="selectedItems.length > 5" size="small" class="mr-1 mb-1">
{{ tm('dialogs.batchDelete.andMore', { count: selectedItems.length - 5 }) }}
</v-chip>
</div>
<v-alert type="warning" variant="tonal" class="mb-3">
{{ tm('dialogs.batchDelete.warning') }}
</v-alert>
</v-card-text>
<v-divider></v-divider>
<v-card-actions class="pa-4">
<v-spacer></v-spacer>
<v-btn variant="text" @click="dialogBatchDelete = false" :disabled="loading">
{{ tm('dialogs.batchDelete.cancel') }}
</v-btn>
<v-btn color="error" @click="batchDeleteConversations" :loading="loading">
{{ tm('dialogs.batchDelete.confirm') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 消息提示 -->
<v-snackbar :timeout="3000" elevation="24" :color="messageType" v-model="showMessage" location="top">
{{ message }}
@@ -321,6 +317,7 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor';
import MarkdownIt from 'markdown-it';
import { useCommonStore } from '@/stores/common';
import { useI18n, useModuleI18n } from '@/i18n/composables';
import MessageList from '@/components/chat/MessageList.vue';
// 配置markdown-it默认安全设置
const md = new MarkdownIt({
@@ -333,13 +330,14 @@ const md = new MarkdownIt({
export default {
name: 'ConversationPage',
components: {
VueMonacoEditor
VueMonacoEditor,
MessageList
},
setup() {
const { t, locale } = useI18n();
const { tm } = useModuleI18n('features/conversation');
return {
t,
tm,
@@ -353,32 +351,13 @@ export default {
conversations: [],
search: '',
headers: [],
selectedItems: [], // 批量选择的项目
// 筛选条件
platformFilter: [],
messageTypeFilter: [],
lastAppliedFilters: null, // 记录上次应用的筛选条件
// 平台颜色映射
platformColors: {
'telegram': 'blue-lighten-1',
'qq_official': 'purple-lighten-1',
'qq_official_webhook': 'purple-lighten-2',
'aiocqhttp': 'deep-purple-lighten-1',
'lark': 'cyan-darken-1',
'wecom': 'green-darken-1',
'dingtalk': 'blue-darken-2',
'default': 'grey-lighten-1'
},
// 消息类型颜色映射
messageTypeColors: {
'GroupMessage': 'green',
'FriendMessage': 'blue',
'GuildMessage': 'purple',
'default': 'grey'
},
// 分页数据
pagination: {
page: 1,
@@ -386,11 +365,13 @@ export default {
total: 0,
total_pages: 0
},
pageSizeOptions: [10, 20, 50, 100], // 每页大小选项
// 对话框控制
dialogView: false,
dialogEdit: false,
dialogDelete: false,
dialogBatchDelete: false, // 批量删除对话框
// 选中的对话
selectedConversation: null,
@@ -402,11 +383,6 @@ export default {
cid: '',
title: ''
},
defaultItem: {
user_id: '',
cid: '',
title: ''
},
// 表单验证
valid: true,
@@ -454,12 +430,18 @@ export default {
tableHeaders() {
return [
{ title: this.tm('table.headers.title'), key: 'title', sortable: true },
{ title: this.tm('table.headers.platform'), key: 'platform', sortable: true, width: '120px' },
{ title: this.tm('table.headers.type'), key: 'messageType', sortable: true, width: '100px' },
{ title: this.tm('table.headers.sessionId'), key: 'sessionId', sortable: true, width: '100px' },
{
title: this.tm('table.headers.sessionId'),
align: 'center',
children: [
{ title: this.tm('table.headers.platform'), key: 'platform', sortable: true, width: '120px' },
{ title: this.tm('table.headers.type'), key: 'messageType', sortable: true, width: '100px' },
{ title: '会话 ID', key: 'sessionId', sortable: true, width: '100px' },
],
},
{ title: this.tm('table.headers.createdAt'), key: 'created_at', sortable: true, width: '180px' },
{ title: this.tm('table.headers.updatedAt'), key: 'updated_at', sortable: true, width: '180px' },
{ title: this.tm('table.headers.actions'), key: 'actions', sortable: false, align: 'center', width: '240px' }
{ title: this.tm('table.headers.actions'), key: 'actions', sortable: false, align: 'center' }
];
},
@@ -487,24 +469,40 @@ export default {
];
},
// 筛选后的对话 - 现在只用于额外的客户端筛选排除astrbot和webchat
filteredConversations() {
return this.conversations.filter(conv => {
// 排除 user_id 为 astrbot 或 platform 为 webchat 的对话
if (conv.user_id === 'astrbot' || conv.sessionInfo?.platform === 'webchat') {
return false;
}
return true;
});
},
// 当前的筛选条件对象
currentFilters() {
const platforms = this.platformFilter.map(item =>
typeof item === 'object' ? item.value : item
);
return {
platforms: this.platformFilter,
platforms: platforms,
messageTypes: this.messageTypeFilter,
search: this.search
};
},
// 将对话历史转换为 MessageList 组件期望的格式
formattedMessages() {
return this.conversationHistory.map(msg => {
console.log('处理消息:', msg.role, msg.image_url, msg.audio_url);
if (msg.role === 'user') {
return {
content: {
type: 'user',
message: this.extractTextFromContent(msg.content),
image_url: this.extractImagesFromContent(msg.content),
}
};
} else {
return {
content: {
type: 'bot',
message: this.extractTextFromContent(msg.content),
embedded_images: this.extractImagesFromContent(msg.content),
}
};
}
});
}
},
@@ -541,16 +539,6 @@ export default {
};
},
// 重置过滤条件
resetFilters() {
this.platformFilter = [];
this.messageTypeFilter = [];
this.search = '';
// 立即应用筛选,不使用防抖
this.pagination.page = 1;
this.fetchConversations();
},
// 处理表格选项变更(页面大小等)
handleTableOptions(options) {
// 处理页面大小变更
@@ -579,16 +567,6 @@ export default {
return { platform: 'default', messageType: 'default', sessionId: userId };
},
// 获取平台对应的颜色
getPlatformColor(platform) {
return this.platformColors[platform] || this.platformColors.default;
},
// 获取消息类型对应的颜色
getMessageTypeColor(messageType) {
return this.messageTypeColors[messageType] || this.messageTypeColors.default;
},
// 获取消息类型的显示文本
getMessageTypeDisplay(messageType) {
const typeMap = {
@@ -610,9 +588,12 @@ export default {
page_size: this.pagination.page_size
};
// 添加筛选条件
// 添加筛选条件 - 处理combobox的混合数据格式
if (this.platformFilter.length > 0) {
params.platforms = this.platformFilter.join(',');
const platforms = this.platformFilter.map(item =>
typeof item === 'object' ? item.value : item
);
params.platforms = platforms.join(',');
}
if (this.messageTypeFilter.length > 0) {
@@ -620,19 +601,15 @@ export default {
}
if (this.search) {
params.search = this.search;
params.search = this.search.trim();
}
// 添加排除条件
params.exclude_ids = 'astrbot';
params.exclude_platforms = 'webchat';
console.log(`正在请求对话列表: /api/conversation/list 参数:`, params);
const response = await axios.get('/api/conversation/list', { params });
console.log('收到对话列表响应:', response.data);
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
if (response.data.status === "ok") {
@@ -836,6 +813,88 @@ export default {
}
} catch (error) {
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.deleteError'));
} finally {
this.loading = false;
this.selectedItems = this.selectedItems.filter(item =>
!(item.user_id === this.selectedConversation.user_id && item.cid === this.selectedConversation.cid)
);
this.selectedConversation = null;
}
},
// 处理页面大小变更
onPageSizeChange() {
this.pagination.page = 1; // 重置到第一页
this.fetchConversations();
},
// 确认批量删除
confirmBatchDelete() {
if (this.selectedItems.length === 0) {
this.showErrorMessage(this.tm('messages.noItemSelected'));
return;
}
this.dialogBatchDelete = true;
},
// 从选择中移除项目
removeFromSelection(item) {
const index = this.selectedItems.findIndex(selected =>
selected.user_id === item.user_id && selected.cid === item.cid
);
if (index !== -1) {
this.selectedItems.splice(index, 1);
}
},
// 批量删除对话
async batchDeleteConversations() {
if (this.selectedItems.length === 0) {
this.showErrorMessage(this.tm('messages.noItemSelected'));
return;
}
this.loading = true;
try {
// 准备批量删除的数据
const conversations = this.selectedItems.map(item => ({
user_id: item.user_id,
cid: item.cid
}));
const response = await axios.post('/api/conversation/delete', {
conversations: conversations
});
if (response.data.status === "ok") {
const result = response.data.data;
this.dialogBatchDelete = false;
this.selectedItems = []; // 清空选择
// 显示结果消息
if (result.failed_count > 0) {
this.showErrorMessage(
this.tm('messages.batchDeletePartial', {
deleted: result.deleted_count,
failed: result.failed_count
})
);
} else {
this.showSuccessMessage(
this.tm('messages.batchDeleteSuccess', {
count: result.deleted_count
})
);
}
// 刷新列表
this.fetchConversations();
} else {
this.showErrorMessage(response.data.message || this.tm('messages.batchDeleteError'));
}
} catch (error) {
console.error('批量删除对话出错:', error);
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.batchDeleteError'));
} finally {
this.loading = false;
}
@@ -858,35 +917,6 @@ export default {
}).format(date);
},
// 格式化消息内容
formatMessage(content) {
// content 可能是数组
// [{"type": "image_url", "image_url": {"url": url_or_base64}}, {"type": "text", "text": "text"}]
let final_content = content;
if (Array.isArray(content)) {
// 处理数组内容
final_content = content.map(item => {
if (item.type === 'image_url') {
return `<img src="${item.image_url.url}" alt="Image" />`;
} else if (item.type === 'text') {
return item.text;
}
return '';
}).join('\n');
} else if (typeof content === 'object') {
// 处理对象内容
final_content = Object.values(content).join('');
} else if (typeof content === 'string') {
// 处理字符串内容
final_content = content;
} else if (!final_content) return this.tm('status.emptyContent');
// 使用markdown-it处理默认安全html: false会禁用HTML标签
return md.render(final_content);
},
// 显示成功消息
showSuccessMessage(message) {
this.message = message;
@@ -899,16 +929,36 @@ export default {
this.message = message;
this.messageType = 'error';
this.showMessage = true;
},
// 从内容中提取文本
extractTextFromContent(content) {
if (typeof content === 'string') {
return content;
} else if (Array.isArray(content)) {
return content.filter(item => item.type === 'text')
.map(item => item.text)
.join('\n');
} else if (typeof content === 'object') {
return Object.values(content).filter(val => typeof val === 'string').join('');
}
return '';
},
// 从内容中提取图片URL
extractImagesFromContent(content) {
if (Array.isArray(content)) {
return content.filter(item => item.type === 'image_url')
.map(item => item.image_url?.url)
.filter(url => url);
}
return [];
}
}
}
</script>
<style>
.conversation-page {
padding: 20px;
}
.actions-wrapper {
display: flex;
justify-content: flex-end;
@@ -918,11 +968,6 @@ export default {
.action-button {
border-radius: 8px;
font-weight: 500;
transition: all 0.2s ease;
}
.action-button:hover {
transform: translateY(-2px);
}
.monaco-editor-container {
@@ -932,7 +977,7 @@ export default {
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
}
/* 聊天消息样式 */
/* 聊天消息容器样式 */
.conversation-messages-container {
max-height: 500px;
overflow-y: auto;
@@ -941,87 +986,6 @@ export default {
background-color: #f9f9f9;
}
.message-list {
display: flex;
flex-direction: column;
gap: 16px;
}
.message-item {
margin-bottom: 8px;
animation: fadeIn 0.3s ease-out;
}
.user-message {
display: flex;
justify-content: flex-end;
align-items: flex-start;
gap: 12px;
}
.bot-message {
display: flex;
justify-content: flex-start;
align-items: flex-start;
gap: 12px;
}
.message-bubble {
padding: 12px 16px;
border-radius: 18px;
max-width: 80%;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
.user-bubble {
background-color: #f0f4ff;
color: #333;
border-top-right-radius: 4px;
}
.bot-bubble {
background-color: #fff;
border: 1px solid #eaeaea;
color: #333;
border-top-left-radius: 4px;
}
.user-avatar,
.bot-avatar {
margin-top: 2px;
}
/* 附件样式 */
.image-attachments {
display: flex;
gap: 8px;
margin-top: 8px;
flex-wrap: wrap;
}
.attached-image {
width: 120px;
height: 120px;
object-fit: cover;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
transition: transform 0.2s ease;
}
.attached-image:hover {
transform: scale(1.05);
}
.audio-attachment {
margin-top: 8px;
}
.audio-player {
width: 100%;
height: 36px;
border-radius: 18px;
}
/* 对话详情卡片 */
.conversation-detail-card {
max-height: 90vh;
@@ -1029,95 +993,6 @@ export default {
flex-direction: column;
}
/* Markdown内容样式 */
.markdown-content {
font-family: inherit;
line-height: 1.6;
}
.markdown-content h1,
.markdown-content h2,
.markdown-content h3,
.markdown-content h4,
.markdown-content h5,
.markdown-content h6 {
margin-top: 16px;
margin-bottom: 10px;
font-weight: 600;
color: #333;
}
.markdown-content h1 {
font-size: 1.8em;
border-bottom: 1px solid #eee;
padding-bottom: 6px;
}
.markdown-content h2 {
font-size: 1.5em;
}
.markdown-content h3 {
font-size: 1.3em;
}
.markdown-content li {
margin-left: 16px;
margin-bottom: 4px;
}
.markdown-content p {
margin-top: 10px;
margin-bottom: 10px;
}
.markdown-content pre {
background-color: #f8f8f8;
padding: 12px;
border-radius: 6px;
overflow-x: auto;
margin: 12px 0;
}
.markdown-content code {
background-color: #f5f0ff;
padding: 2px 4px;
border-radius: 4px;
font-family: 'Fira Code', monospace;
font-size: 0.9em;
color: #673ab7;
}
.markdown-content img {
max-width: 100%;
border-radius: 8px;
margin: 10px 0;
}
.markdown-content blockquote {
border-left: 4px solid #673ab7;
padding-left: 16px;
color: #666;
margin: 16px 0;
}
.markdown-content table {
border-collapse: collapse;
width: 100%;
margin: 16px 0;
}
.markdown-content th,
.markdown-content td {
border: 1px solid #eee;
padding: 8px 12px;
text-align: left;
}
.markdown-content th {
background-color: #f5f0ff;
}
/* 动画 */
@keyframes fadeIn {
from {

View File

@@ -470,7 +470,7 @@ const refreshPluginMarket = async () => {
trimExtensionName();
checkAlreadyInstalled();
checkUpdate();
toast(tm('messages.refreshSuccess'), "success");
} catch (err) {
toast(tm('messages.refreshFailed') + " " + err, "error");
@@ -518,27 +518,12 @@ onMounted(async () => {
<v-row>
<v-col cols="12" md="12">
<v-card variant="flat">
<v-card-item>
<template v-slot:prepend>
<div class="plugin-page-icon d-flex justify-center align-center rounded-lg mr-4">
<v-icon size="36" color="primary">mdi-puzzle</v-icon>
</div>
</template>
<v-card-title class="text-h4 font-weight-bold">
{{ tm('title') }}
</v-card-title>
<v-card-subtitle class="text-subtitle-1 mt-1 text-medium-emphasis">
{{ tm('subtitle') }}
</v-card-subtitle>
</v-card-item>
<!-- 标签页 -->
<v-card-text>
<!-- 标签栏和搜索栏 - 响应式布局 -->
<div class="mb-4">
<div class="mb-4 d-flex flex-wrap">
<!-- 标签栏 -->
<v-tabs v-model="activeTab" color="primary" class="mb-3">
<v-tabs v-model="activeTab" color="primary">
<v-tab value="installed">
<v-icon class="mr-2">mdi-puzzle</v-icon>
{{ tm('tabs.installed') }}
@@ -550,17 +535,16 @@ onMounted(async () => {
</v-tabs>
<!-- 搜索栏 - 在移动端时独占一行 -->
<v-row class="mb-2">
<v-col cols="12" sm="6" md="4" lg="3">
<v-text-field v-if="activeTab == 'market'" v-model="marketSearch" density="compact"
:label="tm('search.marketPlaceholder')" prepend-inner-icon="mdi-magnify" variant="solo-filled" flat
hide-details single-line>
</v-text-field>
<v-text-field v-else v-model="pluginSearch" density="compact" :label="tm('search.placeholder')"
prepend-inner-icon="mdi-magnify" variant="solo-filled" flat hide-details single-line>
</v-text-field>
</v-col>
</v-row>
<div style="flex-grow: 1; min-width: 250px; max-width: 400px; margin-left: auto; margin-top: 8px;">
<v-text-field v-if="activeTab == 'market'" v-model="marketSearch" density="compact"
:label="tm('search.marketPlaceholder')" prepend-inner-icon="mdi-magnify" variant="solo-filled" flat
hide-details single-line>
</v-text-field>
<v-text-field v-else v-model="pluginSearch" density="compact" :label="tm('search.placeholder')"
prepend-inner-icon="mdi-magnify" variant="solo-filled" flat hide-details single-line>
</v-text-field>
</div>
</div>
@@ -776,18 +760,13 @@ onMounted(async () => {
<div class="d-flex align-center mb-2" style="justify-content: space-between;">
<h2>{{ tm('market.allPlugins') }}</h2>
<div class="d-flex align-center">
<v-btn
variant="tonal"
size="small"
@click="refreshPluginMarket"
:loading="refreshingMarket"
class="mr-2"
>
<v-btn variant="tonal" size="small" @click="refreshPluginMarket" :loading="refreshingMarket"
class="mr-2">
<v-icon>mdi-refresh</v-icon>
{{ tm('buttons.refresh') }}
</v-btn>
<v-switch v-model="showPluginFullName" :label="tm('market.showFullName')" hide-details density="compact"
style="margin-left: 12px" />
<v-switch v-model="showPluginFullName" :label="tm('market.showFullName')" hide-details
density="compact" style="margin-left: 12px" />
</div>
</div>
@@ -827,7 +806,7 @@ onMounted(async () => {
<template v-slot:item.tags="{ item }">
<span v-if="item.tags.length === 0">-</span>
<v-chip v-for="tag in item.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'"
size="x-small" v-show="tag !== 'danger'">
size="x-small" v-show="tag !== 'danger'" class="ma-1">
{{ tag }}</v-chip>
</template>
<template v-slot:item.actions="{ item }">

View File

@@ -23,7 +23,7 @@
<!-- 人格卡片网格 -->
<v-row>
<v-col v-for="persona in personas" :key="persona.persona_id" cols="12" md="6" lg="4" xl="3">
<v-card class="persona-card" elevation="2" rounded="lg" @click="viewPersona(persona)">
<v-card class="persona-card" rounded="md" @click="viewPersona(persona)">
<v-card-title class="d-flex justify-space-between align-center">
<div class="text-truncate ml-2">
{{ persona.persona_id }}
@@ -296,9 +296,9 @@
<v-card-text>
<div class="mb-4">
<h4 class="text-h6 mb-2">{{ tm('form.systemPrompt') }}</h4>
<div class="system-prompt-content">
<pre class="system-prompt-content">
{{ viewingPersona.system_prompt }}
</div>
</pre>
</div>
<div v-if="viewingPersona.begin_dialogs && viewingPersona.begin_dialogs.length > 0" class="mb-4">
@@ -759,10 +759,6 @@ export default {
cursor: pointer;
}
.persona-card:hover {
box-shadow: 0 8px 25px 0 rgba(0, 0, 0, 0.15);
}
.system-prompt-preview {
font-size: 14px;
line-height: 1.4;
@@ -775,10 +771,10 @@ export default {
}
.system-prompt-content {
background-color: rgba(var(--v-theme-surface-variant), 0.3);
max-height: 400px;
overflow: auto;
padding: 12px;
border-radius: 8px;
font-family: 'Roboto Mono', monospace;
font-size: 14px;
line-height: 1.5;
white-space: pre-wrap;

View File

@@ -10,7 +10,8 @@
{{ tm('subtitle') }}
</p>
</div>
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddPlatformDialog = true" rounded="xl" size="x-large">
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddPlatformDialog = true"
rounded="xl" size="x-large">
{{ tm('addAdapter') }}
</v-btn>
</v-row>
@@ -25,14 +26,9 @@
<v-row v-else>
<v-col v-for="(platform, index) in config_data.platform || []" :key="index" cols="12" md="6" lg="4" xl="3">
<item-card
:item="platform"
title-field="id"
enabled-field="enable"
:bglogo="getPlatformIcon(platform.type || platform.id)"
@toggle-enabled="platformStatusChange"
@delete="deletePlatform"
@edit="editPlatform">
<item-card :item="platform" title-field="id" enabled-field="enable"
:bglogo="getPlatformIcon(platform.type || platform.id)" @toggle-enabled="platformStatusChange"
@delete="deletePlatform" @edit="editPlatform">
</item-card>
</v-col>
</v-row>
@@ -61,59 +57,13 @@
</v-container>
<!-- 添加平台适配器对话框 -->
<v-dialog v-model="showAddPlatformDialog" max-width="900px" min-height="80%">
<v-card class="platform-selection-dialog">
<v-card-title class="bg-primary text-white py-3 px-4" style="display: flex; align-items: center;">
<v-icon color="white" class="me-2">mdi-plus-circle</v-icon>
<span>{{ tm('dialog.addPlatform') }}</span>
<v-spacer></v-spacer>
<v-btn icon variant="text" color="white" @click="showAddPlatformDialog = false">
<v-icon>mdi-close</v-icon>
</v-btn>
</v-card-title>
<v-card-text class="pa-4" style="overflow-y: auto;">
<v-row class="mt-1">
<v-col v-for="(template, name) in metadata['platform_group']?.metadata?.platform?.config_template || {}"
:key="name" cols="12" sm="6" md="6">
<v-card variant="outlined" hover class="platform-card" @click="selectPlatformTemplate(name)">
<div class="platform-card-content">
<div class="platform-card-text">
<v-card-title class="platform-card-title">{{ tm('dialog.connectTitle', { name }) }}</v-card-title>
<v-card-text class="text-caption text-medium-emphasis platform-card-description">
{{ getPlatformDescription(template, name) }}
</v-card-text>
</div>
<div class="platform-card-logo">
<img :src="getPlatformIcon(template.type)" v-if="getPlatformIcon(template.type)" class="platform-logo-img">
<div v-else class="platform-logo-fallback">
{{ name[0].toUpperCase() }}
</div>
</div>
</div>
</v-card>
</v-col>
<v-col
v-if="Object.keys(metadata['platform_group']?.metadata?.platform?.config_template || {}).length === 0"
cols="12">
<v-alert type="info" variant="tonal">
{{ tm('dialog.noTemplates') }}
</v-alert>
</v-col>
</v-row>
</v-card-text>
</v-card>
</v-dialog>
<AddNewPlatform v-model:show="showAddPlatformDialog" :metadata="metadata"
@select-template="selectPlatformTemplate" />
<!-- 配置对话框 -->
<v-dialog v-model="showPlatformCfg" persistent width="900px" max-width="90%">
<v-card>
<v-card-title class="bg-primary text-white py-3">
<v-icon color="white" class="me-2">{{ updatingMode ? 'mdi-pencil' : 'mdi-plus' }}</v-icon>
<span>{{ updatingMode ? tm('dialog.edit') : tm('dialog.add') }} {{ newSelectedPlatformName }} {{
tm('dialog.adapter') }}</span>
</v-card-title>
<v-card
:title="updatingMode ? tm('dialog.edit') : tm('dialog.add') + ` ${newSelectedPlatformName} ` + tm('dialog.adapter')">
<v-card-text class="py-4">
<v-row>
<v-col cols="12">
@@ -164,7 +114,7 @@
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">{{ tm('dialog.idConflict.confirm')
}}</v-btn>
}}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -177,7 +127,9 @@
</v-card-title>
<v-card-text class="py-4">
<p>{{ tm('dialog.securityWarning.aiocqhttpTokenMissing') }}</p>
<span><a href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7" target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
<span><a
href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7"
target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
</v-card-text>
<v-card-actions class="px-4 pb-4">
<v-spacer></v-spacer>
@@ -199,8 +151,10 @@ import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import ItemCard from '@/components/shared/ItemCard.vue';
import AddNewPlatform from '@/components/platform/AddNewPlatform.vue';
import { useCommonStore } from '@/stores/common';
import { useI18n, useModuleI18n } from '@/i18n/composables';
import { getPlatformIcon, getTutorialLink } from '@/utils/platformUtils';
export default {
name: 'PlatformPage',
@@ -208,7 +162,8 @@ export default {
AstrBotConfig,
WaitingForRestart,
ConsoleDisplayer,
ItemCard
ItemCard,
AddNewPlatform
},
setup() {
const { t } = useI18n();
@@ -285,66 +240,14 @@ export default {
},
methods: {
// 从工具函数导入
getPlatformIcon,
openTutorial() {
const tutorialUrl = this.getTutorialLink(this.newSelectedPlatformConfig.type);
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);
window.open(tutorialUrl, '_blank');
},
getPlatformIcon(name) {
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
} else if (name === 'wecom') {
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
} else if (name === 'lark') {
return new URL('@/assets/images/platform_logos/lark.png', import.meta.url).href
} else if (name === 'dingtalk') {
return new URL('@/assets/images/platform_logos/dingtalk.svg', import.meta.url).href
} else if (name === 'telegram') {
return new URL('@/assets/images/platform_logos/telegram.svg', import.meta.url).href
} else if (name === 'discord') {
return new URL('@/assets/images/platform_logos/discord.svg', import.meta.url).href
} else if (name === 'slack') {
return new URL('@/assets/images/platform_logos/slack.svg', import.meta.url).href
} else if (name === 'kook') {
return new URL('@/assets/images/platform_logos/kook.png', import.meta.url).href
} else if (name === 'vocechat') {
return new URL('@/assets/images/platform_logos/vocechat.png', import.meta.url).href
} else if (name === 'satori' || name === 'Satori') {
return new URL('@/assets/images/platform_logos/satori.png', import.meta.url).href
}
},
getTutorialLink(platform_type) {
let tutorial_map = {
"qq_official_webhook": "https://docs.astrbot.app/deploy/platform/qqofficial/webhook.html",
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
"wechatpadpro": "https://docs.astrbot.app/deploy/platform/wechat/wechatpadpro.html",
"weixin_official_account": "https://docs.astrbot.app/deploy/platform/weixin-official-account.html",
"discord": "https://docs.astrbot.app/deploy/platform/discord.html",
"slack": "https://docs.astrbot.app/deploy/platform/slack.html",
"kook": "https://docs.astrbot.app/deploy/platform/kook.html",
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
}
return tutorial_map[platform_type] || "https://docs.astrbot.app";
},
getPlatformDescription(template, name) {
// special judge for community platforms
if (name.includes('vocechat')) {
return "由 @HikariFroya 提供。";
} else if (name.includes('kook')) {
return "由 @wuyan1003 提供。"
}
},
getConfig() {
axios.get('/api/config/get').then((res) => {
this.config_data = res.data.data.config;
@@ -355,7 +258,7 @@ export default {
});
},
// 添加一个新方法来选择平台模板
// 选择平台模板
selectPlatformTemplate(name) {
this.newSelectedPlatformName = name;
this.showPlatformCfg = true;
@@ -363,7 +266,6 @@ export default {
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(
this.metadata['platform_group']?.metadata?.platform?.config_template[name] || {}
));
this.showAddPlatformDialog = false;
},
addFromDefaultConfigTmpl(index) {
@@ -480,7 +382,7 @@ export default {
this.oneBotEmptyTokenWarningResolve(continueWithWarning);
this.oneBotEmptyTokenWarningResolve = null;
}
if (!continueWithWarning) {
this.loading = false;
}
@@ -532,84 +434,4 @@ export default {
padding: 20px;
padding-top: 8px;
}
.platform-selection-dialog .v-card-title {
border-top-left-radius: 4px;
border-top-right-radius: 4px;
}
.platform-card {
transition: all 0.3s ease;
height: 100%;
cursor: pointer;
overflow: hidden;
position: relative;
}
.platform-card:hover {
transform: translateY(-4px);
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
border-color: var(--v-primary-base);
}
.platform-card-content {
display: flex;
align-items: center;
height: 100px;
padding: 16px;
position: relative;
z-index: 2;
}
.platform-card-text {
flex: 1;
display: flex;
flex-direction: column;
justify-content: center;
}
.platform-card-title {
font-size: 15px;
font-weight: 600;
margin-bottom: 4px;
padding: 0;
}
.platform-card-description {
padding: 0;
margin: 0;
}
.platform-card-logo {
position: absolute;
right: 0;
top: 0;
bottom: 0;
width: 80px;
display: flex;
align-items: center;
justify-content: center;
z-index: 1;
}
.platform-logo-img {
max-width: 60px;
max-height: 60px;
opacity: 0.6;
object-fit: contain;
}
.platform-logo-fallback {
width: 50px;
height: 50px;
border-radius: 50%;
background-color: var(--v-primary-base);
color: white;
display: flex;
align-items: center;
justify-content: center;
font-size: 24px;
font-weight: bold;
opacity: 0.3;
}
</style>

View File

@@ -56,14 +56,16 @@
<v-row v-else>
<v-col v-for="(provider, index) in filteredProviders" :key="index" cols="12" md="6" lg="4" xl="3">
<item-card
:item="provider"
title-field="id"
<item-card
:item="provider"
title-field="id"
enabled-field="enable"
@toggle-enabled="providerStatusChange"
:bglogo="getProviderIcon(provider.provider)"
@delete="deleteProvider"
@edit="configExistingProvider">
@delete="deleteProvider"
@edit="configExistingProvider"
@copy="copyProvider"
:show-copy-button="true">
<template v-slot:details="{ item }">
</template>
</item-card>
@@ -95,7 +97,7 @@
<v-alert v-if="providerStatuses.length === 0" type="info" variant="tonal">
{{ tm('availability.noData') }}
</v-alert>
<v-container v-else class="pa-0">
<v-row>
<v-col v-for="status in providerStatuses" :key="status.id" cols="12" sm="6" md="4">
@@ -113,7 +115,7 @@
></v-progress-circular>
<span class="font-weight-bold">{{ status.id }}</span>
<v-chip :color="getStatusColor(status.status)" size="small" class="ml-2">
{{ getStatusText(status.status) }}
</v-chip>
@@ -153,86 +155,15 @@
</v-container>
<!-- 添加提供商对话框 -->
<v-dialog v-model="showAddProviderDialog" max-width="1100px" min-height="95%">
<v-card class="provider-selection-dialog">
<v-card-title class="bg-primary text-white py-3 px-4" style="display: flex; align-items: center;">
<v-icon color="white" class="me-2">mdi-plus-circle</v-icon>
<span>{{ tm('dialogs.addProvider.title') }}</span>
<v-spacer></v-spacer>
<v-btn icon variant="text" color="white" @click="showAddProviderDialog = false">
<v-icon>mdi-close</v-icon>
</v-btn>
</v-card-title>
<v-card-text class="pa-4" style="overflow-y: auto;">
<v-tabs v-model="activeProviderTab" grow slider-color="primary" bg-color="background">
<v-tab value="chat_completion" class="font-weight-medium px-3">
<v-icon start>mdi-message-text</v-icon>
{{ tm('dialogs.addProvider.tabs.basic') }}
</v-tab>
<v-tab value="speech_to_text" class="font-weight-medium px-3">
<v-icon start>mdi-microphone-message</v-icon>
{{ tm('dialogs.addProvider.tabs.speechToText') }}
</v-tab>
<v-tab value="text_to_speech" class="font-weight-medium px-3">
<v-icon start>mdi-volume-high</v-icon>
{{ tm('dialogs.addProvider.tabs.textToSpeech') }}
</v-tab>
<v-tab value="embedding" class="font-weight-medium px-3">
<v-icon start>mdi-code-json</v-icon>
{{ tm('dialogs.addProvider.tabs.embedding') }}
</v-tab>
<v-tab value="rerank" class="font-weight-medium px-3">
<v-icon start>mdi-compare-vertical</v-icon>
{{ tm('dialogs.addProvider.tabs.rerank') }}
</v-tab>
</v-tabs>
<v-window v-model="activeProviderTab" class="mt-4">
<v-window-item v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
:key="tabType"
:value="tabType">
<v-row class="mt-1">
<v-col v-for="(template, name) in getTemplatesByType(tabType)"
:key="name"
cols="12" sm="6" md="4">
<v-card variant="outlined" hover class="provider-card" @click="selectProviderTemplate(name)">
<div class="provider-card-content">
<div class="provider-card-text">
<v-card-title class="provider-card-title">接入 {{ name }}</v-card-title>
<v-card-text class="text-caption text-medium-emphasis provider-card-description">
{{ getProviderDescription(template, name) }}
</v-card-text>
</div>
<div class="provider-card-logo">
<img :src="getProviderIcon(template.provider)" v-if="getProviderIcon(template.provider)" class="provider-logo-img">
<div v-else class="provider-logo-fallback">
{{ name[0].toUpperCase() }}
</div>
</div>
</div>
</v-card>
</v-col>
<v-col v-if="Object.keys(getTemplatesByType(tabType)).length === 0" cols="12">
<v-alert type="info" variant="tonal">
{{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
</v-alert>
</v-col>
</v-row>
</v-window-item>
</v-window>
</v-card-text>
</v-card>
</v-dialog>
<AddNewProvider
v-model:show="showAddProviderDialog"
:metadata="metadata"
@select-template="selectProviderTemplate"
/>
<!-- 配置对话框 -->
<v-dialog v-model="showProviderCfg" width="900" persistent>
<v-card>
<v-card-title class="bg-primary text-white py-3">
<v-icon color="white" class="me-2">{{ updatingMode ? 'mdi-pencil' : 'mdi-plus' }}</v-icon>
<span>{{ updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') }} {{ newSelectedProviderName }} {{ tm('dialogs.config.provider') }}</span>
</v-card-title>
<v-card :title="updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') + ` ${newSelectedProviderName} ` + tm('dialogs.config.provider')">
<v-card-text class="py-4">
<AstrBotConfig
:iterable="newSelectedProviderConfig"
@@ -307,7 +238,9 @@ import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import ItemCard from '@/components/shared/ItemCard.vue';
import AddNewProvider from '@/components/provider/AddNewProvider.vue';
import { useModuleI18n } from '@/i18n/composables';
import { getProviderIcon } from '@/utils/providerUtils';
export default {
name: 'ProviderPage',
@@ -315,7 +248,8 @@ export default {
AstrBotConfig,
WaitingForRestart,
ConsoleDisplayer,
ItemCard
ItemCard,
AddNewProvider
},
setup() {
const { tm } = useModuleI18n('features/provider');
@@ -348,17 +282,16 @@ export default {
save_message_success: "success",
showConsole: false,
// 显示状态部分
showStatus: false,
// 供应商状态相关
providerStatuses: [],
loadingStatus: false,
// 新增提供商对话框相关
showAddProviderDialog: false,
activeProviderTab: 'chat_completion',
// 添加提供商类型分类
activeProviderTypeTab: 'all',
@@ -370,6 +303,7 @@ export default {
"googlegenai_chat_completion": "chat_completion",
"zhipu_chat_completion": "chat_completion",
"dify": "chat_completion",
"coze": "chat_completion",
"dashscope": "chat_completion",
"openai_whisper_api": "speech_to_text",
"openai_whisper_selfhost": "speech_to_text",
@@ -437,7 +371,7 @@ export default {
}
};
},
// 根据选择的标签过滤提供商列表
filteredProviders() {
if (!this.config_data.provider || this.activeProviderTypeTab === 'all') {
@@ -449,7 +383,7 @@ export default {
if (provider.provider_type) {
return provider.provider_type === this.activeProviderTypeTab;
}
// 否则使用映射关系
const mappedType = this.oldVersionProviderTypeMapping[provider.type];
return mappedType === this.activeProviderTypeTab;
@@ -472,6 +406,9 @@ export default {
});
},
// 从工具函数导入
getProviderIcon,
// 获取空列表文本
getEmptyText() {
if (this.activeProviderTypeTab === 'all') {
@@ -481,63 +418,11 @@ export default {
}
},
// 按提供商类型获取模板列表
getTemplatesByType(type) {
const templates = this.metadata['provider_group']?.metadata?.provider?.config_template || {};
const filtered = {};
for (const [name, template] of Object.entries(templates)) {
if (template.provider_type === type) {
filtered[name] = template;
}
}
return filtered;
},
// 获取提供商类型对应的图标
getProviderIcon(type) {
const icons = {
'openai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
'azure': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/azure.svg',
'xai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/xai.svg',
'anthropic': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/anthropic.svg',
'ollama': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ollama.svg',
'google': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
'deepseek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
'modelscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/modelscope.svg',
'zhipu': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
'siliconflow': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',
'minimax': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/minimax.svg',
'302ai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
'microsoft': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/microsoft.svg',
'vllm': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/vllm.svg',
};
return icons[type] || '';
},
// 获取Tab类型的中文名称
getTabTypeName(tabType) {
return this.messages.tabTypes[tabType] || tabType;
},
// 获取提供商简介
getProviderDescription(template, name) {
if (name == 'OpenAI') {
return this.tm('providers.description.openai', { type: template.type });
} else if (name == 'vLLM Rerank') {
return this.tm('providers.description.vllm_rerank', { type: template.type });
}
return this.tm('providers.description.default', { type: template.type });
},
// 选择提供商模板
selectProviderTemplate(name) {
this.newSelectedProviderName = name;
@@ -546,7 +431,6 @@ export default {
this.newSelectedProviderConfig = JSON.parse(JSON.stringify(
this.metadata['provider_group']?.metadata?.provider?.config_template[name] || {}
));
this.showAddProviderDialog = false;
},
configExistingProvider(provider) {
@@ -657,6 +541,40 @@ export default {
}
},
async copyProvider(providerToCopy) {
console.log('copyProvider triggered for:', providerToCopy);
// 1. 创建深拷贝
const newProviderConfig = JSON.parse(JSON.stringify(providerToCopy));
// 2. 生成唯一的 ID
const generateUniqueId = (baseId) => {
let newId = `${baseId}_copy`;
let counter = 1;
const existingIds = this.config_data.provider.map(p => p.id);
while (existingIds.includes(newId)) {
newId = `${baseId}_copy_${counter}`;
counter++;
}
return newId;
};
newProviderConfig.id = generateUniqueId(providerToCopy.id);
// 3. 设置为禁用状态,等待用户手动开启
newProviderConfig.enable = false;
this.loading = true;
try {
// 4. 调用后端接口创建
const res = await axios.post('/api/config/provider/new', newProviderConfig);
this.showSuccess(res.data.message || `成功复制并创建了 ${newProviderConfig.id}`);
this.getConfig(); // 5. 刷新列表
} catch (err) {
this.showError(err.response?.data?.message || err.message);
} finally {
this.loading = false;
}
},
deleteProvider(provider) {
if (confirm(this.tm('messages.confirm.delete', { id: provider.id }))) {
axios.post('/api/config/provider/delete', { id: provider.id }).then((res) => {
@@ -694,14 +612,14 @@ export default {
this.save_message_success = "error";
this.save_message_snack = true;
},
// 获取供应商状态
async fetchProviderStatus() {
if (this.loadingStatus) return;
this.loadingStatus = true;
this.showStatus = true; // 自动展开状态部分
// 1. 立即初始化UI为pending状态
this.providerStatuses = this.config_data.provider.map(p => ({
id: p.id,
@@ -818,89 +736,6 @@ export default {
padding-top: 8px;
}
.provider-card {
transition: all 0.3s ease;
height: 100%;
cursor: pointer;
overflow: hidden;
position: relative;
}
.provider-card:hover {
transform: translateY(-4px);
box-shadow: 0 4px 25px 0 rgba(0, 0, 0, 0.05);
border-color: var(--v-primary-base);
}
.provider-card-content {
display: flex;
align-items: center;
height: 100px;
padding: 16px;
position: relative;
z-index: 2;
}
.provider-card-text {
flex: 1;
display: flex;
flex-direction: column;
justify-content: center;
}
.provider-card-title {
font-size: 15px;
font-weight: 600;
margin-bottom: 4px;
padding: 0;
}
.provider-card-description {
padding: 0;
margin: 0;
}
.provider-card-logo {
position: absolute;
right: 0;
top: 0;
bottom: 0;
width: 80px;
display: flex;
align-items: center;
justify-content: center;
z-index: 1;
}
.provider-logo-img {
width: 60px;
height: 60px;
opacity: 0.6;
object-fit: contain;
}
.provider-logo-fallback {
width: 50px;
height: 50px;
border-radius: 50%;
background-color: var(--v-primary-base);
color: white;
display: flex;
align-items: center;
justify-content: center;
font-size: 24px;
font-weight: bold;
opacity: 0.3;
}
.v-tabs {
border-radius: 8px;
}
.v-window {
border-radius: 4px;
}
.status-card {
height: 120px;
overflow-y: auto;

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,5 @@
<template>
<div class="dashboard-container">
<div class="dashboard-header">
<h1 class="dashboard-title">{{ t('title') }}</h1>
<div class="dashboard-subtitle">{{ t('subtitle') }}</div>
</div>
<v-slide-y-transition>
<v-row v-if="noticeTitle && noticeContent" class="notice-row">
<v-alert
@@ -166,29 +161,10 @@ export default {
background-color: var(--v-theme-background);
min-height: calc(100vh - 64px);
border-radius: 10px;
}
.dashboard-header {
margin-bottom: 24px;
padding-bottom: 16px;
border-bottom: 1px solid rgba(0, 0, 0, 0.06);
}
.dashboard-title {
font-size: 24px;
font-weight: 600;
color: var(--v-theme-primaryText);
margin-bottom: 4px;
}
.dashboard-subtitle {
font-size: 14px;
color: var(--v-theme-secondaryText);
}
.notice-row {
margin-bottom: 20px;
margin-bottom: 16px;
}
.dashboard-alert {

View File

@@ -98,6 +98,7 @@ export default {
.stat-value-wrapper {
display: flex;
flex-wrap: wrap;
align-items: baseline;
justify-content: space-between;
margin-bottom: 4px;

View File

@@ -44,7 +44,7 @@
<div class="stat-box" :class="{'trend-up': growthRate > 0, 'trend-down': growthRate < 0}">
<div class="stat-label">{{ t('charts.messageTrend.growthRate') }}</div>
<div class="stat-number">
<v-icon size="small" :icon="growthRate > 0 ? 'mdi-arrow-up' : 'mdi-arrow-down'"></v-icon>
<v-icon v-show="growthRate !== 0" size="small" :icon="growthRate > 0 ? 'mdi-arrow-up' : 'mdi-arrow-down'"></v-icon>
{{ Math.abs(growthRate) }}%
</div>
</div>
@@ -303,8 +303,10 @@ export default {
.chart-header {
display: flex;
flex-wrap: wrap;
justify-content: space-between;
align-items: flex-start;
gap: 10px;
margin-bottom: 20px;
}
@@ -321,7 +323,7 @@ export default {
}
.time-select {
max-width: 150px;
max-width: fit-content;
font-size: 14px;
}
@@ -349,6 +351,7 @@ export default {
font-weight: 600;
color: var(--v-theme-primaryText);
display: flex;
flex-wrap: wrap;
align-items: center;
}

View File

@@ -527,12 +527,11 @@ UID: {user_id} 此 ID 可用于设置管理员。
return
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
if provider and provider.meta().type in ["dify", "coze"]:
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message(
"已重置当前 Dify 会话,新聊天将更换到新的会话。"
"已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。"
)
)
return
@@ -755,8 +754,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
创建新对话
"""
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
if provider and provider.meta().type in ["dify", "coze"]:
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message("成功,下次聊天将是新对话。")
@@ -783,8 +781,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
"""创建新群聊对话"""
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
if provider and provider.meta().type in ["dify", "coze"]:
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message("成功,下次聊天将是新对话。")
@@ -823,7 +820,6 @@ UID: {user_id} 此 ID 可用于设置管理员。
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
if not data["data"]:
message.set_result(MessageEventResult().message("未找到任何对话。"))
@@ -1214,6 +1210,12 @@ UID: {user_id} 此 ID 可用于设置管理员。
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
req.prompt = user_info + req.prompt
if cfg.get("group_name_display") and event.message_obj.group_id:
group_name = event.message_obj.group.group_name
if group_name:
req.system_prompt += f"\nGroup name: {group_name}\n"
# 启用附加时间戳
if cfg.get("datetime_system_prompt"):
current_time = None
@@ -1230,6 +1232,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
)
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
img_cap_prov_id = cfg.get("default_image_caption_provider_id")
if req.conversation:
# persona inject
persona_id = req.conversation.persona_id or cfg.get("default_personality")
@@ -1270,7 +1273,6 @@ UID: {user_id} 此 ID 可用于设置管理员。
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
# image caption
img_cap_prov_id = cfg.get("default_image_caption_provider_id")
if img_cap_prov_id and req.image_urls:
img_cap_prompt = cfg.get(
"image_caption_prompt", "Please describe the image."
@@ -1307,9 +1309,12 @@ UID: {user_id} 此 ID 可用于设置管理员。
break
if image_seg:
try:
if prov := self.context.get_using_provider(
event.unified_msg_origin
):
prov = None
if img_cap_prov_id:
prov = self.context.get_provider_by_id(img_cap_prov_id)
if prov is None:
prov = self.context.get_using_provider(event.unified_msg_origin)
if prov:
llm_resp = await prov.text_chat(
prompt="Please describe the image content.",
image_urls=[await image_seg.convert_to_file_path()],
@@ -1318,6 +1323,8 @@ UID: {user_id} 此 ID 可用于设置管理员。
req.system_prompt += (
f"Image Caption: {llm_resp.completion_text}\n"
)
else:
logger.warning("No provider found for image captioning.")
except BaseException as e:
logger.error(f"处理引用图片失败: {e}")
@@ -1337,22 +1344,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
logger.error(f"ltm: {e}")
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd")
@filter.command("alter_cmd", alias={"alter"})
async def alter_cmd(self, event: AstrMessageEvent):
# token = event.message_str.split(" ")
token = self.parse_commands(event.message_str)
if token.len < 2:
if token.len < 3:
yield event.plain_result(
"可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令\n /alter_cmd reset config 打开reset权限配置"
"该指令用于设置指令或指令组的权限。\n"
"格式: /alter_cmd <cmd_name> <admin/member>\n"
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
"/alter_cmd reset config 打开 reset 权限配置"
)
return
cmd_name = token.get(1)
cmd_type = token.get(2)
cmd_name = " ".join(token.tokens[1:-1])
cmd_type = token.get(-1)
# ============================
# 对reset权限进行特殊处理
# ============================
if cmd_name == "reset" and cmd_type == "config":
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get("astrbot", {})
@@ -1402,16 +1409,18 @@ UID: {user_id} 此 ID 可用于设置管理员。
# 查找指令
found_command = None
cmd_group = False
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
if filter_.command_name == cmd_name:
if filter_.equals(cmd_name):
found_command = handler
break
elif isinstance(filter_, CommandGroupFilter):
if cmd_name == filter_.group_name:
if filter_.equals(cmd_name):
found_command = handler
cmd_group = True
break
if not found_command:
@@ -1448,8 +1457,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
else filter.PermissionType.MEMBER
),
)
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
cmd_group_str = "指令组" if cmd_group else "指令"
yield event.plain_result(
f"已将「{cmd_name}{cmd_group_str} 的权限级别调整为 {cmd_type}"
)
async def update_reset_permission(self, scene_key: str, perm_type: str):
"""更新reset命令在特定场景下的权限设置

View File

@@ -178,7 +178,7 @@ class Main(star.Star):
return results
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str:
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
event.set_result(
MessageEventResult().message(
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。"
@@ -210,7 +210,7 @@ class Main(star.Star):
processed_results = await asyncio.gather(*tasks, return_exceptions=True)
ret = ""
for processed_result in processed_results:
if isinstance(processed_result, Exception):
if isinstance(processed_result, BaseException):
logger.error(f"Error processing search result: {processed_result}")
continue
ret += processed_result
@@ -335,7 +335,7 @@ class Main(star.Star):
@filter.on_llm_request(priority=-10000)
async def edit_web_search_tools(
self, event: AstrMessageEvent, req: ProviderRequest
) -> str:
):
"""Get the session conversation for the given event."""
cfg = self.context.get_config(umo=event.unified_msg_origin)
prov_settings = cfg.get("provider_settings", {})
@@ -347,6 +347,9 @@ class Main(star.Star):
req.func_tool = tool_set.get_full_tool_set()
tool_set = req.func_tool
if not tool_set:
return
if not websearch_enable:
# pop tools
for tool_name in self.TOOLS:

View File

@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.1.2"
version = "4.2.0"
description = "易上手的多平台 LLM 聊天机器人及开发框架"
readme = "README.md"
requires-python = ">=3.10"
@@ -49,6 +49,8 @@ dependencies = [
"watchfiles>=1.0.5",
"websockets>=15.0.1",
"wechatpy>=1.8.18",
"audioop-lts ; python_full_version >= '3.13'",
"click>=8.2.1",
]
[project.scripts]

View File

@@ -42,4 +42,5 @@ slack-sdk
pydub
sqlmodel
deprecated
sqlalchemy[asyncio]
sqlalchemy[asyncio]
audioop-lts; python_version>='3.13'

View File

@@ -1,5 +1,7 @@
import pytest
import pytest_asyncio
import os
import asyncio
from quart import Quart
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.db.sqlite import SQLiteDatabase
@@ -9,36 +11,46 @@ from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@pytest.fixture(scope="module")
def core_lifecycle_td():
db = SQLiteDatabase("data/data_v3.db")
@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
db = SQLiteDatabase(str(tmp_db_path))
log_broker = LogBroker()
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
return core_lifecycle_td
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
await core_lifecycle.initialize()
return core_lifecycle
@pytest.fixture(scope="module")
def app(core_lifecycle_td):
db = SQLiteDatabase("data/data_v3.db")
server = AstrBotDashboard(core_lifecycle_td, db)
def app(core_lifecycle_td: AstrBotCoreLifecycle):
"""Creates a Quart app instance for testing."""
shutdown_event = asyncio.Event()
# The db instance is already part of the core_lifecycle_td
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
return server.app
@pytest.fixture(scope="module")
def header():
return {}
@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login",
json={
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
},
)
data = await response.get_json()
assert data["status"] == "ok"
token = data["data"]["token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_init_core_lifecycle_td(core_lifecycle_td):
await core_lifecycle_td.initialize()
assert core_lifecycle_td is not None
@pytest.mark.asyncio
async def test_auth_login(
app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict
):
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Tests the login functionality with both wrong and correct credentials."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login", json={"username": "wrong", "password": "password"}
@@ -55,31 +67,32 @@ async def test_auth_login(
)
data = await response.get_json()
assert data["status"] == "ok" and "token" in data["data"]
header["Authorization"] = f"Bearer {data['data']['token']}"
@pytest.mark.asyncio
async def test_get_stat(app: Quart, header: dict):
async def test_get_stat(app: Quart, authenticated_header: dict):
test_client = app.test_client()
response = await test_client.get("/api/stat/get")
assert response.status_code == 401
response = await test_client.get("/api/stat/get", headers=header)
response = await test_client.get("/api/stat/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok" and "platform" in data["data"]
@pytest.mark.asyncio
async def test_plugins(app: Quart, header: dict):
async def test_plugins(app: Quart, authenticated_header: dict):
test_client = app.test_client()
# 已经安装的插件
response = await test_client.get("/api/plugin/get", headers=header)
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
# 插件市场
response = await test_client.get("/api/plugin/market_list", headers=header)
response = await test_client.get(
"/api/plugin/market_list", headers=authenticated_header
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
@@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict):
response = await test_client.post(
"/api/plugin/install",
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
headers=header,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict):
# 插件更新
response = await test_client.post(
"/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header
"/api/plugin/update",
json={"name": "astrbot_plugin_essential"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict):
response = await test_client.post(
"/api/plugin/uninstall",
json={"name": "astrbot_plugin_essential"},
headers=header,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -132,9 +147,9 @@ async def test_plugins(app: Quart, header: dict):
@pytest.mark.asyncio
async def test_check_update(app: Quart, header: dict):
async def test_check_update(app: Quart, authenticated_header: dict):
test_client = app.test_client()
response = await test_client.get("/api/update/check", headers=header)
response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "success"
@@ -142,24 +157,45 @@ async def test_check_update(app: Quart, header: dict):
@pytest.mark.asyncio
async def test_do_update(
app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
tmp_path_factory,
):
global VERSION
test_client = app.test_client()
os.makedirs("data/astrbot_release", exist_ok=True)
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
VERSION = "114.514.1919810"
response = await test_client.post(
"/api/update/do", headers=header, json={"version": "latest"}
# Use a temporary path for the mock update to avoid side effects
temp_release_dir = tmp_path_factory.mktemp("release")
release_path = temp_release_dir / "astrbot"
async def mock_update(*args, **kwargs):
"""Mocks the update process by creating a directory in the temp path."""
os.makedirs(release_path, exist_ok=True)
return
async def mock_download_dashboard(*args, **kwargs):
"""Mocks the dashboard download to prevent network access."""
return
async def mock_pip_install(*args, **kwargs):
"""Mocks pip install to prevent actual installation."""
return
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard
)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "error" # 已经是最新版本
response = await test_client.post(
"/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False}
"/api/update/do",
headers=authenticated_header,
json={"version": "v3.4.0", "reboot": False},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert os.path.exists("data/astrbot_release/astrbot")
assert os.path.exists(release_path)

View File

@@ -1,5 +1,9 @@
import os
import sys
# 将项目根目录添加到 sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import pytest
from unittest import mock
from main import check_env, check_dashboard_files
@@ -27,29 +31,58 @@ def test_check_env(monkeypatch):
@pytest.mark.asyncio
async def test_check_dashboard_files(monkeypatch):
async def test_check_dashboard_files_not_exists(monkeypatch):
"""Tests dashboard download when files do not exist."""
monkeypatch.setattr(os.path, "exists", lambda x: False)
async def mock_get(*args, **kwargs):
class MockResponse:
status = 200
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
mock_download.assert_called_once()
async def read(self):
return b"content"
return MockResponse()
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
"""Tests that dashboard is not downloaded when it exists and version matches."""
# Mock os.path.exists to return True
monkeypatch.setattr(os.path, "exists", lambda x: True)
with mock.patch("aiohttp.ClientSession.get", new=mock_get):
with mock.patch("builtins.open", mock.mock_open()) as mock_file:
with mock.patch("zipfile.ZipFile.extractall") as mock_extractall:
# Mock get_dashboard_version to return the current version
with mock.patch("main.get_dashboard_version") as mock_get_version:
# We need to import VERSION from main's context
from main import VERSION
async def mock_aenter(_):
await check_dashboard_files()
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
mock_extractall.assert_called_once()
mock_get_version.return_value = f"v{VERSION}"
async def mock_aexit(obj, exc_type, exc, tb):
return
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
# Assert that download_dashboard was NOT called
mock_download.assert_not_called()
mock_extractall.__aenter__ = mock_aenter
mock_extractall.__aexit__ = mock_aexit
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
"""Tests that a warning is logged when dashboard version mismatches."""
monkeypatch.setattr(os.path, "exists", lambda x: True)
with mock.patch("main.get_dashboard_version") as mock_get_version:
mock_get_version.return_value = "v0.0.1" # A different version
with mock.patch("main.logger.warning") as mock_logger_warning:
await check_dashboard_files()
mock_logger_warning.assert_called_once()
call_args, _ = mock_logger_warning.call_args
assert "不符" in call_args[0]
@pytest.mark.asyncio
async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
"""Tests that providing a valid webui_dir skips all checks."""
valid_dir = "/tmp/my-custom-webui"
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.get_dashboard_version") as mock_get_version:
result = await check_dashboard_files(webui_dir=valid_dir)
assert result == valid_dir
mock_download.assert_not_called()
mock_get_version.assert_not_called()

View File

@@ -1,285 +0,0 @@
import pytest
import logging
import os
import asyncio
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import (
AstrBotMessage,
MessageMember,
MessageType,
)
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core.message.components import Plain, At
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context
from asyncio import Queue
SESSION_ID_IN_WHITELIST = "test_sid_wl"
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
TEST_LLM_PROVIDER = {
"id": "zhipu_default",
"type": "openai_chat_completion",
"enable": True,
"key": [os.getenv("ZHIPU_API_KEY")],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
}
TEST_COMMANDS = [
["help", "已注册的 AstrBot 内置指令"],
["tool ls", "函数工具"],
["tool on websearch", "激活工具"],
["tool off websearch", "停用工具"],
["plugin", "已加载的插件"],
["t2i", "文本转图片模式"],
["sid", "此 ID 可用于设置会话白名单。"],
["op test_op", "授权成功。"],
["deop test_op", "取消授权成功。"],
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
["provider", "当前载入的 LLM 提供商"],
["reset", "重置成功"],
# ["model", "查看、切换提供商模型列表"],
["history", "历史记录:"],
["key", "当前 Key"],
["persona", "[Persona]"],
]
class FakeAstrMessageEvent(AstrMessageEvent):
def __init__(self, abm: AstrBotMessage = None):
meta = PlatformMetadata("test_platform", "test")
super().__init__(
message_str=abm.message_str,
message_obj=abm,
platform_meta=meta,
session_id=abm.session_id,
)
async def send(self, message: MessageChain):
await super().send(message)
@staticmethod
def create_fake_event(
message_str: str,
session_id: str = "test_sid",
is_at: bool = False,
is_group: bool = False,
sender_id: str = "123456",
):
abm = AstrBotMessage()
abm.message_str = message_str
abm.group_id = "test"
abm.message = [Plain(message_str)]
if is_at:
abm.message.append(At(qq="bot"))
abm.self_id = "bot"
abm.sender = MessageMember(sender_id, "mika")
abm.timestamp = 1234567890
abm.message_id = "test"
abm.session_id = session_id
if is_group:
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
return FakeAstrMessageEvent(abm)
@pytest.fixture(scope="module")
def event_queue():
return Queue()
@pytest.fixture(scope="module")
def config():
cfg = AstrBotConfig()
cfg["platform_settings"]["id_whitelist"] = [
"test_platform:FriendMessage:test_sid_wl",
"test_platform:GroupMessage:test_sid_wl",
]
cfg["admins_id"] = ["123456"]
cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"]
cfg["provider"] = [TEST_LLM_PROVIDER]
return cfg
@pytest.fixture(scope="module")
def db():
return SQLiteDatabase("data/data_v3.db")
@pytest.fixture(scope="module")
def platform_manager(event_queue, config):
return PlatformManager(config, event_queue)
@pytest.fixture(scope="module")
def provider_manager(config, db):
return ProviderManager(config, db)
@pytest.fixture(scope="module")
def star_context(event_queue, config, db, platform_manager, provider_manager):
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
return star_context
@pytest.fixture(scope="module")
def plugin_manager(star_context, config):
plugin_manager = PluginManager(star_context, config)
# await plugin_manager.reload()
asyncio.run(plugin_manager.reload())
return plugin_manager
@pytest.fixture(scope="module")
def pipeline_context(config, plugin_manager):
return PipelineContext(config, plugin_manager)
@pytest.fixture(scope="module")
def pipeline_scheduler(pipeline_context):
return PipelineScheduler(pipeline_context)
@pytest.mark.asyncio
async def test_platform_initialization(platform_manager: PlatformManager):
await platform_manager.initialize()
@pytest.mark.asyncio
async def test_provider_initialization(provider_manager: ProviderManager):
await provider_manager.initialize()
@pytest.mark.asyncio
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
await pipeline_scheduler.initialize()
@pytest.mark.asyncio
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
"""测试唤醒"""
# 群聊无 @ 无指令
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any(
"执行阶段 WhitelistCheckStage" not in message for message in caplog.messages
)
# 群聊有 @ 无指令
mock_event = FakeAstrMessageEvent.create_fake_event(
"test", is_group=True, is_at=True
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
# 群聊有指令
mock_event = FakeAstrMessageEvent.create_fake_event(
"/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST
)
await pipeline_scheduler.execute(mock_event)
assert mock_event._has_send_oper is True
@pytest.mark.asyncio
async def test_pipeline_wl(
pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog
):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"test", SESSION_ID_IN_WHITELIST, sender_id="123"
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any(
"不在会话白名单中,已终止事件传播。" not in message
for message in caplog.messages
), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any(
"不在会话白名单中,已终止事件传播。" in message for message in caplog.messages
), "日志中未找到预期的消息"
@pytest.mark.asyncio
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
# 测试默认屏蔽词
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"色情", session_id=SESSION_ID_IN_WHITELIST
) # 测试需要。
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), (
"日志中未找到预期的消息"
)
# 测试额外屏蔽词
mock_event = FakeAstrMessageEvent.create_fake_event(
"TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), (
"日志中未找到预期的消息"
)
mock_event = FakeAstrMessageEvent.create_fake_event(
"_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" not in message for message in caplog.messages)
# TODO: 测试 百度AI 的内容安全检查
@pytest.mark.asyncio
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert mock_event.get_result() is not None
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
@pytest.mark.asyncio
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert any(
"web_searcher - search_from_search_engine" in message
for message in caplog.messages
)
@pytest.mark.asyncio
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
for command in TEST_COMMANDS:
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
command[0], session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
assert any(command[1] in message for message in caplog.messages)

View File

@@ -1,5 +1,6 @@
import pytest
import os
from unittest.mock import MagicMock
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@@ -8,18 +9,51 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db.sqlite import SQLiteDatabase
from asyncio import Queue
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase("data/data_v3.db")
star_context = Context(event_queue, config, db)
@pytest.fixture
def plugin_manager_pm():
return PluginManager(star_context, config)
def plugin_manager_pm(tmp_path):
"""
Provides a fully isolated PluginManager instance for testing.
- Uses a temporary directory for plugins.
- Uses a temporary database.
- Creates a fresh context for each test.
"""
# Create temporary resources
temp_plugins_path = tmp_path / "plugins"
temp_plugins_path.mkdir()
temp_db_path = tmp_path / "test_db.db"
# Create fresh, isolated instances for the context
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase(str(temp_db_path))
# Set the plugin store path in the config to the temporary directory
config.plugin_store_path = str(temp_plugins_path)
# Mock dependencies for the context
provider_manager = MagicMock()
platform_manager = MagicMock()
conversation_manager = MagicMock()
message_history_manager = MagicMock()
persona_manager = MagicMock()
astrbot_config_mgr = MagicMock()
star_context = Context(
event_queue,
config,
db,
provider_manager,
platform_manager,
conversation_manager,
message_history_manager,
persona_manager,
astrbot_config_mgr,
)
# Create the PluginManager instance
manager = PluginManager(star_context, config)
yield manager
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
@@ -36,48 +70,76 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
@pytest.mark.asyncio
async def test_plugin_crud(plugin_manager_pm: PluginManager):
"""测试插件安装和重载"""
os.makedirs("data/plugins", exist_ok=True)
async def test_install_plugin(plugin_manager_pm: PluginManager):
"""Tests successful plugin installation in an isolated environment."""
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert plugin_path is not None
plugin_info = await plugin_manager_pm.install_plugin(test_repo)
plugin_path = os.path.join(
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
)
assert plugin_info is not None
assert os.path.exists(plugin_path)
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# shutil.rmtree(plugin_path)
assert any(md.name == "astrbot_plugin_essential" for md in star_registry), (
"Plugin 'astrbot_plugin_essential' was not loaded into star_registry."
)
# install plugin which is not exists
@pytest.mark.asyncio
async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that installing a non-existent plugin raises an exception."""
with pytest.raises(Exception):
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
await plugin_manager_pm.install_plugin(
"https://github.com/Soulter/non_existent_repo"
)
# update
@pytest.mark.asyncio
async def test_update_plugin(plugin_manager_pm: PluginManager):
"""Tests updating an existing plugin in an isolated environment."""
# First, install the plugin
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
await plugin_manager_pm.install_plugin(test_repo)
# Then, update it
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
# uninstall
@pytest.mark.asyncio
async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that updating a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("non_existent_plugin")
@pytest.mark.asyncio
async def test_uninstall_plugin(plugin_manager_pm: PluginManager):
"""Tests successful plugin uninstallation in an isolated environment."""
# First, install the plugin
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
await plugin_manager_pm.install_plugin(test_repo)
plugin_path = os.path.join(
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
)
assert os.path.exists(plugin_path) # Pre-condition
# Then, uninstall it
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
assert not os.path.exists(plugin_path)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), (
"Plugin 'astrbot_plugin_essential' was not unloaded from star_registry."
)
assert not any(
"astrbot_plugin_essential" in md.handler_module_path
for md in star_handlers_registry
), (
"Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry."
)
@pytest.mark.asyncio
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that uninstalling a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
# TODO: file installation
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")

3407
uv.lock generated

File diff suppressed because it is too large Load Diff