Compare commits
61 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5b3ce8424 | ||
|
|
80cbbfa5ca | ||
|
|
9177bb660f | ||
|
|
a3df39a01a | ||
|
|
25dce05cbb | ||
|
|
1542ea3e03 | ||
|
|
6084abbcfe | ||
|
|
ed19b63914 | ||
|
|
4efeb85296 | ||
|
|
fc76665615 | ||
|
|
3a044bb71a | ||
|
|
cddd606562 | ||
|
|
7a5bc51c11 | ||
|
|
9f939b4b6f | ||
|
|
80a86f5b1b | ||
|
|
a0ce1855ab | ||
|
|
a4b43b884a | ||
|
|
824c0f6667 | ||
|
|
a030fe8491 | ||
|
|
3a9429e8ef | ||
|
|
c4eb1ab748 | ||
|
|
29ed19d600 | ||
|
|
0cc65513a5 | ||
|
|
debc048659 | ||
|
|
92f5c918dd | ||
|
|
9519f1e8e2 | ||
|
|
a8f874bf05 | ||
|
|
9d9917e45b | ||
|
|
91ee0a870d | ||
|
|
6cbbffc5a9 | ||
|
|
8f26fd34d1 | ||
|
|
fda655f6d7 | ||
|
|
a663d6509b | ||
|
|
9ec8839efa | ||
|
|
a7a0350eb2 | ||
|
|
39a7a0d960 | ||
|
|
7740e1e131 | ||
|
|
9dce1ed47e | ||
|
|
e84a00d3a5 | ||
|
|
88a944cb57 | ||
|
|
20c32e72cc | ||
|
|
4788c20816 | ||
|
|
e83fc570a4 | ||
|
|
e841b6af88 | ||
|
|
ea6f209557 | ||
|
|
9bfa726107 | ||
|
|
d24902c66d | ||
|
|
72aea2d3f3 | ||
|
|
dc9612d564 | ||
|
|
1770556d56 | ||
|
|
888fb84aee | ||
|
|
d597fd056d | ||
|
|
dea0ab3974 | ||
|
|
da6facd7d7 | ||
|
|
bb8ab5f173 | ||
|
|
ac8a541059 | ||
|
|
0e66771f0e | ||
|
|
d3a295a801 | ||
|
|
f2df771771 | ||
|
|
7b72cd87a5 | ||
|
|
9431efc6d1 |
51
.github/PULL_REQUEST_TEMPLATE.md
vendored
51
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,19 +1,46 @@
|
||||
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
||||
解决了 #XYZ
|
||||
<!-- 如果有的话,请指定此 PR 旨在解决的 ISSUE 编号。 -->
|
||||
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
|
||||
|
||||
### Motivation
|
||||
fixes #XYZ
|
||||
|
||||
<!--解释为什么要改动-->
|
||||
---
|
||||
|
||||
### Modifications
|
||||
### Motivation / 动机
|
||||
|
||||
<!--简单解释你的改动-->
|
||||
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
|
||||
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
|
||||
|
||||
### Check
|
||||
### Modifications / 改动点
|
||||
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
|
||||
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
|
||||
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
|
||||
|
||||
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||
- [ ] 👀 我的更改经过良好的测试
|
||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
||||
- [ ] 😮 我的更改没有引入恶意代码
|
||||
### Verification Steps / 验证步骤
|
||||
|
||||
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤(例如:1. 导航到... 2. 点击...)。-->
|
||||
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
|
||||
|
||||
### Screenshots or Test Results / 运行截图或测试结果
|
||||
|
||||
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
|
||||
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
|
||||
|
||||
### Compatibility & Breaking Changes / 兼容性与破坏性变更
|
||||
|
||||
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
|
||||
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
|
||||
|
||||
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
|
||||
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
|
||||
|
||||
---
|
||||
|
||||
### Checklist / 检查清单
|
||||
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
||||
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
|
||||
|
||||
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
|
||||
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
|
||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
|
||||
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.
|
||||
|
||||
36
.github/auto_assign.yml
vendored
Normal file
36
.github/auto_assign.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
# Set to true to add reviewers to pull requests
|
||||
addReviewers: true
|
||||
|
||||
# Set to true to add assignees to pull requests
|
||||
addAssignees: false
|
||||
|
||||
# A list of reviewers to be added to pull requests (GitHub user name)
|
||||
reviewers:
|
||||
- Soulter
|
||||
- Raven95676
|
||||
- Larch-C
|
||||
- anka-afk
|
||||
- advent259141
|
||||
# - zouyonghe
|
||||
|
||||
# A number of reviewers added to the pull request
|
||||
# Set 0 to add all the reviewers (default: 0)
|
||||
numberOfReviewers: 2
|
||||
|
||||
# A list of assignees, overrides reviewers if set
|
||||
# assignees:
|
||||
# - assigneeA
|
||||
|
||||
# A number of assignees to add to the pull request
|
||||
# Set to 0 to add all of the assignees.
|
||||
# Uses numberOfReviewers if unset.
|
||||
# numberOfAssignees: 2
|
||||
|
||||
# A list of keywords to be skipped the process that add reviewers if pull requests include it
|
||||
skipKeywords:
|
||||
- wip
|
||||
- draft
|
||||
|
||||
# A list of users to be skipped by both the add reviewers and add assignees processes
|
||||
# skipUsers:
|
||||
# - dependabot[bot]
|
||||
2
.github/workflows/auto_release.yml
vendored
2
.github/workflows/auto_release.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
|
||||
34
.github/workflows/code-format.yml
vendored
Normal file
34
.github/workflows/code-format.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: Code Format Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
push:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
format-check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install UV
|
||||
run: pip install uv
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync
|
||||
|
||||
- name: Check code formatting with ruff
|
||||
run: |
|
||||
uv run ruff format --check .
|
||||
|
||||
- name: Check code style with ruff
|
||||
run: |
|
||||
uv run ruff check .
|
||||
2
.github/workflows/coverage_test.yml
vendored
2
.github/workflows/coverage_test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
|
||||
1
.github/workflows/dashboard_ci.yml
vendored
1
.github/workflows/dashboard_ci.yml
vendored
@@ -37,6 +37,7 @@ jobs:
|
||||
!dist/**/*.md
|
||||
|
||||
- name: Create GitHub Release
|
||||
if: github.event_name == 'push'
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
tag: release-${{ github.sha }}
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'Stale issue message'
|
||||
|
||||
24
README.md
24
README.md
@@ -14,7 +14,6 @@
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||

|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
@@ -100,7 +99,7 @@ uv run main.py
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 开发者群:753075035
|
||||
- 开发者群:975206796
|
||||
- 开发者群(备份):295657329
|
||||
|
||||
### Telegram 群组
|
||||
@@ -111,7 +110,6 @@ uv run main.py
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
| 平台 | 支持性 |
|
||||
@@ -128,6 +126,8 @@ uv run main.py
|
||||
| Discord | ✔ |
|
||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||
| Satori | ✔ |
|
||||
| Misskey | ✔ |
|
||||
|
||||
## ⚡ 提供商支持情况
|
||||
|
||||
@@ -173,7 +173,6 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
@@ -182,10 +181,18 @@ pre-commit install
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
此外,本项目的诞生离不开以下开源项目:
|
||||
此外,本项目的诞生离不开以下开源项目的帮助:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
||||
|
||||
另外,一些同类型其他的活跃开源 Bot 项目:
|
||||
|
||||
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
|
||||
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
|
||||
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
|
||||
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
|
||||
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
|
||||
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
@@ -193,14 +200,11 @@ pre-commit install
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from astrbot.core.star.register import (
|
||||
register_permission_type as permission_type,
|
||||
register_custom_filter as custom_filter,
|
||||
register_on_astrbot_loaded as on_astrbot_loaded,
|
||||
register_on_platform_loaded as on_platform_loaded,
|
||||
register_on_llm_request as on_llm_request,
|
||||
register_on_llm_response as on_llm_response,
|
||||
register_llm_tool as llm_tool,
|
||||
@@ -41,6 +42,7 @@ __all__ = [
|
||||
"custom_filter",
|
||||
"PermissionType",
|
||||
"on_astrbot_loaded",
|
||||
"on_platform_loaded",
|
||||
"on_llm_request",
|
||||
"llm_tool",
|
||||
"on_decorating_result",
|
||||
|
||||
@@ -124,15 +124,17 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
result.append({
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
}
|
||||
)
|
||||
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
@@ -142,15 +144,17 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for plugin_id, plugin_info in data.items():
|
||||
online_plugins.append({
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
})
|
||||
online_plugins.append(
|
||||
{
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
|
||||
@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
|
||||
class Agent(Generic[TContext]):
|
||||
name: str
|
||||
instructions: str | None = None
|
||||
tools: list[str, FunctionTool] | None = None
|
||||
tools: list[str | FunctionTool] | None = None
|
||||
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
||||
|
||||
@@ -92,7 +92,7 @@ class MCPClient:
|
||||
self.session: Optional[mcp.ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
self.name = None
|
||||
self.name: str | None = None
|
||||
self.active: bool = True
|
||||
self.tools: list[mcp.Tool] = []
|
||||
self.server_errlogs: list[str] = []
|
||||
@@ -198,6 +198,8 @@ class MCPClient:
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
if not self.session:
|
||||
raise Exception("MCP Client is not initialized")
|
||||
response = await self.session.list_tools()
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
import typing as T
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@@ -14,4 +14,5 @@ class ContextWrapper(Generic[TContext]):
|
||||
context: TContext
|
||||
event: AstrMessageEvent
|
||||
|
||||
|
||||
NoContext = ContextWrapper[None]
|
||||
|
||||
@@ -258,7 +258,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
yield MessageChain(
|
||||
type="tool_direct_result"
|
||||
).base64_image(res.content[0].data)
|
||||
).base64_image(resource.blob)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -269,17 +269,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context,
|
||||
func_tool_name,
|
||||
func_tool_args,
|
||||
resp,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
@@ -289,27 +278,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
||||
|
||||
self.run_context.event.clear_result()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from deprecated import deprecated
|
||||
from typing import Awaitable, Literal, Any, Optional
|
||||
from typing import Awaitable, Callable, Literal, Any, Optional
|
||||
from .mcp_client import MCPClient
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
|
||||
class FunctionTool:
|
||||
"""A class representing a function tool that can be used in function calling."""
|
||||
|
||||
name: str | None = None
|
||||
name: str
|
||||
parameters: dict | None = None
|
||||
description: str | None = None
|
||||
handler: Awaitable | None = None
|
||||
handler: Callable[..., Awaitable[Any]] | None = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str | None = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
@@ -51,7 +51,7 @@ class ToolSet:
|
||||
This class provides methods to add, remove, and retrieve tools, as well as
|
||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||
|
||||
def __init__(self, tools: list[FunctionTool] = None):
|
||||
def __init__(self, tools: list[FunctionTool] | None = None):
|
||||
self.tools: list[FunctionTool] = tools or []
|
||||
|
||||
def empty(self) -> bool:
|
||||
@@ -79,7 +79,13 @@ class ToolSet:
|
||||
return None
|
||||
|
||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
||||
def add_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
):
|
||||
"""Add a function tool to the set."""
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
@@ -104,7 +110,7 @@ class ToolSet:
|
||||
self.remove_tool(name)
|
||||
|
||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||
def get_func(self, name: str) -> list[FunctionTool]:
|
||||
def get_func(self, name: str) -> FunctionTool | None:
|
||||
"""Get all function tools."""
|
||||
return self.get_tool(name)
|
||||
|
||||
@@ -125,7 +131,11 @@ class ToolSet:
|
||||
},
|
||||
}
|
||||
|
||||
if tool.parameters.get("properties") or not omit_empty_parameter_field:
|
||||
if (
|
||||
tool.parameters
|
||||
and tool.parameters.get("properties")
|
||||
or not omit_empty_parameter_field
|
||||
):
|
||||
func_def["function"]["parameters"] = tool.parameters
|
||||
|
||||
result.append(func_def)
|
||||
@@ -135,14 +145,14 @@ class ToolSet:
|
||||
"""Convert tools to Anthropic API format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
input_schema = {"type": "object"}
|
||||
if tool.parameters:
|
||||
input_schema["properties"] = tool.parameters.get("properties", {})
|
||||
input_schema["required"] = tool.parameters.get("required", [])
|
||||
tool_def = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": tool.parameters.get("properties", {}),
|
||||
"required": tool.parameters.get("required", []),
|
||||
},
|
||||
"input_schema": input_schema,
|
||||
}
|
||||
result.append(tool_def)
|
||||
return result
|
||||
@@ -210,14 +220,15 @@ class ToolSet:
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
tools = []
|
||||
for tool in self.tools:
|
||||
d = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": convert_schema(tool.parameters),
|
||||
}
|
||||
for tool in self.tools
|
||||
]
|
||||
if tool.parameters:
|
||||
d["parameters"] = convert_schema(tool.parameters)
|
||||
tools.append(d)
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.0.0"
|
||||
VERSION = "4.1.6"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -56,10 +56,11 @@ DEFAULT_CONFIG = {
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"websearch_provider": "default",
|
||||
"websearch_tavily_key": "",
|
||||
"websearch_tavily_key": [],
|
||||
"web_search_link": False,
|
||||
"display_reasoning_text": False,
|
||||
"identifier": False,
|
||||
"group_name_display": False,
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
@@ -235,6 +236,16 @@ CONFIG_METADATA_2 = {
|
||||
"discord_guild_id_for_debug": "",
|
||||
"discord_activity_name": "",
|
||||
},
|
||||
"Misskey": {
|
||||
"id": "misskey",
|
||||
"type": "misskey",
|
||||
"enable": False,
|
||||
"misskey_instance_url": "https://misskey.example",
|
||||
"misskey_token": "",
|
||||
"misskey_default_visibility": "public",
|
||||
"misskey_local_only": False,
|
||||
"misskey_enable_chat": True,
|
||||
},
|
||||
"Slack": {
|
||||
"id": "slack",
|
||||
"type": "slack",
|
||||
@@ -252,7 +263,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "satori",
|
||||
"enable": False,
|
||||
"satori_api_base_url": "http://localhost:5140/satori/v1",
|
||||
"satori_endpoint": "ws://127.0.0.1:5140/satori/v1/events",
|
||||
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
|
||||
"satori_token": "",
|
||||
"satori_auto_reconnect": True,
|
||||
"satori_heartbeat_interval": 10,
|
||||
@@ -261,34 +272,34 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"items": {
|
||||
"satori_api_base_url": {
|
||||
"description": "Satori API Base URL",
|
||||
"description": "Satori API 终结点",
|
||||
"type": "string",
|
||||
"hint": "The base URL for the Satori API.",
|
||||
"hint": "Satori API 的基础地址。",
|
||||
},
|
||||
"satori_endpoint": {
|
||||
"description": "Satori WebSocket Endpoint",
|
||||
"description": "Satori WebSocket 终结点",
|
||||
"type": "string",
|
||||
"hint": "The WebSocket endpoint for Satori events.",
|
||||
"hint": "Satori 事件的 WebSocket 端点。",
|
||||
},
|
||||
"satori_token": {
|
||||
"description": "Satori Token",
|
||||
"description": "Satori 令牌",
|
||||
"type": "string",
|
||||
"hint": "The token used for authenticating with the Satori API.",
|
||||
"hint": "用于 Satori API 身份验证的令牌。",
|
||||
},
|
||||
"satori_auto_reconnect": {
|
||||
"description": "Enable Auto Reconnect",
|
||||
"description": "启用自动重连",
|
||||
"type": "bool",
|
||||
"hint": "Whether to automatically reconnect the WebSocket on disconnection.",
|
||||
"hint": "断开连接时是否自动重新连接 WebSocket。",
|
||||
},
|
||||
"satori_heartbeat_interval": {
|
||||
"description": "Satori Heartbeat Interval",
|
||||
"description": "Satori 心跳间隔",
|
||||
"type": "int",
|
||||
"hint": "The interval (in seconds) for sending heartbeat messages.",
|
||||
"hint": "发送心跳消息的间隔(秒)。",
|
||||
},
|
||||
"satori_reconnect_delay": {
|
||||
"description": "Satori Reconnect Delay",
|
||||
"description": "Satori 重连延迟",
|
||||
"type": "int",
|
||||
"hint": "The delay (in seconds) before attempting to reconnect.",
|
||||
"hint": "尝试重新连接前的延迟时间(秒)。",
|
||||
},
|
||||
"slack_connection_mode": {
|
||||
"description": "Slack Connection Mode",
|
||||
@@ -336,6 +347,32 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||
},
|
||||
"misskey_instance_url": {
|
||||
"description": "Misskey 实例 URL",
|
||||
"type": "string",
|
||||
"hint": "例如 https://misskey.example,填写 Bot 账号所在的 Misskey 实例地址",
|
||||
},
|
||||
"misskey_token": {
|
||||
"description": "Misskey Access Token",
|
||||
"type": "string",
|
||||
"hint": "连接服务设置生成的 API 鉴权访问令牌(Access token)",
|
||||
},
|
||||
"misskey_default_visibility": {
|
||||
"description": "默认帖子可见性",
|
||||
"type": "string",
|
||||
"options": ["public", "home", "followers"],
|
||||
"hint": "机器人发帖时的默认可见性设置。public:公开,home:主页时间线,followers:仅关注者。",
|
||||
},
|
||||
"misskey_local_only": {
|
||||
"description": "仅限本站(不参与联合)",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
|
||||
},
|
||||
"misskey_enable_chat": {
|
||||
"description": "启用聊天消息响应",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人将会监听和响应私信聊天消息",
|
||||
},
|
||||
"telegram_command_register": {
|
||||
"description": "Telegram 命令注册",
|
||||
"type": "bool",
|
||||
@@ -599,6 +636,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||
},
|
||||
@@ -613,6 +651,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"xAI": {
|
||||
@@ -625,6 +664,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Anthropic": {
|
||||
@@ -654,6 +694,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"LM Studio": {
|
||||
@@ -667,6 +708,7 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "llama-3.1-8b",
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini(OpenAI兼容)": {
|
||||
@@ -682,6 +724,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-1.5-flash",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini": {
|
||||
@@ -722,6 +765,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"302.AI": {
|
||||
@@ -734,6 +778,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"硅基流动": {
|
||||
@@ -749,6 +794,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
@@ -764,6 +810,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
"id": "compshare",
|
||||
@@ -777,6 +824,7 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "moonshotai/Kimi-K2-Instruct",
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Kimi": {
|
||||
@@ -789,6 +837,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"智谱 AI": {
|
||||
@@ -847,6 +896,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"FastGPT": {
|
||||
@@ -858,6 +908,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.fastgpt.in/api/v1",
|
||||
"timeout": 60,
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"Whisper(API)": {
|
||||
"id": "whisper",
|
||||
@@ -1102,6 +1153,12 @@ CONFIG_METADATA_2 = {
|
||||
"render_type": "checkbox",
|
||||
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
|
||||
},
|
||||
"custom_extra_body": {
|
||||
"description": "自定义请求体参数",
|
||||
"type": "dict",
|
||||
"items": {},
|
||||
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值。",
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"invisible": True,
|
||||
@@ -1704,6 +1761,9 @@ CONFIG_METADATA_2 = {
|
||||
"identifier": {
|
||||
"type": "bool",
|
||||
},
|
||||
"group_name_display": {
|
||||
"type": "bool",
|
||||
},
|
||||
"datetime_system_prompt": {
|
||||
"type": "bool",
|
||||
},
|
||||
@@ -1883,17 +1943,31 @@ CONFIG_METADATA_3 = {
|
||||
"_special": "select_provider",
|
||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||
},
|
||||
"provider_stt_settings.enable": {
|
||||
"description": "默认启用语音转文本",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_stt_settings.provider_id": {
|
||||
"description": "语音转文本模型",
|
||||
"type": "string",
|
||||
"hint": "留空代表不使用。",
|
||||
"_special": "select_provider_stt",
|
||||
"condition": {
|
||||
"provider_stt_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_tts_settings.enable": {
|
||||
"description": "默认启用文本转语音",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_tts_settings.provider_id": {
|
||||
"description": "文本转语音模型",
|
||||
"type": "string",
|
||||
"hint": "留空代表不使用。",
|
||||
"_special": "select_provider_tts",
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.image_caption_prompt": {
|
||||
"description": "图片转述提示词",
|
||||
@@ -1938,7 +2012,9 @@ CONFIG_METADATA_3 = {
|
||||
},
|
||||
"provider_settings.websearch_tavily_key": {
|
||||
"description": "Tavily API Key",
|
||||
"type": "string",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可添加多个 Key 进行轮询。",
|
||||
"condition": {
|
||||
"provider_settings.websearch_provider": "tavily",
|
||||
},
|
||||
@@ -1961,6 +2037,11 @@ CONFIG_METADATA_3 = {
|
||||
"description": "用户识别",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.group_name_display": {
|
||||
"description": "显示群名称",
|
||||
"type": "bool",
|
||||
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
|
||||
},
|
||||
"provider_settings.datetime_system_prompt": {
|
||||
"description": "现实世界时间感知",
|
||||
"type": "bool",
|
||||
@@ -2108,41 +2189,41 @@ CONFIG_METADATA_3 = {
|
||||
"description": "内容安全",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"platform_settings.content_safety.also_use_in_response": {
|
||||
"content_safety.also_use_in_response": {
|
||||
"description": "同时检查模型的响应内容",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_settings.content_safety.baidu_aip.enable": {
|
||||
"content_safety.baidu_aip.enable": {
|
||||
"description": "使用百度内容安全审核",
|
||||
"type": "bool",
|
||||
"hint": "您需要手动安装 baidu-aip 库。",
|
||||
},
|
||||
"platform_settings.content_safety.baidu_aip.app_id": {
|
||||
"content_safety.baidu_aip.app_id": {
|
||||
"description": "App ID",
|
||||
"type": "string",
|
||||
"condition": {
|
||||
"platform_settings.content_safety.baidu_aip.enable": True,
|
||||
"content_safety.baidu_aip.enable": True,
|
||||
},
|
||||
},
|
||||
"platform_settings.content_safety.baidu_aip.api_key": {
|
||||
"content_safety.baidu_aip.api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
"condition": {
|
||||
"platform_settings.content_safety.baidu_aip.enable": True,
|
||||
"content_safety.baidu_aip.enable": True,
|
||||
},
|
||||
},
|
||||
"platform_settings.content_safety.baidu_aip.secret_key": {
|
||||
"content_safety.baidu_aip.secret_key": {
|
||||
"description": "Secret Key",
|
||||
"type": "string",
|
||||
"condition": {
|
||||
"platform_settings.content_safety.baidu_aip.enable": True,
|
||||
"content_safety.baidu_aip.enable": True,
|
||||
},
|
||||
},
|
||||
"platform_settings.content_safety.internal_keywords.enable": {
|
||||
"content_safety.internal_keywords.enable": {
|
||||
"description": "关键词检查",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_settings.content_safety.internal_keywords.extra_keywords": {
|
||||
"content_safety.internal_keywords.extra_keywords": {
|
||||
"description": "额外关键词",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
|
||||
@@ -53,7 +53,7 @@ async def do_migration_v4(
|
||||
await migration_webchat_data(db_helper, platform_id_map)
|
||||
|
||||
# 执行偏好设置迁移
|
||||
await migration_preferences(db_helper,platform_id_map)
|
||||
await migration_preferences(db_helper, platform_id_map)
|
||||
|
||||
# 执行平台统计表迁移
|
||||
await migration_platform_table(db_helper, platform_id_map)
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
@@ -42,4 +43,5 @@ class SharedPreferences:
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
|
||||
|
||||
sp = SharedPreferences()
|
||||
|
||||
@@ -4,6 +4,7 @@ from astrbot.core.db.po import Platform, Stats
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话存储
|
||||
@@ -76,7 +77,7 @@ PRAGMA encoding = 'UTF-8';
|
||||
"""
|
||||
|
||||
|
||||
class SQLiteDatabase():
|
||||
class SQLiteDatabase:
|
||||
def __init__(self, db_path: str) -> None:
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
@@ -18,6 +18,7 @@ from astrbot.core.db.po import (
|
||||
from sqlalchemy import select, update, delete, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy import or_
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
|
||||
@@ -153,8 +154,22 @@ class SQLiteDatabase(BaseDatabase):
|
||||
ConversationV2.platform_id.in_(platform_ids)
|
||||
)
|
||||
if search_query:
|
||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||
base_query = base_query.where(
|
||||
ConversationV2.title.ilike(f"%{search_query}%")
|
||||
or_(
|
||||
ConversationV2.title.ilike(f"%{search_query}%"),
|
||||
ConversationV2.content.ilike(f"%{search_query}%"),
|
||||
ConversationV2.user_id.ilike(f"%{search_query}%"),
|
||||
)
|
||||
)
|
||||
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
||||
for msg_type in kwargs["message_types"]:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.user_id.ilike(f"%:{msg_type}:%")
|
||||
)
|
||||
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.platform_id.in_(kwargs["platforms"])
|
||||
)
|
||||
|
||||
# Get total count matching the filters
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .vec_db import FaissVecDB
|
||||
|
||||
__all__ = ["FaissVecDB"]
|
||||
__all__ = ["FaissVecDB"]
|
||||
|
||||
@@ -113,7 +113,8 @@ class FaissVecDB(BaseVecDB):
|
||||
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
||||
)
|
||||
top_k_results = [
|
||||
top_k_results[reranked_result.index] for reranked_result in reranked_results
|
||||
top_k_results[reranked_result.index]
|
||||
for reranked_result in reranked_results
|
||||
]
|
||||
|
||||
return top_k_results
|
||||
|
||||
@@ -22,6 +22,7 @@ class InitialLoader:
|
||||
self.db = db
|
||||
self.logger = logger
|
||||
self.log_broker = log_broker
|
||||
self.webui_dir: str | None = None
|
||||
|
||||
async def start(self):
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
@@ -35,8 +36,10 @@ class InitialLoader:
|
||||
|
||||
core_task = core_lifecycle.start()
|
||||
|
||||
webui_dir = self.webui_dir
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
|
||||
)
|
||||
task = asyncio.gather(
|
||||
core_task, self.dashboard_server.run()
|
||||
|
||||
@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
|
||||
self.strategy_selector = StrategySelector(config)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, check_text: str = None
|
||||
self, event: AstrMessageEvent, check_text: str | None = None
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""检查内容安全"""
|
||||
text = check_text if check_text else event.get_message_str()
|
||||
|
||||
@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
|
||||
self.secret_key = sk
|
||||
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
||||
|
||||
def check(self, content: str):
|
||||
def check(self, content: str) -> tuple[bool, str]:
|
||||
res = self.client.textCensorUserDefined(content)
|
||||
if "conclusionType" not in res:
|
||||
return False, ""
|
||||
|
||||
@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
|
||||
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
||||
# )
|
||||
|
||||
def check(self, content: str) -> bool:
|
||||
def check(self, content: str) -> tuple[bool, str]:
|
||||
for keyword in self.keywords:
|
||||
if re.search(keyword, content):
|
||||
return False, "内容安全检查不通过,匹配到敏感词。"
|
||||
|
||||
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
@@ -36,6 +36,9 @@ async def call_handler(
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
@@ -77,7 +80,7 @@ async def call_event_hook(
|
||||
|
||||
Returns:
|
||||
bool: 如果事件被终止,返回 True
|
||||
# """
|
||||
#"""
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
hook_type, plugins_name=event.plugins_name
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ import copy
|
||||
import json
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
@@ -133,6 +134,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
if agent_runner.done():
|
||||
llm_response = agent_runner.get_final_llm_resp()
|
||||
|
||||
if not llm_response:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
||||
)
|
||||
@@ -148,7 +158,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
yield mcp.types.TextContent(
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
@@ -200,7 +210,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
):
|
||||
if not tool.mcp_client:
|
||||
raise ValueError("MCP client is not available for MCP function tools.")
|
||||
res = await tool.mcp_client.session.call_tool(
|
||||
|
||||
session = tool.mcp_client.session
|
||||
if not session:
|
||||
raise ValueError("MCP session is not available for MCP function tools.")
|
||||
res = await session.call_tool(
|
||||
name=tool.name,
|
||||
arguments=tool_args,
|
||||
)
|
||||
@@ -271,11 +285,11 @@ async def run_agent(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
astr_event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
)
|
||||
)
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
@@ -325,7 +339,7 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||
|
||||
async def _get_session_conv(self, event: AstrMessageEvent):
|
||||
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
|
||||
umo = event.unified_msg_origin
|
||||
conv_mgr = self.conv_manager
|
||||
|
||||
@@ -337,6 +351,8 @@ class LLMRequestSubStage(Stage):
|
||||
if not conversation:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
raise RuntimeError("无法创建新的对话。")
|
||||
return conversation
|
||||
|
||||
async def process(
|
||||
@@ -444,7 +460,10 @@ class LLMRequestSubStage(Stage):
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
plugin = star_map.get(tool.handler_module_path)
|
||||
mp = tool.handler_module_path
|
||||
if not mp:
|
||||
continue
|
||||
plugin = star_map.get(mp)
|
||||
if not plugin:
|
||||
continue
|
||||
if plugin.name in event.plugins_name or plugin.reserved:
|
||||
|
||||
@@ -34,12 +34,14 @@ class StarRequestSubStage(Stage):
|
||||
|
||||
for handler in activated_handlers:
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_path not in star_map:
|
||||
continue
|
||||
logger.debug(
|
||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||
md = star_map.get(handler.handler_module_path)
|
||||
if not md:
|
||||
logger.warning(
|
||||
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
|
||||
)
|
||||
continue
|
||||
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
|
||||
try:
|
||||
wrapper = call_handler(event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
@@ -49,7 +51,7 @@ class StarRequestSubStage(Stage):
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
|
||||
if event.is_at_or_wake_command:
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
import random
|
||||
import asyncio
|
||||
import math
|
||||
import traceback
|
||||
import astrbot.core.message.components as Comp
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from ..context import PipelineContext, call_event_hook
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.components import BaseMessageComponent, ComponentType
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -114,6 +112,43 @@ class RespondStage(Stage):
|
||||
# 如果所有组件都为空
|
||||
return True
|
||||
|
||||
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
|
||||
"""检查是否需要分段回复"""
|
||||
if not self.enable_seg:
|
||||
return False
|
||||
|
||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
"qq_official",
|
||||
"weixin_official_account",
|
||||
"dingtalk",
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _extract_comp(
|
||||
self,
|
||||
raw_chain: list[BaseMessageComponent],
|
||||
extract_types: set[ComponentType],
|
||||
modify_raw_chain: bool = True,
|
||||
):
|
||||
extracted = []
|
||||
if modify_raw_chain:
|
||||
remaining = []
|
||||
for comp in raw_chain:
|
||||
if comp.type in extract_types:
|
||||
extracted.append(comp)
|
||||
else:
|
||||
remaining.append(comp)
|
||||
raw_chain[:] = remaining
|
||||
else:
|
||||
extracted = [comp for comp in raw_chain if comp.type in extract_types]
|
||||
|
||||
return extracted
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -123,7 +158,14 @@ class RespondStage(Stage):
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
)
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
if result.async_stream is None:
|
||||
logger.warning("async_stream 为空,跳过发送。")
|
||||
return
|
||||
# 流式结果直接交付平台适配器处理
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented", False
|
||||
@@ -148,87 +190,71 @@ class RespondStage(Stage):
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
||||
non_record_comps = [
|
||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||
]
|
||||
|
||||
if (
|
||||
self.enable_seg
|
||||
and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
# 发送消息链
|
||||
# Record 需要强制单独发送
|
||||
need_separately = {ComponentType.Record}
|
||||
if self.is_seg_reply_required(event):
|
||||
header_comps = self._extract_comp(
|
||||
result.chain,
|
||||
{ComponentType.Reply, ComponentType.At},
|
||||
modify_raw_chain=True,
|
||||
)
|
||||
and event.get_platform_name()
|
||||
not in ["qq_official", "weixin_official_account", "dingtalk"]
|
||||
):
|
||||
decorated_comps = []
|
||||
if self.reply_with_mention:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Comp.At):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
if self.reply_with_quote:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Comp.Reply):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
# leverage lock to guarentee the order of message sending among different events
|
||||
if not result.chain or len(result.chain) == 0:
|
||||
# may fix #2670
|
||||
logger.warning(
|
||||
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
|
||||
)
|
||||
return
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
for comp in result.chain:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
if comp.type in need_separately:
|
||||
await event.send(MessageChain([comp]))
|
||||
else:
|
||||
await event.send(MessageChain([*header_comps, comp]))
|
||||
header_comps.clear()
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
if all(
|
||||
comp.type in {ComponentType.Reply, ComponentType.At}
|
||||
for comp in result.chain
|
||||
):
|
||||
# may fix #2670
|
||||
logger.warning(
|
||||
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
|
||||
)
|
||||
return
|
||||
sep_comps = self._extract_comp(
|
||||
result.chain,
|
||||
need_separately,
|
||||
modify_raw_chain=True,
|
||||
)
|
||||
for comp in sep_comps:
|
||||
chain = MessageChain([comp])
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
await event.send(chain)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {chain}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
chain = MessageChain(result.chain)
|
||||
if result.chain and len(result.chain) > 0:
|
||||
try:
|
||||
await event.send(chain)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {chain}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
await event.send(MessageChain(non_record_comps))
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
logger.info(
|
||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
)
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
|
||||
return
|
||||
|
||||
event.clear_result()
|
||||
|
||||
@@ -24,7 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
|
||||
@@ -55,7 +55,7 @@ class AstrBotMessage:
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||
message_id: str # 消息id
|
||||
group_id: str = "" # 群组id,如果为私聊,则为空
|
||||
group: Group # 群组
|
||||
sender: MessageMember # 发送者
|
||||
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
@@ -64,6 +64,28 @@ class AstrBotMessage:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.timestamp = int(time.time())
|
||||
self.group = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
@property
|
||||
def group_id(self) -> str:
|
||||
"""
|
||||
向后兼容的 group_id 属性
|
||||
群组id,如果为私聊,则为空
|
||||
"""
|
||||
if self.group:
|
||||
return self.group.group_id
|
||||
return ""
|
||||
|
||||
@group_id.setter
|
||||
def group_id(self, value: str):
|
||||
"""设置 group_id"""
|
||||
if value:
|
||||
if self.group:
|
||||
self.group.group_id = value
|
||||
else:
|
||||
self.group = Group(group_id=value)
|
||||
else:
|
||||
self.group = None
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List
|
||||
from asyncio import Queue
|
||||
from .register import platform_cls_map
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, star_map, EventType
|
||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
|
||||
@@ -66,27 +67,39 @@ class PlatformManager:
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
)
|
||||
case "lark":
|
||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||
from .sources.lark.lark_adapter import (
|
||||
LarkPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "dingtalk":
|
||||
from .sources.dingtalk.dingtalk_adapter import (
|
||||
DingtalkPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "telegram":
|
||||
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
||||
from .sources.telegram.tg_adapter import (
|
||||
TelegramPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wecom":
|
||||
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||
from .sources.wecom.wecom_adapter import (
|
||||
WecomPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "weixin_official_account":
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||
WeixinOfficialAccountPlatformAdapter, # noqa
|
||||
WeixinOfficialAccountPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "discord":
|
||||
from .sources.discord.discord_platform_adapter import (
|
||||
DiscordPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "misskey":
|
||||
from .sources.misskey.misskey_adapter import (
|
||||
MisskeyPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "slack":
|
||||
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
||||
case "satori":
|
||||
from .sources.satori.satori_adapter import SatoriPlatformAdapter # noqa: F401
|
||||
from .sources.satori.satori_adapter import (
|
||||
SatoriPlatformAdapter, # noqa: F401
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||
@@ -115,6 +128,17 @@ class PlatformManager:
|
||||
)
|
||||
)
|
||||
)
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnPlatformLoadedEvent
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.info(
|
||||
f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler()
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task):
|
||||
try:
|
||||
|
||||
@@ -182,11 +182,13 @@ class AiocqhttpAdapter(Platform):
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
str(event.sender["user_id"]), event.sender["nickname"]
|
||||
str(event.sender["user_id"]),
|
||||
event.sender.get("card") or event.sender.get("nickname", "N/A"),
|
||||
)
|
||||
if event["message_type"] == "group":
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
abm.group.group_name = event.get("group_name", "N/A")
|
||||
elif event["message_type"] == "private":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
@@ -321,7 +323,9 @@ class AiocqhttpAdapter(Platform):
|
||||
user_id=int(m["data"]["qq"]),
|
||||
no_cache=False,
|
||||
)
|
||||
nickname = at_info.get("nick", "") or at_info.get("nickname", "")
|
||||
nickname = at_info.get("nick", "") or at_info.get(
|
||||
"nickname", ""
|
||||
)
|
||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||
|
||||
abm.message.append(
|
||||
|
||||
@@ -54,9 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
logger.debug(f"send image: {ret}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"钉钉图片处理失败: {e}")
|
||||
logger.warning(f"跳过图片发送: {image_path}")
|
||||
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
|
||||
continue
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
@@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot):
|
||||
await self.on_ready_once_callback()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
|
||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
|
||||
)
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
@@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot):
|
||||
message_data = self._create_message_data(message)
|
||||
await self.on_message_received(message_data)
|
||||
|
||||
|
||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||
"""从交互中提取内容"""
|
||||
interaction_type = interaction.type
|
||||
|
||||
@@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent):
|
||||
self.url = url
|
||||
self.disabled = disabled
|
||||
|
||||
|
||||
class DiscordReference(BaseMessageComponent):
|
||||
"""Discord引用组件"""
|
||||
|
||||
type: str = "discord_reference"
|
||||
|
||||
def __init__(self, message_id: str, channel_id: str):
|
||||
self.message_id = message_id
|
||||
self.channel_id = channel_id
|
||||
@@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
def to_discord_view(self) -> discord.ui.View:
|
||||
"""转换为Discord View对象"""
|
||||
view = discord.ui.View(timeout=self.timeout)
|
||||
|
||||
@@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
|
||||
# 解析消息链为 Discord 所需的对象
|
||||
try:
|
||||
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
|
||||
(
|
||||
content,
|
||||
files,
|
||||
view,
|
||||
embeds,
|
||||
reference_message_id,
|
||||
) = await self._parse_to_discord(message)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
||||
return
|
||||
@@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
files.append(
|
||||
discord.File(BytesIO(file_bytes),
|
||||
filename=i.name)
|
||||
discord.File(BytesIO(file_bytes), filename=i.name)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
391
astrbot/core/platform/sources/misskey/misskey_adapter.py
Normal file
391
astrbot/core/platform/sources/misskey/misskey_adapter.py
Normal file
@@ -0,0 +1,391 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Awaitable
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.platform import (
|
||||
AstrBotMessage,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
import astrbot.api.message_components as Comp
|
||||
|
||||
from .misskey_api import MisskeyAPI
|
||||
from .misskey_event import MisskeyPlatformEvent
|
||||
from .misskey_utils import (
|
||||
serialize_message_chain,
|
||||
resolve_message_visibility,
|
||||
is_valid_user_session_id,
|
||||
is_valid_room_session_id,
|
||||
add_at_mention_if_needed,
|
||||
process_files,
|
||||
extract_sender_info,
|
||||
create_base_message,
|
||||
process_at_mention,
|
||||
cache_user_info,
|
||||
cache_room_info,
|
||||
)
|
||||
|
||||
|
||||
@register_platform_adapter("misskey", "Misskey 平台适配器")
|
||||
class MisskeyPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config or {}
|
||||
self.settings = platform_settings or {}
|
||||
self.instance_url = self.config.get("misskey_instance_url", "")
|
||||
self.access_token = self.config.get("misskey_token", "")
|
||||
self.max_message_length = self.config.get("max_message_length", 3000)
|
||||
self.default_visibility = self.config.get(
|
||||
"misskey_default_visibility", "public"
|
||||
)
|
||||
self.local_only = self.config.get("misskey_local_only", False)
|
||||
self.enable_chat = self.config.get("misskey_enable_chat", True)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.api: Optional[MisskeyAPI] = None
|
||||
self._running = False
|
||||
self.client_self_id = ""
|
||||
self._bot_username = ""
|
||||
self._user_cache = {}
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
default_config = {
|
||||
"misskey_instance_url": "",
|
||||
"misskey_token": "",
|
||||
"max_message_length": 3000,
|
||||
"misskey_default_visibility": "public",
|
||||
"misskey_local_only": False,
|
||||
"misskey_enable_chat": True,
|
||||
}
|
||||
default_config.update(self.config)
|
||||
|
||||
return PlatformMetadata(
|
||||
name="misskey",
|
||||
description="Misskey 平台适配器",
|
||||
id=self.config.get("id", "misskey"),
|
||||
default_config_tmpl=default_config,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
if not self.instance_url or not self.access_token:
|
||||
logger.error("[Misskey] 配置不完整,无法启动")
|
||||
return
|
||||
|
||||
self.api = MisskeyAPI(self.instance_url, self.access_token)
|
||||
self._running = True
|
||||
|
||||
try:
|
||||
user_info = await self.api.get_current_user()
|
||||
self.client_self_id = str(user_info.get("id", ""))
|
||||
self._bot_username = user_info.get("username", "")
|
||||
logger.info(
|
||||
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 获取用户信息失败: {e}")
|
||||
self._running = False
|
||||
return
|
||||
|
||||
await self._start_websocket_connection()
|
||||
|
||||
async def _start_websocket_connection(self):
|
||||
backoff_delay = 1.0
|
||||
max_backoff = 300.0
|
||||
backoff_multiplier = 1.5
|
||||
connection_attempts = 0
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
connection_attempts += 1
|
||||
if not self.api:
|
||||
logger.error("[Misskey] API 客户端未初始化")
|
||||
break
|
||||
|
||||
streaming = self.api.get_streaming_client()
|
||||
streaming.add_message_handler("notification", self._handle_notification)
|
||||
if self.enable_chat:
|
||||
streaming.add_message_handler(
|
||||
"newChatMessage", self._handle_chat_message
|
||||
)
|
||||
streaming.add_message_handler("_debug", self._debug_handler)
|
||||
|
||||
if await streaming.connect():
|
||||
logger.info(
|
||||
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
|
||||
)
|
||||
connection_attempts = 0 # 重置计数器
|
||||
await streaming.subscribe_channel("main")
|
||||
if self.enable_chat:
|
||||
await streaming.subscribe_channel("messaging")
|
||||
await streaming.subscribe_channel("messagingIndex")
|
||||
logger.info("[Misskey] 聊天频道已订阅")
|
||||
|
||||
backoff_delay = 1.0 # 重置延迟
|
||||
await streaming.listen()
|
||||
else:
|
||||
logger.error(
|
||||
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
|
||||
)
|
||||
|
||||
if self._running:
|
||||
logger.info(
|
||||
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
|
||||
)
|
||||
await asyncio.sleep(backoff_delay)
|
||||
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
|
||||
|
||||
async def _handle_notification(self, data: Dict[str, Any]):
|
||||
try:
|
||||
logger.debug(
|
||||
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
notification_type = data.get("type")
|
||||
if notification_type in ["mention", "reply", "quote"]:
|
||||
note = data.get("note")
|
||||
if note and self._is_bot_mentioned(note):
|
||||
logger.info(
|
||||
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
|
||||
)
|
||||
message = await self.convert_message(note)
|
||||
event = MisskeyPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.api,
|
||||
)
|
||||
self.commit_event(event)
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 处理通知失败: {e}")
|
||||
|
||||
async def _handle_chat_message(self, data: Dict[str, Any]):
|
||||
try:
|
||||
logger.debug(
|
||||
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
sender_id = str(
|
||||
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
|
||||
)
|
||||
if sender_id == self.client_self_id:
|
||||
return
|
||||
|
||||
room_id = data.get("toRoomId")
|
||||
if room_id:
|
||||
raw_text = data.get("text", "")
|
||||
logger.debug(
|
||||
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
|
||||
)
|
||||
|
||||
message = await self.convert_room_message(data)
|
||||
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
|
||||
else:
|
||||
message = await self.convert_chat_message(data)
|
||||
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
|
||||
|
||||
event = MisskeyPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.api,
|
||||
)
|
||||
self.commit_event(event)
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
|
||||
|
||||
async def _debug_handler(self, data: Dict[str, Any]):
|
||||
logger.debug(
|
||||
f"[Misskey] 收到未处理事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
|
||||
text = note.get("text", "")
|
||||
if not text:
|
||||
return False
|
||||
|
||||
mentions = note.get("mentions", [])
|
||||
if self._bot_username and f"@{self._bot_username}" in text:
|
||||
return True
|
||||
if self.client_self_id in [str(uid) for uid in mentions]:
|
||||
return True
|
||||
|
||||
reply = note.get("reply")
|
||||
if reply and isinstance(reply, dict):
|
||||
reply_user_id = str(reply.get("user", {}).get("id", ""))
|
||||
if reply_user_id == self.client_self_id:
|
||||
return bool(self._bot_username and f"@{self._bot_username}" in text)
|
||||
|
||||
return False
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSession, message_chain: MessageChain
|
||||
) -> Awaitable[Any]:
|
||||
if not self.api:
|
||||
logger.error("[Misskey] API 客户端未初始化")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
try:
|
||||
session_id = session.session_id
|
||||
text, has_at_user = serialize_message_chain(message_chain.chain)
|
||||
|
||||
if not has_at_user and session_id:
|
||||
user_info = self._user_cache.get(session_id)
|
||||
text = add_at_mention_if_needed(text, user_info, has_at_user)
|
||||
|
||||
if not text or not text.strip():
|
||||
logger.warning("[Misskey] 消息内容为空,跳过发送")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
if len(text) > self.max_message_length:
|
||||
text = text[: self.max_message_length] + "..."
|
||||
|
||||
if session_id and is_valid_user_session_id(session_id):
|
||||
from .misskey_utils import extract_user_id_from_session_id
|
||||
|
||||
user_id = extract_user_id_from_session_id(session_id)
|
||||
await self.api.send_message(user_id, text)
|
||||
elif session_id and is_valid_room_session_id(session_id):
|
||||
from .misskey_utils import extract_room_id_from_session_id
|
||||
|
||||
room_id = extract_room_id_from_session_id(session_id)
|
||||
await self.api.send_room_message(room_id, text)
|
||||
else:
|
||||
visibility, visible_user_ids = resolve_message_visibility(
|
||||
user_id=session_id,
|
||||
user_cache=self._user_cache,
|
||||
self_id=self.client_self_id,
|
||||
default_visibility=self.default_visibility,
|
||||
)
|
||||
|
||||
await self.api.create_note(
|
||||
text,
|
||||
visibility=visibility,
|
||||
visible_user_ids=visible_user_ids,
|
||||
local_only=self.local_only,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 发送消息失败: {e}")
|
||||
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
|
||||
sender_info = extract_sender_info(raw_data, is_chat=False)
|
||||
message = create_base_message(
|
||||
raw_data,
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=False,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
cache_user_info(
|
||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
||||
)
|
||||
|
||||
message_parts = []
|
||||
raw_text = raw_data.get("text", "")
|
||||
|
||||
if raw_text:
|
||||
text_parts, processed_text = process_at_mention(
|
||||
message, raw_text, self._bot_username, self.client_self_id
|
||||
)
|
||||
message_parts.extend(text_parts)
|
||||
|
||||
files = raw_data.get("files", [])
|
||||
file_parts = process_files(message, files)
|
||||
message_parts.extend(file_parts)
|
||||
|
||||
message.message_str = (
|
||||
" ".join(part for part in message_parts if part.strip())
|
||||
if message_parts
|
||||
else ""
|
||||
)
|
||||
return message
|
||||
|
||||
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
|
||||
sender_info = extract_sender_info(raw_data, is_chat=True)
|
||||
message = create_base_message(
|
||||
raw_data,
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=True,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
cache_user_info(
|
||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
|
||||
)
|
||||
|
||||
raw_text = raw_data.get("text", "")
|
||||
if raw_text:
|
||||
message.message.append(Comp.Plain(raw_text))
|
||||
|
||||
files = raw_data.get("files", [])
|
||||
process_files(message, files, include_text_parts=False)
|
||||
|
||||
message.message_str = raw_text if raw_text else ""
|
||||
return message
|
||||
|
||||
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
|
||||
sender_info = extract_sender_info(raw_data, is_chat=True)
|
||||
room_id = raw_data.get("toRoomId", "")
|
||||
message = create_base_message(
|
||||
raw_data,
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=False,
|
||||
room_id=room_id,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
|
||||
cache_user_info(
|
||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
||||
)
|
||||
cache_room_info(self._user_cache, raw_data, self.client_self_id)
|
||||
|
||||
raw_text = raw_data.get("text", "")
|
||||
message_parts = []
|
||||
|
||||
if raw_text:
|
||||
if self._bot_username and f"@{self._bot_username}" in raw_text:
|
||||
text_parts, processed_text = process_at_mention(
|
||||
message, raw_text, self._bot_username, self.client_self_id
|
||||
)
|
||||
message_parts.extend(text_parts)
|
||||
else:
|
||||
message.message.append(Comp.Plain(raw_text))
|
||||
message_parts.append(raw_text)
|
||||
|
||||
files = raw_data.get("files", [])
|
||||
file_parts = process_files(message, files)
|
||||
message_parts.extend(file_parts)
|
||||
|
||||
message.message_str = (
|
||||
" ".join(part for part in message_parts if part.strip())
|
||||
if message_parts
|
||||
else ""
|
||||
)
|
||||
return message
|
||||
|
||||
async def terminate(self):
|
||||
self._running = False
|
||||
if self.api:
|
||||
await self.api.close()
|
||||
|
||||
def get_client(self) -> Any:
|
||||
return self.api
|
||||
404
astrbot/core/platform/sources/misskey/misskey_api.py
Normal file
404
astrbot/core/platform/sources/misskey/misskey_api.py
Normal file
@@ -0,0 +1,404 @@
|
||||
import json
|
||||
from typing import Any, Optional, Dict, List, Callable, Awaitable
|
||||
import uuid
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
import websockets
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
|
||||
) from e
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
# Constants
|
||||
API_MAX_RETRIES = 3
|
||||
HTTP_OK = 200
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""Misskey API 基础异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIConnectionError(APIError):
|
||||
"""网络连接异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIRateLimitError(APIError):
|
||||
"""API 频率限制异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(APIError):
|
||||
"""认证失败异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketError(APIError):
|
||||
"""WebSocket 连接异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StreamingClient:
|
||||
def __init__(self, instance_url: str, access_token: str):
|
||||
self.instance_url = instance_url.rstrip("/")
|
||||
self.access_token = access_token
|
||||
self.websocket: Optional[Any] = None
|
||||
self.is_connected = False
|
||||
self.message_handlers: Dict[str, Callable] = {}
|
||||
self.channels: Dict[str, str] = {}
|
||||
self._running = False
|
||||
self._last_pong = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
try:
|
||||
ws_url = self.instance_url.replace("https://", "wss://").replace(
|
||||
"http://", "ws://"
|
||||
)
|
||||
ws_url += f"/streaming?i={self.access_token}"
|
||||
|
||||
self.websocket = await websockets.connect(
|
||||
ws_url, ping_interval=30, ping_timeout=10
|
||||
)
|
||||
self.is_connected = True
|
||||
self._running = True
|
||||
|
||||
logger.info("[Misskey WebSocket] 已连接")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
|
||||
self.is_connected = False
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
self._running = False
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
self.websocket = None
|
||||
self.is_connected = False
|
||||
logger.info("[Misskey WebSocket] 连接已断开")
|
||||
|
||||
async def subscribe_channel(
|
||||
self, channel_type: str, params: Optional[Dict] = None
|
||||
) -> str:
|
||||
if not self.is_connected or not self.websocket:
|
||||
raise WebSocketError("WebSocket 未连接")
|
||||
|
||||
channel_id = str(uuid.uuid4())
|
||||
message = {
|
||||
"type": "connect",
|
||||
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
|
||||
}
|
||||
|
||||
await self.websocket.send(json.dumps(message))
|
||||
self.channels[channel_id] = channel_type
|
||||
return channel_id
|
||||
|
||||
async def unsubscribe_channel(self, channel_id: str):
|
||||
if (
|
||||
not self.is_connected
|
||||
or not self.websocket
|
||||
or channel_id not in self.channels
|
||||
):
|
||||
return
|
||||
|
||||
message = {"type": "disconnect", "body": {"id": channel_id}}
|
||||
|
||||
await self.websocket.send(json.dumps(message))
|
||||
del self.channels[channel_id]
|
||||
|
||||
def add_message_handler(
|
||||
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
|
||||
):
|
||||
self.message_handlers[event_type] = handler
|
||||
|
||||
async def listen(self):
|
||||
if not self.is_connected or not self.websocket:
|
||||
raise WebSocketError("WebSocket 未连接")
|
||||
|
||||
try:
|
||||
async for message in self.websocket:
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._handle_message(data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
|
||||
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
|
||||
self.is_connected = False
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning(
|
||||
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
|
||||
)
|
||||
self.is_connected = False
|
||||
except websockets.exceptions.InvalidHandshake as e:
|
||||
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
|
||||
self.is_connected = False
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
|
||||
self.is_connected = False
|
||||
|
||||
async def _handle_message(self, data: Dict[str, Any]):
|
||||
message_type = data.get("type")
|
||||
body = data.get("body", {})
|
||||
|
||||
logger.debug(
|
||||
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
if message_type == "channel":
|
||||
channel_id = body.get("id")
|
||||
event_type = body.get("type")
|
||||
event_body = body.get("body", {})
|
||||
|
||||
logger.debug(
|
||||
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
|
||||
)
|
||||
|
||||
if channel_id in self.channels:
|
||||
channel_type = self.channels[channel_id]
|
||||
handler_key = f"{channel_type}:{event_type}"
|
||||
|
||||
if handler_key in self.message_handlers:
|
||||
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
|
||||
await self.message_handlers[handler_key](event_body)
|
||||
elif event_type in self.message_handlers:
|
||||
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
|
||||
await self.message_handlers[event_type](event_body)
|
||||
else:
|
||||
logger.debug(
|
||||
f"[Misskey WebSocket] 未找到处理器: {handler_key} 或 {event_type}"
|
||||
)
|
||||
if "_debug" in self.message_handlers:
|
||||
await self.message_handlers["_debug"](
|
||||
{
|
||||
"type": event_type,
|
||||
"body": event_body,
|
||||
"channel": channel_type,
|
||||
}
|
||||
)
|
||||
|
||||
elif message_type in self.message_handlers:
|
||||
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
|
||||
await self.message_handlers[message_type](body)
|
||||
else:
|
||||
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
|
||||
if "_debug" in self.message_handlers:
|
||||
await self.message_handlers["_debug"](data)
|
||||
|
||||
|
||||
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
|
||||
def decorator(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
last_exc = None
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except retryable_exceptions as e:
|
||||
last_exc = e
|
||||
continue
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class MisskeyAPI:
|
||||
def __init__(self, instance_url: str, access_token: str):
|
||||
self.instance_url = instance_url.rstrip("/")
|
||||
self.access_token = access_token
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self.streaming: Optional[StreamingClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.streaming:
|
||||
await self.streaming.disconnect()
|
||||
self.streaming = None
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
logger.debug("[Misskey API] 客户端已关闭")
|
||||
|
||||
def get_streaming_client(self) -> StreamingClient:
|
||||
if not self.streaming:
|
||||
self.streaming = StreamingClient(self.instance_url, self.access_token)
|
||||
return self.streaming
|
||||
|
||||
@property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
headers = {"Authorization": f"Bearer {self.access_token}"}
|
||||
self._session = aiohttp.ClientSession(headers=headers)
|
||||
return self._session
|
||||
|
||||
def _handle_response_status(self, status: int, endpoint: str):
|
||||
"""处理 HTTP 响应状态码"""
|
||||
if status == 400:
|
||||
logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
|
||||
raise APIError(f"Bad request for {endpoint}")
|
||||
elif status in (401, 403):
|
||||
logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
|
||||
raise AuthenticationError(f"Authentication failed for {endpoint}")
|
||||
elif status == 429:
|
||||
logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
|
||||
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
|
||||
else:
|
||||
logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
|
||||
raise APIConnectionError(f"HTTP {status} for {endpoint}")
|
||||
|
||||
async def _process_response(
|
||||
self, response: aiohttp.ClientResponse, endpoint: str
|
||||
) -> Any:
|
||||
"""处理 API 响应"""
|
||||
if response.status == HTTP_OK:
|
||||
try:
|
||||
result = await response.json()
|
||||
if endpoint == "i/notifications":
|
||||
notifications_data = (
|
||||
result
|
||||
if isinstance(result, list)
|
||||
else result.get("notifications", [])
|
||||
if isinstance(result, dict)
|
||||
else []
|
||||
)
|
||||
if notifications_data:
|
||||
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
|
||||
else:
|
||||
logger.debug(f"API 请求成功: {endpoint}")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"响应不是有效的 JSON 格式: {e}")
|
||||
raise APIConnectionError("Invalid JSON response") from e
|
||||
else:
|
||||
try:
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
|
||||
)
|
||||
except Exception:
|
||||
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
|
||||
|
||||
self._handle_response_status(response.status, endpoint)
|
||||
raise APIConnectionError(f"Request failed for {endpoint}")
|
||||
|
||||
@retry_async(
|
||||
max_retries=API_MAX_RETRIES,
|
||||
retryable_exceptions=(APIConnectionError, APIRateLimitError),
|
||||
)
|
||||
async def _make_request(
|
||||
self, endpoint: str, data: Optional[Dict[str, Any]] = None
|
||||
) -> Any:
|
||||
url = f"{self.instance_url}/api/{endpoint}"
|
||||
payload = {"i": self.access_token}
|
||||
if data:
|
||||
payload.update(data)
|
||||
|
||||
try:
|
||||
async with self.session.post(url, json=payload) as response:
|
||||
return await self._process_response(response, endpoint)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"HTTP 请求错误: {e}")
|
||||
raise APIConnectionError(f"HTTP request failed: {e}") from e
|
||||
|
||||
async def create_note(
|
||||
self,
|
||||
text: str,
|
||||
visibility: str = "public",
|
||||
reply_id: Optional[str] = None,
|
||||
visible_user_ids: Optional[List[str]] = None,
|
||||
local_only: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建新贴文"""
|
||||
data: Dict[str, Any] = {
|
||||
"text": text,
|
||||
"visibility": visibility,
|
||||
"localOnly": local_only,
|
||||
}
|
||||
if reply_id:
|
||||
data["replyId"] = reply_id
|
||||
if visible_user_ids and visibility == "specified":
|
||||
data["visibleUserIds"] = visible_user_ids
|
||||
|
||||
result = await self._make_request("notes/create", data)
|
||||
note_id = result.get("createdNote", {}).get("id", "unknown")
|
||||
logger.debug(f"发帖成功,note_id: {note_id}")
|
||||
return result
|
||||
|
||||
async def get_current_user(self) -> Dict[str, Any]:
|
||||
"""获取当前用户信息"""
|
||||
return await self._make_request("i", {})
|
||||
|
||||
async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
|
||||
"""发送聊天消息"""
|
||||
result = await self._make_request(
|
||||
"chat/messages/create-to-user", {"toUserId": user_id, "text": text}
|
||||
)
|
||||
message_id = result.get("id", "unknown")
|
||||
logger.debug(f"聊天发送成功,message_id: {message_id}")
|
||||
return result
|
||||
|
||||
async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
|
||||
"""发送房间消息"""
|
||||
result = await self._make_request(
|
||||
"chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
|
||||
)
|
||||
message_id = result.get("id", "unknown")
|
||||
logger.debug(f"房间消息发送成功,message_id: {message_id}")
|
||||
return result
|
||||
|
||||
async def get_messages(
|
||||
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取聊天消息历史"""
|
||||
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
|
||||
if since_id:
|
||||
data["sinceId"] = since_id
|
||||
|
||||
result = await self._make_request("chat/messages/user-timeline", data)
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
|
||||
return []
|
||||
|
||||
async def get_mentions(
|
||||
self, limit: int = 10, since_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取提及通知"""
|
||||
data: Dict[str, Any] = {"limit": limit}
|
||||
if since_id:
|
||||
data["sinceId"] = since_id
|
||||
data["includeTypes"] = ["mention", "reply", "quote"]
|
||||
|
||||
result = await self._make_request("i/notifications", data)
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
elif isinstance(result, dict) and "notifications" in result:
|
||||
return result["notifications"]
|
||||
else:
|
||||
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
|
||||
return []
|
||||
123
astrbot/core/platform/sources/misskey/misskey_event.py
Normal file
123
astrbot/core/platform/sources/misskey/misskey_event.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
|
||||
from astrbot.api.message_components import Plain
|
||||
|
||||
from .misskey_utils import (
|
||||
serialize_message_chain,
|
||||
resolve_visibility_from_raw_message,
|
||||
is_valid_user_session_id,
|
||||
is_valid_room_session_id,
|
||||
add_at_mention_if_needed,
|
||||
extract_user_id_from_session_id,
|
||||
extract_room_id_from_session_id,
|
||||
)
|
||||
|
||||
|
||||
class MisskeyPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
def _is_system_command(self, message_str: str) -> bool:
|
||||
"""检测是否为系统指令"""
|
||||
if not message_str or not message_str.strip():
|
||||
return False
|
||||
|
||||
system_prefixes = ["/", "!", "#", ".", "^"]
|
||||
message_trimmed = message_str.strip()
|
||||
|
||||
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
content, has_at = serialize_message_chain(message.chain)
|
||||
|
||||
if not content:
|
||||
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
|
||||
return
|
||||
|
||||
try:
|
||||
original_message_id = getattr(self.message_obj, "message_id", None)
|
||||
raw_message = getattr(self.message_obj, "raw_message", {})
|
||||
|
||||
if raw_message and not has_at:
|
||||
user_data = raw_message.get("user", {})
|
||||
user_info = {
|
||||
"username": user_data.get("username", ""),
|
||||
"nickname": user_data.get("name", user_data.get("username", "")),
|
||||
}
|
||||
content = add_at_mention_if_needed(content, user_info, has_at)
|
||||
|
||||
# 根据会话类型选择发送方式
|
||||
if hasattr(self.client, "send_message") and is_valid_user_session_id(
|
||||
self.session_id
|
||||
):
|
||||
user_id = extract_user_id_from_session_id(self.session_id)
|
||||
await self.client.send_message(user_id, content)
|
||||
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
|
||||
self.session_id
|
||||
):
|
||||
room_id = extract_room_id_from_session_id(self.session_id)
|
||||
await self.client.send_room_message(room_id, content)
|
||||
elif original_message_id and hasattr(self.client, "create_note"):
|
||||
visibility, visible_user_ids = resolve_visibility_from_raw_message(
|
||||
raw_message
|
||||
)
|
||||
await self.client.create_note(
|
||||
content,
|
||||
reply_id=original_message_id,
|
||||
visibility=visibility,
|
||||
visible_user_ids=visible_user_ids,
|
||||
)
|
||||
elif hasattr(self.client, "create_note"):
|
||||
logger.debug("[MisskeyEvent] 创建新帖子")
|
||||
await self.client.create_note(content)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MisskeyEvent] 发送失败: {e}")
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
327
astrbot/core/platform/sources/misskey/misskey_utils.py
Normal file
327
astrbot/core/platform/sources/misskey/misskey_utils.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""Misskey 平台适配器通用工具函数"""
|
||||
|
||||
from typing import Dict, Any, List, Tuple, Optional, Union
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
|
||||
|
||||
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
|
||||
"""将消息链序列化为文本字符串"""
|
||||
text_parts = []
|
||||
has_at = False
|
||||
|
||||
def process_component(component):
|
||||
nonlocal has_at
|
||||
if isinstance(component, Comp.Plain):
|
||||
return component.text
|
||||
elif isinstance(component, Comp.File):
|
||||
file_name = getattr(component, "name", "文件")
|
||||
return f"[文件: {file_name}]"
|
||||
elif isinstance(component, Comp.At):
|
||||
has_at = True
|
||||
return f"@{component.qq}"
|
||||
elif hasattr(component, "text"):
|
||||
text = getattr(component, "text", "")
|
||||
if "@" in text:
|
||||
has_at = True
|
||||
return text
|
||||
else:
|
||||
return str(component)
|
||||
|
||||
for component in chain:
|
||||
if isinstance(component, Comp.Node) and component.content:
|
||||
for node_comp in component.content:
|
||||
result = process_component(node_comp)
|
||||
if result:
|
||||
text_parts.append(result)
|
||||
else:
|
||||
result = process_component(component)
|
||||
if result:
|
||||
text_parts.append(result)
|
||||
|
||||
return "".join(text_parts), has_at
|
||||
|
||||
|
||||
def resolve_message_visibility(
|
||||
user_id: Optional[str],
|
||||
user_cache: Dict[str, Any],
|
||||
self_id: Optional[str],
|
||||
default_visibility: str = "public",
|
||||
) -> Tuple[str, Optional[List[str]]]:
|
||||
"""解析 Misskey 消息的可见性设置"""
|
||||
visibility = default_visibility
|
||||
visible_user_ids = None
|
||||
|
||||
if user_id and user_cache:
|
||||
user_info = user_cache.get(user_id)
|
||||
if user_info:
|
||||
original_visibility = user_info.get("visibility", default_visibility)
|
||||
if original_visibility == "specified":
|
||||
visibility = "specified"
|
||||
original_visible_users = user_info.get("visible_user_ids", [])
|
||||
users_to_include = [user_id]
|
||||
if self_id:
|
||||
users_to_include.append(self_id)
|
||||
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||
else:
|
||||
visibility = original_visibility
|
||||
|
||||
return visibility, visible_user_ids
|
||||
|
||||
|
||||
def resolve_visibility_from_raw_message(
|
||||
raw_message: Dict[str, Any], self_id: Optional[str] = None
|
||||
) -> Tuple[str, Optional[List[str]]]:
|
||||
"""从原始消息数据中解析可见性设置"""
|
||||
visibility = "public"
|
||||
visible_user_ids = None
|
||||
|
||||
if not raw_message:
|
||||
return visibility, visible_user_ids
|
||||
|
||||
original_visibility = raw_message.get("visibility", "public")
|
||||
if original_visibility == "specified":
|
||||
visibility = "specified"
|
||||
original_visible_users = raw_message.get("visibleUserIds", [])
|
||||
sender_id = raw_message.get("userId", "")
|
||||
|
||||
users_to_include = []
|
||||
if sender_id:
|
||||
users_to_include.append(sender_id)
|
||||
if self_id:
|
||||
users_to_include.append(self_id)
|
||||
|
||||
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||
else:
|
||||
visibility = original_visibility
|
||||
|
||||
return visibility, visible_user_ids
|
||||
|
||||
|
||||
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
|
||||
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
|
||||
if not isinstance(session_id, str) or "%" not in session_id:
|
||||
return False
|
||||
|
||||
parts = session_id.split("%")
|
||||
return (
|
||||
len(parts) == 2
|
||||
and parts[0] == "chat"
|
||||
and bool(parts[1])
|
||||
and parts[1] != "unknown"
|
||||
)
|
||||
|
||||
|
||||
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
|
||||
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
|
||||
if not isinstance(session_id, str) or "%" not in session_id:
|
||||
return False
|
||||
|
||||
parts = session_id.split("%")
|
||||
return (
|
||||
len(parts) == 2
|
||||
and parts[0] == "room"
|
||||
and bool(parts[1])
|
||||
and parts[1] != "unknown"
|
||||
)
|
||||
|
||||
|
||||
def extract_user_id_from_session_id(session_id: str) -> str:
|
||||
"""从 session_id 中提取用户 ID"""
|
||||
if "%" in session_id:
|
||||
parts = session_id.split("%")
|
||||
if len(parts) >= 2:
|
||||
return parts[1]
|
||||
return session_id
|
||||
|
||||
|
||||
def extract_room_id_from_session_id(session_id: str) -> str:
|
||||
"""从 session_id 中提取房间 ID"""
|
||||
if "%" in session_id:
|
||||
parts = session_id.split("%")
|
||||
if len(parts) >= 2 and parts[0] == "room":
|
||||
return parts[1]
|
||||
return session_id
|
||||
|
||||
|
||||
def add_at_mention_if_needed(
|
||||
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
|
||||
) -> str:
|
||||
"""如果需要且没有@用户,则添加@用户"""
|
||||
if has_at or not user_info:
|
||||
return text
|
||||
|
||||
username = user_info.get("username")
|
||||
nickname = user_info.get("nickname")
|
||||
|
||||
if username:
|
||||
mention = f"@{username}"
|
||||
if not text.startswith(mention):
|
||||
text = f"{mention}\n{text}".strip()
|
||||
elif nickname:
|
||||
mention = f"@{nickname}"
|
||||
if not text.startswith(mention):
|
||||
text = f"{mention}\n{text}".strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
|
||||
"""创建文件组件和描述文本"""
|
||||
file_url = file_info.get("url", "")
|
||||
file_name = file_info.get("name", "未知文件")
|
||||
file_type = file_info.get("type", "")
|
||||
|
||||
if file_type.startswith("image/"):
|
||||
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
|
||||
elif file_type.startswith("audio/"):
|
||||
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
|
||||
elif file_type.startswith("video/"):
|
||||
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
|
||||
else:
|
||||
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
|
||||
|
||||
|
||||
def process_files(
|
||||
message: AstrBotMessage, files: list, include_text_parts: bool = True
|
||||
) -> list:
|
||||
"""处理文件列表,添加到消息组件中并返回文本描述"""
|
||||
file_parts = []
|
||||
for file_info in files:
|
||||
component, part_text = create_file_component(file_info)
|
||||
message.message.append(component)
|
||||
if include_text_parts:
|
||||
file_parts.append(part_text)
|
||||
return file_parts
|
||||
|
||||
|
||||
def extract_sender_info(
|
||||
raw_data: Dict[str, Any], is_chat: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""提取发送者信息"""
|
||||
if is_chat:
|
||||
sender = raw_data.get("fromUser", {})
|
||||
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
|
||||
else:
|
||||
sender = raw_data.get("user", {})
|
||||
sender_id = str(sender.get("id", ""))
|
||||
|
||||
return {
|
||||
"sender": sender,
|
||||
"sender_id": sender_id,
|
||||
"nickname": sender.get("name", sender.get("username", "")),
|
||||
"username": sender.get("username", ""),
|
||||
}
|
||||
|
||||
|
||||
def create_base_message(
|
||||
raw_data: Dict[str, Any],
|
||||
sender_info: Dict[str, Any],
|
||||
client_self_id: str,
|
||||
is_chat: bool = False,
|
||||
room_id: Optional[str] = None,
|
||||
unique_session: bool = False,
|
||||
) -> AstrBotMessage:
|
||||
"""创建基础消息对象"""
|
||||
message = AstrBotMessage()
|
||||
message.raw_message = raw_data
|
||||
message.message = []
|
||||
|
||||
message.sender = MessageMember(
|
||||
user_id=sender_info["sender_id"],
|
||||
nickname=sender_info["nickname"],
|
||||
)
|
||||
|
||||
if room_id:
|
||||
session_prefix = "room"
|
||||
session_id = f"{session_prefix}%{room_id}"
|
||||
if unique_session:
|
||||
session_id += f"_{sender_info['sender_id']}"
|
||||
message.type = MessageType.GROUP_MESSAGE
|
||||
message.group_id = room_id
|
||||
elif is_chat:
|
||||
session_prefix = "chat"
|
||||
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
else:
|
||||
session_prefix = "note"
|
||||
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
message.session_id = (
|
||||
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
|
||||
)
|
||||
message.message_id = str(raw_data.get("id", ""))
|
||||
message.self_id = client_self_id
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def process_at_mention(
|
||||
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
|
||||
) -> Tuple[List[str], str]:
|
||||
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
|
||||
message_parts = []
|
||||
|
||||
if not raw_text:
|
||||
return message_parts, ""
|
||||
|
||||
if bot_username and raw_text.startswith(f"@{bot_username}"):
|
||||
at_mention = f"@{bot_username}"
|
||||
message.message.append(Comp.At(qq=client_self_id))
|
||||
remaining_text = raw_text[len(at_mention) :].strip()
|
||||
if remaining_text:
|
||||
message.message.append(Comp.Plain(remaining_text))
|
||||
message_parts.append(remaining_text)
|
||||
return message_parts, remaining_text
|
||||
else:
|
||||
message.message.append(Comp.Plain(raw_text))
|
||||
message_parts.append(raw_text)
|
||||
return message_parts, raw_text
|
||||
|
||||
|
||||
def cache_user_info(
|
||||
user_cache: Dict[str, Any],
|
||||
sender_info: Dict[str, Any],
|
||||
raw_data: Dict[str, Any],
|
||||
client_self_id: str,
|
||||
is_chat: bool = False,
|
||||
):
|
||||
"""缓存用户信息"""
|
||||
if is_chat:
|
||||
user_cache_data = {
|
||||
"username": sender_info["username"],
|
||||
"nickname": sender_info["nickname"],
|
||||
"visibility": "specified",
|
||||
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
|
||||
}
|
||||
else:
|
||||
user_cache_data = {
|
||||
"username": sender_info["username"],
|
||||
"nickname": sender_info["nickname"],
|
||||
"visibility": raw_data.get("visibility", "public"),
|
||||
"visible_user_ids": raw_data.get("visibleUserIds", []),
|
||||
}
|
||||
|
||||
user_cache[sender_info["sender_id"]] = user_cache_data
|
||||
|
||||
|
||||
def cache_room_info(
|
||||
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
|
||||
):
|
||||
"""缓存房间信息"""
|
||||
room_data = raw_data.get("toRoom")
|
||||
room_id = raw_data.get("toRoomId")
|
||||
|
||||
if room_data and room_id:
|
||||
room_cache_key = f"room:{room_id}"
|
||||
user_cache[room_cache_key] = {
|
||||
"room_id": room_id,
|
||||
"room_name": room_data.get("name", ""),
|
||||
"room_description": room_data.get("description", ""),
|
||||
"owner_id": room_data.get("ownerId", ""),
|
||||
"visibility": "specified",
|
||||
"visible_user_ids": [client_self_id],
|
||||
}
|
||||
@@ -94,10 +94,15 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
plain_text,
|
||||
image_base64,
|
||||
image_path,
|
||||
record_file_path
|
||||
record_file_path,
|
||||
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||
|
||||
if not plain_text and not image_base64 and not image_path and not record_file_path:
|
||||
if (
|
||||
not plain_text
|
||||
and not image_base64
|
||||
and not image_path
|
||||
and not record_file_path
|
||||
):
|
||||
return
|
||||
|
||||
payload = {
|
||||
@@ -118,7 +123,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
if record_file_path: # group record msg
|
||||
if record_file_path: # group record msg
|
||||
media = await self.upload_group_and_c2c_record(
|
||||
record_file_path, 3, group_openid=source.group_openid
|
||||
)
|
||||
@@ -134,9 +139,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
if record_file_path: # c2c record
|
||||
if record_file_path: # c2c record
|
||||
media = await self.upload_group_and_c2c_record(
|
||||
record_file_path, 3, openid = source.author.user_openid
|
||||
record_file_path, 3, openid=source.author.user_openid
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
@@ -190,58 +195,55 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
async def upload_group_and_c2c_record(
|
||||
self,
|
||||
file_source: str,
|
||||
file_type: int,
|
||||
srv_send_msg: bool = False,
|
||||
**kwargs
|
||||
self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs
|
||||
) -> Optional[Media]:
|
||||
"""
|
||||
上传媒体文件
|
||||
"""
|
||||
# 构建基础payload
|
||||
payload = {
|
||||
"file_type": file_type,
|
||||
"srv_send_msg": srv_send_msg
|
||||
}
|
||||
|
||||
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
|
||||
|
||||
# 处理文件数据
|
||||
if os.path.exists(file_source):
|
||||
# 读取本地文件
|
||||
async with aiofiles.open(file_source, 'rb') as f:
|
||||
async with aiofiles.open(file_source, "rb") as f:
|
||||
file_content = await f.read()
|
||||
# use base64 encode
|
||||
payload["file_data"] = base64.b64encode(file_content).decode('utf-8')
|
||||
payload["file_data"] = base64.b64encode(file_content).decode("utf-8")
|
||||
else:
|
||||
# 使用URL
|
||||
payload["url"] = file_source
|
||||
|
||||
|
||||
# 添加接收者信息和确定路由
|
||||
if "openid" in kwargs:
|
||||
payload["openid"] = kwargs["openid"]
|
||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||
elif "group_openid" in kwargs:
|
||||
payload["group_openid"] =kwargs["group_openid"]
|
||||
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"])
|
||||
payload["group_openid"] = kwargs["group_openid"]
|
||||
route = Route(
|
||||
"POST",
|
||||
"/v2/groups/{group_openid}/files",
|
||||
group_openid=kwargs["group_openid"],
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
# 使用底层HTTP请求
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
|
||||
if result:
|
||||
return Media(
|
||||
file_uuid=result.get("file_uuid"),
|
||||
file_info=result.get("file_info"),
|
||||
ttl=result.get("ttl", 0),
|
||||
file_id=result.get("id", "")
|
||||
file_id=result.get("id", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传请求错误: {e}")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def post_c2c_message(
|
||||
self,
|
||||
openid: str,
|
||||
@@ -286,19 +288,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
image_base64 = image_base64.removeprefix("base64://")
|
||||
elif isinstance(i, Record):
|
||||
if i.file:
|
||||
record_wav_path = await i.convert_to_file_path() # wav 路径
|
||||
record_wav_path = await i.convert_to_file_path() # wav 路径
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
record_tecent_silk_path = os.path.join(
|
||||
temp_dir, f"{uuid.uuid4()}.silk"
|
||||
)
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path)
|
||||
duration = await wav_to_tencent_silk(
|
||||
record_wav_path, record_tecent_silk_path
|
||||
)
|
||||
if duration > 0:
|
||||
record_file_path = record_tecent_silk_path
|
||||
else:
|
||||
record_file_path = None
|
||||
record_file_path = None
|
||||
logger.error("转换音频格式时出错:音频时长不大于0")
|
||||
except Exception as e:
|
||||
logger.error(f"处理语音时出错: {e}")
|
||||
record_file_path = None
|
||||
record_file_path = None
|
||||
else:
|
||||
logger.debug(f"qq_official 忽略 {i.type}")
|
||||
return plain_text, image_base64, image_file_path, record_file_path
|
||||
|
||||
@@ -17,7 +17,14 @@ from astrbot.api.platform import (
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
from astrbot.api.message_components import Plain, Image, At, File, Record
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
At,
|
||||
File,
|
||||
Record,
|
||||
Reply,
|
||||
)
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
|
||||
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
|
||||
)
|
||||
self.token = self.config.get("satori_token", "")
|
||||
self.endpoint = self.config.get(
|
||||
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
|
||||
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
|
||||
)
|
||||
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
|
||||
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
|
||||
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="satori",
|
||||
description="Satori 通用协议适配器",
|
||||
id=self.config["id"],
|
||||
)
|
||||
|
||||
self.ws: Optional[ClientConnection] = None
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.sequence = 0
|
||||
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
|
||||
return self.metadata
|
||||
|
||||
def _is_websocket_closed(self, ws) -> bool:
|
||||
"""检查WebSocket连接是否已关闭"""
|
||||
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
|
||||
|
||||
abm.self_id = login.get("user", {}).get("id", "")
|
||||
|
||||
content = message.get("content", "")
|
||||
abm.message = await self.parse_satori_elements(content)
|
||||
# 消息链
|
||||
abm.message = []
|
||||
|
||||
content = message.get("content", "")
|
||||
|
||||
quote = message.get("quote")
|
||||
content_for_parsing = content # 副本
|
||||
|
||||
# 提取<quote>标签
|
||||
if "<quote" in content:
|
||||
try:
|
||||
quote_info = await self._extract_quote_element(content)
|
||||
if quote_info:
|
||||
quote = quote_info["quote"]
|
||||
content_for_parsing = quote_info["content_without_quote"]
|
||||
except Exception as e:
|
||||
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
|
||||
|
||||
if quote:
|
||||
# 引用消息
|
||||
quote_abm = await self._convert_quote_message(quote)
|
||||
if quote_abm:
|
||||
sender_id = quote_abm.sender.user_id
|
||||
if isinstance(sender_id, str) and sender_id.isdigit():
|
||||
sender_id = int(sender_id)
|
||||
elif not isinstance(sender_id, int):
|
||||
sender_id = 0 # 默认值
|
||||
|
||||
reply_component = Reply(
|
||||
id=quote_abm.message_id,
|
||||
chain=quote_abm.message,
|
||||
sender_id=quote_abm.sender.user_id,
|
||||
sender_nickname=quote_abm.sender.nickname,
|
||||
time=quote_abm.timestamp,
|
||||
message_str=quote_abm.message_str,
|
||||
text=quote_abm.message_str,
|
||||
qq=sender_id,
|
||||
)
|
||||
abm.message.append(reply_component)
|
||||
|
||||
# 解析消息内容
|
||||
content_elements = await self.parse_satori_elements(content_for_parsing)
|
||||
abm.message.extend(content_elements)
|
||||
|
||||
# parse message_str
|
||||
abm.message_str = ""
|
||||
for comp in abm.message:
|
||||
for comp in content_elements:
|
||||
if isinstance(comp, Plain):
|
||||
abm.message_str += comp.text
|
||||
|
||||
@@ -333,6 +386,163 @@ class SatoriPlatformAdapter(Platform):
|
||||
logger.error(f"转换 Satori 消息失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_namespace_prefixes(self, content: str) -> set:
|
||||
"""提取XML内容中的命名空间前缀"""
|
||||
prefixes = set()
|
||||
|
||||
# 查找所有标签
|
||||
i = 0
|
||||
while i < len(content):
|
||||
# 查找开始标签
|
||||
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
|
||||
# 找到标签结束位置
|
||||
tag_end = content.find(">", i)
|
||||
if tag_end != -1:
|
||||
# 提取标签内容
|
||||
tag_content = content[i + 1 : tag_end]
|
||||
# 检查是否有命名空间前缀
|
||||
if ":" in tag_content and "xmlns:" not in tag_content:
|
||||
# 分割标签名
|
||||
parts = tag_content.split()
|
||||
if parts:
|
||||
tag_name = parts[0]
|
||||
if ":" in tag_name:
|
||||
prefix = tag_name.split(":")[0]
|
||||
# 确保是有效的命名空间前缀
|
||||
if (
|
||||
prefix.isalnum()
|
||||
or prefix.replace("_", "").isalnum()
|
||||
):
|
||||
prefixes.add(prefix)
|
||||
i = tag_end + 1
|
||||
else:
|
||||
i += 1
|
||||
# 查找结束标签
|
||||
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
|
||||
# 找到标签结束位置
|
||||
tag_end = content.find(">", i)
|
||||
if tag_end != -1:
|
||||
# 提取标签内容
|
||||
tag_content = content[i + 2 : tag_end]
|
||||
# 检查是否有命名空间前缀
|
||||
if ":" in tag_content:
|
||||
prefix = tag_content.split(":")[0]
|
||||
# 确保是有效的命名空间前缀
|
||||
if prefix.isalnum() or prefix.replace("_", "").isalnum():
|
||||
prefixes.add(prefix)
|
||||
i = tag_end + 1
|
||||
else:
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return prefixes
|
||||
|
||||
async def _extract_quote_element(self, content: str) -> Optional[dict]:
|
||||
"""提取<quote>标签信息"""
|
||||
try:
|
||||
# 处理命名空间前缀问题
|
||||
processed_content = content
|
||||
if ":" in content and not content.startswith("<root"):
|
||||
prefixes = self._extract_namespace_prefixes(content)
|
||||
|
||||
# 构建命名空间声明
|
||||
ns_declarations = " ".join(
|
||||
[
|
||||
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
||||
for prefix in prefixes
|
||||
]
|
||||
)
|
||||
|
||||
# 包装内容
|
||||
processed_content = f"<root {ns_declarations}>{content}</root>"
|
||||
elif not content.startswith("<root"):
|
||||
processed_content = f"<root>{content}</root>"
|
||||
else:
|
||||
processed_content = content
|
||||
|
||||
root = ET.fromstring(processed_content)
|
||||
|
||||
# 查找<quote>标签
|
||||
quote_element = None
|
||||
for elem in root.iter():
|
||||
tag_name = elem.tag
|
||||
if "}" in tag_name:
|
||||
tag_name = tag_name.split("}")[1]
|
||||
if tag_name.lower() == "quote":
|
||||
quote_element = elem
|
||||
break
|
||||
|
||||
if quote_element is not None:
|
||||
# 提取quote标签的属性
|
||||
quote_id = quote_element.get("id", "")
|
||||
|
||||
# 提取<quote>标签内部的内容
|
||||
inner_content = ""
|
||||
if quote_element.text:
|
||||
inner_content += quote_element.text
|
||||
for child in quote_element:
|
||||
inner_content += ET.tostring(
|
||||
child, encoding="unicode", method="xml"
|
||||
)
|
||||
if child.tail:
|
||||
inner_content += child.tail
|
||||
|
||||
# 构造移除了<quote>标签的内容
|
||||
content_without_quote = content.replace(
|
||||
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
|
||||
)
|
||||
|
||||
return {
|
||||
"quote": {"id": quote_id, "content": inner_content},
|
||||
"content_without_quote": content_without_quote,
|
||||
}
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"提取<quote>标签时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
|
||||
"""转换引用消息"""
|
||||
try:
|
||||
quote_abm = AstrBotMessage()
|
||||
quote_abm.message_id = quote.get("id", "")
|
||||
|
||||
# 解析引用消息的发送者
|
||||
quote_author = quote.get("author", {})
|
||||
if quote_author:
|
||||
quote_abm.sender = MessageMember(
|
||||
user_id=quote_author.get("id", ""),
|
||||
nickname=quote_author.get("nick", quote_author.get("name", "")),
|
||||
)
|
||||
else:
|
||||
# 如果没有作者信息,使用默认值
|
||||
quote_abm.sender = MessageMember(
|
||||
user_id=quote.get("user_id", ""),
|
||||
nickname="内容",
|
||||
)
|
||||
|
||||
# 解析引用消息内容
|
||||
quote_content = quote.get("content", "")
|
||||
quote_abm.message = await self.parse_satori_elements(quote_content)
|
||||
|
||||
quote_abm.message_str = ""
|
||||
for comp in quote_abm.message:
|
||||
if isinstance(comp, Plain):
|
||||
quote_abm.message_str += comp.text
|
||||
|
||||
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
|
||||
|
||||
# 如果没有任何内容,使用默认文本
|
||||
if not quote_abm.message_str.strip():
|
||||
quote_abm.message_str = "[引用消息]"
|
||||
|
||||
return quote_abm
|
||||
except Exception as e:
|
||||
logger.error(f"转换引用消息失败: {e}")
|
||||
return None
|
||||
|
||||
async def parse_satori_elements(self, content: str) -> list:
|
||||
"""解析 Satori 消息元素"""
|
||||
elements = []
|
||||
@@ -341,12 +551,35 @@ class SatoriPlatformAdapter(Platform):
|
||||
return elements
|
||||
|
||||
try:
|
||||
wrapped_content = f"<root>{content}</root>"
|
||||
root = ET.fromstring(wrapped_content)
|
||||
# 处理命名空间前缀问题
|
||||
processed_content = content
|
||||
if ":" in content and not content.startswith("<root"):
|
||||
prefixes = self._extract_namespace_prefixes(content)
|
||||
|
||||
# 构建命名空间声明
|
||||
ns_declarations = " ".join(
|
||||
[
|
||||
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
||||
for prefix in prefixes
|
||||
]
|
||||
)
|
||||
|
||||
# 包装内容
|
||||
processed_content = f"<root {ns_declarations}>{content}</root>"
|
||||
elif not content.startswith("<root"):
|
||||
processed_content = f"<root>{content}</root>"
|
||||
else:
|
||||
processed_content = content
|
||||
|
||||
root = ET.fromstring(processed_content)
|
||||
await self._parse_xml_node(root, elements)
|
||||
except ET.ParseError as e:
|
||||
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
|
||||
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
|
||||
# 如果解析失败,将整个内容当作纯文本
|
||||
if content.strip():
|
||||
elements.append(Plain(text=content))
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
|
||||
raise e
|
||||
|
||||
# 如果没有解析到任何元素,将整个内容当作纯文本
|
||||
@@ -361,7 +594,12 @@ class SatoriPlatformAdapter(Platform):
|
||||
elements.append(Plain(text=node.text))
|
||||
|
||||
for child in node:
|
||||
tag_name = child.tag.lower()
|
||||
# 获取标签名,去除命名空间前缀
|
||||
tag_name = child.tag
|
||||
if "}" in tag_name:
|
||||
tag_name = tag_name.split("}")[1]
|
||||
tag_name = tag_name.lower()
|
||||
|
||||
attrs = child.attrib
|
||||
|
||||
if tag_name == "at":
|
||||
@@ -372,31 +610,59 @@ class SatoriPlatformAdapter(Platform):
|
||||
src = attrs.get("src", "")
|
||||
if not src:
|
||||
continue
|
||||
if src.startswith("data:image/"):
|
||||
src = src.split(",")[1]
|
||||
elements.append(Image.fromBase64(src))
|
||||
elif src.startswith("http"):
|
||||
elements.append(Image.fromURL(src))
|
||||
else:
|
||||
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
|
||||
elements.append(Image(file=src))
|
||||
|
||||
elif tag_name == "file":
|
||||
src = attrs.get("src", "")
|
||||
name = attrs.get("name", "文件")
|
||||
if src:
|
||||
elements.append(File(file=src, name=name))
|
||||
elements.append(File(name=name, file=src))
|
||||
|
||||
elif tag_name in ("audio", "record"):
|
||||
src = attrs.get("src", "")
|
||||
if not src:
|
||||
continue
|
||||
if src.startswith("data:audio/"):
|
||||
src = src.split(",")[1]
|
||||
elements.append(Record.fromBase64(src))
|
||||
elif src.startswith("http"):
|
||||
elements.append(Record.fromURL(src))
|
||||
elements.append(Record(file=src))
|
||||
|
||||
elif tag_name == "quote":
|
||||
# quote标签已经被特殊处理
|
||||
pass
|
||||
|
||||
elif tag_name == "face":
|
||||
face_id = attrs.get("id", "")
|
||||
face_name = attrs.get("name", "")
|
||||
face_type = attrs.get("type", "")
|
||||
|
||||
if face_name:
|
||||
elements.append(Plain(text=f"[表情:{face_name}]"))
|
||||
elif face_id and face_type:
|
||||
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
|
||||
elif face_id:
|
||||
elements.append(Plain(text=f"[表情ID:{face_id}]"))
|
||||
else:
|
||||
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
|
||||
elements.append(Plain(text="[表情]"))
|
||||
|
||||
elif tag_name == "ark":
|
||||
# 作为纯文本添加到消息链中
|
||||
data = attrs.get("data", "")
|
||||
if data:
|
||||
import html
|
||||
|
||||
decoded_data = html.unescape(data)
|
||||
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
||||
else:
|
||||
elements.append(Plain(text="[ARK卡片]"))
|
||||
|
||||
elif tag_name == "json":
|
||||
# JSON标签 视为ARK卡片消息
|
||||
data = attrs.get("data", "")
|
||||
if data:
|
||||
import html
|
||||
|
||||
decoded_data = html.unescape(data)
|
||||
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
||||
else:
|
||||
elements.append(Plain(text="[JSON卡片]"))
|
||||
|
||||
else:
|
||||
# 未知标签,递归处理其内容
|
||||
|
||||
@@ -17,6 +17,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
session_id: str,
|
||||
adapter: "SatoriPlatformAdapter",
|
||||
):
|
||||
# 更新平台元数据
|
||||
if adapter and hasattr(adapter, "logins") and adapter.logins:
|
||||
current_login = adapter.logins[0]
|
||||
platform_name = current_login.get("platform", "satori")
|
||||
user = current_login.get("user", {})
|
||||
user_id = user.get("id", "") if user else ""
|
||||
if not platform_meta.id and user_id:
|
||||
platform_meta.id = f"{platform_name}({user_id})"
|
||||
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.adapter = adapter
|
||||
self.platform = None
|
||||
|
||||
@@ -308,7 +308,9 @@ class SlackAdapter(Platform):
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
return base64_content
|
||||
else:
|
||||
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
|
||||
logger.error(
|
||||
f"Failed to download slack file: {resp.status} {await resp.text()}"
|
||||
)
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
|
||||
async def run(self) -> Awaitable[Any]:
|
||||
|
||||
@@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||
}
|
||||
file_url = response["files"][0]["permalink"]
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
return chunks
|
||||
|
||||
@classmethod
|
||||
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str):
|
||||
async def send_with_client(
|
||||
cls, client: ExtBot, message: MessageChain, user_name: str
|
||||
):
|
||||
image_path = None
|
||||
|
||||
has_reply = False
|
||||
@@ -216,7 +218,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
delta = ""
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
class WebChatQueueMgr:
|
||||
def __init__(self) -> None:
|
||||
self.queues = {}
|
||||
@@ -30,4 +31,5 @@ class WebChatQueueMgr:
|
||||
"""Check if a queue exists for the given conversation ID"""
|
||||
return conversation_id in self.queues
|
||||
|
||||
|
||||
webchat_queue_mgr = WebChatQueueMgr()
|
||||
|
||||
@@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform):
|
||||
def _extract_auth_key(self, data):
|
||||
"""Helper method to extract auth_key from response data."""
|
||||
if isinstance(data, dict):
|
||||
auth_keys = data.get("authKeys") # 新接口
|
||||
auth_keys = data.get("authKeys") # 新接口
|
||||
if isinstance(auth_keys, list) and auth_keys:
|
||||
return auth_keys[0]
|
||||
elif isinstance(data, list) and data: # 旧接口
|
||||
elif isinstance(data, list) and data: # 旧接口
|
||||
return data[0]
|
||||
return None
|
||||
|
||||
@@ -234,7 +234,9 @@ class WeChatPadProAdapter(Platform):
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
|
||||
logger.error(
|
||||
f"生成授权码失败: {response.status}, {await response.text()}"
|
||||
)
|
||||
return
|
||||
|
||||
response_data = await response.json()
|
||||
@@ -245,7 +247,9 @@ class WeChatPadProAdapter(Platform):
|
||||
if self.auth_key:
|
||||
logger.info("成功获取授权码")
|
||||
else:
|
||||
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
|
||||
logger.error(
|
||||
f"生成授权码成功但未找到授权码: {response_data}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"生成授权码失败: {response_data}")
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
|
||||
@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
|
||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
||||
data = {
|
||||
"token": token,
|
||||
"cursor": cursor,
|
||||
"limit": limit,
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
return self._post("kf/sync_msg", data=data)
|
||||
|
||||
def get_service_state(self, open_kfid, external_userid):
|
||||
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
}
|
||||
return self._post("kf/service_state/get", data=data)
|
||||
|
||||
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
||||
def trans_service_state(
|
||||
self, open_kfid, external_userid, service_state, servicer_userid=""
|
||||
):
|
||||
"""
|
||||
变更会话状态
|
||||
|
||||
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
"""
|
||||
return self._get("kf/customer/get_upgrade_service_config")
|
||||
|
||||
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
||||
def upgrade_service(
|
||||
self, open_kfid, external_userid, service_type, member=None, groupchat=None
|
||||
):
|
||||
"""
|
||||
为客户升级为专员或客户群服务
|
||||
|
||||
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||
return self._post("kf/get_corp_statistic", data=data)
|
||||
|
||||
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
||||
def get_servicer_statistic(
|
||||
self, start_time, end_time, open_kfid=None, servicer_userid=None
|
||||
):
|
||||
"""
|
||||
获取「客户数据统计」接待人员明细数据
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from optionaldict import optionaldict
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
|
||||
class WeChatKFMessage(BaseWeChatAPI):
|
||||
"""
|
||||
发送微信客服消息
|
||||
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
|
||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||
)
|
||||
|
||||
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
||||
def send_msgmenu(
|
||||
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "msgmenu",
|
||||
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
||||
"msgmenu": {
|
||||
"head_content": head_content,
|
||||
"list": menu_list,
|
||||
"tail_content": tail_content,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
||||
def send_location(
|
||||
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "location",
|
||||
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
||||
"msgmenu": {
|
||||
"name": name,
|
||||
"address": address,
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
||||
def send_miniprogram(
|
||||
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "miniprogram",
|
||||
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
||||
"msgmenu": {
|
||||
"appid": appid,
|
||||
"title": title,
|
||||
"thumb_media_id": thumb_media_id,
|
||||
"pagepath": pagepath,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||
result = await asyncio.wait_for(
|
||||
asyncio.shield(future), 60
|
||||
) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
|
||||
@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
|
||||
@@ -65,13 +65,16 @@ class AssistantMessageSegment:
|
||||
role: str = "assistant"
|
||||
|
||||
def to_dict(self):
|
||||
ret = {
|
||||
ret: dict[str, str | list[dict]] = {
|
||||
"role": self.role,
|
||||
}
|
||||
if self.content:
|
||||
ret["content"] = self.content
|
||||
if self.tool_calls:
|
||||
ret["tool_calls"] = self.tool_calls
|
||||
tool_calls_dict = [
|
||||
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
|
||||
]
|
||||
ret["tool_calls"] = tool_calls_dict
|
||||
return ret
|
||||
|
||||
|
||||
@@ -117,7 +120,14 @@ class ProviderRequest:
|
||||
"""模型名称,为 None 时使用提供商的默认模型"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||
return (
|
||||
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
|
||||
f"image_count={len(self.image_urls or [])}, "
|
||||
f"func_tool={self.func_tool}, "
|
||||
f"contexts={self._print_friendly_context()}, "
|
||||
f"system_prompt={self.system_prompt}, "
|
||||
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
@@ -297,6 +307,7 @@ class LLMResponse:
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class RerankResult:
|
||||
index: int
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
from typing import Dict, List, Awaitable
|
||||
from typing import Dict, List, Awaitable, Callable, Any
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
|
||||
@@ -109,7 +109,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Awaitable,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
) -> FuncTool:
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
@@ -132,7 +132,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Awaitable,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
) -> None:
|
||||
"""添加函数调用工具
|
||||
|
||||
@@ -220,7 +220,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
ready_future: asyncio.Future = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
try:
|
||||
|
||||
@@ -7,7 +7,13 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .provider import (
|
||||
Provider,
|
||||
STTProvider,
|
||||
TTSProvider,
|
||||
EmbeddingProvider,
|
||||
RerankProvider,
|
||||
)
|
||||
from .register import llm_tools, provider_cls_map
|
||||
from ..persona_mgr import PersonaManager
|
||||
|
||||
@@ -38,7 +44,12 @@ class ProviderManager:
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map: dict[str, Provider] = {}
|
||||
self.rerank_provider_insts: List[RerankProvider] = []
|
||||
"""加载的 Rerank Provider 的实例"""
|
||||
self.inst_map: dict[
|
||||
str,
|
||||
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
||||
] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
|
||||
@@ -87,19 +98,31 @@ class ProviderManager:
|
||||
)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
|
||||
prov = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
||||
prov, TTSProvider
|
||||
):
|
||||
self.curr_tts_provider_inst = prov
|
||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
||||
prov, STTProvider
|
||||
):
|
||||
self.curr_stt_provider_inst = prov
|
||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
||||
prov, Provider
|
||||
):
|
||||
self.curr_provider_inst = prov
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
return self.inst_map.get(provider_id)
|
||||
|
||||
def get_using_provider(self, provider_type: ProviderType, umo=None):
|
||||
def get_using_provider(
|
||||
self, provider_type: ProviderType, umo=None
|
||||
) -> Provider | STTProvider | TTSProvider | None:
|
||||
"""获取正在使用的提供商实例。
|
||||
|
||||
Args:
|
||||
@@ -303,12 +326,14 @@ class ProviderManager:
|
||||
provider_metadata = provider_cls_map[provider_config["type"]]
|
||||
try:
|
||||
# 按任务实例化提供商
|
||||
cls_type = provider_metadata.cls_type
|
||||
if not cls_type:
|
||||
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
||||
return
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
@@ -327,9 +352,7 @@ class ProviderManager:
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
@@ -345,7 +368,7 @@ class ProviderManager:
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
inst = provider_metadata.cls_type(
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
self.selected_default_persona,
|
||||
@@ -366,13 +389,16 @@ class ProviderManager:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
elif provider_metadata.provider_type == ProviderType.RERANK:
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.rerank_provider_insts.append(inst)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
@@ -388,6 +414,7 @@ class ProviderManager:
|
||||
|
||||
# 和配置文件保持同步
|
||||
config_ids = [provider["id"] for provider in self.providers_config]
|
||||
logger.debug(f"providers in user's config: {config_ids}")
|
||||
for key in list(self.inst_map.keys()):
|
||||
if key not in config_ids:
|
||||
await self.terminate_provider(key)
|
||||
@@ -426,11 +453,17 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
if self.inst_map[provider_id] in self.provider_insts:
|
||||
self.provider_insts.remove(self.inst_map[provider_id])
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if isinstance(prov_inst, Provider):
|
||||
self.provider_insts.remove(prov_inst)
|
||||
if self.inst_map[provider_id] in self.stt_provider_insts:
|
||||
self.stt_provider_insts.remove(self.inst_map[provider_id])
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if isinstance(prov_inst, STTProvider):
|
||||
self.stt_provider_insts.remove(prov_inst)
|
||||
if self.inst_map[provider_id] in self.tts_provider_insts:
|
||||
self.tts_provider_insts.remove(self.inst_map[provider_id])
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if isinstance(prov_inst, TTSProvider):
|
||||
self.tts_provider_insts.remove(prov_inst)
|
||||
|
||||
if self.inst_map[provider_id] == self.curr_provider_inst:
|
||||
self.curr_provider_inst = None
|
||||
|
||||
@@ -98,7 +98,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
|
||||
# FishAudio的reference_id通常是32位十六进制字符串
|
||||
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
|
||||
pattern = r'^[a-fA-F0-9]{32}$'
|
||||
pattern = r"^[a-fA-F0-9]{32}$"
|
||||
return bool(re.match(pattern, reference_id.strip()))
|
||||
|
||||
async def _generate_request(self, text: str) -> dict:
|
||||
|
||||
@@ -99,12 +99,15 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for key in to_del:
|
||||
del payloads[key]
|
||||
|
||||
model = payloads.get("model", "")
|
||||
# 针对 qwen3 非 thinking 模型的特殊处理:非流式调用必须设置 enable_thinking=false
|
||||
if "qwen3" in model.lower() and "thinking" not in model.lower():
|
||||
extra_body["enable_thinking"] = False
|
||||
# 读取并合并 custom_extra_body 配置
|
||||
custom_extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
if isinstance(custom_extra_body, dict):
|
||||
extra_body.update(custom_extra_body)
|
||||
|
||||
model = payloads.get("model", "").lower()
|
||||
|
||||
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
|
||||
elif model == "deepseek-reasoner" and "tools" in payloads:
|
||||
if model == "deepseek-reasoner" and "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
@@ -137,6 +140,12 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
# 不在默认参数中的参数放在 extra_body 中
|
||||
extra_body = {}
|
||||
|
||||
# 读取并合并 custom_extra_body 配置
|
||||
custom_extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
if isinstance(custom_extra_body, dict):
|
||||
extra_body.update(custom_extra_body)
|
||||
|
||||
to_del = []
|
||||
for key in payloads.keys():
|
||||
if key not in self.default_params:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import aiohttp
|
||||
from astrbot import logger
|
||||
from ..provider import RerankProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType, RerankResult
|
||||
@@ -44,6 +45,11 @@ class VLLMRerankProvider(RerankProvider):
|
||||
response_data = await response.json()
|
||||
results = response_data.get("results", [])
|
||||
|
||||
if not results:
|
||||
logger.warning(
|
||||
f"Rerank API 返回了空的列表数据。原始响应: {response_data}"
|
||||
)
|
||||
|
||||
return [
|
||||
RerankResult(
|
||||
index=result["index"],
|
||||
|
||||
@@ -6,6 +6,7 @@ from astrbot.core.provider.provider import (
|
||||
TTSProvider,
|
||||
STTProvider,
|
||||
EmbeddingProvider,
|
||||
RerankProvider,
|
||||
)
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -23,7 +24,7 @@ from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from typing import Awaitable, Any, Callable
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
@@ -103,9 +104,14 @@ class Context:
|
||||
"""
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
||||
return self.provider_manager.inst_map.get(provider_id)
|
||||
def get_provider_by_id(
|
||||
self, provider_id: str
|
||||
) -> (
|
||||
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
|
||||
):
|
||||
"""通过 ID 获取对应的 LLM Provider。"""
|
||||
prov = self.provider_manager.inst_map.get(provider_id)
|
||||
return prov
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
@@ -130,34 +136,43 @@ class Context:
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||
"""
|
||||
return self.provider_manager.get_using_provider(
|
||||
prov = self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
if prov and not isinstance(prov, Provider):
|
||||
raise ValueError("返回的 Provider 不是 Provider 类型")
|
||||
return prov
|
||||
|
||||
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
|
||||
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
|
||||
"""
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
return self.provider_manager.get_using_provider(
|
||||
prov = self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
umo=umo,
|
||||
)
|
||||
if prov and not isinstance(prov, TTSProvider):
|
||||
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
|
||||
return prov
|
||||
|
||||
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
|
||||
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
|
||||
"""
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
return self.provider_manager.get_using_provider(
|
||||
prov = self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||
umo=umo,
|
||||
)
|
||||
if prov and not isinstance(prov, STTProvider):
|
||||
raise ValueError("返回的 Provider 不是 STTProvider 类型")
|
||||
return prov
|
||||
|
||||
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
||||
"""获取 AstrBot 的配置。"""
|
||||
@@ -245,7 +260,11 @@ class Context:
|
||||
"""
|
||||
|
||||
def register_llm_tool(
|
||||
self, name: str, func_args: list, desc: str, func_obj: Awaitable
|
||||
self,
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
func_obj: Callable[..., Awaitable[Any]],
|
||||
) -> None:
|
||||
"""
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
@@ -267,9 +286,7 @@ class Context:
|
||||
desc=desc,
|
||||
)
|
||||
star_handlers_registry.append(md)
|
||||
self.provider_manager.llm_tools.add_func(
|
||||
name, func_args, desc, func_obj, func_obj
|
||||
)
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
|
||||
@@ -281,7 +298,7 @@ class Context:
|
||||
command_name: str,
|
||||
desc: str,
|
||||
priority: int,
|
||||
awaitable: Awaitable,
|
||||
awaitable: Callable[..., Awaitable[Any]],
|
||||
use_regex=False,
|
||||
ignore_prefix=False,
|
||||
):
|
||||
|
||||
@@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter):
|
||||
def __init__(
|
||||
self,
|
||||
group_name: str,
|
||||
alias: set = None,
|
||||
parent_group: CommandGroupFilter = None,
|
||||
alias: set | None = None,
|
||||
parent_group: CommandGroupFilter | None = None,
|
||||
):
|
||||
self.group_name = group_name
|
||||
self.alias = alias if alias else set()
|
||||
@@ -54,8 +54,8 @@ class CommandGroupFilter(HandlerFilter):
|
||||
self,
|
||||
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
||||
prefix: str = "",
|
||||
event: AstrMessageEvent = None,
|
||||
cfg: AstrBotConfig = None,
|
||||
event: AstrMessageEvent | None = None,
|
||||
cfg: AstrBotConfig | None = None,
|
||||
) -> str:
|
||||
result = ""
|
||||
for sub_filter in sub_command_filters:
|
||||
@@ -113,8 +113,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
)
|
||||
raise ValueError(
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
||||
+ tree
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||
)
|
||||
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
|
||||
@@ -2,7 +2,6 @@ import enum
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from typing import Union
|
||||
|
||||
|
||||
class PlatformAdapterType(enum.Flag):
|
||||
@@ -19,6 +18,7 @@ class PlatformAdapterType(enum.Flag):
|
||||
VOCECHAT = enum.auto()
|
||||
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
||||
SATORI = enum.auto()
|
||||
MISSKEY = enum.auto()
|
||||
ALL = (
|
||||
AIOCQHTTP
|
||||
| QQOFFICIAL
|
||||
@@ -33,6 +33,7 @@ class PlatformAdapterType(enum.Flag):
|
||||
| VOCECHAT
|
||||
| WEIXIN_OFFICIAL_ACCOUNT
|
||||
| SATORI
|
||||
| MISSKEY
|
||||
)
|
||||
|
||||
|
||||
@@ -50,15 +51,19 @@ ADAPTER_NAME_2_TYPE = {
|
||||
"vocechat": PlatformAdapterType.VOCECHAT,
|
||||
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
||||
"satori": PlatformAdapterType.SATORI,
|
||||
"misskey": PlatformAdapterType.MISSKEY,
|
||||
}
|
||||
|
||||
|
||||
class PlatformAdapterTypeFilter(HandlerFilter):
|
||||
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
|
||||
self.type_or_str = platform_adapter_type_or_str
|
||||
def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str):
|
||||
if isinstance(platform_adapter_type_or_str, str):
|
||||
self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str)
|
||||
else:
|
||||
self.platform_type = platform_adapter_type_or_str
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
adapter_name = event.get_platform_name()
|
||||
if adapter_name in ADAPTER_NAME_2_TYPE:
|
||||
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
|
||||
if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None:
|
||||
return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type)
|
||||
return False
|
||||
|
||||
@@ -8,6 +8,7 @@ from .star_handler import (
|
||||
register_permission_type,
|
||||
register_custom_filter,
|
||||
register_on_astrbot_loaded,
|
||||
register_on_platform_loaded,
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_llm_tool,
|
||||
@@ -26,6 +27,7 @@ __all__ = [
|
||||
"register_permission_type",
|
||||
"register_custom_filter",
|
||||
"register_on_astrbot_loaded",
|
||||
"register_on_platform_loaded",
|
||||
"register_on_llm_request",
|
||||
"register_on_llm_response",
|
||||
"register_llm_tool",
|
||||
|
||||
@@ -5,7 +5,9 @@ from astrbot.core.star import StarMetadata, star_map
|
||||
_warned_register_star = False
|
||||
|
||||
|
||||
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
||||
def register_star(
|
||||
name: str, author: str, desc: str, version: str, repo: str | None = None
|
||||
):
|
||||
"""注册一个插件(Star)。
|
||||
|
||||
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
||||
|
||||
@@ -12,7 +12,7 @@ from ..filter.platform_adapter_type import (
|
||||
from ..filter.permission import PermissionTypeFilter, PermissionType
|
||||
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
||||
from ..filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from typing import Awaitable, Any, Callable
|
||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.agent.agent import Agent
|
||||
@@ -20,15 +20,19 @@ from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
def get_handler_full_name(awaitable: Awaitable) -> str:
|
||||
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
|
||||
"""获取 Handler 的全名"""
|
||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||
|
||||
|
||||
def get_handler_or_create(
|
||||
handler: Awaitable, event_type: EventType, dont_add=False, **kwargs
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
event_type: EventType,
|
||||
dont_add=False,
|
||||
**kwargs,
|
||||
) -> StarHandlerMetadata:
|
||||
"""获取 Handler 或者创建一个新的 Handler"""
|
||||
handler_full_name = get_handler_full_name(handler)
|
||||
@@ -59,22 +63,35 @@ def get_handler_or_create(
|
||||
|
||||
|
||||
def register_command(
|
||||
command_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
||||
command_name: str | None = None,
|
||||
sub_command: str | None = None,
|
||||
alias: set | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""注册一个 Command."""
|
||||
new_command = None
|
||||
add_to_event_filters = False
|
||||
if isinstance(command_name, RegisteringCommandable):
|
||||
# 子指令
|
||||
parent_command_names = command_name.parent_group.get_complete_command_names()
|
||||
new_command = CommandFilter(
|
||||
sub_command, alias, None, parent_command_names=parent_command_names
|
||||
)
|
||||
command_name.parent_group.add_sub_command_filter(new_command)
|
||||
if sub_command is not None:
|
||||
parent_command_names = (
|
||||
command_name.parent_group.get_complete_command_names()
|
||||
)
|
||||
new_command = CommandFilter(
|
||||
sub_command, alias, None, parent_command_names=parent_command_names
|
||||
)
|
||||
command_name.parent_group.add_sub_command_filter(new_command)
|
||||
else:
|
||||
logger.warning(
|
||||
f"注册指令{command_name} 的子指令时未提供 sub_command 参数。"
|
||||
)
|
||||
else:
|
||||
# 裸指令
|
||||
new_command = CommandFilter(command_name, alias, None)
|
||||
add_to_event_filters = True
|
||||
if command_name is None:
|
||||
logger.warning("注册裸指令时未提供 command_name 参数。")
|
||||
else:
|
||||
new_command = CommandFilter(command_name, alias, None)
|
||||
add_to_event_filters = True
|
||||
|
||||
def decorator(awaitable):
|
||||
if not add_to_event_filters:
|
||||
@@ -84,8 +101,9 @@ def register_command(
|
||||
handler_md = get_handler_or_create(
|
||||
awaitable, EventType.AdapterMessageEvent, **kwargs
|
||||
)
|
||||
new_command.init_handler_md(handler_md)
|
||||
handler_md.event_filters.append(new_command)
|
||||
if new_command:
|
||||
new_command.init_handler_md(handler_md)
|
||||
handler_md.event_filters.append(new_command)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
@@ -163,26 +181,38 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
||||
|
||||
|
||||
def register_command_group(
|
||||
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
||||
command_group_name: str | None = None,
|
||||
sub_command: str | None = None,
|
||||
alias: set | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""注册一个 CommandGroup"""
|
||||
new_group = None
|
||||
if isinstance(command_group_name, RegisteringCommandable):
|
||||
# 子指令组
|
||||
new_group = CommandGroupFilter(
|
||||
sub_command, alias, parent_group=command_group_name.parent_group
|
||||
)
|
||||
command_group_name.parent_group.add_sub_command_filter(new_group)
|
||||
if sub_command is None:
|
||||
logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
|
||||
else:
|
||||
new_group = CommandGroupFilter(
|
||||
sub_command, alias, parent_group=command_group_name.parent_group
|
||||
)
|
||||
command_group_name.parent_group.add_sub_command_filter(new_group)
|
||||
else:
|
||||
# 根指令组
|
||||
new_group = CommandGroupFilter(command_group_name, alias)
|
||||
if command_group_name is None:
|
||||
logger.warning("根指令组的名称未指定")
|
||||
else:
|
||||
new_group = CommandGroupFilter(command_group_name, alias)
|
||||
|
||||
def decorator(obj):
|
||||
# 根指令组
|
||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(new_group)
|
||||
if new_group:
|
||||
handler_md = get_handler_or_create(
|
||||
obj, EventType.AdapterMessageEvent, **kwargs
|
||||
)
|
||||
handler_md.event_filters.append(new_group)
|
||||
|
||||
return RegisteringCommandable(new_group)
|
||||
return RegisteringCommandable(new_group)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -267,6 +297,18 @@ def register_on_astrbot_loaded(**kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_platform_loaded(**kwargs):
|
||||
"""
|
||||
当平台加载完成时
|
||||
"""
|
||||
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnPlatformLoadedEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_llm_request(**kwargs):
|
||||
"""当有 LLM 请求时的事件
|
||||
|
||||
@@ -311,7 +353,7 @@ def register_on_llm_response(**kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def register_llm_tool(name: str = None, **kwargs):
|
||||
def register_llm_tool(name: str | None = None, **kwargs):
|
||||
"""为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||
@@ -349,9 +391,10 @@ def register_llm_tool(name: str = None, **kwargs):
|
||||
if kwargs.get("registering_agent"):
|
||||
registering_agent = kwargs["registering_agent"]
|
||||
|
||||
def decorator(awaitable: Awaitable):
|
||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||
docstring = docstring_parser.parse(awaitable.__doc__)
|
||||
func_doc = awaitable.__doc__ or ""
|
||||
docstring = docstring_parser.parse(func_doc)
|
||||
args = []
|
||||
for arg in docstring.params:
|
||||
if arg.type_name not in SUPPORTED_TYPES:
|
||||
@@ -367,18 +410,18 @@ def register_llm_tool(name: str = None, **kwargs):
|
||||
)
|
||||
# print(llm_tool_name, registering_agent)
|
||||
if not registering_agent:
|
||||
doc_desc = docstring.description.strip() if docstring.description else ""
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(
|
||||
llm_tool_name, args, docstring.description.strip(), md.handler
|
||||
)
|
||||
llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler)
|
||||
else:
|
||||
assert isinstance(registering_agent, RegisteringAgent)
|
||||
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
||||
if registering_agent._agent.tools is None:
|
||||
registering_agent._agent.tools = []
|
||||
registering_agent._agent.tools.append(llm_tools.spec_to_func(
|
||||
llm_tool_name, args, docstring.description.strip(), awaitable
|
||||
))
|
||||
|
||||
desc = docstring.description.strip() if docstring.description else ""
|
||||
tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable)
|
||||
registering_agent._agent.tools.append(tool)
|
||||
|
||||
return awaitable
|
||||
|
||||
@@ -399,8 +442,8 @@ class RegisteringAgent:
|
||||
def register_agent(
|
||||
name: str,
|
||||
instruction: str,
|
||||
tools: list[str | FunctionTool] = None,
|
||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
|
||||
tools: list[str | FunctionTool] | None = None,
|
||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
|
||||
):
|
||||
"""注册一个 Agent
|
||||
|
||||
@@ -412,7 +455,7 @@ def register_agent(
|
||||
"""
|
||||
tools_ = tools or []
|
||||
|
||||
def decorator(awaitable: Awaitable):
|
||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||
AstrAgent = Agent[AstrAgentContext]
|
||||
agent = AstrAgent(
|
||||
name=name,
|
||||
@@ -421,7 +464,7 @@ def register_agent(
|
||||
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
)
|
||||
handoff_tool = HandoffTool(agent=agent)
|
||||
handoff_tool.handler=awaitable
|
||||
handoff_tool.handler = awaitable
|
||||
llm_tools.func_list.append(handoff_tool)
|
||||
return RegisteringAgent(agent)
|
||||
|
||||
|
||||
@@ -84,7 +84,10 @@ class SessionPluginManager:
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put(
|
||||
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
|
||||
"session_plugin_config",
|
||||
session_plugin_config,
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -137,6 +140,9 @@ class SessionPluginManager:
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
if plugin.name is None:
|
||||
continue
|
||||
|
||||
# 检查插件是否在当前会话中启用
|
||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id, plugin.name
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
from .star import star_map
|
||||
|
||||
@@ -34,26 +34,33 @@ class StarHandlerRegistry(Generic[T]):
|
||||
) -> List[StarHandlerMetadata]:
|
||||
handlers = []
|
||||
for handler in self._handlers:
|
||||
# 过滤事件类型
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
# 过滤启用状态
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
# 过滤插件白名单
|
||||
if plugins_name is not None and plugins_name != ["*"]:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not plugin:
|
||||
continue
|
||||
if (
|
||||
plugin.name not in plugins_name
|
||||
and event_type != EventType.OnAstrBotLoadedEvent
|
||||
and event_type
|
||||
not in (
|
||||
EventType.OnAstrBotLoadedEvent,
|
||||
EventType.OnPlatformLoadedEvent,
|
||||
)
|
||||
and not plugin.reserved
|
||||
):
|
||||
continue
|
||||
handlers.append(handler)
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None:
|
||||
return self.star_handlers_map.get(full_name, None)
|
||||
|
||||
def get_handlers_by_module_name(
|
||||
@@ -80,7 +87,7 @@ class StarHandlerRegistry(Generic[T]):
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
star_handlers_registry = StarHandlerRegistry() # type: ignore
|
||||
|
||||
|
||||
class EventType(enum.Enum):
|
||||
@@ -90,6 +97,7 @@ class EventType(enum.Enum):
|
||||
"""
|
||||
|
||||
OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成
|
||||
OnPlatformLoadedEvent = enum.auto() # 平台加载完成
|
||||
|
||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
@@ -115,7 +123,7 @@ class StarHandlerMetadata:
|
||||
handler_module_path: str
|
||||
"""Handler 所在的模块路径。"""
|
||||
|
||||
handler: Awaitable
|
||||
handler: Callable[..., Awaitable[Any]]
|
||||
"""Handler 的函数对象,应当是一个异步函数"""
|
||||
|
||||
event_filters: List[HandlerFilter]
|
||||
|
||||
@@ -43,7 +43,7 @@ class PluginManager:
|
||||
self.updator = PluginUpdator()
|
||||
|
||||
self.context = context
|
||||
self.context._star_manager = self
|
||||
self.context._star_manager = self # type: ignore
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
@@ -478,9 +478,10 @@ class PluginManager:
|
||||
if isinstance(func_tool, HandoffTool):
|
||||
need_apply = []
|
||||
sub_tools = func_tool.agent.tools
|
||||
for sub_tool in sub_tools:
|
||||
if isinstance(sub_tool, FunctionTool):
|
||||
need_apply.append(sub_tool)
|
||||
if sub_tools:
|
||||
for sub_tool in sub_tools:
|
||||
if isinstance(sub_tool, FunctionTool):
|
||||
need_apply.append(sub_tool)
|
||||
else:
|
||||
need_apply = [func_tool]
|
||||
|
||||
@@ -686,6 +687,9 @@ class PluginManager:
|
||||
)
|
||||
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
if plugin.module_path is None or root_dir_name is None:
|
||||
raise Exception(f"插件 {plugin_name} 数据不完整,无法卸载。")
|
||||
|
||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||
|
||||
try:
|
||||
@@ -791,15 +795,17 @@ class PluginManager:
|
||||
if star_metadata.star_cls is None:
|
||||
return
|
||||
|
||||
if '__del__' in star_metadata.star_cls_type.__dict__:
|
||||
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None, star_metadata.star_cls.__del__
|
||||
)
|
||||
elif 'terminate' in star_metadata.star_cls_type.__dict__:
|
||||
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||
await star_metadata.star_cls.terminate()
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if plugin is None:
|
||||
raise Exception(f"插件 {plugin_name} 不存在。")
|
||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
||||
if plugin.module_path in inactivated_plugins:
|
||||
|
||||
@@ -22,7 +22,7 @@ import inspect
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
||||
from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
|
||||
@@ -30,8 +30,13 @@ from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter
|
||||
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import (
|
||||
AiocqhttpMessageEvent,
|
||||
)
|
||||
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import (
|
||||
AiocqhttpAdapter,
|
||||
)
|
||||
|
||||
|
||||
class StarTools:
|
||||
"""
|
||||
@@ -77,7 +82,11 @@ class StarTools:
|
||||
|
||||
@classmethod
|
||||
async def send_message_by_id(
|
||||
cls, type: str, id: str, message_chain: MessageChain, platform: str = "aiocqhttp"
|
||||
cls,
|
||||
type: str,
|
||||
id: str,
|
||||
message_chain: MessageChain,
|
||||
platform: str = "aiocqhttp",
|
||||
):
|
||||
"""
|
||||
根据 id(例如qq号, 群号等) 直接, 主动地发送消息
|
||||
@@ -92,7 +101,9 @@ class StarTools:
|
||||
raise ValueError("StarTools not initialized")
|
||||
platforms = cls._context.platform_manager.get_insts()
|
||||
if platform == "aiocqhttp":
|
||||
adapter = next((p for p in platforms if isinstance(p, AiocqhttpAdapter)), None)
|
||||
adapter = next(
|
||||
(p for p in platforms if isinstance(p, AiocqhttpAdapter)), None
|
||||
)
|
||||
if adapter is None:
|
||||
raise ValueError("未找到适配器: AiocqhttpAdapter")
|
||||
await AiocqhttpMessageEvent.send_message(
|
||||
@@ -115,7 +126,7 @@ class StarTools:
|
||||
message_str: str,
|
||||
message_id: str = "",
|
||||
raw_message: object = None,
|
||||
group_id: str = ""
|
||||
group_id: str = "",
|
||||
) -> AstrBotMessage:
|
||||
"""
|
||||
创建一个AstrBot消息对象
|
||||
@@ -152,7 +163,6 @@ class StarTools:
|
||||
@classmethod
|
||||
async def create_event(
|
||||
cls, abm: AstrBotMessage, platform: str = "aiocqhttp", is_wake: bool = True
|
||||
|
||||
) -> None:
|
||||
"""
|
||||
创建并提交事件到指定平台
|
||||
@@ -167,7 +177,9 @@ class StarTools:
|
||||
raise ValueError("StarTools not initialized")
|
||||
platforms = cls._context.platform_manager.get_insts()
|
||||
if platform == "aiocqhttp":
|
||||
adapter = next((p for p in platforms if isinstance(p, AiocqhttpAdapter)), None)
|
||||
adapter = next(
|
||||
(p for p in platforms if isinstance(p, AiocqhttpAdapter)), None
|
||||
)
|
||||
if adapter is None:
|
||||
raise ValueError("未找到适配器: AiocqhttpAdapter")
|
||||
event = AiocqhttpMessageEvent(
|
||||
@@ -209,7 +221,11 @@ class StarTools:
|
||||
|
||||
@classmethod
|
||||
def register_llm_tool(
|
||||
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
|
||||
cls,
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
func_obj: Callable[..., Awaitable[Any]],
|
||||
) -> None:
|
||||
"""
|
||||
为函数调用(function-calling/tools-use)添加工具
|
||||
@@ -277,7 +293,9 @@ class StarTools:
|
||||
if not plugin_name:
|
||||
raise ValueError("无法获取插件名称")
|
||||
|
||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
||||
data_dir = Path(
|
||||
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
|
||||
)
|
||||
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -32,6 +32,9 @@ class PluginUpdator(RepoZipUpdator):
|
||||
if not repo_url:
|
||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||
|
||||
if not plugin.root_dir_name:
|
||||
raise Exception(f"插件 {plugin.name} 的根目录名未指定。")
|
||||
|
||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
|
||||
@@ -227,9 +227,11 @@ async def download_dashboard(
|
||||
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
||||
|
||||
if latest or len(str(version)) != 40:
|
||||
logger.info(f"准备下载 {version} 发行版本的 AstrBot WebUI 文件")
|
||||
ver_name = "latest" if latest else version
|
||||
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
|
||||
logger.info(
|
||||
f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}"
|
||||
)
|
||||
try:
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
except BaseException as _:
|
||||
@@ -241,24 +243,10 @@ async def download_dashboard(
|
||||
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
else:
|
||||
logger.info(f"准备下载指定版本的 AstrBot WebUI: {version}")
|
||||
|
||||
url = (
|
||||
"https://api.github.com/repos/AstrBotDevs/astrbot-release-harbour/releases"
|
||||
)
|
||||
url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
|
||||
logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
|
||||
if proxy:
|
||||
url = f"{proxy}/{url}"
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
releases = await resp.json()
|
||||
for release in releases:
|
||||
if version in release["tag_name"]:
|
||||
download_url = release["assets"][0]["browser_download_url"]
|
||||
await download_file(download_url, path, show_progress=True)
|
||||
else:
|
||||
logger.warning(f"未找到指定的版本的 Dashboard 构建文件: {version}")
|
||||
return
|
||||
|
||||
await download_file(url, path, show_progress=True)
|
||||
with zipfile.ZipFile(path, "r") as z:
|
||||
z.extractall(extract_path)
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.css" integrity="sha384-wcIxkf4k558AjM3Yz3BBFQUbk/zgIYC2R0QpeeYb+TwlBVMrlgLqwRjRtGZiK7ww" crossorigin="anonymous">
|
||||
<link rel="stylesheet" href="/path/to/styles/default.min.css">
|
||||
<script src="/path/to/highlight.min.js"></script>
|
||||
<script>hljs.highlightAll();</script>
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.js" integrity="sha384-hIoBPJpTUs74ddyc4bFZSM1TVlQDA60VBbJS0oA934VSz82sBx1X7kSx2ATBDIyd" crossorigin="anonymous"></script>
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/contrib/auto-render.min.js" integrity="sha384-43gviWU0YVjaDtb/GhzOouOXtZMP/7XUzwPTstBeZFe/+rCMvRwr4yROQP43s0Xk" crossorigin="anonymous"
|
||||
onload="renderMathInElement(document.getElementById('content'),{delimiters: [{left: '$$', right: '$$', display: true},{left: '$', right: '$', display: false}]});"></script>
|
||||
</head>
|
||||
<body>
|
||||
<div style="background-color: #3276dc; color: #fff; font-size: 64px; ">
|
||||
<span style="font-weight: bold; margin-left: 16px"># AstrBot</span>
|
||||
<span>{{ version }}</span>
|
||||
</div>
|
||||
<article style="margin-top: 32px" id="content"></article>
|
||||
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||
<script>
|
||||
document.getElementById('content').innerHTML = marked.parse(`{{ text | safe}}`);
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
<style>
|
||||
#content {
|
||||
min-width: 200px;
|
||||
max-width: 85%;
|
||||
margin: 0 auto;
|
||||
padding: 2rem 1em 1em;
|
||||
}
|
||||
|
||||
body {
|
||||
word-break: break-word;
|
||||
line-height: 1.75;
|
||||
font-weight: 400;
|
||||
font-size: 32px;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
overflow-x: hidden;
|
||||
color: #333;
|
||||
font-family: -apple-system,BlinkMacSystemFont,Segoe UI,Helvetica,Arial,sans-serif,Apple Color Emoji,Segoe UI Emoji;
|
||||
}
|
||||
h1, h2, h3, h4, h5, h6 {
|
||||
line-height: 1.5;
|
||||
margin-top: 35px;
|
||||
margin-bottom: 10px;
|
||||
padding-bottom: 5px;
|
||||
}
|
||||
h1:first-child, h2:first-child, h3:first-child, h4:first-child, h5:first-child, h6:first-child {
|
||||
margin-top: -1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
h1::before, h2::before, h3::before, h4::before, h5::before, h6::before {
|
||||
content: "#";
|
||||
display: inline-block;
|
||||
color: #3eaf7c;
|
||||
padding-right: 0.23em;
|
||||
}
|
||||
h1 {
|
||||
position: relative;
|
||||
font-size: 2.5rem;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
h1::before {
|
||||
font-size: 2.5rem;
|
||||
}
|
||||
h2 {
|
||||
padding-bottom: 0.5rem;
|
||||
font-size: 2.2rem;
|
||||
border-bottom: 1px solid #ececec;
|
||||
}
|
||||
h3 {
|
||||
font-size: 1.5rem;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
h4 {
|
||||
font-size: 1.25rem;
|
||||
}
|
||||
h5 {
|
||||
font-size: 1rem;
|
||||
}
|
||||
h6 {
|
||||
margin-top: 5px;
|
||||
}
|
||||
p {
|
||||
line-height: inherit;
|
||||
margin-top: 22px;
|
||||
margin-bottom: 22px;
|
||||
}
|
||||
strong {
|
||||
color: #3eaf7c;
|
||||
}
|
||||
img {
|
||||
max-width: 100%;
|
||||
border-radius: 2px;
|
||||
display: block;
|
||||
margin: auto;
|
||||
border: 3px solid rgba(62, 175, 124, 0.2);
|
||||
}
|
||||
hr {
|
||||
border-top: 1px solid #3eaf7c;
|
||||
border-bottom: none;
|
||||
border-left: none;
|
||||
border-right: none;
|
||||
margin-top: 32px;
|
||||
margin-bottom: 32px;
|
||||
}
|
||||
code {
|
||||
font-family: Menlo, Monaco, Consolas, "Courier New", monospace;
|
||||
word-break: break-word;
|
||||
overflow-x: auto;
|
||||
padding: 0.2rem 0.5rem;
|
||||
margin: 0;
|
||||
color: #3eaf7c;
|
||||
font-size: 0.85em;
|
||||
background-color: rgba(27, 31, 35, 0.05);
|
||||
border-radius: 3px;
|
||||
}
|
||||
pre {
|
||||
font-family: Menlo, Monaco, Consolas, "Courier New", monospace;
|
||||
overflow: auto;
|
||||
position: relative;
|
||||
line-height: 1.75;
|
||||
border-radius: 6px;
|
||||
border: 2px solid #3eaf7c;
|
||||
}
|
||||
pre > code {
|
||||
font-size: 12px;
|
||||
padding: 15px 12px;
|
||||
margin: 0;
|
||||
word-break: normal;
|
||||
display: block;
|
||||
overflow-x: auto;
|
||||
color: #333;
|
||||
background: #f8f8f8;
|
||||
}
|
||||
a {
|
||||
font-weight: 500;
|
||||
text-decoration: none;
|
||||
color: #3eaf7c;
|
||||
}
|
||||
a:hover, a:active {
|
||||
border-bottom: 1.5px solid #3eaf7c;
|
||||
}
|
||||
a:before {
|
||||
content: "⇲";
|
||||
}
|
||||
table {
|
||||
display: inline-block !important;
|
||||
font-size: 12px;
|
||||
width: auto;
|
||||
max-width: 100%;
|
||||
overflow: auto;
|
||||
border: solid 1px #3eaf7c;
|
||||
}
|
||||
thead {
|
||||
background: #3eaf7c;
|
||||
color: #fff;
|
||||
text-align: left;
|
||||
}
|
||||
tr:nth-child(2n) {
|
||||
background-color: rgba(62, 175, 124, 0.2);
|
||||
}
|
||||
th, td {
|
||||
padding: 12px 7px;
|
||||
line-height: 24px;
|
||||
}
|
||||
td {
|
||||
min-width: 120px;
|
||||
}
|
||||
blockquote {
|
||||
color: #666;
|
||||
padding: 1px 23px;
|
||||
margin: 22px 0;
|
||||
border-left: 0.5rem solid rgba(62, 175, 124, 0.6);
|
||||
border-color: #42b983;
|
||||
background-color: #f8f8f8;
|
||||
}
|
||||
blockquote::after {
|
||||
display: block;
|
||||
content: "";
|
||||
}
|
||||
blockquote > p {
|
||||
margin: 10px 0;
|
||||
}
|
||||
details {
|
||||
border: none;
|
||||
outline: none;
|
||||
border-left: 4px solid #3eaf7c;
|
||||
padding-left: 10px;
|
||||
margin-left: 4px;
|
||||
}
|
||||
details summary {
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
outline: none;
|
||||
background: white;
|
||||
margin: 0px -17px;
|
||||
}
|
||||
details summary::-webkit-details-marker {
|
||||
color: #3eaf7c;
|
||||
}
|
||||
ol, ul {
|
||||
padding-left: 28px;
|
||||
}
|
||||
ol li, ul li {
|
||||
margin-bottom: 0;
|
||||
list-style: inherit;
|
||||
}
|
||||
ol li .task-list-item, ul li .task-list-item {
|
||||
list-style: none;
|
||||
}
|
||||
ol li .task-list-item ul, ul li .task-list-item ul, ol li .task-list-item ol, ul li .task-list-item ol {
|
||||
margin-top: 0;
|
||||
}
|
||||
ol ul, ul ul, ol ol, ul ol {
|
||||
margin-top: 3px;
|
||||
}
|
||||
ol li {
|
||||
padding-left: 6px;
|
||||
}
|
||||
ol li::marker {
|
||||
color: #3eaf7c;
|
||||
}
|
||||
ul li {
|
||||
list-style: none;
|
||||
}
|
||||
ul li:before {
|
||||
content: "•";
|
||||
margin-right: 4px;
|
||||
color: #3eaf7c;
|
||||
}
|
||||
@media (max-width: 720px) {
|
||||
h1 {
|
||||
font-size: 24px;
|
||||
}
|
||||
h2 {
|
||||
font-size: 20px;
|
||||
}
|
||||
h3 {
|
||||
font-size: 18px;
|
||||
}
|
||||
}
|
||||
|
||||
</style>
|
||||
@@ -2,94 +2,111 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
"""
|
||||
负责管理 t2i HTML 模板的 CRUD 和重置操作。
|
||||
采用“用户覆盖内置”策略:用户模板存储在 data 目录中,并优先于内置模板加载。
|
||||
所有创建、更新、删除操作仅影响用户目录,以确保更新框架时用户数据安全。
|
||||
"""
|
||||
|
||||
CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"]
|
||||
|
||||
def __init__(self):
|
||||
# 修正路径拼接,加入缺失的 'astrbot' 目录
|
||||
self.template_dir = os.path.join(
|
||||
self.builtin_template_dir = os.path.join(
|
||||
get_astrbot_path(), "astrbot", "core", "utils", "t2i", "template"
|
||||
)
|
||||
self.backup_template_path = os.path.join(
|
||||
self.template_dir, "default_template.html.bak"
|
||||
)
|
||||
# 确保模板目录存在
|
||||
os.makedirs(self.template_dir, exist_ok=True)
|
||||
self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates")
|
||||
|
||||
# 检查模板目录中是否有 .html 文件
|
||||
html_files = [f for f in os.listdir(self.template_dir) if f.endswith(".html")]
|
||||
if not html_files and os.path.exists(self.backup_template_path):
|
||||
self.reset_default_template()
|
||||
os.makedirs(self.user_template_dir, exist_ok=True)
|
||||
self._initialize_user_templates()
|
||||
|
||||
def _get_template_path(self, name: str) -> str:
|
||||
"""获取模板的完整路径,防止路径遍历漏洞。"""
|
||||
def _copy_core_templates(self, overwrite: bool = False):
|
||||
"""从内置目录复制核心模板到用户目录。"""
|
||||
for filename in self.CORE_TEMPLATES:
|
||||
src = os.path.join(self.builtin_template_dir, filename)
|
||||
dst = os.path.join(self.user_template_dir, filename)
|
||||
if os.path.exists(src) and (overwrite or not os.path.exists(dst)):
|
||||
shutil.copyfile(src, dst)
|
||||
|
||||
def _initialize_user_templates(self):
|
||||
"""如果用户目录下缺少核心模板,则进行复制。"""
|
||||
self._copy_core_templates(overwrite=False)
|
||||
|
||||
def _get_user_template_path(self, name: str) -> str:
|
||||
"""获取用户模板的完整路径,防止路径遍历漏洞。"""
|
||||
if ".." in name or "/" in name or "\\" in name:
|
||||
raise ValueError("模板名称包含非法字符。")
|
||||
return os.path.join(self.template_dir, f"{name}.html")
|
||||
return os.path.join(self.user_template_dir, f"{name}.html")
|
||||
|
||||
def list_templates(self) -> list[dict]:
|
||||
"""列出所有可用的模板。"""
|
||||
templates = []
|
||||
for filename in os.listdir(self.template_dir):
|
||||
if filename.endswith(".html"):
|
||||
templates.append(
|
||||
{
|
||||
"name": os.path.splitext(filename)[0],
|
||||
"is_default": filename == "base.html",
|
||||
}
|
||||
)
|
||||
return templates
|
||||
|
||||
def get_template(self, name: str) -> str:
|
||||
"""获取指定模板的内容。"""
|
||||
path = self._get_template_path(name)
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError("模板不存在。")
|
||||
def _read_file(self, path: str) -> str:
|
||||
"""读取文件内容。"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def list_templates(self) -> list[dict]:
|
||||
"""
|
||||
列出所有可用模板。
|
||||
该列表是内置模板和用户模板的合并视图,用户模板将覆盖同名的内置模板。
|
||||
"""
|
||||
dirs_to_scan = [self.builtin_template_dir, self.user_template_dir]
|
||||
all_names = {
|
||||
os.path.splitext(f)[0]
|
||||
for d in dirs_to_scan
|
||||
for f in os.listdir(d)
|
||||
if f.endswith(".html")
|
||||
}
|
||||
return [
|
||||
{"name": name, "is_default": name == "base"} for name in sorted(all_names)
|
||||
]
|
||||
|
||||
def get_template(self, name: str) -> str:
|
||||
"""
|
||||
获取指定模板的内容。
|
||||
优先从用户目录加载,如果不存在则回退到内置目录。
|
||||
"""
|
||||
user_path = self._get_user_template_path(name)
|
||||
if os.path.exists(user_path):
|
||||
return self._read_file(user_path)
|
||||
|
||||
builtin_path = os.path.join(self.builtin_template_dir, f"{name}.html")
|
||||
if os.path.exists(builtin_path):
|
||||
return self._read_file(builtin_path)
|
||||
|
||||
raise FileNotFoundError("模板不存在。")
|
||||
|
||||
def create_template(self, name: str, content: str):
|
||||
"""创建一个新的模板文件。"""
|
||||
path = self._get_template_path(name)
|
||||
"""在用户目录中创建一个新的模板文件。"""
|
||||
path = self._get_user_template_path(name)
|
||||
if os.path.exists(path):
|
||||
raise FileExistsError("同名模板已存在。")
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
def update_template(self, name: str, content: str):
|
||||
"""更新一个已存在的模板文件。"""
|
||||
path = self._get_template_path(name)
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError("模板不存在。")
|
||||
"""
|
||||
更新一个模板。此操作始终写入用户目录。
|
||||
如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本,
|
||||
从而实现对内置模板的“覆盖”。
|
||||
"""
|
||||
path = self._get_user_template_path(name)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
def delete_template(self, name: str):
|
||||
"""删除一个模板文件。"""
|
||||
if name == "base":
|
||||
raise ValueError("不能删除默认的 base 模板。")
|
||||
path = self._get_template_path(name)
|
||||
"""
|
||||
仅删除用户目录中的模板文件。
|
||||
如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。
|
||||
"""
|
||||
path = self._get_user_template_path(name)
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError("模板不存在。")
|
||||
raise FileNotFoundError("用户模板不存在,无法删除。")
|
||||
os.remove(path)
|
||||
|
||||
def backup_default_template_if_not_exist(self):
|
||||
"""如果备份不存在,则创建默认模板的备份。"""
|
||||
default_path = os.path.join(self.template_dir, "base.html")
|
||||
if not os.path.exists(self.backup_template_path) and os.path.exists(
|
||||
default_path
|
||||
):
|
||||
shutil.copyfile(default_path, self.backup_template_path)
|
||||
|
||||
def reset_default_template(self):
|
||||
"""重置默认模板。"""
|
||||
if not os.path.exists(self.backup_template_path):
|
||||
raise FileNotFoundError("默认模板的备份文件不存在,无法重置。")
|
||||
|
||||
default_path = os.path.join(self.template_dir, "base.html")
|
||||
shutil.copyfile(self.backup_template_path, default_path)
|
||||
"""
|
||||
将核心模板从内置目录强制重置到用户目录。
|
||||
"""
|
||||
self._copy_core_templates(overwrite=True)
|
||||
|
||||
@@ -157,7 +157,11 @@ class ChatRoute(Route):
|
||||
|
||||
if type == "end":
|
||||
break
|
||||
elif (streaming and type == "complete") or not streaming:
|
||||
elif (
|
||||
(streaming and type == "complete")
|
||||
or not streaming
|
||||
or type == "break"
|
||||
):
|
||||
# append bot message
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
await self.platform_history_mgr.insert(
|
||||
@@ -197,6 +201,7 @@ class ChatRoute(Route):
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
response.timeout = None # fix SSE auto disconnect issue
|
||||
return response
|
||||
|
||||
async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import typing
|
||||
import traceback
|
||||
import os
|
||||
import copy
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
|
||||
@@ -10,7 +10,9 @@ class LogRoute(Route):
|
||||
super().__init__(context)
|
||||
self.log_broker = log_broker
|
||||
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
|
||||
self.app.add_url_rule("/api/log-history", view_func=self.log_history, methods=["GET"])
|
||||
self.app.add_url_rule(
|
||||
"/api/log-history", view_func=self.log_history, methods=["GET"]
|
||||
)
|
||||
|
||||
async def log(self):
|
||||
async def stream():
|
||||
@@ -48,9 +50,15 @@ class LogRoute(Route):
|
||||
"""获取日志历史"""
|
||||
try:
|
||||
logs = list(self.log_broker.log_cache)
|
||||
return Response().ok(data={
|
||||
"logs": logs,
|
||||
}).__dict__
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"logs": logs,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from dataclasses import dataclass
|
||||
from quart import Quart
|
||||
|
||||
@@ -32,10 +32,6 @@ class T2iRoute(Route):
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# 应用启动时,确保备份存在
|
||||
self.manager.backup_default_template_if_not_exist()
|
||||
|
||||
self.register_routes()
|
||||
|
||||
async def list_templates(self):
|
||||
@@ -89,6 +85,7 @@ class T2iRoute(Route):
|
||||
)
|
||||
response.status_code = 400
|
||||
return response
|
||||
name = name.strip()
|
||||
|
||||
self.manager.create_template(name, content)
|
||||
response = jsonify(
|
||||
@@ -118,6 +115,7 @@ class T2iRoute(Route):
|
||||
async def update_template(self, name: str):
|
||||
"""更新一个已存在的T2I模板"""
|
||||
try:
|
||||
name = name.strip()
|
||||
data = await request.json
|
||||
content = data.get("content")
|
||||
if content is None:
|
||||
@@ -126,17 +124,16 @@ class T2iRoute(Route):
|
||||
return response
|
||||
|
||||
self.manager.update_template(name, content)
|
||||
return jsonify(
|
||||
asdict(
|
||||
Response().ok(
|
||||
data={"name": name}, message="Template updated successfully."
|
||||
)
|
||||
)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
response = jsonify(asdict(Response().error("Template not found.")))
|
||||
response.status_code = 404
|
||||
return response
|
||||
|
||||
# 检查更新的是否为当前激活的模板,如果是,则热重载
|
||||
active_template = self.config.get("t2i_active_template", "base")
|
||||
if name == active_template:
|
||||
await self.core_lifecycle.reload_pipeline_scheduler("default")
|
||||
message = f"模板 '{name}' 已更新并重新加载。"
|
||||
else:
|
||||
message = f"模板 '{name}' 已更新。"
|
||||
|
||||
return jsonify(asdict(Response().ok(data={"name": name}, message=message)))
|
||||
except ValueError as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 400
|
||||
@@ -149,6 +146,7 @@ class T2iRoute(Route):
|
||||
async def delete_template(self, name: str):
|
||||
"""删除一个T2I模板"""
|
||||
try:
|
||||
name = name.strip()
|
||||
self.manager.delete_template(name)
|
||||
return jsonify(
|
||||
asdict(Response().ok(message="Template deleted successfully."))
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import traceback
|
||||
|
||||
import aiohttp
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -29,10 +29,19 @@ class AstrBotDashboard:
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
db: BaseDatabase,
|
||||
shutdown_event: asyncio.Event,
|
||||
webui_dir: str | None = None,
|
||||
) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.data_path = os.path.abspath(os.path.join(get_astrbot_data_path(), "dist"))
|
||||
|
||||
# 参数指定webui目录
|
||||
if webui_dir and os.path.exists(webui_dir):
|
||||
self.data_path = os.path.abspath(webui_dir)
|
||||
else:
|
||||
self.data_path = os.path.abspath(
|
||||
os.path.join(get_astrbot_data_path(), "dist")
|
||||
)
|
||||
|
||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||
APP = self.app # noqa
|
||||
self.app.config["MAX_CONTENT_LENGTH"] = (
|
||||
|
||||
15
changelogs/v4.1.0.md
Normal file
15
changelogs/v4.1.0.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# What's Changed
|
||||
|
||||
> 如果已经使用自定义文转图模板,此次升级之后将会被覆盖,请提前备份。路径在 `astrbot/core/utils/t2i/template` 目录下。
|
||||
|
||||
0. ‼️‼️‼️ 修复 LLM 仍会调用已禁用的工具的问题 ([#2729](https://github.com/Soulter/AstrBot/issues/2729))
|
||||
1. ‼️ 修复 WebChat 下,Agent 长时任务时,SSE 连接自动断开的问题
|
||||
2. ‼️ 修复自定义文转图模板更新版本后会被覆盖的问题 ([#2677](https://github.com/Soulter/AstrBot/issues/2677))
|
||||
3. 修复 Satori 适配器教程链接 ([#2668](https://github.com/Soulter/AstrBot/issues/2668))
|
||||
4. 修复插件页表格视图中,点击状态字段表头排序不起作用的问题 ([#2714](https://github.com/Soulter/AstrBot/issues/2714))
|
||||
5. 修复工具调用时的 content 内容在重新加载后没有显示在 webchat 的问题 ([#2727](https://github.com/Soulter/AstrBot/issues/2727))
|
||||
6. 允许添加多个 tavily API Key 进行轮询 ([#2725](https://github.com/Soulter/AstrBot/issues/2725))
|
||||
7. 添加 --webui-dir 启动参数以支持指定 WebUI 构建文件目录 ([#2680](https://github.com/Soulter/AstrBot/issues/2680))
|
||||
8. 兼容指令名和第一个参数之间没有空格的情况 ([#2650](https://github.com/Soulter/AstrBot/issues/2650))
|
||||
9. 支持在 WebUI 自定义 OpenAI API extra_body 参数 ([#2719](https://github.com/Soulter/AstrBot/issues/2719))
|
||||
10. 增加 on_platform_loaded 钩子以在消息平台适配器实例化完成后触发 ([#2651](https://github.com/Soulter/AstrBot/issues/2651))
|
||||
5
changelogs/v4.1.1.md
Normal file
5
changelogs/v4.1.1.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# What's Changed
|
||||
|
||||
修复了 v4.1.0 `model referenced before assignment` 的错误。
|
||||
|
||||
> 如果已经使用自定义文转图模板,此次升级之后将会被覆盖,请提前备份。路径在 `astrbot/core/utils/t2i/template` 目录下。
|
||||
9
changelogs/v4.1.2.md
Normal file
9
changelogs/v4.1.2.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# What's Changed
|
||||
|
||||
0. ‼️‼️‼️ fix: 修复 4.1.1 版本下,指令调用异常的问题
|
||||
1. ‼️‼️ fix: 修复多配置文件配置的不同人格无法生效的问题 ([#2739](https://github.com/AstrBotDevs/AstrBot/issues/2739))
|
||||
2. ‼️‼️ fix: 修复人格所选择的工具无法应用的问题 ([#2739](https://github.com/AstrBotDevs/AstrBot/issues/2739))
|
||||
3. ‼️‼️ fix: 修复平台配置下的「内容安全」组无法生效 ([#2751](https://github.com/AstrBotDevs/AstrBot/issues/2751))
|
||||
4. perf: 检查服务提供商可用性时跳过未启用的提供商,解决部分 `provider with id xxx not found` 的问题
|
||||
|
||||
fixes: [#2724](https://github.com/AstrBotDevs/AstrBot/issues/2724)
|
||||
8
changelogs/v4.1.3.md
Normal file
8
changelogs/v4.1.3.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# What's Changed
|
||||
|
||||
0. ‼️ fix: 修复 4.0.0 版本之后,配置默认 TTS 或者 STT 模型之后仍无法生效的问题 ([#2758](https://github.com/Soulter/AstrBot/issues/2758))
|
||||
1. ‼️ fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 ([#2757](https://github.com/Soulter/AstrBot/issues/2757))
|
||||
2. feat: 支持在 WebUI 复制提供商配置以简化操作 ([#2767](https://github.com/Soulter/AstrBot/issues/2767))
|
||||
3. fix: handle image value correctly for mcp BlobResourceContents ([#2753](https://github.com/Soulter/AstrBot/issues/2753))
|
||||
4. feat: 增加 QQ 群名称识别到 system prompt, 并提供相应的配置 ([#2770](https://github.com/Soulter/AstrBot/issues/2770))
|
||||
5. fix: parameter type/default handling in CommandFilter
|
||||
10
changelogs/v4.1.4.md
Normal file
10
changelogs/v4.1.4.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# What's Changed
|
||||
|
||||
0. ‼️ fix: 修复 4.0.0 版本之后,配置默认 TTS 或者 STT 模型之后仍无法生效的问题 ([#2758](https://github.com/Soulter/AstrBot/issues/2758))
|
||||
1. ‼️ fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 ([#2757](https://github.com/Soulter/AstrBot/issues/2757))
|
||||
2. feat: 支持在 WebUI 复制提供商配置以简化操作 ([#2767](https://github.com/Soulter/AstrBot/issues/2767))
|
||||
3. fix: handle image value correctly for mcp BlobResourceContents ([#2753](https://github.com/Soulter/AstrBot/issues/2753))
|
||||
4. feat: 增加 QQ 群名称识别到 system prompt, 并提供相应的配置 ([#2770](https://github.com/Soulter/AstrBot/issues/2770))
|
||||
5. fix: 修复 4.1.3 的异常问题
|
||||
|
||||
**总之上个版本有很严重的 bug 赶快更新!**
|
||||
11
changelogs/v4.1.5.md
Normal file
11
changelogs/v4.1.5.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# What's Changed
|
||||
|
||||
0. feat: 新增 Misskey 平台适配器 ([#2774](https://github.com/AstrBotDevs/AstrBot/issues/2774))
|
||||
1. fix: 修复aiocqhttp适配器at会获取群昵称而消息不会获取的逻辑不一致 ([#2769](https://github.com/AstrBotDevs/AstrBot/issues/2769))
|
||||
2. fix: 修复「对话管理」页面的关键词搜索功能失效的问题并优化一些 UI 样式 ([#2837](https://github.com/AstrBotDevs/AstrBot/issues/2837))
|
||||
3. fix: 识别「引用消息」的图片时优先使用默认图片转述提供商 ([#2836](https://github.com/AstrBotDevs/AstrBot/issues/2836))
|
||||
5. fix: 修复 Telegram 下流式传输时,第一次输出的内容会被覆盖掉的问题
|
||||
6. perf: 优化统计页内存占用和消息数据趋势的样式 ([#2826](https://github.com/AstrBotDevs/AstrBot/issues/2826))
|
||||
7. perf: 优化 「插件页」、「对话管理页」、「会话管理页」的样式
|
||||
8. fix: on_tool_end hook unavailable
|
||||
9. feat: add audioop-lts dependencies ([#2809](https://github.com/AstrBotDevs/AstrBot/issues/2809))
|
||||
3
changelogs/v4.1.6.md
Normal file
3
changelogs/v4.1.6.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# What's Changed
|
||||
|
||||
1. fix: 修复在某些情况下,出现 「返回的 Provider 不是 Provider 类型的错误」
|
||||
BIN
dashboard/src/assets/images/platform_logos/misskey.png
Normal file
BIN
dashboard/src/assets/images/platform_logos/misskey.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
1347
dashboard/src/components/chat/Chat.vue
Normal file
1347
dashboard/src/components/chat/Chat.vue
Normal file
File diff suppressed because it is too large
Load Diff
775
dashboard/src/components/chat/MessageList.vue
Normal file
775
dashboard/src/components/chat/MessageList.vue
Normal file
@@ -0,0 +1,775 @@
|
||||
<template>
|
||||
<div class="messages-container" ref="messageContainer">
|
||||
<!-- 聊天消息列表 -->
|
||||
<div class="message-list">
|
||||
<div class="message-item fade-in" v-for="(msg, index) in messages" :key="index">
|
||||
<!-- 用户消息 -->
|
||||
<div v-if="msg.content.type == 'user'" class="user-message">
|
||||
<div class="message-bubble user-bubble" :class="{ 'has-audio': msg.content.audio_url }"
|
||||
:style="{ backgroundColor: isDark ? '#2d2e30' : '#e7ebf4' }">
|
||||
<pre
|
||||
style="font-family: inherit; white-space: pre-wrap; word-wrap: break-word;">{{ msg.content.message }}</pre>
|
||||
|
||||
<!-- 图片附件 -->
|
||||
<div class="image-attachments" v-if="msg.content.image_url && msg.content.image_url.length > 0">
|
||||
<div v-for="(img, index) in msg.content.image_url" :key="index" class="image-attachment">
|
||||
<img :src="img" class="attached-image" @click="$emit('openImagePreview', img)" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 音频附件 -->
|
||||
<div class="audio-attachment" v-if="msg.content.audio_url && msg.content.audio_url.length > 0">
|
||||
<audio controls class="audio-player">
|
||||
<source :src="msg.content.audio_url" type="audio/wav">
|
||||
{{ t('messages.errors.browser.audioNotSupported') }}
|
||||
</audio>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bot Messages -->
|
||||
<div v-else class="bot-message">
|
||||
<div v-if="isStreaming && index === messages.length - 1" style="width: 36px; height: 36px;">
|
||||
<v-progress-circular indeterminate size="28" width="2"
|
||||
style="margin-top: 16px;"></v-progress-circular>
|
||||
</div>
|
||||
<v-avatar v-else class="bot-avatar" size="36">
|
||||
<span class="text-h2">✨</span>
|
||||
</v-avatar>
|
||||
<div class="bot-message-content">
|
||||
<div class="message-bubble bot-bubble">
|
||||
<!-- Text -->
|
||||
<div v-if="msg.content.message && msg.content.message.trim()"
|
||||
v-html="md.render(msg.content.message)" class="markdown-content"></div>
|
||||
|
||||
<!-- Image -->
|
||||
<div class="embedded-images"
|
||||
v-if="msg.content.embedded_images && msg.content.embedded_images.length > 0">
|
||||
<div v-for="(img, imgIndex) in msg.content.embedded_images" :key="imgIndex"
|
||||
class="embedded-image">
|
||||
<img :src="img" class="bot-embedded-image"
|
||||
@click="$emit('openImagePreview', img)" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Audio -->
|
||||
<div class="embedded-audio" v-if="msg.content.embedded_audio">
|
||||
<audio controls class="audio-player">
|
||||
<source :src="msg.content.embedded_audio" type="audio/wav">
|
||||
{{ t('messages.errors.browser.audioNotSupported') }}
|
||||
</audio>
|
||||
</div>
|
||||
</div>
|
||||
<div class="message-actions">
|
||||
<v-btn :icon="getCopyIcon(index)" size="small" variant="text" class="copy-message-btn"
|
||||
:class="{ 'copy-success': isCopySuccess(index) }"
|
||||
@click="copyBotMessage(msg.content.message, index)" :title="t('core.common.copy')" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
import MarkdownIt from 'markdown-it';
|
||||
import hljs from 'highlight.js';
|
||||
import 'highlight.js/styles/github.css';
|
||||
|
||||
const md = new MarkdownIt({
|
||||
html: false,
|
||||
breaks: true,
|
||||
linkify: true,
|
||||
highlight: function (code, lang) {
|
||||
if (lang && hljs.getLanguage(lang)) {
|
||||
try {
|
||||
return hljs.highlight(code, { language: lang }).value;
|
||||
} catch (err) {
|
||||
console.error('Highlight error:', err);
|
||||
}
|
||||
}
|
||||
return hljs.highlightAuto(code).value;
|
||||
}
|
||||
});
|
||||
|
||||
export default {
|
||||
name: 'MessageList',
|
||||
props: {
|
||||
messages: {
|
||||
type: Array,
|
||||
required: true
|
||||
},
|
||||
isDark: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
isStreaming: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
},
|
||||
emits: ['openImagePreview'],
|
||||
setup() {
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n('features/chat');
|
||||
|
||||
return {
|
||||
t,
|
||||
tm,
|
||||
md
|
||||
};
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
copiedMessages: new Set(),
|
||||
isUserNearBottom: true,
|
||||
scrollThreshold: 1,
|
||||
scrollTimer: null
|
||||
};
|
||||
},
|
||||
mounted() {
|
||||
this.initCodeCopyButtons();
|
||||
this.initImageClickEvents();
|
||||
this.addScrollListener();
|
||||
this.scrollToBottom();
|
||||
},
|
||||
updated() {
|
||||
this.initCodeCopyButtons();
|
||||
this.initImageClickEvents();
|
||||
if (this.isUserNearBottom) {
|
||||
this.scrollToBottom();
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
// 复制代码到剪贴板
|
||||
copyCodeToClipboard(code) {
|
||||
navigator.clipboard.writeText(code).then(() => {
|
||||
console.log('代码已复制到剪贴板');
|
||||
}).catch(err => {
|
||||
console.error('复制失败:', err);
|
||||
// 如果现代API失败,使用传统方法
|
||||
const textArea = document.createElement('textarea');
|
||||
textArea.value = code;
|
||||
document.body.appendChild(textArea);
|
||||
textArea.select();
|
||||
try {
|
||||
document.execCommand('copy');
|
||||
console.log('代码已复制到剪贴板 (fallback)');
|
||||
} catch (fallbackErr) {
|
||||
console.error('复制失败 (fallback):', fallbackErr);
|
||||
}
|
||||
document.body.removeChild(textArea);
|
||||
});
|
||||
},
|
||||
|
||||
// 复制bot消息到剪贴板
|
||||
copyBotMessage(message, messageIndex) {
|
||||
// 获取对应的消息对象
|
||||
const msgObj = this.messages[messageIndex].content;
|
||||
let textToCopy = '';
|
||||
|
||||
// 如果有文本消息,添加到复制内容中
|
||||
if (message && message.trim()) {
|
||||
// 移除HTML标签,获取纯文本
|
||||
const tempDiv = document.createElement('div');
|
||||
tempDiv.innerHTML = message;
|
||||
textToCopy = tempDiv.textContent || tempDiv.innerText || message;
|
||||
}
|
||||
|
||||
// 如果有内嵌图片,添加说明
|
||||
if (msgObj && msgObj.embedded_images && msgObj.embedded_images.length > 0) {
|
||||
if (textToCopy) textToCopy += '\n\n';
|
||||
textToCopy += `[包含 ${msgObj.embedded_images.length} 张图片]`;
|
||||
}
|
||||
|
||||
// 如果有内嵌音频,添加说明
|
||||
if (msgObj && msgObj.embedded_audio) {
|
||||
if (textToCopy) textToCopy += '\n\n';
|
||||
textToCopy += '[包含音频内容]';
|
||||
}
|
||||
|
||||
// 如果没有任何内容,使用默认文本
|
||||
if (!textToCopy.trim()) {
|
||||
textToCopy = '[媒体内容]';
|
||||
}
|
||||
|
||||
navigator.clipboard.writeText(textToCopy).then(() => {
|
||||
console.log('消息已复制到剪贴板');
|
||||
this.showCopySuccess(messageIndex);
|
||||
}).catch(err => {
|
||||
console.error('复制失败:', err);
|
||||
// 如果现代API失败,使用传统方法
|
||||
const textArea = document.createElement('textarea');
|
||||
textArea.value = textToCopy;
|
||||
document.body.appendChild(textArea);
|
||||
textArea.select();
|
||||
try {
|
||||
document.execCommand('copy');
|
||||
console.log('消息已复制到剪贴板 (fallback)');
|
||||
this.showCopySuccess(messageIndex);
|
||||
} catch (fallbackErr) {
|
||||
console.error('复制失败 (fallback):', fallbackErr);
|
||||
}
|
||||
document.body.removeChild(textArea);
|
||||
});
|
||||
},
|
||||
|
||||
// 显示复制成功提示
|
||||
showCopySuccess(messageIndex) {
|
||||
this.copiedMessages.add(messageIndex);
|
||||
|
||||
// 2秒后移除成功状态
|
||||
setTimeout(() => {
|
||||
this.copiedMessages.delete(messageIndex);
|
||||
}, 2000);
|
||||
},
|
||||
|
||||
// 获取复制按钮图标
|
||||
getCopyIcon(messageIndex) {
|
||||
return this.copiedMessages.has(messageIndex) ? 'mdi-check' : 'mdi-content-copy';
|
||||
},
|
||||
|
||||
// 检查是否为复制成功状态
|
||||
isCopySuccess(messageIndex) {
|
||||
return this.copiedMessages.has(messageIndex);
|
||||
},
|
||||
|
||||
// 获取复制图标SVG
|
||||
getCopyIconSvg() {
|
||||
return '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>';
|
||||
},
|
||||
|
||||
// 获取成功图标SVG
|
||||
getSuccessIconSvg() {
|
||||
return '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><polyline points="20,6 9,17 4,12"></polyline></svg>';
|
||||
},
|
||||
|
||||
// 初始化代码块复制按钮
|
||||
initCodeCopyButtons() {
|
||||
this.$nextTick(() => {
|
||||
const codeBlocks = this.$refs.messageContainer?.querySelectorAll('pre code') || [];
|
||||
codeBlocks.forEach((codeBlock, index) => {
|
||||
const pre = codeBlock.parentElement;
|
||||
if (pre && !pre.querySelector('.copy-code-btn')) {
|
||||
const button = document.createElement('button');
|
||||
button.className = 'copy-code-btn';
|
||||
button.innerHTML = this.getCopyIconSvg();
|
||||
button.title = '复制代码';
|
||||
button.addEventListener('click', () => {
|
||||
this.copyCodeToClipboard(codeBlock.textContent);
|
||||
// 显示复制成功提示
|
||||
button.innerHTML = this.getSuccessIconSvg();
|
||||
button.style.color = '#4caf50';
|
||||
setTimeout(() => {
|
||||
button.innerHTML = this.getCopyIconSvg();
|
||||
button.style.color = '';
|
||||
}, 2000);
|
||||
});
|
||||
pre.style.position = 'relative';
|
||||
pre.appendChild(button);
|
||||
}
|
||||
});
|
||||
});
|
||||
},
|
||||
|
||||
initImageClickEvents() {
|
||||
this.$nextTick(() => {
|
||||
// 查找所有动态生成的图片(在markdown-content中)
|
||||
const images = document.querySelectorAll('.markdown-content img');
|
||||
images.forEach((img) => {
|
||||
if (!img.hasAttribute('data-click-enabled')) {
|
||||
img.style.cursor = 'pointer';
|
||||
img.setAttribute('data-click-enabled', 'true');
|
||||
img.onclick = () => this.$emit('openImagePreview', img.src);
|
||||
}
|
||||
});
|
||||
});
|
||||
},
|
||||
|
||||
scrollToBottom() {
|
||||
this.$nextTick(() => {
|
||||
const container = this.$refs.messageContainer;
|
||||
if (container) {
|
||||
container.scrollTop = container.scrollHeight;
|
||||
this.isUserNearBottom = true; // 程序滚动到底部后标记用户在底部
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
// 添加滚动事件监听器
|
||||
addScrollListener() {
|
||||
const container = this.$refs.messageContainer;
|
||||
if (container) {
|
||||
container.addEventListener('scroll', this.throttledHandleScroll);
|
||||
}
|
||||
},
|
||||
|
||||
// 节流处理滚动事件
|
||||
throttledHandleScroll() {
|
||||
if (this.scrollTimer) return;
|
||||
|
||||
this.scrollTimer = setTimeout(() => {
|
||||
this.handleScroll();
|
||||
this.scrollTimer = null;
|
||||
}, 50); // 50ms 节流
|
||||
},
|
||||
|
||||
// 处理滚动事件
|
||||
handleScroll() {
|
||||
const container = this.$refs.messageContainer;
|
||||
if (container) {
|
||||
const { scrollTop, scrollHeight, clientHeight } = container;
|
||||
const distanceFromBottom = scrollHeight - (scrollTop + clientHeight);
|
||||
|
||||
// 判断用户是否在底部附近
|
||||
this.isUserNearBottom = distanceFromBottom <= this.scrollThreshold;
|
||||
}
|
||||
},
|
||||
|
||||
// 组件销毁时移除监听器
|
||||
beforeUnmount() {
|
||||
const container = this.$refs.messageContainer;
|
||||
if (container) {
|
||||
container.removeEventListener('scroll', this.throttledHandleScroll);
|
||||
}
|
||||
// 清理定时器
|
||||
if (this.scrollTimer) {
|
||||
clearTimeout(this.scrollTimer);
|
||||
this.scrollTimer = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
/* 基础动画 */
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(10px);
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.messages-container {
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
overflow-y: auto;
|
||||
padding: 16px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
/* 消息列表样式 */
|
||||
.message-list {
|
||||
max-width: 900px;
|
||||
margin: 0 auto;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.message-item {
|
||||
margin-bottom: 24px;
|
||||
animation: fadeIn 0.3s ease-out;
|
||||
}
|
||||
|
||||
.user-message {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.bot-message {
|
||||
display: flex;
|
||||
justify-content: flex-start;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.bot-message-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
max-width: 80%;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.message-actions {
|
||||
display: flex;
|
||||
gap: 4px;
|
||||
opacity: 0;
|
||||
transition: opacity 0.2s ease;
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
.bot-message:hover .message-actions {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.copy-message-btn {
|
||||
opacity: 0.6;
|
||||
transition: all 0.2s ease;
|
||||
color: var(--v-theme-secondary);
|
||||
}
|
||||
|
||||
.copy-message-btn:hover {
|
||||
opacity: 1;
|
||||
background-color: rgba(103, 58, 183, 0.1);
|
||||
}
|
||||
|
||||
.copy-message-btn.copy-success {
|
||||
color: #4caf50;
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.copy-message-btn.copy-success:hover {
|
||||
color: #4caf50;
|
||||
background-color: rgba(76, 175, 80, 0.1);
|
||||
}
|
||||
|
||||
.message-bubble {
|
||||
padding: 8px 16px;
|
||||
border-radius: 12px;
|
||||
}
|
||||
|
||||
.user-bubble {
|
||||
color: var(--v-theme-primaryText);
|
||||
padding: 18px 20px;
|
||||
font-size: 15px;
|
||||
max-width: 60%;
|
||||
border-radius: 1.5rem;
|
||||
}
|
||||
|
||||
.bot-bubble {
|
||||
border: 1px solid var(--v-theme-border);
|
||||
color: var(--v-theme-primaryText);
|
||||
font-size: 15px;
|
||||
max-width: 100%;
|
||||
}
|
||||
|
||||
.user-avatar,
|
||||
.bot-avatar {
|
||||
align-self: flex-start;
|
||||
margin-top: 12px;
|
||||
}
|
||||
|
||||
/* 附件样式 */
|
||||
.image-attachments {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-top: 8px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.image-attachment {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.attached-image {
|
||||
width: 120px;
|
||||
height: 120px;
|
||||
object-fit: cover;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.audio-attachment {
|
||||
margin-top: 8px;
|
||||
min-width: 250px;
|
||||
}
|
||||
|
||||
/* 包含音频的消息气泡最小宽度 */
|
||||
.message-bubble.has-audio {
|
||||
min-width: 280px;
|
||||
}
|
||||
|
||||
.audio-player {
|
||||
width: 100%;
|
||||
height: 36px;
|
||||
border-radius: 18px;
|
||||
}
|
||||
|
||||
.embedded-images {
|
||||
margin-top: 8px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.embedded-image {
|
||||
display: flex;
|
||||
justify-content: flex-start;
|
||||
}
|
||||
|
||||
.bot-embedded-image {
|
||||
max-width: 80%;
|
||||
width: auto;
|
||||
height: auto;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.bot-embedded-image:hover {
|
||||
transform: scale(1.02);
|
||||
}
|
||||
|
||||
.embedded-audio {
|
||||
width: 300px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.embedded-audio .audio-player {
|
||||
width: 100%;
|
||||
max-width: 300px;
|
||||
}
|
||||
|
||||
/* 动画类 */
|
||||
.fade-in {
|
||||
animation: fadeIn 0.3s ease-in-out;
|
||||
}
|
||||
</style>
|
||||
|
||||
<style>
|
||||
/* Markdown内容样式 - 需要全局样式 */
|
||||
.markdown-content {
|
||||
font-family: inherit;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.markdown-content h1,
|
||||
.markdown-content h2,
|
||||
.markdown-content h3,
|
||||
.markdown-content h4,
|
||||
.markdown-content h5,
|
||||
.markdown-content h6 {
|
||||
margin-top: 16px;
|
||||
margin-bottom: 10px;
|
||||
font-weight: 600;
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
.markdown-content h1 {
|
||||
font-size: 1.8em;
|
||||
border-bottom: 1px solid var(--v-theme-border);
|
||||
padding-bottom: 6px;
|
||||
}
|
||||
|
||||
.markdown-content h2 {
|
||||
font-size: 1.5em;
|
||||
}
|
||||
|
||||
.markdown-content h3 {
|
||||
font-size: 1.3em;
|
||||
}
|
||||
|
||||
.markdown-content li {
|
||||
margin-left: 16px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.markdown-content p {
|
||||
margin-top: .5rem;
|
||||
margin-bottom: .5rem;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
background-color: var(--v-theme-surface);
|
||||
padding: 12px;
|
||||
border-radius: 6px;
|
||||
overflow-x: auto;
|
||||
margin: 12px 0;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.markdown-content code {
|
||||
background-color: rgb(var(--v-theme-codeBg));
|
||||
padding: 2px 4px;
|
||||
border-radius: 4px;
|
||||
font-family: 'Fira Code', monospace;
|
||||
font-size: 0.9em;
|
||||
color: var(--v-theme-code);
|
||||
}
|
||||
|
||||
/* 代码块中的code标签样式 */
|
||||
.markdown-content pre code {
|
||||
background-color: transparent;
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
font-family: 'Fira Code', 'Consolas', 'Monaco', 'Courier New', monospace;
|
||||
font-size: 0.85em;
|
||||
color: inherit;
|
||||
display: block;
|
||||
overflow-x: auto;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
/* 自定义代码高亮样式 */
|
||||
.markdown-content pre {
|
||||
border: 1px solid var(--v-theme-border);
|
||||
background-color: rgb(var(--v-theme-preBg));
|
||||
border-radius: 16px;
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
/* 确保highlight.js的样式正确应用 */
|
||||
.markdown-content pre code.hljs {
|
||||
background: transparent !important;
|
||||
color: inherit;
|
||||
}
|
||||
|
||||
/* 亮色主题下的代码高亮 */
|
||||
.v-theme--light .markdown-content pre {
|
||||
background-color: #f6f8fa;
|
||||
}
|
||||
|
||||
/* 暗色主题下的代码块样式 */
|
||||
.v-theme--dark .markdown-content pre {
|
||||
background-color: #0d1117 !important;
|
||||
border-color: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
|
||||
.v-theme--dark .markdown-content pre code {
|
||||
color: #e6edf3 !important;
|
||||
}
|
||||
|
||||
/* 暗色主题下的highlight.js样式覆盖 */
|
||||
.v-theme--dark .hljs {
|
||||
background: #0d1117 !important;
|
||||
color: #e6edf3 !important;
|
||||
}
|
||||
|
||||
.v-theme--dark .hljs-keyword,
|
||||
.v-theme--dark .hljs-selector-tag,
|
||||
.v-theme--dark .hljs-built_in,
|
||||
.v-theme--dark .hljs-name,
|
||||
.v-theme--dark .hljs-tag {
|
||||
color: #ff7b72 !important;
|
||||
}
|
||||
|
||||
.v-theme--dark .hljs-string,
|
||||
.v-theme--dark .hljs-title,
|
||||
.v-theme--dark .hljs-section,
|
||||
.v-theme--dark .hljs-attribute,
|
||||
.v-theme--dark .hljs-literal,
|
||||
.v-theme--dark .hljs-template-tag,
|
||||
.v-theme--dark .hljs-template-variable,
|
||||
.v-theme--dark .hljs-type,
|
||||
.v-theme--dark .hljs-addition {
|
||||
color: #a5d6ff !important;
|
||||
}
|
||||
|
||||
.v-theme--dark .hljs-comment,
|
||||
.v-theme--dark .hljs-quote,
|
||||
.v-theme--dark .hljs-deletion,
|
||||
.v-theme--dark .hljs-meta {
|
||||
color: #8b949e !important;
|
||||
}
|
||||
|
||||
.v-theme--dark .hljs-number,
|
||||
.v-theme--dark .hljs-regexp,
|
||||
.v-theme--dark .hljs-symbol,
|
||||
.v-theme--dark .hljs-variable,
|
||||
.v-theme--dark .hljs-template-variable,
|
||||
.v-theme--dark .hljs-link,
|
||||
.v-theme--dark .hljs-selector-attr,
|
||||
.v-theme--dark .hljs-selector-pseudo {
|
||||
color: #79c0ff !important;
|
||||
}
|
||||
|
||||
.v-theme--dark .hljs-function,
|
||||
.v-theme--dark .hljs-class,
|
||||
.v-theme--dark .hljs-title.class_ {
|
||||
color: #d2a8ff !important;
|
||||
}
|
||||
|
||||
/* 复制按钮样式 */
|
||||
.copy-code-btn {
|
||||
position: absolute;
|
||||
top: 8px;
|
||||
right: 8px;
|
||||
background: rgba(255, 255, 255, 0.9);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: 4px;
|
||||
padding: 6px;
|
||||
cursor: pointer;
|
||||
opacity: 0;
|
||||
transition: all 0.2s ease;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: #666;
|
||||
font-size: 12px;
|
||||
z-index: 10;
|
||||
backdrop-filter: blur(4px);
|
||||
}
|
||||
|
||||
.copy-code-btn:hover {
|
||||
background: rgba(255, 255, 255, 1);
|
||||
color: #333;
|
||||
transform: scale(1.05);
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.copy-code-btn:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.markdown-content pre:hover .copy-code-btn {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.v-theme--dark .copy-code-btn {
|
||||
background: rgba(45, 45, 45, 0.9);
|
||||
border-color: rgba(255, 255, 255, 0.15);
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
.v-theme--dark .copy-code-btn:hover {
|
||||
background: rgba(45, 45, 45, 1);
|
||||
color: #fff;
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
.markdown-content img {
|
||||
max-width: 100%;
|
||||
border-radius: 8px;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
border-left: 4px solid var(--v-theme-secondary);
|
||||
padding-left: 16px;
|
||||
color: var(--v-theme-secondaryText);
|
||||
margin: 16px 0;
|
||||
}
|
||||
|
||||
.markdown-content table {
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
margin: 16px 0;
|
||||
}
|
||||
|
||||
.markdown-content th,
|
||||
.markdown-content td {
|
||||
border: 1px solid var(--v-theme-background);
|
||||
padding: 8px 12px;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.markdown-content th {
|
||||
background-color: var(--v-theme-containerBg);
|
||||
}
|
||||
</style>
|
||||
@@ -2,6 +2,7 @@
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref, computed } from 'vue'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import ObjectEditor from './ObjectEditor.vue'
|
||||
import ProviderSelector from './ProviderSelector.vue'
|
||||
import PersonaSelector from './PersonaSelector.vue'
|
||||
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
|
||||
@@ -80,7 +81,7 @@ function shouldShowItem(itemMeta, itemKey) {
|
||||
|
||||
function hasVisibleItemsAfter(items, currentIndex) {
|
||||
const itemEntries = Object.entries(items)
|
||||
|
||||
|
||||
// 检查当前索引之后是否还有可见的配置项
|
||||
for (let i = currentIndex + 1; i < itemEntries.length; i++) {
|
||||
const [itemKey, itemValue] = itemEntries[i]
|
||||
@@ -89,7 +90,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return false
|
||||
}
|
||||
</script>
|
||||
@@ -130,7 +131,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-expand-transition>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Regular Property -->
|
||||
<template v-else>
|
||||
<v-row v-if="!metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)" class="config-row">
|
||||
@@ -145,7 +146,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-list-item-title>
|
||||
|
||||
<v-list-item-subtitle class="property-hint">
|
||||
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint"
|
||||
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint"
|
||||
class="important-hint">‼️</span>
|
||||
{{ metadata[metadataKey].items[key]?.hint }}
|
||||
</v-list-item-subtitle>
|
||||
@@ -153,10 +154,10 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="1" class="d-flex align-center type-indicator">
|
||||
<v-chip v-if="!metadata[metadataKey].items[key]?.invisible"
|
||||
color="primary"
|
||||
label
|
||||
size="x-small"
|
||||
<v-chip v-if="!metadata[metadataKey].items[key]?.invisible"
|
||||
color="primary"
|
||||
label
|
||||
size="x-small"
|
||||
variant="flat">
|
||||
{{ metadata[metadataKey].items[key]?.type || 'string' }}
|
||||
</v-chip>
|
||||
@@ -166,35 +167,35 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<div v-if="metadata[metadataKey].items[key]" class="w-100">
|
||||
<!-- Special handling for specific metadata types -->
|
||||
<div v-if="metadata[metadataKey].items[key]?._special === 'select_provider'">
|
||||
<ProviderSelector
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'chat_completion'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_provider_stt'">
|
||||
<ProviderSelector
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'speech_to_text'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_provider_tts'">
|
||||
<ProviderSelector
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'text_to_speech'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_persona'">
|
||||
<PersonaSelector
|
||||
<PersonaSelector
|
||||
v-model="iterable[key]"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector
|
||||
<KnowledgeBaseSelector
|
||||
v-model="iterable[key]"
|
||||
/>
|
||||
</div>
|
||||
<!-- List item with options-->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.type === 'list' && metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible && metadata[metadataKey].items[key]?.render_type === 'checkbox'"
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.type === 'list' && metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible && metadata[metadataKey].items[key]?.render_type === 'checkbox'"
|
||||
class="d-flex flex-wrap gap-20">
|
||||
<v-checkbox
|
||||
v-for="(option, index) in metadata[metadataKey].items[key]?.options"
|
||||
@@ -233,10 +234,10 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
|
||||
<!-- Code Editor with Full Screen Option -->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.editor_mode && !metadata[metadataKey].items[key]?.invisible" class="editor-container">
|
||||
<VueMonacoEditor
|
||||
:theme="metadata[metadataKey].items[key]?.editor_theme || 'vs-light'"
|
||||
<VueMonacoEditor
|
||||
:theme="metadata[metadataKey].items[key]?.editor_theme || 'vs-light'"
|
||||
:language="metadata[metadataKey].items[key]?.editor_language || 'json'"
|
||||
style="min-height: 100px; flex-grow: 1; border: 1px solid rgba(0, 0, 0, 0.1);"
|
||||
style="min-height: 100px; flex-grow: 1; border: 1px solid rgba(0, 0, 0, 0.1);"
|
||||
v-model:value="iterable[key]"
|
||||
>
|
||||
</VueMonacoEditor>
|
||||
@@ -252,7 +253,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<v-icon>mdi-fullscreen</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- String input -->
|
||||
<v-text-field
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'string' && !metadata[metadataKey].items[key]?.invisible"
|
||||
@@ -262,7 +263,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
|
||||
|
||||
<!-- Numeric input -->
|
||||
<v-text-field
|
||||
v-else-if="(metadata[metadataKey].items[key]?.type === 'int' || metadata[metadataKey].items[key]?.type === 'float') && !metadata[metadataKey]?.invisible"
|
||||
@@ -273,7 +274,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
type="number"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
|
||||
|
||||
<!-- Text area -->
|
||||
<v-textarea
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'text' && !metadata[metadataKey].items[key]?.invisible"
|
||||
@@ -283,7 +284,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-textarea>
|
||||
|
||||
|
||||
<!-- Boolean switch -->
|
||||
<v-switch
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'bool' && !metadata[metadataKey].items[key]?.invisible"
|
||||
@@ -293,20 +294,27 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
density="compact"
|
||||
hide-details
|
||||
></v-switch>
|
||||
|
||||
|
||||
<!-- List item -->
|
||||
<ListConfigItem
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'list' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
class="config-field"
|
||||
/>
|
||||
|
||||
<!-- Dict item (key-value editor) -->
|
||||
<ObjectEditor
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'dict' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
class="config-field"
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Fallback for unknown metadata -->
|
||||
<div v-else class="w-100">
|
||||
<v-text-field
|
||||
v-model="iterable[key]"
|
||||
:label="key"
|
||||
<v-text-field
|
||||
v-model="iterable[key]"
|
||||
:label="key"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
@@ -316,14 +324,14 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<v-divider
|
||||
<v-divider
|
||||
v-if="hasVisibleItemsAfter(filteredIterable, index) && !metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)"
|
||||
class="config-divider"
|
||||
></v-divider>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Simple Value Configuration -->
|
||||
<div v-else class="simple-config">
|
||||
<v-row class="config-row">
|
||||
@@ -342,9 +350,9 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="1" class="d-flex align-center type-indicator">
|
||||
<v-chip v-if="!metadata[metadataKey]?.invisible"
|
||||
color="primary"
|
||||
label
|
||||
<v-chip v-if="!metadata[metadataKey]?.invisible"
|
||||
color="primary"
|
||||
label
|
||||
size="x-small"
|
||||
variant="flat">
|
||||
{{ metadata[metadataKey]?.type }}
|
||||
@@ -364,7 +372,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-select>
|
||||
|
||||
|
||||
<!-- String input -->
|
||||
<v-text-field
|
||||
v-else-if="metadata[metadataKey]?.type === 'string' && !metadata[metadataKey]?.invisible"
|
||||
@@ -374,7 +382,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
|
||||
|
||||
<!-- Numeric input -->
|
||||
<v-text-field
|
||||
v-else-if="(metadata[metadataKey]?.type === 'int' || metadata[metadataKey]?.type === 'float') && !metadata[metadataKey]?.invisible"
|
||||
@@ -385,7 +393,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
type="number"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
|
||||
|
||||
<!-- Text area -->
|
||||
<v-textarea
|
||||
v-else-if="metadata[metadataKey]?.type === 'text' && !metadata[metadataKey]?.invisible"
|
||||
@@ -396,7 +404,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-textarea>
|
||||
|
||||
|
||||
<!-- Boolean switch -->
|
||||
<v-switch
|
||||
v-else-if="metadata[metadataKey]?.type === 'bool' && !metadata[metadataKey]?.invisible"
|
||||
@@ -406,9 +414,9 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
density="compact"
|
||||
hide-details
|
||||
></v-switch>
|
||||
|
||||
|
||||
<!-- List item -->
|
||||
<ListConfigItem
|
||||
<ListConfigItem
|
||||
v-else-if="metadata[metadataKey]?.type === 'list' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
class="config-field"
|
||||
@@ -435,9 +443,9 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-toolbar-items>
|
||||
</v-toolbar>
|
||||
<v-card-text class="pa-0">
|
||||
<VueMonacoEditor
|
||||
<VueMonacoEditor
|
||||
:theme="currentEditingTheme"
|
||||
:language="currentEditingLanguage"
|
||||
:language="currentEditingLanguage"
|
||||
style="height: calc(100vh - 64px);"
|
||||
v-model:value="currentEditingKeyIterable[currentEditingKey]"
|
||||
>
|
||||
@@ -567,11 +575,11 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
.nested-object {
|
||||
padding-left: 8px;
|
||||
}
|
||||
|
||||
|
||||
.config-row {
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
|
||||
.property-info, .type-indicator, .config-input {
|
||||
padding: 4px;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref, computed } from 'vue'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import ObjectEditor from './ObjectEditor.vue'
|
||||
import ProviderSelector from './ProviderSelector.vue'
|
||||
import PersonaSelector from './PersonaSelector.vue'
|
||||
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
|
||||
@@ -102,7 +103,7 @@ function shouldShowItem(itemMeta, itemKey) {
|
||||
|
||||
function hasVisibleItemsAfter(items, currentIndex) {
|
||||
const itemEntries = Object.entries(items)
|
||||
|
||||
|
||||
// 检查当前索引之后是否还有可见的配置项
|
||||
for (let i = currentIndex + 1; i < itemEntries.length; i++) {
|
||||
const [itemKey, itemMeta] = itemEntries[i]
|
||||
@@ -110,7 +111,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return false
|
||||
}
|
||||
</script>
|
||||
@@ -188,13 +189,20 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
color="primary" inset density="compact" hide-details style="display: flex; justify-content: end;"></v-switch>
|
||||
|
||||
<!-- List item for JSON selector -->
|
||||
<ListConfigItem
|
||||
<ListConfigItem
|
||||
v-else-if="itemMeta?.type === 'list'"
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
button-text="修改"
|
||||
class="config-field"
|
||||
/>
|
||||
|
||||
<!-- Object editor for JSON selector -->
|
||||
<ObjectEditor
|
||||
v-else-if="itemMeta?.type === 'dict'"
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
class="config-field"
|
||||
/>
|
||||
|
||||
<!-- Fallback for JSON selector -->
|
||||
<v-text-field v-else v-model="createSelectorModel(itemKey).value" density="compact" variant="outlined"
|
||||
class="config-field" hide-details></v-text-field>
|
||||
@@ -202,48 +210,48 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
|
||||
<!-- Special handling for specific metadata types -->
|
||||
<div v-else-if="itemMeta?._special === 'select_provider'">
|
||||
<ProviderSelector
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'chat_completion'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_provider_stt'">
|
||||
<ProviderSelector
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'speech_to_text'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_provider_tts'">
|
||||
<ProviderSelector
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'text_to_speech'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'provider_pool'">
|
||||
<ProviderSelector
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'chat_completion'"
|
||||
button-text="选择提供商池..."
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_persona'">
|
||||
<PersonaSelector
|
||||
<PersonaSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'persona_pool'">
|
||||
<PersonaSelector
|
||||
<PersonaSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
button-text="选择人格池..."
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector
|
||||
<KnowledgeBaseSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_plugin_set'">
|
||||
<PluginSetSelector
|
||||
<PluginSetSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
</div>
|
||||
@@ -261,12 +269,12 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<small class="text-grey">已选择的插件:</small>
|
||||
</div>
|
||||
<div class="d-flex flex-wrap ga-2 mt-2">
|
||||
<v-chip
|
||||
v-for="plugin in (createSelectorModel(itemKey).value || [])"
|
||||
:key="plugin"
|
||||
size="small"
|
||||
label
|
||||
color="primary"
|
||||
<v-chip
|
||||
v-for="plugin in (createSelectorModel(itemKey).value || [])"
|
||||
:key="plugin"
|
||||
size="small"
|
||||
label
|
||||
color="primary"
|
||||
variant="outlined"
|
||||
>
|
||||
{{ plugin === '*' ? '所有插件' : plugin }}
|
||||
|
||||
@@ -4,28 +4,28 @@
|
||||
<span class="text-h2 text-truncate" :title="getItemTitle()">{{ getItemTitle() }}</span>
|
||||
<v-tooltip location="top">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-switch
|
||||
color="primary"
|
||||
hide-details
|
||||
density="compact"
|
||||
<v-switch
|
||||
color="primary"
|
||||
hide-details
|
||||
density="compact"
|
||||
:model-value="getItemEnabled()"
|
||||
:loading="loading"
|
||||
:disabled="loading"
|
||||
v-bind="props"
|
||||
v-bind="props"
|
||||
@update:model-value="toggleEnabled"
|
||||
></v-switch>
|
||||
</template>
|
||||
<span>{{ getItemEnabled() ? t('core.common.itemCard.enabled') : t('core.common.itemCard.disabled') }}</span>
|
||||
</v-tooltip>
|
||||
</v-card-title>
|
||||
|
||||
|
||||
<v-card-text>
|
||||
<slot name="item-details" :item="item"></slot>
|
||||
</v-card-text>
|
||||
|
||||
|
||||
<v-card-actions style="margin: 8px;">
|
||||
<v-btn
|
||||
variant="outlined"
|
||||
variant="outlined"
|
||||
color="error"
|
||||
rounded="xl"
|
||||
@click="$emit('delete', item)"
|
||||
@@ -40,6 +40,15 @@
|
||||
>
|
||||
{{ t('core.common.itemCard.edit') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
v-if="showCopyButton"
|
||||
variant="tonal"
|
||||
color="secondary"
|
||||
rounded="xl"
|
||||
@click="$emit('copy', item)"
|
||||
>
|
||||
{{ t('core.common.itemCard.copy') }}
|
||||
</v-btn>
|
||||
<v-spacer></v-spacer>
|
||||
</v-card-actions>
|
||||
|
||||
@@ -83,9 +92,13 @@ export default {
|
||||
loading: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
showCopyButton: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
},
|
||||
emits: ['toggle-enabled', 'delete', 'edit'],
|
||||
emits: ['toggle-enabled', 'delete', 'edit', 'copy'],
|
||||
methods: {
|
||||
getItemTitle() {
|
||||
return this.item[this.titleField];
|
||||
|
||||
282
dashboard/src/components/shared/ObjectEditor.vue
Normal file
282
dashboard/src/components/shared/ObjectEditor.vue
Normal file
@@ -0,0 +1,282 @@
|
||||
<template>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<div>
|
||||
<span v-if="!modelValue || Object.keys(modelValue).length === 0" style="color: rgb(var(--v-theme-primaryText));">
|
||||
暂无项目
|
||||
</span>
|
||||
<div v-else class="d-flex flex-wrap ga-2">
|
||||
<v-chip v-for="key in displayKeys" :key="key" size="x-small" label color="primary">
|
||||
{{ key.length > 20 ? key.slice(0, 20) + '...' : key }}
|
||||
</v-chip>
|
||||
<v-chip v-if="Object.keys(modelValue).length > maxDisplayItems" size="x-small" label color="grey-lighten-1">
|
||||
+{{ Object.keys(modelValue).length - maxDisplayItems }}
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- Key-Value Management Dialog -->
|
||||
<v-dialog v-model="dialog" max-width="600px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
|
||||
{{ dialogTitle }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-4" style="max-height: 400px; overflow-y: auto;">
|
||||
<div v-if="localKeyValuePairs.length > 0">
|
||||
<div v-for="(pair, index) in localKeyValuePairs" :key="index" class="key-value-pair">
|
||||
<v-row no-gutters align="center" class="mb-2">
|
||||
<v-col cols="4">
|
||||
<v-text-field
|
||||
v-model="pair.key"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
placeholder="键名"
|
||||
@blur="updateKey(index, pair.key)"
|
||||
></v-text-field>
|
||||
</v-col>
|
||||
<v-col cols="7" class="pl-2 d-flex align-center justify-end">
|
||||
<v-text-field
|
||||
v-if="pair.type === 'string'"
|
||||
v-model="pair.value"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
placeholder="字符串值"
|
||||
></v-text-field>
|
||||
<v-text-field
|
||||
v-else-if="pair.type === 'number'"
|
||||
v-model.number="pair.value"
|
||||
type="number"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
placeholder="数值"
|
||||
></v-text-field>
|
||||
<v-switch
|
||||
v-else-if="pair.type === 'boolean'"
|
||||
v-model="pair.value"
|
||||
density="compact"
|
||||
hide-details
|
||||
color="primary"
|
||||
></v-switch>
|
||||
</v-col>
|
||||
<v-col cols="1" class="pl-2">
|
||||
<v-btn
|
||||
icon
|
||||
variant="text"
|
||||
size="small"
|
||||
color="error"
|
||||
@click="removeKeyValuePair(index)"
|
||||
>
|
||||
<v-icon>mdi-delete</v-icon>
|
||||
</v-btn>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-code-json</v-icon>
|
||||
<p class="text-grey mt-4">暂无参数</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<!-- Add new key-value pair section -->
|
||||
<v-card-text class="pa-4">
|
||||
<div class="d-flex align-center ga-2">
|
||||
<v-text-field
|
||||
v-model="newKey"
|
||||
label="新键名"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
></v-text-field>
|
||||
<v-select
|
||||
v-model="newValueType"
|
||||
:items="['string', 'number', 'boolean']"
|
||||
label="值类型"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
style="max-width: 120px;"
|
||||
></v-select>
|
||||
<v-btn @click="addKeyValuePair" variant="tonal" color="primary">
|
||||
<v-icon>mdi-plus</v-icon>
|
||||
添加
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelDialog">取消</v-btn>
|
||||
<v-btn color="primary" @click="confirmDialog">确认</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed, watch } from 'vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: Object,
|
||||
required: true
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '修改'
|
||||
},
|
||||
dialogTitle: {
|
||||
type: String,
|
||||
default: '修改键值对'
|
||||
},
|
||||
maxDisplayItems: {
|
||||
type: Number,
|
||||
default: 1
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
|
||||
const dialog = ref(false)
|
||||
const localKeyValuePairs = ref([])
|
||||
const originalKeyValuePairs = ref([])
|
||||
const newKey = ref('')
|
||||
const newValueType = ref('string')
|
||||
|
||||
// 计算要显示的键名
|
||||
const displayKeys = computed(() => {
|
||||
return Object.keys(props.modelValue).slice(0, props.maxDisplayItems)
|
||||
})
|
||||
|
||||
// 监听 modelValue 变化,主要用于初始化
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
// This watch is primarily for initialization or external changes
|
||||
// The dialog-based editing handles internal updates
|
||||
}, { immediate: true })
|
||||
|
||||
function initializeLocalKeyValuePairs() {
|
||||
localKeyValuePairs.value = []
|
||||
for (const [key, value] of Object.entries(props.modelValue)) {
|
||||
localKeyValuePairs.value.push({
|
||||
key: key,
|
||||
value: value,
|
||||
type: typeof value // Store the original type
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function openDialog() {
|
||||
initializeLocalKeyValuePairs()
|
||||
originalKeyValuePairs.value = JSON.parse(JSON.stringify(localKeyValuePairs.value)) // Deep copy
|
||||
newKey.value = ''
|
||||
newValueType.value = 'string'
|
||||
dialog.value = true
|
||||
}
|
||||
|
||||
function addKeyValuePair() {
|
||||
const key = newKey.value.trim()
|
||||
if (key !== '') {
|
||||
const isKeyExists = localKeyValuePairs.value.some(pair => pair.key === key)
|
||||
if (isKeyExists) {
|
||||
alert('键名已存在')
|
||||
return
|
||||
}
|
||||
|
||||
let defaultValue
|
||||
switch (newValueType.value) {
|
||||
case 'number':
|
||||
defaultValue = 0
|
||||
break
|
||||
case 'boolean':
|
||||
defaultValue = false
|
||||
break
|
||||
default: // string
|
||||
defaultValue = ""
|
||||
break
|
||||
}
|
||||
|
||||
localKeyValuePairs.value.push({
|
||||
key: key,
|
||||
value: defaultValue,
|
||||
type: newValueType.value
|
||||
})
|
||||
newKey.value = ''
|
||||
}
|
||||
}
|
||||
|
||||
function removeKeyValuePair(index) {
|
||||
localKeyValuePairs.value.splice(index, 1)
|
||||
}
|
||||
|
||||
function updateKey(index, newKey) {
|
||||
const originalKey = localKeyValuePairs.value[index].key
|
||||
// 如果键名没有改变,则不执行任何操作
|
||||
if (originalKey === newKey) return
|
||||
|
||||
// 检查新键名是否已存在
|
||||
const isKeyExists = localKeyValuePairs.value.some((pair, i) => i !== index && pair.key === newKey)
|
||||
if (isKeyExists) {
|
||||
// 如果键名已存在,提示用户并恢复原值
|
||||
alert('键名已存在')
|
||||
// 将键名恢复为修改前的原始值
|
||||
localKeyValuePairs.value[index].key = originalKey
|
||||
return
|
||||
}
|
||||
|
||||
// 更新本地副本
|
||||
localKeyValuePairs.value[index].key = newKey
|
||||
}
|
||||
|
||||
function confirmDialog() {
|
||||
const updatedValue = {}
|
||||
for (const pair of localKeyValuePairs.value) {
|
||||
let convertedValue = pair.value
|
||||
// 根据声明的类型进行转换
|
||||
switch (pair.type) {
|
||||
case 'number':
|
||||
// 尝试转换为数字,如果失败则保持原值(或设为默认值0)
|
||||
convertedValue = Number(pair.value)
|
||||
// 可选:检查是否为有效数字,无效则设为0或报错
|
||||
// if (isNaN(convertedValue)) convertedValue = 0;
|
||||
break
|
||||
case 'boolean':
|
||||
// 布尔值通常由 v-switch 正确处理,但为保险起见可以显式转换
|
||||
// 注意:在 JavaScript 中,只有严格的 false, 0, "", null, undefined, NaN 会被转换为 false
|
||||
// 这里直接赋值 pair.value 应该是安全的,因为 v-model 绑定的就是布尔值
|
||||
// convertedValue = Boolean(pair.value)
|
||||
break
|
||||
case 'string':
|
||||
default:
|
||||
// 默认转换为字符串
|
||||
convertedValue = String(pair.value)
|
||||
break
|
||||
}
|
||||
updatedValue[pair.key] = convertedValue
|
||||
}
|
||||
emit('update:modelValue', updatedValue)
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function cancelDialog() {
|
||||
// Reset to original state
|
||||
localKeyValuePairs.value = JSON.parse(JSON.stringify(originalKeyValuePairs.value))
|
||||
dialog.value = false
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.key-value-pair {
|
||||
width: 100%;
|
||||
}
|
||||
</style>
|
||||
@@ -1,20 +0,0 @@
|
||||
<script setup lang="ts">
|
||||
const props = defineProps({
|
||||
title: String
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<v-card variant="outlined" elevation="0" class="withbg">
|
||||
<v-card-item>
|
||||
<div class="d-sm-flex align-center justify-space-between">
|
||||
<v-card-title>{{ props.title }}</v-card-title>
|
||||
<slot name="action"></slot>
|
||||
</div>
|
||||
</v-card-item>
|
||||
<v-divider></v-divider>
|
||||
<v-card-text>
|
||||
<slot />
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</template>
|
||||
@@ -73,6 +73,7 @@
|
||||
"disabled": "已禁用",
|
||||
"delete": "删除",
|
||||
"edit": "编辑",
|
||||
"copy": "复制",
|
||||
"noData": "暂无数据"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
"subtitle": "管理和查看用户对话历史记录",
|
||||
"filters": {
|
||||
"title": "筛选条件",
|
||||
"platform": "平台",
|
||||
"platform": "消息平台 ID",
|
||||
"type": "类型",
|
||||
"search": "搜索关键词",
|
||||
"reset": "重置"
|
||||
@@ -15,9 +15,9 @@
|
||||
"table": {
|
||||
"headers": {
|
||||
"title": "对话标题",
|
||||
"platform": "平台",
|
||||
"platform": "消息平台 ID",
|
||||
"type": "类型",
|
||||
"sessionId": "ID",
|
||||
"sessionId": "ID (UMO)",
|
||||
"createdAt": "创建时间",
|
||||
"updatedAt": "更新时间",
|
||||
"actions": "操作"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"refresh": "刷新",
|
||||
"edit": "编辑",
|
||||
"apply": "应用批量设置",
|
||||
"editName": "编辑会话名称",
|
||||
"editName": "备注",
|
||||
"save": "保存",
|
||||
"cancel": "取消"
|
||||
},
|
||||
@@ -22,13 +22,13 @@
|
||||
"table": {
|
||||
"headers": {
|
||||
"sessionStatus": "会话状态",
|
||||
"sessionInfo": "会话信息",
|
||||
"sessionInfo": "ID (UMO)",
|
||||
"persona": "人格",
|
||||
"chatProvider": "Chat Provider",
|
||||
"sttProvider": "STT Provider",
|
||||
"ttsProvider": "TTS Provider",
|
||||
"llmStatus": "LLM启停",
|
||||
"ttsStatus": "TTS启停",
|
||||
"chatProvider": "聊天模型",
|
||||
"sttProvider": "语音识别模型",
|
||||
"ttsProvider": "语音合成模型",
|
||||
"llmStatus": "启用 LLM",
|
||||
"ttsStatus": "启用 TTS",
|
||||
"pluginManagement": "插件管理"
|
||||
}
|
||||
},
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user