Compare commits

..

4 Commits

Author SHA1 Message Date
Soulter
6a59e894e1 Merge remote-tracking branch 'origin/master' into dev 2025-08-02 14:05:38 +08:00
Raven95676
fb10faa2dc Merge branch 'master' into dev 2025-07-26 19:13:58 +08:00
Raven95676
e4df0e83ed Merge branch 'releases/3.5.23' into dev 2025-07-26 16:49:27 +08:00
Raven95676
4a22664b8e Merge branch 'releases/3.5.23' into dev 2025-07-26 16:34:46 +08:00
163 changed files with 7872 additions and 12606 deletions

View File

@@ -0,0 +1,31 @@
---
name: '🥳 发布插件'
title: "[Plugin] 插件名"
about: 提交插件到插件市场
labels: [ "plugin-publish" ]
assignees: ''
---
欢迎发布插件到插件市场!
## 插件基本信息
请将插件信息填写到下方的 Json 代码块中。`tags`(插件标签)和 `social_link`(社交链接)选填。
```json
{
"name": "插件名",
"desc": "插件介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": ""
}
```
## 检查
- [ ] 我的插件经过完整的测试
- [ ] 我的插件不包含恶意代码
- [ ] 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。

View File

@@ -1,56 +0,0 @@
name: 🥳 发布插件
description: 提交插件到插件市场
title: "[Plugin] 插件名"
labels: ["plugin-publish"]
assignees: []
body:
- type: markdown
attributes:
value: |
欢迎发布插件到插件市场!
- type: markdown
attributes:
value: |
## 插件基本信息
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
- type: textarea
id: plugin-info
attributes:
label: 插件信息
description: 请在下方代码块中填写您的插件信息确保反引号包裹了JSON
value: |
```json
{
"name": "插件名",
"desc": "插件介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": ""
}
```
validations:
required: true
- type: markdown
attributes:
value: |
## 检查
- type: checkboxes
id: checks
attributes:
label: 插件检查清单
description: 请确认以下所有项目
options:
- label: 我的插件经过完整的测试
required: true
- label: 我的插件不包含恶意代码
required: true
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true

View File

@@ -1,63 +0,0 @@
# AstrBot Development Instructions
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
## Working Effectively
### Bootstrap and Install Dependencies
- **Python 3.10+ required** - Check `.python-version` file
- Install UV package manager: `pip install uv`
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
- Create required directories: `mkdir -p data/plugins data/config data/temp`
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
- Navigate to dashboard: `cd dashboard`
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
- Dashboard creates optimized production build in `dashboard/dist/`
### Testing
- Do not generate test files for now.
### Code Quality and Linting
- Install ruff linter: `uv add --dev ruff`
- Check code style: `uv run ruff check .` -- takes <1 second
- Check formatting: `uv run ruff format --check .` -- takes <1 second
- Fix formatting: `uv run ruff format .`
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
### Common Issues and Workarounds
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
## CI/CD Integration
- GitHub Actions workflows in `.github/workflows/`
- Docker builds supported via `Dockerfile`
- Pre-commit hooks enforce ruff formatting and linting
## Docker Support
- Primary deployment method: `docker run soulter/astrbot:latest`
- Compose file available: `compose.yml`
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
- Volume mount required: `./data:/AstrBot/data`
## Multi-language Support
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
- UI supports internationalization
- Default language is Chinese
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.

View File

@@ -13,7 +13,7 @@ jobs:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Dashboard Build
run: |
@@ -70,7 +70,7 @@ jobs:
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5

View File

@@ -56,7 +56,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v5
uses: actions/checkout@v4
with:
fetch-depth: 0

View File

@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: npm install, build
run: |
@@ -25,8 +25,6 @@ jobs:
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
mkdir -p dashboard/dist/assets
echo $COMMIT_SHA > dashboard/dist/assets/version
cd dashboard
zip -r dist.zip dist
- name: Archive production artifacts
uses: actions/upload-artifact@v4
@@ -35,13 +33,3 @@ jobs:
path: |
dashboard/dist
!dist/**/*.md
- name: Create GitHub Release
uses: ncipollo/release-action@v1
with:
tag: release-${{ github.sha }}
owner: AstrBotDevs
repo: astrbot-release-harbour
body: "Automated release from commit ${{ github.sha }}"
token: ${{ secrets.ASTRBOT_HARBOUR_TOKEN }}
artifacts: "dashboard/dist.zip"

View File

@@ -12,7 +12,7 @@ jobs:
steps:
- name: Pull The Codes
uses: actions/checkout@v5
uses: actions/checkout@v4
with:
fetch-depth: 0 # Must be 0 so we can fetch tags

View File

@@ -1,4 +1,4 @@
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
<p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
@@ -25,52 +25,59 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
> [!WARNING]
>
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
## ✨ 近期更新
<details><summary>1. AstrBot 现已自带知识库能力</summary>
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
![image](https://github.com/user-attachments/assets/28b639b0-bb5c-4958-8e94-92ae8cfd1ab4)
</details>
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富
5. **WebUI**。可视化配置和管理机器人,功能齐全
> [!NOTE]
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper
2. **多消息平台接入**。支持接入 QQOneBot、QQ 官方机器人平台、QQ 频道、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
> [!TIP]
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。
## ✨ 使用方式
#### Docker 部署
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
#### 宝塔面板部署
AstrBot 与宝塔面板合作,已上架至宝塔面板。
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### 1Panel 部署
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
#### 在 雨云 上部署
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### 在 Replit 上部署
社区贡献的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### Windows 一键安装器部署
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### 宝塔面板部署
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### CasaOS 部署
社区贡献的部署方式。
@@ -94,8 +101,24 @@ git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
或者,直接通过 uvx 安装 AstrBot
```bash
mkdir astrbot && cd astrbot
uvx astrbot init
# uvx astrbot run
```
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
#### 在 Replit 上部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### 在 雨云 上部署
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
## ⚡ 消息平台支持情况
| 平台 | 支持性 |

View File

@@ -3,18 +3,5 @@ from astrbot import logger
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_agent as agent
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
__all__ = [
"AstrBotConfig",
"logger",
"html_renderer",
"llm_tool",
"agent",
"sp",
"ToolSet",
"FunctionTool",
"BaseFunctionToolExecutor",
]
__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"]

View File

@@ -117,9 +117,6 @@ def build_plug_list(plugins_dir: Path) -> list:
# 从 metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir)
if "desc" not in metadata and "description" in metadata:
metadata["desc"] = metadata["description"]
# 如果成功加载元数据,添加到结果列表
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]

View File

@@ -1,4 +1,5 @@
import os
import asyncio
from .log import LogManager, LogBroker # noqa
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
@@ -20,7 +21,7 @@ html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot")
db_helper = SQLiteDatabase(DB_PATH)
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
sp = SharedPreferences(db_helper=db_helper)
sp = SharedPreferences()
# 文件令牌服务
file_token_service = FileTokenService()
pip_installer = PipInstaller(

View File

@@ -1,13 +0,0 @@
from dataclasses import dataclass
from .tool import FunctionTool
from typing import Generic
from .run_context import TContext
from .hooks import BaseAgentRunHooks
@dataclass
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str, FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -1,34 +0,0 @@
from typing import Generic
from .tool import FunctionTool
from .agent import Agent
from .run_context import TContext
class HandoffTool(FunctionTool, Generic[TContext]):
"""Handoff tool for delegating tasks to another agent."""
def __init__(
self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
):
self.agent = agent
super().__init__(
name=f"transfer_to_{agent.name}",
parameters=parameters or self.default_parameters(),
description=agent.instructions or self.default_description(agent.name),
**kwargs,
)
def default_parameters(self) -> dict:
return {
"type": "object",
"properties": {
"input": {
"type": "string",
"description": "The input to be handed off to another agent. This should be a clear and concise request or task.",
},
},
}
def default_description(self, agent_name: str | None) -> str:
agent_name = agent_name or "another"
return f"Delegate tasks to {self.name} agent to handle the request."

View File

@@ -1,27 +0,0 @@
import mcp
from dataclasses import dataclass
from .run_context import ContextWrapper, TContext
from typing import Generic
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.agent.tool import FunctionTool
@dataclass
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_tool_start(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
): ...
async def on_tool_end(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
tool_result: mcp.types.CallToolResult | None,
): ...
async def on_agent_done(
self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
): ...

View File

@@ -1,208 +0,0 @@
import asyncio
import logging
from datetime import timedelta
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core.utils.log_pipe import LogPipe
try:
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""快速测试 MCP 服务器可达性"""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
如果 `url` 参数存在:
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
# 处理 MCP 服务的错误日志
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
if not success:
raise Exception(error_msg)
if cfg.get("transport") != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str):
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
), # type: ignore
),
)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done

View File

@@ -1,12 +0,0 @@
from dataclasses import dataclass
import typing as T
from astrbot.core.message.message_event_result import MessageChain
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData

View File

@@ -1,17 +0,0 @@
from dataclasses import dataclass
from typing import Any, Generic
from typing_extensions import TypeVar
from astrbot.core.platform.astr_message_event import AstrMessageEvent
TContext = TypeVar("TContext", default=Any)
@dataclass
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
event: AstrMessageEvent
NoContext = ContextWrapper[None]

View File

@@ -1,3 +0,0 @@
from .base import BaseAgentRunner
__all__ = ["BaseAgentRunner"]

View File

@@ -1,256 +0,0 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Literal, Any, Optional
from .mcp_client import MCPClient
@dataclass
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str | None = None
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str | None = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient | None = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
def __dict__(self) -> dict[str, Any]:
"""将 FunctionTool 转换为字典格式"""
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
"origin": self.origin,
"mcp_server_name": self.mcp_server_name,
}
class ToolSet:
"""A set of function tools that can be used in function calling.
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
"""Check if the tool set is empty."""
return len(self.tools) == 0
def add_tool(self, tool: FunctionTool):
"""Add a tool to the set."""
# 检查是否已存在同名工具
for i, existing_tool in enumerate(self.tools):
if existing_tool.name == tool.name:
self.tools[i] = tool
return
self.tools.append(tool)
def remove_tool(self, name: str):
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
def get_tool(self, name: str) -> Optional[FunctionTool]:
"""Get a tool by its name."""
for tool in self.tools:
if tool.name == name:
return tool
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FunctionTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.add_tool(_func)
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
def remove_func(self, name: str):
"""Remove a function tool by its name."""
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> list[FunctionTool]:
"""Get all function tools."""
return self.get_tool(name)
@property
def func_list(self) -> list[FunctionTool]:
"""Get the list of function tools."""
return self.tools
def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
"""Convert tools to OpenAI API function calling schema format."""
result = []
for tool in self.tools:
func_def = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
},
}
if tool.parameters.get("properties") or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
return result
def anthropic_schema(self) -> list[dict]:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
}
result.append(tool_def)
return result
def google_schema(self) -> dict:
"""Convert tools to Google GenAI API format."""
def convert_schema(schema: dict) -> dict:
"""Convert schema to Gemini API format."""
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
if properties:
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = [
{
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
declarations = {}
if tools:
declarations["function_declarations"] = tools
return declarations
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
return self.openai_schema(omit_empty_parameter_field)
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
def get_func_desc_anthropic_style(self):
return self.anthropic_schema()
@deprecated(reason="Use google_schema() instead", version="4.0.0")
def get_func_desc_google_genai_style(self):
return self.google_schema()
def names(self) -> list[str]:
"""获取所有工具的名称列表"""
return [tool.name for tool in self.tools]
def __len__(self):
return len(self.tools)
def __bool__(self):
return len(self.tools) > 0
def __iter__(self):
return iter(self.tools)
def __repr__(self):
return f"ToolSet(tools={self.tools})"
def __str__(self):
return f"ToolSet(tools={self.tools})"

View File

@@ -1,11 +0,0 @@
import mcp
from typing import Any, Generic, AsyncGenerator
from .run_context import TContext, ContextWrapper
from .tool import FunctionTool
class BaseFunctionToolExecutor(Generic[TContext]):
@classmethod
async def execute(
cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...

View File

@@ -1,11 +0,0 @@
from dataclasses import dataclass
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@dataclass
class AstrAgentContext:
provider: Provider
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool

View File

@@ -1,276 +0,0 @@
import os
import uuid
from astrbot.core import AstrBotConfig, logger
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
from typing import TypeVar, TypedDict
_VT = TypeVar("_VT")
class ConfInfo(TypedDict):
"""Configuration information for a specific session or platform."""
id: str # UUID of the configuration or "default"
umop: list[str] # Unified Message Origin Pattern
name: str
path: str # File name to the configuration file
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
id="default",
umop=["::"],
name="default",
path=ASTRBOT_CONFIG_PATH,
)
class AstrBotConfigManager:
"""A class to manage the system configuration of AstrBot, aka ACM"""
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
self.sp = sp
self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config
self._load_all_configs()
def _load_all_configs(self):
"""Load all configurations from the shared preferences."""
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
for uuid_, meta in abconf_data.items():
filename = meta["path"]
conf_path = os.path.join(get_astrbot_config_path(), filename)
if os.path.exists(conf_path):
conf = AstrBotConfig(config_path=conf_path)
self.confs[uuid_] = conf
else:
logger.warning(
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
)
continue
def _is_umo_match(self, p1: str, p2: str) -> bool:
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
p1_ls = p1.split(":")
p2_ls = p2.split(":")
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
Returns:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
"""
# uuid -> { "umop": list, "path": str, "name": str }
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if isinstance(umo, MessageSession):
umo = str(umo)
else:
try:
umo = str(MessageSession.from_str(umo)) # validate
except Exception:
return DEFAULT_CONFIG_CONF_INFO
for uuid_, meta in abconf_data.items():
for pattern in meta["umop"]:
if self._is_umo_match(pattern, umo):
return ConfInfo(**meta, id=uuid_)
return DEFAULT_CONFIG_CONF_INFO
def _save_conf_mapping(
self,
abconf_path: str,
abconf_id: str,
umo_parts: list[str] | list[MessageSession],
abconf_name: str | None = None,
) -> None:
"""保存配置文件的映射关系"""
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
"umop": umo_parts,
"path": abconf_path,
"name": random_word,
}
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
if not umo:
return self.confs["default"]
if isinstance(umo, MessageSession):
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
uuid_ = self._load_conf_mapping(umo)["id"]
conf = self.confs.get(uuid_)
if not conf:
conf = self.confs["default"] # default MUST exists
return conf
@property
def default_conf(self) -> AstrBotConfig:
"""获取默认配置文件"""
return self.confs["default"]
def get_conf_info(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件元数据"""
if isinstance(umo, MessageSession):
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
return self._load_conf_mapping(umo)
def get_conf_list(self) -> list[ConfInfo]:
"""获取所有配置文件的元数据列表"""
conf_list = []
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
abconf_mapping = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
for uuid_, meta in abconf_mapping.items():
conf_list.append(ConfInfo(**meta, id=uuid_))
return conf_list
def create_conf(
self,
umo_parts: list[str] | list[MessageSession],
config: dict = DEFAULT_CONFIG,
name: str | None = None,
) -> str:
"""
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
"""
conf_uuid = str(uuid.uuid4())
conf_file_name = f"abconf_{conf_uuid}.json"
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
conf = AstrBotConfig(config_path=conf_path, default_config=config)
conf.save_config()
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
self.confs[conf_uuid] = conf
return conf_uuid
def delete_conf(self, conf_id: str) -> bool:
"""删除指定配置文件
Args:
conf_id: 配置文件的 UUID
Returns:
bool: 删除是否成功
Raises:
ValueError: 如果试图删除默认配置文件
"""
if conf_id == "default":
raise ValueError("不能删除默认配置文件")
# 从映射中移除
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
# 获取配置文件路径
conf_path = os.path.join(
get_astrbot_config_path(), abconf_data[conf_id]["path"]
)
# 删除配置文件
try:
if os.path.exists(conf_path):
os.remove(conf_path)
logger.info(f"已删除配置文件: {conf_path}")
except Exception as e:
logger.error(f"删除配置文件 {conf_path} 失败: {e}")
return False
# 从内存中移除
if conf_id in self.confs:
del self.confs[conf_id]
# 从映射中移除
del abconf_data[conf_id]
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
logger.info(f"成功删除配置文件 {conf_id}")
return True
def update_conf_info(
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
) -> bool:
"""更新配置文件信息
Args:
conf_id: 配置文件的 UUID
name: 新的配置文件名称 (可选)
umo_parts: 新的 UMO 部分列表 (可选)
Returns:
bool: 更新是否成功
"""
if conf_id == "default":
raise ValueError("不能更新默认配置文件的信息")
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
# 更新名称
if name is not None:
abconf_data[conf_id]["name"] = name
# 更新 UMO 部分
if umo_parts is not None:
# 验证 UMO 部分格式
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data[conf_id]["umop"] = umo_parts
# 保存更新
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
logger.info(f"成功更新配置文件 {conf_id} 的信息")
return True
def g(
self, umo: str | None = None, key: str | None = None, default: _VT = None
) -> _VT:
"""获取配置项。umo 为 None 时使用默认配置"""
if umo is None:
return self.confs["default"].get(key, default)
conf = self.get_conf(umo)
return conf.get(key, default)

File diff suppressed because it is too large Load Diff

View File

@@ -5,44 +5,40 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
from astrbot.core.db.po import Conversation
class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
self.session_conversations: Dict[str, str] = {}
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp())
updated_at = int(conv_v2.updated_at.timestamp())
return Conversation(
platform_id=conv_v2.platform_id,
user_id=conv_v2.user_id,
cid=conv_v2.conversation_id,
history=json.dumps(conv_v2.content or []),
title=conv_v2.title,
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
)
def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save())
async def new_conversation(
self,
unified_msg_origin: str,
platform_id: str | None = None,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
) -> str:
async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
"""新建对话,并将当前会话的对话转移到新对话
Args:
@@ -50,23 +46,11 @@ class ConversationManager:
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
if not platform_id:
# 如果没有提供 platform_id则从 unified_msg_origin 中解析
parts = unified_msg_origin.split(":")
if len(parts) >= 3:
platform_id = parts[0]
if not platform_id:
platform_id = "unknown"
conv = await self.db.create_conversation(
user_id=unified_msg_origin,
platform_id=platform_id,
content=content,
title=title,
persona_id=persona_id,
)
self.session_conversations[unified_msg_origin] = conv.conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
return conv.conversation_id
conversation_id = str(uuid.uuid4())
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话
@@ -76,10 +60,10 @@ class ConversationManager:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str | None = None
self, unified_msg_origin: str, conversation_id: str = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
@@ -87,18 +71,13 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
f = False
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
f = True
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
if f:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
del self.session_conversations[unified_msg_origin]
sp.put("session_conversation", self.session_conversations)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
"""获取会话当前的对话 ID
Args:
@@ -106,19 +85,14 @@ class ConversationManager:
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
ret = self.session_conversations.get(unified_msg_origin, None)
if not ret:
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
if ret:
self.session_conversations[unified_msg_origin] = ret
return ret
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(
self,
unified_msg_origin: str,
conversation_id: str,
create_if_not_exists: bool = False,
) -> Conversation | None:
) -> Conversation:
"""获取会话的对话
Args:
@@ -127,74 +101,27 @@ class ConversationManager:
Returns:
conversation (Conversation): 对话对象
"""
conv = await self.db.get_conversation_by_id(cid=conversation_id)
conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
if not conv and create_if_not_exists:
# 如果对话不存在且需要创建,则新建一个对话
conversation_id = await self.new_conversation(unified_msg_origin)
conv = await self.db.get_conversation_by_id(cid=conversation_id)
conv_res = None
if conv:
conv_res = self._convert_conv_from_v2_to_v1(conv)
return conv_res
return self.db.get_conversation_by_user_id(
unified_msg_origin, conversation_id
)
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(
self, unified_msg_origin: str | None = None, platform_id: str | None = None
) -> List[Conversation]:
"""获取对话列表
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
"""获取会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversations (List[Conversation]): 对话对象列表
"""
convs = await self.db.get_conversations(
user_id=unified_msg_origin, platform_id=platform_id
)
convs_res = []
for conv in convs:
conv_res = self._convert_conv_from_v2_to_v1(conv)
convs_res.append(conv_res)
return convs_res
async def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platform_ids: list[str] | None = None,
search_query: str = "",
**kwargs,
) -> tuple[list[Conversation], int]:
"""获取过滤后的对话列表
Args:
page (int): 页码, 默认为 1
page_size (int): 每页大小, 默认为 20
platform_ids (list[str]): 平台 ID 列表, 可选
search_query (str): 搜索查询字符串, 可选
Returns:
conversations (list[Conversation]): 对话对象列表
"""
convs, cnt = await self.db.get_filtered_conversations(
page=page,
page_size=page_size,
platform_ids=platform_ids,
search_query=search_query,
**kwargs,
)
convs_res = []
for conv in convs:
conv_res = self._convert_conv_from_v2_to_v1(conv)
convs_res.append(conv_res)
return convs_res, cnt
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(
self,
unified_msg_origin: str,
conversation_id: str | None = None,
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
):
"""更新会话的对话
@@ -203,55 +130,40 @@ class ConversationManager:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if not conversation_id:
# 如果没有提供 conversation_id则获取当前的
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
if conversation_id:
await self.db.update_conversation(
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
title=title,
persona_id=persona_id,
content=history,
history=json.dumps(history),
)
async def update_conversation_title(
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
):
async def update_conversation_title(self, unified_msg_origin: str, title: str):
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
Deprecated:
Use `update_conversation` with `title` parameter instead.
"""
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
conversation_id=conversation_id,
title=title,
)
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin, cid=conversation_id, title=title
)
async def update_conversation_persona_id(
self,
unified_msg_origin: str,
persona_id: str,
conversation_id: str | None = None,
self, unified_msg_origin: str, persona_id: str
):
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
Deprecated:
Use `update_conversation` with `persona_id` parameter instead.
"""
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
conversation_id=conversation_id,
persona_id=persona_id,
)
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
)
async def get_human_readable_context(
self, unified_msg_origin, conversation_id, page=1, page_size=10

View File

@@ -15,23 +15,20 @@ import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config, html_renderer
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger, sp
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
@@ -50,23 +47,12 @@ class AstrBotCoreLifecycle:
self.db = db # 初始化数据库
# 设置代理
proxy_config = self.astrbot_config.get("http_proxy", "")
if proxy_config != "":
os.environ["https_proxy"] = proxy_config
os.environ["http_proxy"] = proxy_config
logger.debug(f"Using proxy: {proxy_config}")
# 设置 no_proxy
no_proxy_list = self.astrbot_config.get("no_proxy", [])
os.environ["no_proxy"] = ",".join(no_proxy_list)
else:
# 清空代理环境变量
if "https_proxy" in os.environ:
del os.environ["https_proxy"]
if "http_proxy" in os.environ:
del os.environ["http_proxy"]
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
logger.debug("HTTP proxy cleared")
if self.astrbot_config.get("http_proxy", ""):
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
if proxy := os.environ.get("https_proxy"):
logger.debug(f"Using proxy: {proxy}")
os.environ["no_proxy"] = "localhost"
async def initialize(self):
"""
@@ -80,26 +66,11 @@ class AstrBotCoreLifecycle:
else:
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
await self.db.initialize()
await html_renderer.initialize()
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
default_config=self.astrbot_config, sp=sp
)
# 初始化事件队列
self.event_queue = Queue()
# 初始化人格管理器
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
await self.persona_mgr.initialize()
# 初始化供应商管理器
self.provider_manager = ProviderManager(
self.astrbot_config_mgr, self.db, self.persona_mgr
)
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
@@ -107,9 +78,6 @@ class AstrBotCoreLifecycle:
# 初始化对话管理器
self.conversation_manager = ConversationManager(self.db)
# 初始化平台消息历史管理器
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
@@ -118,9 +86,6 @@ class AstrBotCoreLifecycle:
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.platform_message_history_manager,
self.persona_mgr,
self.astrbot_config_mgr,
)
# 初始化插件管理器
@@ -133,16 +98,16 @@ class AstrBotCoreLifecycle:
await self.provider_manager.initialize()
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
self.pipeline_scheduler = PipelineScheduler(
PipelineContext(self.astrbot_config, self.plugin_manager)
)
await self.pipeline_scheduler.initialize()
# 初始化更新器
self.astrbot_updator = AstrBotUpdator()
# 初始化事件总线
self.event_bus = EventBus(
self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
)
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
# 记录启动时间
self.start_time = int(time.time())
@@ -259,39 +224,6 @@ class AstrBotCoreLifecycle:
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
asyncio.create_task(
platform_inst.run(),
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
)
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
)
return tasks
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
"""加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
mapping = {}
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
mapping[conf_id] = scheduler
return mapping
async def reload_pipeline_scheduler(self, conf_id: str):
"""重新加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
if not ab_config:
raise ValueError(f"配置文件 {conf_id} 不存在")
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
self.pipeline_scheduler_mapping[conf_id] = scheduler

View File

@@ -1,20 +1,7 @@
import abc
import datetime
import typing as T
from deprecated import deprecated
from dataclasses import dataclass
from astrbot.core.db.po import (
Stats,
PlatformStat,
ConversationV2,
PlatformMessageHistory,
Attachment,
Persona,
Preference,
)
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from typing import List, Dict, Any, Tuple
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@dataclass
@@ -23,262 +10,152 @@ class BaseDatabase(abc.ABC):
数据库基类
"""
DATABASE_URL = ""
def __init__(self) -> None:
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.engine, class_=AsyncSession, expire_on_commit=False
)
async def initialize(self):
"""初始化数据库连接"""
pass
@asynccontextmanager
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
"""Get a database session."""
if not self.inited:
await self.initialize()
self.inited = True
async with self.AsyncSessionLocal() as session:
yield session
def insert_base_metrics(self, metrics: dict):
"""插入基础指标数据"""
self.insert_platform_metrics(metrics["platform_stats"])
self.insert_plugin_metrics(metrics["plugin_stats"])
self.insert_command_metrics(metrics["command_stats"])
self.insert_llm_metrics(metrics["llm_stats"])
@abc.abstractmethod
def insert_platform_metrics(self, metrics: dict):
"""插入平台指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_plugin_metrics(self, metrics: dict):
"""插入插件指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_command_metrics(self, metrics: dict):
"""插入指令指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_llm_metrics(self, metrics: dict):
"""插入 LLM 指标数据"""
raise NotImplementedError
@abc.abstractmethod
def update_llm_history(self, session_id: str, content: str, provider_type: str):
"""更新 LLM 历史记录。当不存在 session_id 时插入"""
raise NotImplementedError
@abc.abstractmethod
def get_llm_history(
self, session_id: str = None, provider_type: str = None
) -> List[LLMHistory]:
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据"""
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_total_message_count(self) -> int:
"""获取总消息数"""
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据(合并)"""
raise NotImplementedError
# New methods in v4.0.0
@abc.abstractmethod
def insert_atri_vision_data(self, vision_data: ATRIVision):
"""插入 ATRI 视觉数据"""
raise NotImplementedError
@abc.abstractmethod
async def insert_platform_stats(
self,
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime.datetime | None = None,
) -> None:
"""Insert a new platform statistic record."""
...
def get_atri_vision_data(self) -> List[ATRIVision]:
"""获取 ATRI 视觉数据"""
raise NotImplementedError
@abc.abstractmethod
async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records."""
...
def get_atri_vision_data_by_path_or_id(
self, url_or_path: str, id: str
) -> ATRIVision:
"""通过 url 或 path 获取 ATRI 视觉数据"""
raise NotImplementedError
@abc.abstractmethod
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
...
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
"""通过 user_id 和 cid 获取 Conversation"""
raise NotImplementedError
@abc.abstractmethod
async def get_conversations(
self, user_id: str | None = None, platform_id: str | None = None
) -> list[ConversationV2]:
"""Get all conversations for a specific user and platform_id(optional).
content is not included in the result.
"""
...
def new_conversation(self, user_id: str, cid: str):
"""新建 Conversation"""
raise NotImplementedError
@abc.abstractmethod
async def get_conversation_by_id(self, cid: str) -> ConversationV2:
"""Get a specific conversation by its ID."""
...
def get_conversations(self, user_id: str) -> List[Conversation]:
raise NotImplementedError
@abc.abstractmethod
async def get_all_conversations(
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新 Conversation"""
raise NotImplementedError
@abc.abstractmethod
def delete_conversation(self, user_id: str, cid: str):
"""删除 Conversation"""
raise NotImplementedError
@abc.abstractmethod
def update_conversation_title(self, user_id: str, cid: str, title: str):
"""更新 Conversation 标题"""
raise NotImplementedError
@abc.abstractmethod
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
"""更新 Conversation Persona ID"""
raise NotImplementedError
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> list[ConversationV2]:
"""Get all conversations with pagination."""
...
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页
Args:
page: 页码从1开始
page_size: 每页数量
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
@abc.abstractmethod
async def get_filtered_conversations(
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platform_ids: list[str] | None = None,
search_query: str = "",
**kwargs,
) -> tuple[list[ConversationV2], int]:
"""Get conversations filtered by platform IDs and search query."""
...
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表
@abc.abstractmethod
async def create_conversation(
self,
user_id: str,
platform_id: str,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
cid: str | None = None,
created_at: datetime.datetime | None = None,
updated_at: datetime.datetime | None = None,
) -> ConversationV2:
"""Create a new conversation."""
...
Args:
page: 页码
page_size: 每页数量
platforms: 平台筛选列表
message_types: 消息类型筛选列表
search_query: 搜索关键词
exclude_ids: 排除的用户ID列表
exclude_platforms: 排除的平台列表
@abc.abstractmethod
async def update_conversation(
self,
cid: str,
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
) -> None:
"""Update a conversation's history."""
...
@abc.abstractmethod
async def delete_conversation(self, cid: str) -> None:
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,
platform_id: str,
user_id: str,
content: list[dict],
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
"""Insert a new platform message history record."""
...
@abc.abstractmethod
async def delete_platform_message_offset(
self, platform_id: str, user_id: str, offset_sec: int = 86400
) -> None:
"""Delete platform message history records older than the specified offset."""
...
@abc.abstractmethod
async def get_platform_message_history(
self,
platform_id: str,
user_id: str,
page: int = 1,
page_size: int = 20,
) -> list[PlatformMessageHistory]:
"""Get platform message history for a specific user."""
...
@abc.abstractmethod
async def insert_attachment(
self,
path: str,
type: str,
mime_type: str,
):
"""Insert a new attachment record."""
...
@abc.abstractmethod
async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
"""Get an attachment by its ID."""
...
@abc.abstractmethod
async def insert_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona:
"""Insert a new persona record."""
...
@abc.abstractmethod
async def get_persona_by_id(self, persona_id: str) -> Persona:
"""Get a persona by its ID."""
...
@abc.abstractmethod
async def get_personas(self) -> list[Persona]:
"""Get all personas for a specific bot."""
...
@abc.abstractmethod
async def update_persona(
self,
persona_id: str,
system_prompt: str | None = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona | None:
"""Update a persona's system prompt or begin dialogs."""
...
@abc.abstractmethod
async def delete_persona(self, persona_id: str) -> None:
"""Delete a persona by its ID."""
...
@abc.abstractmethod
async def insert_preference_or_update(
self, scope: str, scope_id: str, key: str, value: dict
) -> Preference:
"""Insert a new preference record."""
...
@abc.abstractmethod
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
"""Get a preference by scope ID and key."""
...
@abc.abstractmethod
async def get_preferences(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""Get all preferences for a specific scope ID or key."""
...
@abc.abstractmethod
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
"""Remove a preference by scope ID and key."""
...
@abc.abstractmethod
async def clear_preferences(self, scope: str, scope_id: str) -> None:
"""Clear all preferences for a specific scope ID."""
...
# @abc.abstractmethod
# async def insert_llm_message(
# self,
# cid: str,
# role: str,
# content: list,
# tool_calls: list = None,
# tool_call_id: str = None,
# parent_id: str = None,
# ) -> LLMMessage:
# """Insert a new LLM message into the conversation."""
# ...
# @abc.abstractmethod
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
# """Get all LLM messages for a specific conversation."""
# ...
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError

View File

@@ -1,64 +0,0 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig
from astrbot.api import logger, sp
from .migra_3_to_4 import (
migration_conversation_table,
migration_platform_table,
migration_webchat_data,
migration_persona_data,
migration_preferences,
)
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
"""
检查是否需要进行数据库迁移
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4则需要进行迁移。
"""
data_v3_exists = os.path.exists(get_astrbot_data_path())
if not data_v3_exists:
return False
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_v4"
)
if migration_done:
return False
return True
async def do_migration_v4(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
astrbot_config: AstrBotConfig,
):
"""
执行数据库迁移
迁移旧的 webchat_conversation 表到新的 conversation 表。
迁移旧的 platform 到新的 platform_stats 表。
"""
if not await check_migration_needed_v4(db_helper):
return
logger.info("开始执行数据库迁移...")
# 执行会话表迁移
await migration_conversation_table(db_helper, platform_id_map)
# 执行人格数据迁移
await migration_persona_data(db_helper, astrbot_config)
# 执行 WebChat 数据迁移
await migration_webchat_data(db_helper, platform_id_map)
# 执行偏好设置迁移
await migration_preferences(db_helper,platform_id_map)
# 执行平台统计表迁移
await migration_platform_table(db_helper, platform_id_map)
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_v4", True)
logger.info("数据库迁移完成。")

View File

@@ -1,338 +0,0 @@
import json
import datetime
from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from .shared_preferences_v3 import sp as sp_v3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
"""
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
2. 迁移旧的 platform 到新的 platform_stats 表。
"""
def get_platform_id(
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
{"platform_id": old_platform_name, "platform_type": old_platform_name},
).get("platform_id", old_platform_name)
def get_platform_type(
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
{"platform_id": old_platform_name, "platform_type": old_platform_name},
).get("platform_type", old_platform_name)
async def migration_conversation_table(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
platform_id_map, session.platform_name
)
session.platform_id = platform_id # 更新平台名称为新的 ID
conv_v2 = ConversationV2(
user_id=str(session),
content=json.loads(conv.history) if conv.history else [],
platform_id=platform_id,
title=conv.title,
persona_id=conv.persona_id,
conversation_id=conv.cid,
created_at=datetime.datetime.fromtimestamp(conv.created_at),
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
)
dbsession.add(conv_v2)
except Exception as e:
logger.error(
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
async def migration_platform_table(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
secs_from_2023_4_10_to_now = (
datetime.datetime.now(datetime.timezone.utc)
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
).total_seconds()
offset_sec = int(secs_from_2023_4_10_to_now)
logger.info(f"迁移旧平台数据offset_sec: {offset_sec} 秒。")
stats = db_helper_v3.get_base_stats(offset_sec=offset_sec)
logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...")
platform_stats_v3 = stats.platform
if not platform_stats_v3:
logger.info("没有找到旧平台数据,跳过迁移。")
return
first_time_stamp = platform_stats_v3[0].timestamp
end_time_stamp = platform_stats_v3[-1].timestamp
start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时
end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时
idx = 0
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
total_buckets = (end_time - start_time) // 3600
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
if bucket_idx % 500 == 0:
progress = int((bucket_idx + 1) / total_buckets * 100)
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
cnt = 0
while (
idx < len(platform_stats_v3)
and platform_stats_v3[idx].timestamp < bucket_end
):
cnt += platform_stats_v3[idx].count
idx += 1
if cnt == 0:
continue
platform_id = get_platform_id(
platform_id_map, platform_stats_v3[idx].name
)
platform_type = get_platform_type(
platform_id_map, platform_stats_v3[idx].name
)
try:
await dbsession.execute(
text("""
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
VALUES (:timestamp, :platform_id, :platform_type, :count)
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
count = platform_stats.count + EXCLUDED.count
"""),
{
"timestamp": datetime.datetime.fromtimestamp(
bucket_end, tz=datetime.timezone.utc
),
"platform_id": platform_id,
"platform_type": platform_type,
"count": cnt,
},
)
except Exception:
logger.error(
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
exc_info=True,
)
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
async def migration_webchat_data(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" in conv.user_id:
continue
platform_id = "webchat"
history = json.loads(conv.history) if conv.history else []
for msg in history:
type_ = msg.get("type") # user type, "bot" or "user"
new_history = PlatformMessageHistory(
platform_id=platform_id,
user_id=conv.cid, # we use conv.cid as user_id for webchat
content=msg,
sender_id=type_,
sender_name=type_,
)
dbsession.add(new_history)
except Exception:
logger.error(
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
async def migration_persona_data(
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
):
"""
迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
"""
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
total_personas = len(v3_persona_config)
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
for idx, persona in enumerate(v3_persona_config):
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
progress = int((idx + 1) / total_personas * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
try:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
mood_prompt = ""
user_turn = True
for mood_dialog in mood_imitation_dialogs:
if user_turn:
mood_prompt += f"A: {mood_dialog}\n"
else:
mood_prompt += f"B: {mood_dialog}\n"
user_turn = not user_turn
system_prompt = persona.get("prompt", "")
if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
persona_new = await db_helper.insert_persona(
persona_id=persona["name"],
system_prompt=system_prompt,
begin_dialogs=begin_dialogs,
)
logger.info(
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
async def migration_preferences(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
# 1. global scope migration
keys = [
"inactivated_llm_tools",
"inactivated_plugins",
"curr_provider",
"curr_provider_tts",
"curr_provider_stt",
"alter_cmd",
]
for key in keys:
value = sp_v3.get(key)
if value is not None:
await sp.put_async("global", "global", key, value)
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
# 2. umo scope migration
session_conversation = sp_v3.get("session_conversation", default={})
for umo, conversation_id in session_conversation.items():
if not umo or not conversation_id:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
session_service_config = sp_v3.get("session_service_config", default={})
for umo, config in session_service_config.items():
if not umo or not config:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_service_config", config)
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
session_variables = sp_v3.get("session_variables", default={})
for umo, variables in session_variables.items():
if not umo or not variables:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_variables", variables)
except Exception as e:
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
session_provider_perf = sp_v3.get("session_provider_perf", default={})
for umo, perf in session_provider_perf.items():
if not umo or not perf:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
for provider_type, provider_id in perf.items():
await sp.put_async(
"umo", str(session), f"provider_perf_{provider_type}", provider_id
)
logger.info(
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
)
except Exception as e:
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)

View File

@@ -1,45 +0,0 @@
import json
import os
from typing import TypeVar
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
self._data = self._load_preferences()
def _load_preferences(self):
if os.path.exists(self.path):
try:
with open(self.path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)
return {}
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
self._data[key] = value
self._save_preferences()
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
self._data.clear()
self._save_preferences()
sp = SharedPreferences()

View File

@@ -1,493 +0,0 @@
import sqlite3
import time
from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
@dataclass
class Conversation:
"""LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
INIT_SQL = """
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';
"""
class SQLiteDatabase():
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path
sql = INIT_SQL
# 初始化数据库
self.conn = self._get_conn(self.db_path)
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
"""
PRAGMA table_info(webchat_conversation)
"""
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
"""
)
self.conn.commit()
if not has_persona_id:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
"""
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: Tuple = None):
conn = self.conn
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
conn = self._get_conn(self.db_path)
c = conn.cursor()
if params:
c.execute(sql, params)
c.close()
else:
c.execute(sql)
c.close()
conn.commit()
def insert_platform_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def insert_llm_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM platform
"""
+ where_clause
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform=platform)
def get_total_message_count(self) -> int:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT SUM(count) FROM platform
"""
)
res = c.fetchone()
c.close()
return res[0]
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT name, SUM(count), timestamp FROM platform
"""
+ where_clause
+ " GROUP BY name"
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform, [], [])
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
res = c.fetchone()
c.close()
if not res:
return
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
"""
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
""",
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
""",
(user_id,),
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
)
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新对话,并且同时更新时间"""
updated_at = int(time.time())
self._exec_sql(
"""
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
""",
(history, updated_at, user_id, cid),
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
""",
(title, user_id, cid),
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
""",
(persona_id, user_id, cid),
)
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
"""
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()

View File

@@ -1,233 +1,7 @@
import uuid
"""指标数据"""
from datetime import datetime, timezone
from dataclasses import dataclass, field
from sqlmodel import (
SQLModel,
Text,
JSON,
UniqueConstraint,
Field,
)
from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True):
"""This class represents the statistics of bot usage across different platforms.
Note: In astrbot v4, we moved `platform` table to here.
"""
__tablename__ = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
platform_id: str = Field(nullable=False)
platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc.
count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"timestamp",
"platform_id",
"platform_type",
name="uix_platform_stats",
),
)
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"
inner_conversation_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conversation_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
content: Optional[list] = Field(default=None, sa_type=JSON)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
title: Optional[str] = Field(default=None, max_length=255)
persona_id: Optional[str] = Field(default=None)
__table_args__ = (
UniqueConstraint(
"conversation_id",
name="uix_conversation_id",
),
)
class Persona(SQLModel, table=True):
"""Persona is a set of instructions for LLMs to follow.
It can be used to customize the behavior of LLMs.
"""
__tablename__ = "personas"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
tools: Optional[list] = Field(default=None, sa_type=JSON)
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"persona_id",
name="uix_persona_id",
),
)
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
__tablename__ = "preferences"
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
scope: str = Field(nullable=False)
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
scope_id: str = Field(nullable=False)
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
key: str = Field(nullable=False)
value: dict = Field(sa_type=JSON, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"scope",
"scope_id",
"key",
name="uix_preference_scope_scope_id_key",
),
)
class PlatformMessageHistory(SQLModel, table=True):
"""This class represents the message history for a specific platform.
It is used to store messages that are not LLM-generated, such as user messages
or platform-specific messages.
"""
__tablename__ = "platform_message_history"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
sender_name: Optional[str] = Field(
default=None
) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class Attachment(SQLModel, table=True):
"""This class represents attachments for messages in AstrBot.
Attachments can be images, files, or other media types.
"""
__tablename__ = "attachments"
inner_attachment_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
attachment_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
path: str = Field(nullable=False) # Path to the file on disk
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
mime_type: str = Field(nullable=False) # MIME type of the file
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"attachment_id",
name="uix_attachment_id",
),
)
@dataclass
class Conversation:
"""LLM 对话类
对于 WebChathistory 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
在 v4.0.0 版本及之后WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中,
"""
platform_id: str
user_id: str
cid: str
"""对话 ID, 是 uuid 格式的字符串"""
history: str = ""
"""字符串格式的对话列表。"""
title: str | None = ""
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
class Personality(TypedDict):
"""LLM 人格类。
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
prompt: str = ""
name: str = ""
begin_dialogs: list[str] = []
mood_imitation_dialogs: list[str] = []
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
_begin_dialogs_processed: list[dict] = []
_mood_imitation_dialogs_processed: str = ""
# ====
# Deprecated, and will be removed in future versions.
# ====
from typing import List
@dataclass
@@ -239,6 +13,77 @@ class Platform:
timestamp: int
@dataclass
class Provider:
"""供应商使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Plugin:
"""插件使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Command:
"""命令使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Stats:
platform: list[Platform] = field(default_factory=list)
platform: List[Platform] = field(default_factory=list)
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
@dataclass
class LLMHistory:
"""LLM 聊天时持久化的信息"""
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision:
"""Deprecated"""
id: str
url_or_path: str
caption: str
is_meme: bool
keywords: List[str]
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
@dataclass
class Conversation:
"""LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,50 @@
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';

View File

@@ -5,7 +5,6 @@ from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
class FaissVecDB(BaseVecDB):
@@ -18,7 +17,6 @@ class FaissVecDB(BaseVecDB):
doc_store_path: str,
index_store_path: str,
embedding_provider: EmbeddingProvider,
rerank_provider: RerankProvider | None = None,
):
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
@@ -28,14 +26,11 @@ class FaissVecDB(BaseVecDB):
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider
async def initialize(self):
await self.document_storage.initialize()
async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
@@ -58,12 +53,7 @@ class FaissVecDB(BaseVecDB):
return int_id
async def retrieve(
self,
query: str,
k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
) -> list[Result]:
"""
搜索最相似的文档。
@@ -72,7 +62,6 @@ class FaissVecDB(BaseVecDB):
query (str): 查询文本
k (int): 返回的最相似文档的数量
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
rerank (bool): 是否使用重排序。这需要在实例化时提供 rerank_provider, 如果未提供并且 rerank 为 True, 不会抛出异常。
metadata_filters (dict): 元数据过滤器
Returns:
@@ -83,6 +72,7 @@ class FaissVecDB(BaseVecDB):
vector=np.array([embedding]).astype("float32"),
k=fetch_k if metadata_filters else k,
)
# TODO: rerank
if len(indices[0]) == 0 or indices[0][0] == -1:
return []
# normalize scores
@@ -93,7 +83,7 @@ class FaissVecDB(BaseVecDB):
)
if not fetched_docs:
return []
result_docs: list[Result] = []
result_docs = []
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
for i, indice_idx in enumerate(indices[0]):
@@ -103,20 +93,7 @@ class FaissVecDB(BaseVecDB):
fetch_doc = fetched_docs[pos]
score = scores[0][i]
result_docs.append(Result(similarity=float(score), data=fetch_doc))
top_k_results = result_docs[:k]
if rerank and self.rerank_provider:
documents = [doc.data["text"] for doc in top_k_results]
reranked_results = await self.rerank_provider.rerank(query, documents)
reranked_results = sorted(
reranked_results, key=lambda x: x.relevance_score, reverse=True
)
top_k_results = [
top_k_results[reranked_result.index] for reranked_result in reranked_results
]
return top_k_results
return result_docs[:k]
async def delete(self, doc_id: int):
"""

View File

@@ -16,32 +16,30 @@ from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler
from astrbot.core import logger
from .platform import AstrMessageEvent
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
class EventBus:
"""用于处理事件的分发和处理"""
"""事件总线: 用于处理事件的分发和处理
def __init__(
self,
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None,
):
维护一个异步队列, 来接受各种消息事件
"""
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
self.astrbot_config_mgr = astrbot_config_mgr
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
async def dispatch(self):
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
while True:
event: AstrMessageEvent = await self.event_queue.get()
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
asyncio.create_task(scheduler.execute(event))
event: AstrMessageEvent = (
await self.event_queue.get()
) # 从事件队列中获取新的事件
self._print_event(event) # 打印日志
asyncio.create_task(
self.pipeline_scheduler.execute(event)
) # 创建新的异步任务来执行管道调度器的处理逻辑
def _print_event(self, event: AstrMessageEvent, conf_name: str):
def _print_event(self, event: AstrMessageEvent):
"""用于记录事件信息
Args:
@@ -50,10 +48,10 @@ class EventBus:
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name():
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
else:
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
)

View File

@@ -2,8 +2,6 @@ import asyncio
import os
import uuid
import time
from urllib.parse import urlparse, unquote
import platform
class FileTokenService:
@@ -17,9 +15,7 @@ class FileTokenService:
async def _cleanup_expired_tokens(self):
"""清理过期的令牌"""
now = time.time()
expired_tokens = [
token for token, (_, expire) in self.staged_files.items() if expire < now
]
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
for token in expired_tokens:
self.staged_files.pop(token, None)
@@ -36,35 +32,15 @@ class FileTokenService:
Raises:
FileNotFoundError: 当路径不存在时抛出
"""
# 处理 file:///
try:
parsed_uri = urlparse(file_path)
if parsed_uri.scheme == "file":
local_path = unquote(parsed_uri.path)
if platform.system() == "Windows" and local_path.startswith("/"):
local_path = local_path[1:]
else:
# 如果没有 file:/// 前缀,则认为是普通路径
local_path = file_path
except Exception:
# 解析失败时,按原路径处理
local_path = file_path
async with self.lock:
await self._cleanup_expired_tokens()
if not os.path.exists(local_path):
raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})"
)
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
file_token = str(uuid.uuid4())
expire_time = time.time() + (
timeout if timeout is not None else self.default_timeout
)
# 存储转换后的真实路径
self.staged_files[file_token] = (local_path, expire_time)
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
self.staged_files[file_token] = (file_path, expire_time)
return file_token
async def handle_file(self, file_token: str) -> str:

View File

@@ -1,183 +0,0 @@
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, Personality
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.platform.message_session import MessageSession
from astrbot import logger
DEFAULT_PERSONALITY = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
begin_dialogs=[],
mood_imitation_dialogs=[],
tools=None,
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed="",
)
class PersonaManager:
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
self.db = db_helper
self.acm = acm
default_ps = acm.default_conf.get("provider_settings", {})
self.default_persona: str = default_ps.get("default_personality", "default")
self.personas: list[Persona] = []
self.selected_default_persona: Persona | None = None
self.personas_v3: list[Personality] = []
self.selected_default_persona_v3: Personality | None = None
self.persona_v3_config: list[dict] = []
async def initialize(self):
self.personas = await self.get_all_personas()
self.get_v3_persona_data()
logger.info(f"已加载 {len(self.personas)} 个人格。")
async def get_persona(self, persona_id: str):
"""获取指定 persona 的信息"""
persona = await self.db.get_persona_by_id(persona_id)
if not persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
return persona
async def get_default_persona_v3(
self, umo: str | MessageSession | None = None
) -> Personality:
"""获取默认 persona"""
cfg = self.acm.get_conf(umo)
default_persona_id = cfg.get("provider_settings", {}).get(
"default_personality", "default"
)
if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY
try:
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
except Exception:
return DEFAULT_PERSONALITY
async def delete_persona(self, persona_id: str):
"""删除指定 persona"""
if not await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} does not exist.")
await self.db.delete_persona(persona_id)
self.personas = [p for p in self.personas if p.persona_id != persona_id]
self.get_v3_persona_data()
async def update_persona(
self,
persona_id: str,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id)
if not existing_persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
persona = await self.db.update_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
if persona:
for i, p in enumerate(self.personas):
if p.persona_id == persona_id:
self.personas[i] = persona
break
self.get_v3_persona_data()
return persona
async def get_all_personas(self) -> list[Persona]:
"""获取所有 personas"""
return await self.db.get_personas()
async def create_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} already exists.")
new_persona = await self.db.insert_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
self.personas.append(new_persona)
self.get_v3_persona_data()
return new_persona
def get_v3_persona_data(
self,
) -> tuple[list[dict], list[Personality], Personality]:
"""获取 AstrBot <4.0.0 版本的 persona 数据。
Returns:
- list[dict]: 包含 persona 配置的字典列表。
- list[Personality]: 包含 Personality 对象的列表。
- Personality: 默认选择的 Personality 对象。
"""
v3_persona_config = [
{
"prompt": persona.system_prompt,
"name": persona.persona_id,
"begin_dialogs": persona.begin_dialogs or [],
"mood_imitation_dialogs": [], # deprecated
"tools": persona.tools,
}
for persona in self.personas
]
personas_v3: list[Personality] = []
selected_default_persona: Personality | None = None
for persona_cfg in v3_persona_config:
begin_dialogs = persona_cfg.get("begin_dialogs", [])
bd_processed = []
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append(
{
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
}
)
user_turn = not user_turn
try:
persona = Personality(
**persona_cfg,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed="", # deprecated
)
if persona["name"] == self.default_persona:
selected_default_persona = persona
personas_v3.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not selected_default_persona and len(personas_v3) > 0:
# 默认选择第一个
selected_default_persona = personas_v3[0]
if not selected_default_persona:
selected_default_persona = DEFAULT_PERSONALITY
personas_v3.append(selected_default_persona)
self.personas_v3 = personas_v3
self.selected_default_persona_v3 = selected_default_persona
self.persona_v3_config = v3_persona_config
self.selected_default_persona = Persona(
persona_id=selected_default_persona["name"],
system_prompt=selected_default_persona["prompt"],
begin_dialogs=selected_default_persona["begin_dialogs"],
tools=selected_default_persona["tools"] or None,
)
return v3_persona_config, personas_v3, selected_default_persona

View File

@@ -4,6 +4,7 @@ from astrbot.core.message.message_event_result import (
)
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
@@ -20,6 +21,7 @@ STAGES_ORDER = [
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
@@ -32,6 +34,7 @@ __all__ = [
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PlatformCompatibilityStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",

View File

@@ -1,7 +1,14 @@
import inspect
import traceback
import typing as T
from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star import PluginManager
from .context_utils import call_handler, call_event_hook
from astrbot.api import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
@dataclass
@@ -10,6 +17,97 @@ class PipelineContext:
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
async def call_event_hook(
self,
event: AstrMessageEvent,
hook_type: EventType,
*args,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
"""
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, platform_id=platform_id
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return event.is_stopped()
async def call_handler(
self,
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[None, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -1,98 +0,0 @@
import inspect
import traceback
import typing as T
from astrbot import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
async def call_event_hook(
event: AstrMessageEvent,
hook_type: EventType,
*args,
**kwargs,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
# """
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args, **kwargs)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return event.is_stopped()

View File

@@ -0,0 +1,56 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core import logger
@register_stage
class PlatformCompatibilityStage(Stage):
"""检查所有处理器的平台兼容性。
这个阶段会检查所有处理器是否在当前平台启用如果未启用则设置platform_compatible属性为False。
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化平台兼容性检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 获取当前平台ID
platform_id = event.get_platform_id()
# 获取已激活的处理器
activated_handlers = event.get_extra("activated_handlers")
if activated_handlers is None:
activated_handlers = []
# 标记不兼容的处理器
for handler in activated_handlers:
if not isinstance(handler, StarHandlerMetadata):
continue
# 检查处理器是否在当前平台启用
enabled = handler.is_enabled_for_platform(platform_id)
if not enabled:
if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name
logger.debug(
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
)
# 设置处理器为平台不兼容状态
# TODO: 更好的标记方式
handler.platform_compatible = False
else:
# 确保处理器为平台兼容状态
handler.platform_compatible = True
# 更新已激活的处理器列表
event.set_extra("activated_handlers", activated_handlers)

View File

@@ -1,33 +1,32 @@
import abc
import typing as T
from enum import Enum, auto
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponse
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from astrbot.core.provider import Provider
from dataclasses import dataclass
from astrbot.core.provider.entities import LLMResponse
from ....message.message_event_result import MessageChain
from enum import Enum, auto
class AgentState(Enum):
"""Defines the state of the agent."""
IDLE = auto() # Initial state
RUNNING = auto() # Currently processing
DONE = auto() # Completed
ERROR = auto() # Error state
"""Agent 状态枚举"""
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
class BaseAgentRunner(T.Generic[TContext]):
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData
class BaseAgentRunner:
@abc.abstractmethod
async def reset(
self,
provider: Provider,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
async def reset(self) -> None:
"""
Reset the agent to its initial state.
This method should be called before starting a new run.

View File

@@ -1,12 +1,10 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentState
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponseData
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
from ...context import PipelineContext
from astrbot.core.provider.provider import Provider
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageChain,
)
@@ -23,8 +21,8 @@ from mcp.types import (
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
CallToolResult,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
if sys.version_info >= (3, 12):
@@ -33,25 +31,28 @@ else:
from typing_extensions import override
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
# TODO:
# 1. 处理平台不兼容的处理器
class ToolLoopAgent(BaseAgentRunner):
def __init__(
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.provider = provider
self.req = None
self.event = event
self.pipeline_ctx = pipeline_ctx
self._state = AgentState.IDLE
self.final_llm_resp = None
self.streaming = False
@override
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
self.req = req
self.streaming = streaming
self.final_llm_resp = None
self._state = AgentState.IDLE
self.tool_executor = tool_executor
self.agent_hooks = agent_hooks
self.run_context = run_context
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
@@ -77,12 +78,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
@@ -129,10 +124,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 执行事件钩子
if await self.pipeline_ctx.call_event_hook(
self.event, EventType.OnLLMResponseEvent, llm_resp
):
return
# 返回 LLM 结果
if llm_resp.result_chain:
@@ -196,33 +193,50 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
try:
await self.agent_hooks.on_tool_start(
self.run_context, func_tool, func_tool_args
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
except Exception as e:
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
executor = self.tool_executor.execute(
tool=func_tool,
run_context=self.run_context,
**func_tool_args,
)
async for resp in executor:
if isinstance(resp, CallToolResult):
res = resp
if isinstance(res.content[0], TextContent):
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if not res:
continue
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
content=resource.text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
@@ -233,85 +247,43 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
else:
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
# 尝试调用工具函数
wrapper = self.pipeline_ctx.call_handler(
self.event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None:
# Tool 返回结果
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
yield MessageChain().message(resp)
else:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(
type="tool_direct_result"
).base64_image(res.content[0].data)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool_name,
func_tool_args,
resp,
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.run_context.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
else:
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
self.run_context.event.clear_result()
self.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(

View File

@@ -20,270 +20,12 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_handler
from ...context import PipelineContext
from ..agent_runner.tool_loop_agent import ToolLoopAgent
from ..stage import Stage
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.star_handler import star_map
from astrbot.core.astr_agent_context import AstrAgentContext
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
if tool.origin == "local":
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
elif tool.origin == "mcp":
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
raise Exception(f"Unknown function origin: {tool.origin}")
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input", "agent")
agent_runner = AgentRunner()
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description,
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
)
astr_agent_ctx = AstrAgentContext(
provider=run_context.context.provider,
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
)
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await run_context.event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx, event=run_context.event
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
streaming=run_context.context.streaming,
)
async for _ in run_agent(agent_runner, 15, True):
pass
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
result = (
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
)
text_content = mcp.types.TextContent(
type="text",
text=result,
)
yield mcp.types.CallToolResult(content=[text_content])
else:
yield mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not run_context.event:
raise ValueError("Event must be provided for local function tools.")
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run"):
raise ValueError("Tool must have a valid handler or 'run' method.")
awaitable = tool.handler or getattr(tool, "run")
wrapper = call_handler(
event=run_context.event,
handler=awaitable,
**tool_args,
)
async for resp in wrapper:
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await tool.mcp_client.session.call_tool(
name=tool.name,
arguments=tool_args,
)
if not res:
return
yield res
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.event, EventType.OnLLMResponseEvent, llm_response
)
MAIN_AGENT_HOOKS = MainAgentHooks()
async def run_agent(
agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue
if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
astr_event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
class LLMRequestSubStage(Stage):
@@ -323,20 +65,6 @@ class LLMRequestSubStage(Stage):
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent):
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
# 获取对话上下文
cid = await conv_mgr.get_curr_conversation_id(umo)
if not cid:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
return conversation
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
@@ -372,14 +100,30 @@ class LLMRequestSubStage(Stage):
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(
event.unified_msg_origin
)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
if not conversation:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
@@ -389,7 +133,7 @@ class LLMRequestSubStage(Stage):
return
# 执行请求 LLM 前事件钩子。
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if isinstance(req.contexts, str):
@@ -423,62 +167,92 @@ class LLMRequestSubStage(Stage):
# fix messages
req.contexts = self.fix_messages(req.contexts)
# check provider modalities
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
req.image_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
if "tool_use" not in provider_cfg:
logger.debug(f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。")
req.func_tool = None
# 插件可用性设置
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
plugin = star_map.get(tool.handler_module_path)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
# run agent
agent_runner = AgentRunner()
# Call Agent
tool_loop_agent = ToolLoopAgent(
provider=provider,
event=event,
pipeline_ctx=self.ctx,
)
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
)
astr_agent_ctx = AstrAgentContext(
provider=provider,
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
)
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
async def requesting():
step_idx = 0
while step_idx < self.max_step:
step_idx += 1
try:
async for resp in tool_loop_agent.step():
if event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if self.streaming_response:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if (
self.show_tool_use
or event.get_platform_name() == "webchat"
):
resp.data["chain"].type = "tool_call"
await event.send(resp.data["chain"])
continue
if not self.streaming_response:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if tool_loop_agent.done():
break
except Exception as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
)
)
if self.streaming_response:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use)
)
.set_async_stream(requesting())
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if tool_loop_agent.done():
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
@@ -492,15 +266,15 @@ class LLMRequestSubStage(Stage):
)
)
else:
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
async for _ in requesting():
yield
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
@@ -533,10 +307,19 @@ class LLMRequestSubStage(Stage):
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
unified_msg_origin=event.unified_msg_origin,
title=title,
conversation_id=req.conversation.cid,
event.unified_msg_origin, title=title
)
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
# webchat adapter 中session_id 的格式是 f"webchat!{username}!{cid}"
# TODO: 优化 WebChat 适配器的对话管理
if event.session_id:
username, cid = event.session_id.split("!")[1:3]
db_helper = self.ctx.plugin_manager.context._db
db_helper.update_conversation_title(
user_id=username,
cid=cid,
title=title,
)
async def _save_to_history(
self,
@@ -552,10 +335,6 @@ class LLMRequestSubStage(Stage):
):
return
if not llm_response.completion_text and not req.tool_calls_result:
logger.debug("LLM 响应为空,不保存记录。")
return
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入

View File

@@ -2,7 +2,7 @@
本地 Agent 模式的 AstrBot 插件调用 Stage
"""
from ...context import PipelineContext, call_handler
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -33,6 +33,16 @@ class StarRequestSubStage(Stage):
handlers_parsed_params = {}
for handler in activated_handlers:
# 检查处理器是否在当前平台兼容
if (
hasattr(handler, "platform_compatible")
and handler.platform_compatible is False
):
logger.debug(
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
)
continue
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
@@ -40,7 +50,7 @@ class StarRequestSubStage(Stage):
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
)
wrapper = call_handler(event, handler.handler, **params)
wrapper = self.ctx.call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果

View File

@@ -128,7 +128,7 @@ class RespondStage(Stage):
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_id()})")
logger.info(f"应用流式输出({event.get_platform_name()})")
await event.send_streaming(result.async_stream, use_fallback)
return
elif len(result.chain) > 0:
@@ -144,6 +144,8 @@ class RespondStage(Stage):
try:
if await self._is_empty_message_chain(result.chain):
logger.info("消息为空,跳过发送阶段")
event.clear_result()
event.stop_event()
return
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
@@ -214,7 +216,7 @@ class RespondStage(Stage):
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
)
for handler in handlers:
try:

View File

@@ -64,10 +64,9 @@ class ResultDecorateStage(Stage):
]
self.content_safe_check_stage = None
if self.content_safe_check_reply:
for stage_cls in registered_stages:
if stage_cls.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx)
for stage in registered_stages:
if stage.__class__.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage
async def process(
self, event: AstrMessageEvent
@@ -99,7 +98,7 @@ class ResultDecorateStage(Stage):
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
)
for handler in handlers:
try:

View File

@@ -11,17 +11,16 @@ class PipelineScheduler:
def __init__(self, context: PipelineContext):
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__)
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
) # 按照顺序排序
self.ctx = context # 上下文对象
self.stages = [] # 存储阶段实例
async def initialize(self):
"""初始化管道调度器时, 初始化所有阶段"""
for stage_cls in registered_stages:
stage_instance = stage_cls() # 创建实例
await stage_instance.initialize(self.ctx)
self.stages.append(stage_instance)
for stage in registered_stages:
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
"""依次执行各个阶段
@@ -30,9 +29,9 @@ class PipelineScheduler:
event (AstrMessageEvent): 事件对象
from_stage (int): 从第几个阶段开始执行, 默认从0开始
"""
for i in range(from_stage, len(self.stages)):
stage = self.stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__.__name__}")
for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
coroutine = stage.process(
event
) # 调用阶段的process方法, 返回协程或者异步生成器

View File

@@ -1,15 +1,15 @@
from __future__ import annotations
import abc
from typing import List, AsyncGenerator, Union, Type
from typing import List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
def register_stage(cls):
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
registered_stages.append(cls)
registered_stages.append(cls())
return cls

View File

@@ -112,17 +112,8 @@ class WakingCheckStage(Stage):
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
# 将 plugins_name 设置到 event 中
enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"])
if enabled_plugins_name == ["*"]:
# 如果是 *,则表示所有插件都启用
event.plugins_name = None
else:
event.plugins_name = enabled_plugins_name
logger.debug(f"enabled_plugins_name: {enabled_plugins_name}")
for handler in star_handlers_registry.get_handlers_by_event_type(
EventType.AdapterMessageEvent, plugins_name=event.plugins_name
EventType.AdapterMessageEvent
):
# filter 需满足 AND 逻辑关系
passed = True

View File

@@ -3,10 +3,9 @@ import asyncio
import re
import hashlib
import uuid
from dataclasses import dataclass
from typing import List, Union, Optional, AsyncGenerator
from astrbot import logger
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
Plain,
@@ -24,7 +23,21 @@ from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
@dataclass
class MessageSesion:
platform_name: str
message_type: MessageType
session_id: str
def __str__(self):
return f"{self.platform_name}:{self.message_type.value}:{self.session_id}"
@staticmethod
def from_str(session_str: str):
platform_name, message_type, session_id = session_str.split(":")
return MessageSesion(platform_name, MessageType(message_type), session_id)
class AstrMessageEvent(abc.ABC):
@@ -51,7 +64,7 @@ class AstrMessageEvent(abc.ABC):
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.id,
platform_name=platform_meta.name,
message_type=message_obj.type,
session_id=session_id,
)
@@ -65,23 +78,13 @@ class AstrMessageEvent(abc.ABC):
self.call_llm = False
"""是否在此消息事件中禁止默认的 LLM 请求"""
self.plugins_name: list[str] | None = None
"""该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。"""
# back_compability
self.platform = platform_meta
def get_platform_name(self):
"""获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。
NOTE: 用户可能会同时运行多个相同类型的平台适配器。"""
return self.platform_meta.name
def get_platform_id(self):
"""获取这个事件所属的平台的 ID。
NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。
"""
return self.platform_meta.id
def get_message_str(self) -> str:
@@ -185,7 +188,6 @@ class AstrMessageEvent(abc.ABC):
"""
清除额外的信息。
"""
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
self._extras.clear()
def is_private_chat(self) -> bool:

View File

@@ -18,9 +18,6 @@ class PlatformManager:
self.platforms_config = config["platform"]
self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
这个配置中的 unique_session 需要特殊处理,
约定整个项目中对 unique_session 的引用都从 default 的配置中获取"""
self.event_queue = event_queue
async def initialize(self):

View File

@@ -1,28 +0,0 @@
from astrbot.core.platform.message_type import MessageType
from dataclasses import dataclass
@dataclass
class MessageSession:
"""描述一条消息在 AstrBot 中对应的会话的唯一标识。
如果您需要实例化 MessageSession请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。"""
platform_name: str
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
message_type: MessageType
session_id: str
platform_id: str = None
def __str__(self):
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
def __post_init__(self):
self.platform_id = self.platform_name
@staticmethod
def from_str(session_str: str):
platform_id, message_type, session_id = session_str.split(":")
return MessageSession(platform_id, MessageType(message_type), session_id)
MessageSesion = MessageSession # back compatibility

View File

@@ -5,7 +5,7 @@ from asyncio import Queue
from .platform_metadata import PlatformMetadata
from .astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from .message_session import MessageSesion
from .astr_message_event import MessageSesion
from astrbot.core.utils.metrics import Metric

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
@dataclass
class PlatformMetadata:
name: str
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
"""平台的名称"""
description: str
"""平台的描述"""
id: str = None

View File

@@ -26,7 +26,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
segment.text,
"AstrBot",
segment.text,
self.message_obj.raw_message,
)

View File

@@ -3,23 +3,15 @@ import botpy.message
import botpy.types
import botpy.types.message
import asyncio
import base64
import aiofiles
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from astrbot.api.message_components import Plain, Image
from botpy import Client
from botpy.http import Route
from astrbot.api import logger
from botpy.types.message import Media
from botpy.types import message
from typing import Optional
import random
import uuid
import os
class QQOfficialMessageEvent(AstrMessageEvent):
@@ -44,7 +36,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
ret = None
try:
async for chain in generator:
source = self.message_obj.raw_message
@@ -94,10 +85,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text,
image_base64,
image_path,
record_file_path
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
if not plain_text and not image_base64 and not image_path and not record_file_path:
if not plain_text and not image_base64 and not image_path:
return
payload = {
@@ -108,8 +98,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
payload["msg_seq"] = random.randint(1, 10000)
ret = None
match type(source):
case botpy.message.GroupMessage:
if image_base64:
@@ -118,12 +106,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path: # group record msg
media = await self.upload_group_and_c2c_record(
record_file_path, 3, group_openid=source.group_openid
)
payload["media"] = media
payload["msg_type"] = 7
ret = await self.bot.api.post_group_message(
group_openid=source.group_openid, **payload
)
@@ -134,12 +116,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path: # c2c record
media = await self.upload_group_and_c2c_record(
record_file_path, 3, openid = source.author.user_openid
)
payload["media"] = media
payload["msg_type"] = 7
if stream:
ret = await self.post_c2c_message(
openid=source.author.user_openid,
@@ -189,59 +165,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
return await self.bot.api._http.request(route, json=payload)
async def upload_group_and_c2c_record(
self,
file_source: str,
file_type: int,
srv_send_msg: bool = False,
**kwargs
) -> Optional[Media]:
"""
上传媒体文件
"""
# 构建基础payload
payload = {
"file_type": file_type,
"srv_send_msg": srv_send_msg
}
# 处理文件数据
if os.path.exists(file_source):
# 读取本地文件
async with aiofiles.open(file_source, 'rb') as f:
file_content = await f.read()
# use base64 encode
payload["file_data"] = base64.b64encode(file_content).decode('utf-8')
else:
# 使用URL
payload["url"] = file_source
# 添加接收者信息和确定路由
if "openid" in kwargs:
payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
elif "group_openid" in kwargs:
payload["group_openid"] =kwargs["group_openid"]
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"])
else:
return None
try:
# 使用底层HTTP请求
result = await self.bot.api._http.request(route, json=payload)
if result:
return Media(
file_uuid=result.get("file_uuid"),
file_info=result.get("file_info"),
ttl=result.get("ttl", 0),
file_id=result.get("id", "")
)
except Exception as e:
logger.error(f"上传请求错误: {e}")
return None
async def post_c2c_message(
self,
openid: str,
@@ -268,7 +191,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text = ""
image_base64 = None # only one img supported
image_file_path = None
record_file_path = None
for i in message.chain:
if isinstance(i, Plain):
plain_text += i.text
@@ -284,21 +206,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
else:
image_base64 = file_to_base64(i.file)
image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record):
if i.file:
record_wav_path = await i.convert_to_file_path() # wav 路径
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
try:
duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path)
if duration > 0:
record_file_path = record_tecent_silk_path
else:
record_file_path = None
logger.error("转换音频格式时出错音频时长不大于0")
except Exception as e:
logger.error(f"处理语音时出错: {e}")
record_file_path = None
else:
logger.debug(f"qq_official 忽略 {i.type}")
return plain_text, image_base64, image_file_path, record_file_path
return plain_text, image_base64, image_file_path

View File

@@ -77,7 +77,7 @@ class WebChatAdapter(Platform):
os.makedirs(self.imgs_dir, exist_ok=True)
self.metadata = PlatformMetadata(
name="webchat", description="webchat", id="webchat"
name="webchat", description="webchat", id=self.config.get("id", "")
)
async def send_by_session(

View File

@@ -1,47 +0,0 @@
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import PlatformMessageHistory
class PlatformMessageHistoryManager:
def __init__(self, db_helper: BaseDatabase):
self.db = db_helper
async def insert(
self,
platform_id: str,
user_id: str,
content: list[dict], # TODO: parse from message chain
sender_id: str = None,
sender_name: str = None,
):
"""Insert a new platform message history record."""
await self.db.insert_platform_message_history(
platform_id=platform_id,
user_id=user_id,
content=content,
sender_id=sender_id,
sender_name=sender_name,
)
async def get(
self,
platform_id: str,
user_id: str,
page: int = 1,
page_size: int = 200,
) -> list[PlatformMessageHistory]:
"""Get platform message history for a specific user."""
history = await self.db.get_platform_message_history(
platform_id=platform_id,
user_id=user_id,
page=page,
page_size=page_size,
)
history.reverse()
return history
async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400):
"""Delete platform message history records older than the specified offset."""
await self.db.delete_platform_message_offset(
platform_id=platform_id, user_id=user_id, offset_sec=offset_sec
)

View File

@@ -5,7 +5,7 @@ from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
from astrbot.core.agent.tool import ToolSet
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
@@ -20,7 +20,6 @@ class ProviderType(enum.Enum):
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
EMBEDDING = "embedding"
RERANK = "rerank"
@dataclass
@@ -98,7 +97,7 @@ class ProviderRequest:
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
func_tool: ToolSet | None = None
func_tool: FuncCall | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
"""上下文。格式与 openai 的上下文格式一致:
@@ -294,10 +293,3 @@ class LLMResponse:
}
)
return ret
@dataclass
class RerankResult:
index: int
"""在候选列表中的索引位置"""
relevance_score: float
"""相关性分数"""

View File

@@ -1,17 +1,32 @@
from __future__ import annotations
import json
import textwrap
import os
import asyncio
import aiohttp
import logging
from datetime import timedelta
from typing import Dict, List, Awaitable
from typing import Dict, List, Awaitable, Literal, Any
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core import sp
from astrbot.core.utils.log_pipe import LogPipe
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.agent.mcp_client import MCPClient
from astrbot.core.agent.tool import ToolSet, FunctionTool
try:
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
@@ -24,10 +39,6 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型
# alias
FuncTool = FunctionTool
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
@@ -94,9 +105,181 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"{e!s}"
class FunctionToolManager:
@dataclass
class FuncTool:
"""
用于描述一个函数调用工具。
"""
name: str
parameters: Dict
description: str
handler: Awaitable = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def execute(self, **args) -> Any:
"""执行函数调用"""
if self.origin == "local":
if not self.handler:
raise Exception(f"Local function {self.name} has no handler")
return await self.handler(**args)
elif self.origin == "mcp":
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
actual_tool_name = (
self.name.split(":")[-1] if ":" in self.name else self.name
)
return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.active: bool = True
self.tools: List[mcp.Tool] = []
self.server_errlogs: List[str] = []
self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
如果 `url` 参数存在:
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
# 处理 MCP 服务的错误日志
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
if not success:
raise Exception(error_msg)
if cfg.get("transport") != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str):
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
), # type: ignore
),
)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done
class FuncCall:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
@@ -104,29 +287,6 @@ class FunctionToolManager:
def empty(self) -> bool:
return len(self.func_list) == 0
def spec_to_func(
self,
name: str,
func_args: list,
desc: str,
handler: Awaitable,
) -> FuncTool:
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
return FuncTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
def add_func(
self,
name: str,
@@ -144,14 +304,22 @@ class FunctionToolManager:
# check if the tool has been added before
self.remove_func(name)
self.func_list.append(
self.spec_to_func(
name=name,
func_args=func_args,
desc=desc,
handler=handler,
)
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FuncTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.func_list.append(_func)
logger.info(f"添加函数调用工具: {name}")
def remove_func(self, name: str) -> None:
@@ -163,15 +331,11 @@ class FunctionToolManager:
self.func_list.pop(i)
break
def get_func(self, name) -> FuncTool | None:
def get_func(self, name) -> FuncTool:
for f in self.func_list:
if f.name == name:
return f
def get_full_tool_set(self) -> ToolSet:
"""获取完整工具集"""
tool_set = ToolSet(self.func_list.copy())
return tool_set
return None
async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
@@ -392,179 +556,203 @@ class FunctionToolManager:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.openai_schema(
omit_empty_parameter_field=omit_empty_parameter_field
)
_l = []
# 处理所有工具包括本地和MCP工具
for f in self.func_list:
if not f.active:
continue
func_ = {
"type": "function",
"function": {
"name": f.name,
# "parameters": f.parameters,
"description": f.description,
},
}
func_["function"]["parameters"] = f.parameters
if not f.parameters.get("properties") and omit_empty_parameter_field:
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True则删除 parameters 字段
del func_["function"]["parameters"]
_l.append(func_)
return _l
def get_func_desc_anthropic_style(self) -> list:
"""
获得 Anthropic API 风格的**已经激活**的工具描述
"""
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.anthropic_schema()
tools = []
for f in self.func_list:
if not f.active:
continue
# Convert internal format to Anthropic style
tool = {
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": f.parameters.get("properties", {}),
# Keep the required field from the original parameters if it exists
"required": f.parameters.get("required", []),
},
}
tools.append(tool)
return tools
def get_func_desc_google_genai_style(self) -> dict:
"""
获得 Google GenAI API 风格的**已经激活**的工具描述
"""
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.google_schema()
def deactivate_llm_tool(self, name: str) -> bool:
"""停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False"""
func_tool = self.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get(
"inactivated_llm_tools", [], scope="global", scope_id="global"
)
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put(
"inactivated_llm_tools",
inactivated_llm_tools,
scope="global",
scope_id="global",
)
return True
return False
# 因为不想解决循环引用,所以这里直接传入 star_map 先了...
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
func_tool = self.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
)
func_tool.active = True
inactivated_llm_tools: list = sp.get(
"inactivated_llm_tools", [], scope="global", scope_id="global"
)
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put(
"inactivated_llm_tools",
inactivated_llm_tools,
scope="global",
scope_id="global",
)
return True
return False
@property
def mcp_config_path(self):
data_dir = get_astrbot_data_path()
return os.path.join(data_dir, "mcp_server.json")
def load_mcp_config(self):
if not os.path.exists(self.mcp_config_path):
# 配置文件不存在,创建默认配置
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
return DEFAULT_MCP_CONFIG
try:
with open(self.mcp_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"加载 MCP 配置失败: {e}")
return DEFAULT_MCP_CONFIG
def save_mcp_config(self, config: dict):
try:
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
return True
except Exception as e:
logger.error(f"保存 MCP 配置失败: {e}")
return False
async def sync_modelscope_mcp_servers(self, access_token: str) -> None:
"""从 ModelScope 平台同步 MCP 服务器配置"""
base_url = "https://www.modelscope.cn/openapi/v1"
url = f"{base_url}/mcp/servers/operational"
headers = {
"Authorization": f"Bearer {access_token.strip()}",
"Content-Type": "application/json",
# Gemini API 支持的数据类型和格式
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
mcp_server_list = data.get("data", {}).get(
"mcp_server_list", []
)
local_mcp_config = self.load_mcp_config()
def convert_schema(schema: dict) -> dict:
"""转换 schema 为 Gemini API 格式"""
synced_count = 0
for server in mcp_server_list:
server_name = server["name"]
operational_urls = server.get("operational_urls", [])
if not operational_urls:
continue
url_info = operational_urls[0]
server_url = url_info.get("url")
if not server_url:
continue
# 添加到配置中(同名会覆盖)
local_mcp_config["mcpServers"][server_name] = {
"url": server_url,
"transport": "sse",
"active": True,
"provider": "modelscope",
}
synced_count += 1
# 如果 schema 包含 anyOf则只返回 anyOf 字段
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
if synced_count > 0:
self.save_mcp_config(local_mcp_config)
tasks = []
for server in mcp_server_list:
name = server["name"]
tasks.append(
self.enable_mcp_server(
name=name,
config=local_mcp_config["mcpServers"][name],
)
)
await asyncio.gather(*tasks)
logger.info(
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器"
)
else:
logger.warning("没有找到可用的 ModelScope MCP 服务器")
else:
raise Exception(
f"ModelScope API 请求失败: HTTP {response.status}"
)
result = {}
except aiohttp.ClientError as e:
raise Exception(f"网络连接错误: {str(e)}")
except Exception as e:
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {str(e)}")
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
# 暂时指定默认为null
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
if properties: # 只在有非空属性时添加
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = [
{
"name": f.name,
"description": f.description,
**({"parameters": convert_schema(f.parameters)}),
}
for f in self.func_list
if f.active
]
declarations = {}
if tools:
declarations["function_declarations"] = tools
return declarations
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
func_definition = json.dumps(_l, ensure_ascii=False)
prompt = textwrap.dedent(f"""
ROLE:
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
TOOLS:
可用的函数列表:
{func_definition}
LIMIT:
1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
EXAMPLE:
1. `用户提问`:请问一下天气怎么样? `函数调用`[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
用户的提问是:{question}
""")
_c = 0
while _c < 3:
try:
res = await provider.text_chat(prompt, session_id)
if res.find("```") != -1:
res = res[res.find("```json") + 7 : res.rfind("```")]
res = json.loads(res)
break
except Exception as e:
_c += 1
if _c == 3:
raise e
if "The message you submitted was too long" in str(e):
raise e
if "res" in res and not res["res"]:
return "", False
tool_call_result = []
for tool in res:
# 说明有函数调用
func_name = tool["name"]
args = tool["args"]
# 调用函数
func_tool = self.get_func(func_name)
if not func_tool:
raise Exception(f"Request function {func_name} not found.")
ret = await func_tool.execute(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)
# alias
FuncCall = FunctionToolManager

View File

@@ -3,32 +3,89 @@ import traceback
from typing import List
from astrbot.core import logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
from .register import llm_tools, provider_cls_map
from ..persona_mgr import PersonaManager
class ProviderManager:
def __init__(
self,
acm: AstrBotConfigManager,
db_helper: BaseDatabase,
persona_mgr: PersonaManager,
):
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
self.providers_config: List = config["provider"]
self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
self.persona_configs: list = config.get("persona", [])
self.astrbot_config = config
# 人格相关属性v4.0.0 版本后被废弃,推荐使用 PersonaManager
self.default_persona_name = persona_mgr.default_persona
# 人格情景管理
# 目前没有拆成独立的模块
self.default_persona_name = self.provider_settings.get(
"default_personality", "default"
)
self.personas: List[Personality] = []
self.selected_default_persona = None
for persona in self.persona_configs:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
bd_processed = []
mid_processed = ""
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append(
{
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
}
)
user_turn = not user_turn
if mood_imitation_dialogs:
if len(mood_imitation_dialogs) % 2 != 0:
logger.error(
f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
)
mood_imitation_dialogs = []
user_turn = True
for dialog in mood_imitation_dialogs:
role = "A" if user_turn else "B"
mid_processed += f"{role}: {dialog}\n"
if not user_turn:
mid_processed += "\n"
user_turn = not user_turn
try:
persona = Personality(
**persona,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed=mid_processed,
)
if persona["name"] == self.default_persona_name:
self.selected_default_persona = persona
self.personas.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not self.selected_default_persona and len(self.personas) > 0:
# 默认选择第一个
self.selected_default_persona = self.personas[0]
if not self.selected_default_persona:
self.selected_default_persona = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed="",
)
self.personas.append(self.selected_default_persona)
self.provider_insts: List[Provider] = []
"""加载的 Provider 的实例"""
@@ -43,111 +100,46 @@ class ProviderManager:
self.llm_tools = llm_tools
self.curr_provider_inst: Provider | None = None
"""默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
"""默认的 Provider 实例"""
self.curr_stt_provider_inst: STTProvider | None = None
"""默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
"""默认的 Speech To Text Provider 实例"""
self.curr_tts_provider_inst: TTSProvider | None = None
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
"""默认的 Text To Speech Provider 实例"""
self.db_helper = db_helper
@property
def persona_configs(self) -> list:
"""动态获取最新的 persona 配置"""
return self.persona_mgr.persona_v3_config
@property
def personas(self) -> list:
"""动态获取最新的 personas 列表"""
return self.persona_mgr.personas_v3
@property
def selected_default_persona(self):
"""动态获取最新的默认选中 persona。已弃用请使用 context.persona_mgr.get_default_persona_v3()"""
return self.persona_mgr.selected_default_persona_v3
# kdb(experimental)
self.curr_kdb_name = ""
kdb_cfg = config.get("knowledge_db", {})
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
async def set_provider(
self, provider_id: str, provider_type: ProviderType, umo: str | None = None
self, provider_id: str, provider_type: ProviderType, umo: str = None
):
"""设置提供商。
Args:
provider_id (str): 提供商 ID。
provider_type (ProviderType): 提供商类型。
umo (str, optional): 用户会话 ID用于提供商会话隔离。
Version 4.0.0: 这个版本下已经默认隔离提供商
umo (str, optional): 用户会话 ID用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
"""
if provider_id not in self.inst_map:
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
if umo:
await sp.session_put(
umo,
f"provider_perf_{provider_type.value}",
provider_id,
)
if umo and self.provider_settings["separate_provider"]:
perf = sp.get("session_provider_perf", {})
session_perf = perf.get(umo, {})
session_perf[provider_type.value] = provider_id
perf[umo] = session_perf
sp.put("session_provider_perf", perf)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
sp.put("curr_provider_tts", provider_id)
elif provider_type == ProviderType.SPEECH_TO_TEXT:
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
sp.put("curr_provider_stt", provider_id)
elif provider_type == ProviderType.CHAT_COMPLETION:
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(self, provider_type: ProviderType, umo=None):
"""获取正在使用的提供商实例。
Args:
provider_type (ProviderType): 提供商类型。
umo (str, optional): 用户会话 ID用于提供商会话隔离。
Returns:
Provider: 正在使用的提供商实例。
"""
provider = None
if umo:
provider_id = sp.get(
f"provider_perf_{provider_type.value}",
None,
scope="umo",
scope_id=umo,
)
if provider_id:
provider = self.inst_map.get(provider_id)
if not provider:
# default setting
config = self.acm.get_conf(umo)
if provider_type == ProviderType.CHAT_COMPLETION:
provider_id = config["provider_settings"].get("default_provider_id")
provider = self.inst_map.get(provider_id)
if not provider:
provider = self.provider_insts[0] if self.provider_insts else None
elif provider_type == ProviderType.SPEECH_TO_TEXT:
provider_id = config["provider_stt_settings"].get("provider_id")
if not provider_id:
return None
provider = self.inst_map.get(provider_id)
if not provider:
provider = (
self.stt_provider_insts[0] if self.stt_provider_insts else None
)
elif provider_type == ProviderType.TEXT_TO_SPEECH:
provider_id = config["provider_tts_settings"].get("provider_id")
if not provider_id:
return None
provider = self.inst_map.get(provider_id)
if not provider:
provider = (
self.tts_provider_insts[0] if self.tts_provider_insts else None
)
else:
raise ValueError(f"Unknown provider type: {provider_type}")
return provider
sp.put("curr_provider", provider_id)
async def initialize(self):
# 逐个初始化提供商
@@ -156,22 +148,13 @@ class ProviderManager:
# 设置默认提供商
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
scope="global",
scope_id="global",
"curr_provider", self.provider_settings.get("default_provider_id")
)
selected_stt_provider_id = sp.get(
"curr_provider_stt",
self.provider_stt_settings.get("provider_id"),
scope="global",
scope_id="global",
"curr_provider_stt", self.provider_stt_settings.get("provider_id")
)
selected_tts_provider_id = sp.get(
"curr_provider_tts",
self.provider_tts_settings.get("provider_id"),
scope="global",
scope_id="global",
"curr_provider_tts", self.provider_tts_settings.get("provider_id")
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
if not self.curr_provider_inst and self.provider_insts:
@@ -279,10 +262,6 @@ class ProviderManager:
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
case "vllm_rerank":
from .sources.vllm_rerank_source import (
VLLMRerankProvider as VLLMRerankProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
@@ -366,7 +345,7 @@ class ProviderManager:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)

View File

@@ -1,18 +1,23 @@
import abc
from typing import List
from typing import AsyncGenerator
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
LLMResponse,
ToolCallsResult,
ProviderType,
RerankResult,
)
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.db.po import Personality
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
name: str = ""
begin_dialogs: List[str] = []
mood_imitation_dialogs: List[str] = []
# cache
_begin_dialogs_processed: List[dict] = []
_mood_imitation_dialogs_processed: str = ""
@dataclass
class ProviderMeta:
id: str
@@ -85,7 +90,7 @@ class Provider(AbstractProvider):
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: ToolSet = None,
func_tool: FuncCall = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
@@ -114,7 +119,7 @@ class Provider(AbstractProvider):
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: ToolSet = None,
func_tool: FuncCall = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
@@ -201,17 +206,3 @@ class EmbeddingProvider(AbstractProvider):
def get_dim(self) -> int:
"""获取向量的维度"""
...
class RerankProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def rerank(
self, query: str, documents: list[str], top_n: int | None = None
) -> list[RerankResult]:
"""获取查询和文档的重排序分数"""
...

View File

@@ -75,7 +75,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_var = await sp.session_get(session_id, "session_variables", default={})
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
if (

View File

@@ -97,7 +97,8 @@ class ProviderDify(Provider):
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_var = await sp.session_get(session_id, "session_variables", default={})
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt

View File

@@ -1,6 +1,5 @@
import os
import uuid
import re
import ormsgpack
from pydantic import BaseModel, conint
from httpx import AsyncClient
@@ -25,8 +24,8 @@ class ServeTTSRequest(BaseModel):
# 参考音频
references: list[ServeReferenceAudio] = []
# 参考模型 ID
# 例如 https://fish.audio/m/626bb6d3f3364c9cbc3aa6a67300a664/
# 其中reference_id为 626bb6d3f3364c9cbc3aa6a67300a664
# 例如 https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# 其中reference_id为 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
# 对中英文文本进行标准化,这可以提高数字的稳定性
normalize: bool = True
@@ -45,7 +44,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "")
self.reference_id: str = provider_config.get("fishaudio-tts-reference-id", "")
self.character: str = provider_config.get("fishaudio-tts-character", "可莉")
self.api_base: str = provider_config.get(
"api_base", "https://api.fish-audio.cn/v1"
@@ -83,43 +81,11 @@ class ProviderFishAudioTTSAPI(TTSProvider):
return item["_id"]
return None
def _validate_reference_id(self, reference_id: str) -> bool:
"""
验证reference_id格式是否有效
Args:
reference_id: 参考模型ID
Returns:
bool: ID是否有效
"""
if not reference_id or not reference_id.strip():
return False
# FishAudio的reference_id通常是32位十六进制字符串
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
pattern = r'^[a-fA-F0-9]{32}$'
return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict:
# 向前兼容逻辑优先使用reference_id如果没有则使用角色名称查询
if self.reference_id and self.reference_id.strip():
# 验证reference_id格式
if not self._validate_reference_id(self.reference_id):
raise ValueError(
f"无效的FishAudio参考模型ID: '{self.reference_id}'. "
f"请确保ID是32位十六进制字符串例如: 626bb6d3f3364c9cbc3aa6a67300a664"
f"您可以从 https://fish.audio/zh-CN/discovery 获取有效的模型ID。"
)
reference_id = self.reference_id.strip()
else:
# 回退到原来的角色名称查询逻辑
reference_id = await self._get_reference_id_by_character(self.character)
return ServeTTSRequest(
text=text,
format="wav",
reference_id=reference_id,
reference_id=await self._get_reference_id_by_character(self.character),
)
async def get_audio(self, text: str) -> str:

View File

@@ -431,7 +431,6 @@ class ProviderGoogleGenAI(Provider):
continue
llm_response = LLMResponse("assistant")
llm_response.raw_completion = result
llm_response.result_chain = self._process_content_parts(result, llm_response)
return llm_response
@@ -471,10 +470,6 @@ class ProviderGoogleGenAI(Provider):
raise
continue
# Accumulate the complete response text for the final response
accumulated_text = ""
final_response = None
async for chunk in result:
llm_response = LLMResponse("assistant", is_chunk=True)
@@ -482,43 +477,27 @@ class ProviderGoogleGenAI(Provider):
part.function_call for part in chunk.candidates[0].content.parts
):
llm_response = LLMResponse("assistant", is_chunk=False)
llm_response.raw_completion = chunk
llm_response.result_chain = self._process_content_parts(
chunk, llm_response
)
yield llm_response
return
break
if chunk.text:
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
yield llm_response
if chunk.candidates[0].finish_reason:
# Process the final chunk for potential tool calls or other content
if chunk.candidates[0].content.parts:
final_response = LLMResponse("assistant", is_chunk=False)
final_response.raw_completion = chunk
final_response.result_chain = self._process_content_parts(
chunk, final_response
llm_response = LLMResponse("assistant", is_chunk=False)
if not chunk.candidates[0].content.parts:
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
else:
llm_response.result_chain = self._process_content_parts(
chunk, llm_response
)
yield llm_response
break
# Yield final complete response with accumulated text
if not final_response:
final_response = LLMResponse("assistant", is_chunk=False)
# Set the complete accumulated text in the final response
if accumulated_text:
final_response.result_chain = MessageChain(
chain=[Comp.Plain(accumulated_text)]
)
elif not final_response.result_chain:
# If no text was accumulated and no final response was set, provide empty space
final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
yield final_response
async def text_chat(
self,
prompt: str,

View File

@@ -99,14 +99,6 @@ class ProviderOpenAIOfficial(Provider):
for key in to_del:
del payloads[key]
model = payloads.get("model", "")
# 针对 qwen3 模型的特殊处理:非流式调用必须设置 enable_thinking=false
if "qwen3" in model.lower():
extra_body["enable_thinking"] = False
# 针对 deepseek 模型的特殊处理deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
elif model == "deepseek-reasoner" and "tools" in payloads:
del payloads["tools"]
completion = await self.client.chat.completions.create(
**payloads, stream=False, extra_body=extra_body
)
@@ -184,7 +176,7 @@ class ProviderOpenAIOfficial(Provider):
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
if choice.message.content is not None:
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
llm_response.result_chain = MessageChain().message(completion_text)
@@ -218,7 +210,7 @@ class ProviderOpenAIOfficial(Provider):
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
)
if llm_response.completion_text is None and not llm_response.tools_call_args:
if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")

View File

@@ -1,59 +0,0 @@
import aiohttp
from ..provider import RerankProvider
from ..register import register_provider_adapter
from ..entities import ProviderType, RerankResult
@register_provider_adapter(
"vllm_rerank",
"VLLM Rerank 适配器",
provider_type=ProviderType.RERANK,
)
class VLLMRerankProvider(RerankProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
self.auth_key = provider_config.get("rerank_api_key", "")
self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
self.base_url = self.base_url.rstrip("/")
self.timeout = provider_config.get("timeout", 20)
self.model = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
h = {}
if self.auth_key:
h["Authorization"] = f"Bearer {self.auth_key}"
self.client = aiohttp.ClientSession(
headers=h,
timeout=aiohttp.ClientTimeout(total=self.timeout),
)
async def rerank(
self, query: str, documents: list[str], top_n: int | None = None
) -> list[RerankResult]:
payload = {
"query": query,
"documents": documents,
"model": self.model,
}
if top_n is not None:
payload["top_n"] = top_n
async with self.client.post(
f"{self.base_url}/v1/rerank", json=payload
) as response:
response_data = await response.json()
results = response_data.get("results", [])
return [
RerankResult(
index=result["index"],
relevance_score=result["relevance_score"],
)
for result in results
]
async def terminate(self) -> None:
"""关闭客户端会话"""
if self.client:
await self.client.close()
self.client = None

View File

@@ -34,7 +34,7 @@ class Star(CommandParserMixin):
@staticmethod
async def html_render(
tmpl: str, data: dict, return_url=True, options: dict | None = None
tmpl: str, data: dict, return_url=True, options: dict = None
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(

View File

@@ -1,24 +1,17 @@
from asyncio import Queue
from typing import List, Union
from astrbot.core.provider.provider import (
Provider,
TTSProvider,
STTProvider,
EmbeddingProvider,
)
from astrbot.core import sp
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform import Platform
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.persona_mgr import PersonaManager
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
@@ -29,7 +22,6 @@ from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
ADAPTER_NAME_2_TYPE,
)
from deprecated import deprecated
class Context:
@@ -37,6 +29,19 @@ class Context:
暴露给插件的接口上下文。
"""
_event_queue: Queue = None
"""事件队列。消息平台通过事件队列传递消息事件。"""
_config: AstrBotConfig = None
"""AstrBot 配置信息"""
_db: BaseDatabase = None
"""AstrBot 数据库"""
provider_manager: ProviderManager = None
platform_manager: PlatformManager = None
registered_web_apis: list = []
# back compatibility
@@ -48,27 +53,18 @@ class Context:
event_queue: Queue,
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager,
platform_manager: PlatformManager,
conversation_manager: ConversationManager,
message_history_manager: PlatformMessageHistoryManager,
persona_manager: PersonaManager,
astrbot_config_mgr: AstrBotConfigManager,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
):
self._event_queue = event_queue
"""事件队列。消息平台通过事件队列传递消息事件。"""
self._config = config
"""AstrBot 默认配置"""
self._db = db
"""AstrBot 数据库"""
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.conversation_manager = conversation_manager
self.message_history_manager = message_history_manager
self.persona_manager = persona_manager
self.astrbot_config_mgr = astrbot_config_mgr
def get_registered_star(self, star_name: str) -> StarMetadata | None:
def get_registered_star(self, star_name: str) -> StarMetadata:
"""根据插件名获取插件的 Metadata"""
for star in star_registry:
if star.name == star_name:
@@ -78,7 +74,7 @@ class Context:
"""获取当前载入的所有插件 Metadata 的列表"""
return star_registry
def get_llm_tool_manager(self) -> FunctionToolManager:
def get_llm_tool_manager(self) -> FuncCall:
"""获取 LLM Tool Manager其用于管理注册的所有的 Function-calling tools"""
return self.provider_manager.llm_tools
@@ -88,14 +84,40 @@ class Context:
Returns:
如果没找到,会返回 False
"""
return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
)
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def deactivate_llm_tool(self, name: str) -> bool:
"""停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False"""
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def register_provider(self, provider: Provider):
"""
@@ -103,7 +125,7 @@ class Context:
"""
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider | None:
def get_provider_by_id(self, provider_id: str) -> Provider:
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
@@ -123,49 +145,51 @@ class Context:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str | None = None) -> Provider | None:
def get_using_provider(self, umo: str = None) -> Provider:
"""
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
Args:
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_provider_inst
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
"""
获取当前使用的用于 TTS 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
provider_type=ProviderType.TEXT_TO_SPEECH,
umo=umo,
)
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_tts_provider_inst
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
"""
获取当前使用的用于 STT 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_stt_provider_inst
def get_config(self, umo: str | None = None) -> AstrBotConfig:
def get_config(self) -> AstrBotConfig:
"""获取 AstrBot 的配置。"""
if not umo:
# using default config
return self._config
else:
return self.astrbot_config_mgr.get_conf(umo)
return self._config
def get_db(self) -> BaseDatabase:
"""获取 AstrBot 数据库。"""
@@ -177,14 +201,9 @@ class Context:
"""
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(
self, platform_type: Union[PlatformAdapterType, str]
) -> Platform | None:
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform:
"""
获取指定类型的平台适配器。
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
@@ -198,20 +217,6 @@ class Context:
):
return platform
def get_platform_inst(self, platform_id: str) -> Platform | None:
"""
获取指定 ID 的平台适配器实例。
Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform
async def send_message(
self, session: Union[str, MessageSesion], message_chain: MessageChain
) -> bool:
@@ -235,7 +240,7 @@ class Context:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().id == session.platform_name:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False

View File

@@ -11,7 +11,6 @@ from .star_handler import (
register_on_llm_request,
register_on_llm_response,
register_llm_tool,
register_agent,
register_on_decorating_result,
register_after_message_sent,
)
@@ -29,7 +28,6 @@ __all__ = [
"register_on_llm_request",
"register_on_llm_response",
"register_llm_tool",
"register_agent",
"register_on_decorating_result",
"register_after_message_sent",
]

View File

@@ -15,11 +15,6 @@ from ..filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core.agent.agent import Agent
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.astr_agent_context import AstrAgentContext
def get_handler_full_name(awaitable: Awaitable) -> str:
@@ -311,7 +306,7 @@ def register_on_llm_response(**kwargs):
return decorator
def register_llm_tool(name: str = None, **kwargs):
def register_llm_tool(name: str = None):
"""为函数调用function-calling / tools-use添加工具。
请务必按照以下格式编写一个工具包括函数注释AstrBot 会尝试解析该函数注释)
@@ -345,9 +340,6 @@ def register_llm_tool(name: str = None, **kwargs):
"""
name_ = name
registering_agent = None
if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Awaitable):
llm_tool_name = name_ if name_ else awaitable.__name__
@@ -365,65 +357,11 @@ def register_llm_tool(name: str = None, **kwargs):
"description": arg.description,
}
)
# print(llm_tool_name, registering_agent)
if not registering_agent:
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(
llm_tool_name, args, docstring.description.strip(), md.handler
)
else:
assert isinstance(registering_agent, RegisteringAgent)
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
if registering_agent._agent.tools is None:
registering_agent._agent.tools = []
registering_agent._agent.tools.append(llm_tools.spec_to_func(
llm_tool_name, args, docstring.description.strip(), awaitable
))
return awaitable
return decorator
class RegisteringAgent:
"""用于 Agent 注册"""
def llm_tool(self, *args, **kwargs):
kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[AstrAgentContext]):
self._agent = agent
def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
):
"""注册一个 Agent
Args:
name: Agent 的名称
instruction: Agent 的指令
tools: Agent 使用的工具列表
run_hooks: Agent 运行时的钩子函数
"""
tools_ = tools or []
def decorator(awaitable: Awaitable):
AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent(
name=name,
instructions=instruction,
tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(
llm_tool_name, args, docstring.description.strip(), md.handler
)
handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler=awaitable
llm_tools.func_list.append(handoff_tool)
return RegisteringAgent(agent)
return awaitable
return decorator

View File

@@ -2,6 +2,8 @@
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
"""
from typing import Dict
from astrbot.core import logger, sp
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -24,9 +26,8 @@ class SessionServiceManager:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config", {}, scope="umo", scope_id=session_id
)
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
# 如果配置了该会话的LLM状态返回该状态
llm_enabled = session_services.get("llm_enabled")
@@ -44,13 +45,16 @@ class SessionServiceManager:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
)
session_config["llm_enabled"] = enabled
sp.put(
"session_service_config", session_config, scope="umo", scope_id=session_id
)
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置LLM状态
session_config[session_id]["llm_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
logger.info(
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
@@ -84,9 +88,8 @@ class SessionServiceManager:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config", {}, scope="umo", scope_id=session_id
)
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
# 如果配置了该会话的TTS状态返回该状态
tts_enabled = session_services.get("tts_enabled")
@@ -104,13 +107,16 @@ class SessionServiceManager:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
)
session_config["tts_enabled"] = enabled
sp.put(
"session_service_config", session_config, scope="umo", scope_id=session_id
)
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置TTS状态
session_config[session_id]["tts_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
logger.info(
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
@@ -144,9 +150,8 @@ class SessionServiceManager:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config", {}, scope="umo", scope_id=session_id
)
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
# 如果配置了该会话的整体状态,返回该状态
session_enabled = session_services.get("session_enabled")
@@ -164,13 +169,16 @@ class SessionServiceManager:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
)
session_config["session_enabled"] = enabled
sp.put(
"session_service_config", session_config, scope="umo", scope_id=session_id
)
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置会话整体状态
session_config[session_id]["session_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
logger.info(
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
@@ -194,7 +202,7 @@ class SessionServiceManager:
# =============================================================================
@staticmethod
def get_session_custom_name(session_id: str) -> str | None:
def get_session_custom_name(session_id: str) -> str:
"""获取会话的自定义名称
Args:
@@ -203,9 +211,8 @@ class SessionServiceManager:
Returns:
str: 自定义名称如果没有设置则返回None
"""
session_services = sp.get(
"session_service_config", {}, scope="umo", scope_id=session_id
)
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
return session_services.get("custom_name")
@staticmethod
@@ -216,17 +223,20 @@ class SessionServiceManager:
session_id: 会话ID (unified_msg_origin)
custom_name: 自定义名称,可以为空字符串来清除名称
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
)
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置自定义名称
if custom_name and custom_name.strip():
session_config["custom_name"] = custom_name.strip()
session_config[session_id]["custom_name"] = custom_name.strip()
else:
# 如果传入空名称,则删除自定义名称
session_config.pop("custom_name", None)
sp.put(
"session_service_config", session_config, scope="umo", scope_id=session_id
)
session_config[session_id].pop("custom_name", None)
# 保存配置
sp.put("session_service_config", session_config)
logger.info(
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
@@ -248,3 +258,36 @@ class SessionServiceManager:
# 如果没有自定义名称返回session_id的最后一段
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
# =============================================================================
# 通用配置方法
# =============================================================================
@staticmethod
def get_session_service_config(session_id: str) -> Dict[str, bool]:
"""获取指定会话的服务配置
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典
"""
session_config = sp.get("session_service_config", {}) or {}
return session_config.get(
session_id,
{
"session_enabled": True, # 默认启用
"llm_enabled": True, # 默认启用
"tts_enabled": True, # 默认启用
},
)
@staticmethod
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
"""获取所有会话的服务配置
Returns:
Dict[str, Dict[str, bool]]: 所有会话的服务配置
"""
return sp.get("session_service_config", {}) or {}

View File

@@ -22,9 +22,7 @@ class SessionPluginManager:
bool: True表示启用False表示禁用
"""
# 获取会话插件配置
session_plugin_config = sp.get(
"session_plugin_config", {}, scope="umo", scope_id=session_id
)
session_plugin_config = sp.get("session_plugin_config", {}) or {}
session_config = session_plugin_config.get(session_id, {})
enabled_plugins = session_config.get("enabled_plugins", [])
@@ -53,9 +51,7 @@ class SessionPluginManager:
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_plugin_config = sp.get(
"session_plugin_config", {}, scope="umo", scope_id=session_id
)
session_plugin_config = sp.get("session_plugin_config", {}) or {}
if session_id not in session_plugin_config:
session_plugin_config[session_id] = {
"enabled_plugins": [],
@@ -83,9 +79,7 @@ class SessionPluginManager:
session_config["enabled_plugins"] = enabled_plugins
session_config["disabled_plugins"] = disabled_plugins
session_plugin_config[session_id] = session_config
sp.put(
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
)
sp.put("session_plugin_config", session_plugin_config)
logger.info(
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
@@ -101,9 +95,7 @@ class SessionPluginManager:
Returns:
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
"""
session_plugin_config = sp.get(
"session_plugin_config", {}, scope="umo", scope_id=session_id
)
session_plugin_config = sp.get("session_plugin_config", {}) or {}
return session_plugin_config.get(
session_id, {"enabled_plugins": [], "disabled_plugins": []}
)

View File

@@ -56,8 +56,32 @@ class StarMetadata:
star_handler_full_names: list[str] = field(default_factory=list)
"""注册的 Handler 的全名列表"""
supported_platforms: dict[str, bool] = field(default_factory=dict)
"""插件支持的平台ID字典key为平台IDvalue为是否支持"""
def __str__(self) -> str:
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
def __repr__(self) -> str:
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
"""更新插件支持的平台列表
Args:
plugin_enable_config: 平台插件启用配置即platform_settings.plugin_enable配置项
"""
if not plugin_enable_config:
return
# 清空之前的配置
self.supported_platforms.clear()
# 遍历所有平台配置
for platform_id, plugins in plugin_enable_config.items():
# 检查该插件在当前平台的配置
if self.name in plugins:
self.supported_platforms[platform_id] = plugins[self.name]
else:
# 如果没有明确配置,默认为启用
self.supported_platforms[platform_id] = True

View File

@@ -7,7 +7,6 @@ from .star import star_map
T = TypeVar("T", bound="StarHandlerMetadata")
class StarHandlerRegistry(Generic[T]):
def __init__(self):
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
@@ -27,10 +26,7 @@ class StarHandlerRegistry(Generic[T]):
print(handler.handler_full_name)
def get_handlers_by_event_type(
self,
event_type: EventType,
only_activated=True,
plugins_name: list[str] | None = None,
self, event_type: EventType, only_activated=True, platform_id=None
) -> List[StarHandlerMetadata]:
handlers = []
for handler in self._handlers:
@@ -40,15 +36,8 @@ class StarHandlerRegistry(Generic[T]):
plugin = star_map.get(handler.handler_module_path)
if not (plugin and plugin.activated):
continue
if plugins_name is not None and plugins_name != ["*"]:
plugin = star_map.get(handler.handler_module_path)
if not plugin:
continue
if (
plugin.name not in plugins_name
and event_type != EventType.OnAstrBotLoadedEvent
and not plugin.reserved
):
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
if not handler.is_enabled_for_platform(platform_id):
continue
handlers.append(handler)
return handlers
@@ -60,8 +49,7 @@ class StarHandlerRegistry(Generic[T]):
self, module_name: str
) -> List[StarHandlerMetadata]:
return [
handler
for handler in self._handlers
handler for handler in self._handlers
if handler.handler_module_path == module_name
]
@@ -79,7 +67,6 @@ class StarHandlerRegistry(Generic[T]):
def __len__(self):
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry()
@@ -132,3 +119,32 @@ class StarHandlerMetadata:
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
"priority", 0
)
def is_enabled_for_platform(self, platform_id: str) -> bool:
"""检查插件是否在指定平台启用
Args:
platform_id: 平台ID这是从event.get_platform_id()获取的,用于唯一标识平台实例
Returns:
bool: 是否启用True表示启用False表示禁用
"""
plugin = star_map.get(self.handler_module_path)
# 如果插件元数据不存在,默认允许执行
if not plugin or not plugin.name:
return True
# 先检查插件是否被激活
if not plugin.activated:
return False
# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
if (
hasattr(plugin, "supported_platforms")
and platform_id in plugin.supported_platforms
):
return plugin.supported_platforms[platform_id]
# 如果没有缓存数据,默认允许执行
return True

View File

@@ -22,7 +22,6 @@ from astrbot.core.utils.astrbot_path import (
get_astrbot_plugin_path,
)
from astrbot.core.utils.io import remove_dir
from astrbot.core.agent.handoff import HandoffTool, FunctionTool
from . import StarMetadata
from .context import Context
@@ -337,8 +336,30 @@ class PluginManager:
result = await self.load(specified_module_path)
# 更新所有插件的平台兼容性
await self.update_all_platform_compatibility()
return result
async def update_all_platform_compatibility(self):
"""更新所有插件的平台兼容性设置"""
# 获取最新的平台插件启用配置
plugin_enable_config = self.config.get("platform_settings", {}).get(
"plugin_enable", {}
)
logger.debug(
f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}"
)
# 遍历所有插件,更新平台兼容性
for plugin in self.context.get_all_stars():
plugin.update_platform_compatibility(plugin_enable_config)
logger.debug(
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
)
return True
async def load(self, specified_module_path=None, specified_dir_name=None):
"""载入插件。
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。
@@ -352,9 +373,10 @@ class PluginManager:
- success (bool): 是否全部加载成功
- error_message (str|None): 错误信息,成功时为 None
"""
inactivated_plugins = await sp.global_get("inactivated_plugins", [])
inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", [])
alter_cmd = await sp.global_get("alter_cmd", {})
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
alter_cmd = sp.get("alter_cmd", {})
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
@@ -458,6 +480,12 @@ class PluginManager:
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
# 更新插件的平台兼容性
plugin_enable_config = self.config.get("platform_settings", {}).get(
"plugin_enable", {}
)
metadata.update_platform_compatibility(plugin_enable_config)
assert metadata.module_path is not None, (
f"插件 {metadata.name} 的模块路径为空。"
)
@@ -475,27 +503,17 @@ class PluginManager:
)
# 绑定 llm_tool handler
for func_tool in llm_tools.func_list:
if isinstance(func_tool, HandoffTool):
need_apply = []
sub_tools = func_tool.agent.tools
for sub_tool in sub_tools:
if isinstance(sub_tool, FunctionTool):
need_apply.append(sub_tool)
else:
need_apply = [func_tool]
for ft in need_apply:
if (
ft.handler
and ft.handler.__module__ == metadata.module_path
):
ft.handler_module_path = metadata.module_path
ft.handler = functools.partial(
ft.handler,
metadata.star_cls, # type: ignore
)
if ft.name in inactivated_llm_tools:
ft.active = False
if (
func_tool.handler
and func_tool.handler.__module__ == metadata.module_path
):
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(
func_tool.handler,
metadata.star_cls, # type: ignore
)
if func_tool.name in inactivated_llm_tools:
func_tool.active = False
else:
# v3.4.0 以前的方式注册插件
@@ -758,12 +776,12 @@ class PluginManager:
await self._terminate_plugin(plugin)
# 加入到 shared_preferences 中
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
inactivated_plugins: list = sp.get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = list(
set(await sp.global_get("inactivated_llm_tools", []))
set(sp.get("inactivated_llm_tools", []))
) # 后向兼容
# 禁用插件启用的 llm_tool
@@ -773,8 +791,8 @@ class PluginManager:
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
await sp.global_put("inactivated_plugins", inactivated_plugins)
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = False
@@ -791,20 +809,20 @@ class PluginManager:
if star_metadata.star_cls is None:
return
if '__del__' in star_metadata.star_cls_type.__dict__:
if hasattr(star_metadata.star_cls, "__del__"):
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
)
elif 'terminate' in star_metadata.star_cls_type.__dict__:
elif hasattr(star_metadata.star_cls, "terminate"):
await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:
inactivated_plugins.remove(plugin.module_path)
await sp.global_put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_plugins", inactivated_plugins)
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
@@ -814,7 +832,7 @@ class PluginManager:
):
inactivated_llm_tools.remove(func_tool.name)
func_tool.active = True
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
await self.reload(plugin_name)

View File

@@ -89,6 +89,7 @@ class AstrBotUpdator(RepoZipUpdator):
file_url = update_data[0]["zipball_url"]
elif str(version).startswith("v"):
# 更新到指定版本
logger.info(f"正在更新到指定版本: {version}")
for data in update_data:
if data["tag_name"] == version:
file_url = data["zipball_url"]
@@ -97,8 +98,8 @@ class AstrBotUpdator(RepoZipUpdator):
else:
if len(str(version)) != 40:
raise Exception("commit hash 长度不正确,应为 40")
file_url = f"https://github.com/Soulter/AstrBot/archive/{version}.zip"
logger.info(f"准备更新至指定版本的 AstrBot Core: {version}")
logger.info(f"正在尝试更新到指定 commit: {version}")
file_url = "https://github.com/Soulter/AstrBot/archive/" + version + ".zip"
if proxy:
proxy = proxy.removesuffix("/")
@@ -106,7 +107,6 @@ class AstrBotUpdator(RepoZipUpdator):
try:
await download_file(file_url, "temp.zip")
logger.info("下载 AstrBot Core 更新文件完成,正在执行解压...")
self.unzip_file("temp.zip", self.MAIN_PATH)
except BaseException as e:
raise e

View File

@@ -8,7 +8,6 @@ import base64
import zipfile
import uuid
import psutil
import logging
import certifi
@@ -17,8 +16,6 @@ from typing import Union
from PIL import Image
from .astrbot_path import get_astrbot_data_path
logger = logging.getLogger("astrbot")
def on_error(func, path, exc_info):
"""
@@ -215,50 +212,19 @@ async def get_dashboard_version():
return None
async def download_dashboard(
path: str | None = None,
extract_path: str = "data",
latest: bool = True,
version: str | None = None,
proxy: str | None = None,
):
async def download_dashboard(path: str = None, extract_path: str = "data"):
"""下载管理面板文件"""
if path is None:
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
if latest or len(str(version)) != 40:
logger.info("准备下载最新发行版本的 AstrBot WebUI")
ver_name = "latest" if latest else version
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
try:
await download_file(dashboard_release_url, path, show_progress=True)
except BaseException as _:
if latest:
dashboard_release_url = "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
else:
dashboard_release_url = f"https://github.com/Soulter/AstrBot/releases/download/{version}/dist.zip"
if proxy:
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
await download_file(dashboard_release_url, path, show_progress=True)
else:
logger.info(f"准备下载指定版本的 AstrBot WebUI: {version}")
url = (
"https://api.github.com/repos/AstrBotDevs/astrbot-release-harbour/releases"
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
try:
await download_file(dashboard_release_url, path, show_progress=True)
except BaseException as _:
dashboard_release_url = (
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
)
if proxy:
url = f"{proxy}/{url}"
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as resp:
if resp.status == 200:
releases = await resp.json()
for release in releases:
if version in release["tag_name"]:
download_url = release["assets"][0]["browser_download_url"]
await download_file(download_url, path, show_progress=True)
else:
logger.warning(f"未找到指定的版本的 Dashboard 构建文件: {version}")
return
await download_file(dashboard_release_url, path, show_progress=True)
print("解压管理面板文件中...")
with zipfile.ZipFile(path, "r") as z:
z.extractall(extract_path)

View File

@@ -58,10 +58,9 @@ class Metric:
pass
try:
if "adapter_name" in kwargs:
await db_helper.insert_platform_stats(
platform_id=kwargs["adapter_name"],
platform_type=kwargs.get("adapter_type", "unknown"),
)
db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1})
if "llm_name" in kwargs:
db_helper.insert_llm_metrics({kwargs["llm_name"]: 1})
except Exception as e:
logger.error(f"保存指标到数据库失败: {e}")
pass

View File

@@ -1,180 +1,43 @@
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Preference
import threading
import asyncio
import json
import os
from typing import TypeVar, Any, overload
from typing import TypeVar
from .astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, db_helper: BaseDatabase, json_storage_path=None):
if json_storage_path is None:
json_storage_path = os.path.join(
get_astrbot_data_path(), "shared_preferences.json"
)
self.path = json_storage_path
self.db_helper = db_helper
def __init__(self, path=None):
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
self._data = self._load_preferences()
self._sync_loop = asyncio.new_event_loop()
t = threading.Thread(target=self._sync_loop.run_forever, daemon=True)
t.start()
def _load_preferences(self):
if os.path.exists(self.path):
try:
with open(self.path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)
return {}
async def get_async(
self,
scope: str,
scope_id: str,
key: str,
default: _VT = None,
) -> _VT:
"""获取指定范围和键的偏好设置"""
if scope_id is not None and key is not None:
result = await self.db_helper.get_preference(scope, scope_id, key)
if result:
ret = result.value["val"]
else:
ret = default
return ret
else:
raise ValueError(
"scope_id and key cannot be None when getting a specific preference."
)
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
async def range_get_async(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""获取指定范围的偏好设置
Note: 返回 Preference 列表,其中的 value 属性是一个 dictvalue["val"] 为值。scope_id 和 key 可以为 None这时返回该范围下所有的偏好设置。
"""
ret = await self.db_helper.get_preferences(scope, scope_id, key)
return ret
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
@overload
async def session_get(
self, umo: None, key: str, default: Any = None
) -> list[Preference]: ...
def put(self, key, value):
self._data[key] = value
self._save_preferences()
@overload
async def session_get(
self, umo: str, key: None, default: Any = None
) -> list[Preference]: ...
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
@overload
async def session_get(
self, umo: None, key: None, default: Any = None
) -> list[Preference]: ...
async def session_get(
self, umo: str | None, key: str | None = None, default: _VT = None
) -> _VT | list[Preference]:
"""获取会话范围的偏好设置
Note: 当 scope_id 或者 key 为 None返回 Preference 列表,其中的 value 属性是一个 dictvalue["val"] 为值。
"""
if umo is None or key is None:
return await self.range_get_async("umo", umo, key)
return await self.get_async("umo", umo, key, default)
@overload
async def global_get(self, key: None, default: Any = None) -> list[Preference]: ...
@overload
async def global_get(self, key: str, default: _VT = None) -> _VT: ...
async def global_get(
self, key: str | None, default: _VT = None
) -> _VT | list[Preference]:
"""获取全局范围的偏好设置
Note: 当 scope_id 或者 key 为 None返回 Preference 列表,其中的 value 属性是一个 dictvalue["val"] 为值。
"""
if key is None:
return await self.range_get_async("global", "global", key)
return await self.get_async("global", "global", key, default)
async def put_async(self, scope: str, scope_id: str, key: str, value: Any):
"""设置指定范围和键的偏好设置"""
await self.db_helper.insert_preference_or_update(
scope, scope_id, key, {"val": value}
)
async def session_put(self, umo: str, key: str, value: Any):
await self.put_async("umo", umo, key, value)
async def global_put(self, key: str, value: Any):
await self.put_async("global", "global", key, value)
async def remove_async(self, scope: str, scope_id: str, key: str):
"""删除指定范围和键的偏好设置"""
await self.db_helper.remove_preference(scope, scope_id, key)
async def session_remove(self, umo: str, key: str):
await self.remove_async("umo", umo, key)
async def global_remove(self, key: str):
"""删除全局偏好设置"""
await self.remove_async("global", "global", key)
async def clear_async(self, scope: str, scope_id: str):
"""清空指定范围的所有偏好设置"""
await self.db_helper.clear_preferences(scope, scope_id)
# ====
# DEPRECATED METHODS
# ====
def get(
self,
key: str,
default: _VT = None,
scope: str | None = None,
scope_id: str | None = "",
) -> _VT:
"""获取偏好设置(已弃用)"""
if scope_id == "":
scope_id = "unknown"
if scope_id is None or key is None:
# result = asyncio.run(self.range_get_async(scope, scope_id, key))
raise ValueError(
"scope_id and key cannot be None when getting a specific preference."
)
result = asyncio.run_coroutine_threadsafe(
self.get_async(scope or "unknown", scope_id or "unknown", key, default),
self._sync_loop,
).result()
return result if result is not None else default
def range_get(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""获取指定范围的偏好设置(已弃用)"""
result = asyncio.run_coroutine_threadsafe(
self.range_get_async(scope, scope_id, key), self._sync_loop
).result()
return result
def put(self, key, value, scope: str | None = None, scope_id: str | None = None):
"""设置偏好设置(已弃用)"""
asyncio.run_coroutine_threadsafe(
self.put_async(scope or "unknown", scope_id or "unknown", key, value),
self._sync_loop,
).result()
def remove(self, key, scope: str | None = None, scope_id: str | None = None):
"""删除偏好设置(已弃用)"""
asyncio.run_coroutine_threadsafe(
self.remove_async(scope or "unknown", scope_id or "unknown", key),
self._sync_loop,
).result()
def clear(self, scope: str | None = None, scope_id: str | None = None):
"""清空偏好设置(已弃用)"""
asyncio.run_coroutine_threadsafe(
self.clear_async(scope or "unknown", scope_id or "unknown"),
self._sync_loop,
).result()
def clear(self):
self._data.clear()
self._save_preferences()

View File

@@ -1,76 +1,37 @@
import aiohttp
import asyncio
import os
import ssl
import certifi
import logging
import random
from . import RenderStrategy
from astrbot.core.config import VERSION
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
CUSTOM_T2I_TEMPLATE_PATH = os.path.join(get_astrbot_data_path(), "t2i_template.html")
logger = logging.getLogger("astrbot")
class NetworkRenderStrategy(RenderStrategy):
def __init__(self, base_url: str | None = None) -> None:
super().__init__()
if not base_url:
self.BASE_RENDER_URL = ASTRBOT_T2I_DEFAULT_ENDPOINT
else:
self.BASE_RENDER_URL = self._clean_url(base_url)
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template", "base.html")
with open(self.TEMPLATE_PATH, "r", encoding="utf-8") as f:
self.DEFAULT_TEMPLATE = f.read()
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
self.BASE_RENDER_URL = base_url
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
self.endpoints = [self.BASE_RENDER_URL]
if self.BASE_RENDER_URL.endswith("/"):
self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1]
if not self.BASE_RENDER_URL.endswith("text2img"):
self.BASE_RENDER_URL += "/text2img"
async def initialize(self):
if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT:
asyncio.create_task(self.get_official_endpoints())
def set_endpoint(self, base_url: str):
if not base_url:
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
self.BASE_RENDER_URL = base_url
async def get_template(self) -> str:
"""获取文转图 HTML 模板
Returns:
str: 文转图 HTML 模板字符串
"""
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
with open(CUSTOM_T2I_TEMPLATE_PATH, "r", encoding="utf-8") as f:
return f.read()
return self.DEFAULT_TEMPLATE
async def get_official_endpoints(self):
"""获取官方的 t2i 端点列表。"""
try:
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.soulter.top/astrbot/t2i-endpoints"
) as resp:
if resp.status == 200:
data = await resp.json()
all_endpoints: list[dict] = data.get("data", [])
self.endpoints = [
ep.get("url")
for ep in all_endpoints
if ep.get("active") and ep.get("url")
]
logger.info(
f"Successfully got {len(self.endpoints)} official T2I endpoints."
)
except Exception as e:
logger.error(f"Failed to get official endpoints: {e}")
def _clean_url(self, url: str):
if url.endswith("/"):
url = url[:-1]
if not url.endswith("text2img"):
url += "/text2img"
return url
if self.BASE_RENDER_URL.endswith("/"):
self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1]
if not self.BASE_RENDER_URL.endswith("text2img"):
self.BASE_RENDER_URL += "/text2img"
async def render_custom_template(
self,
@@ -80,7 +41,6 @@ class NetworkRenderStrategy(RenderStrategy):
options: dict | None = None,
) -> str:
"""使用自定义文转图模板"""
default_options = {"full_page": True, "type": "jpeg", "quality": 40}
if options:
default_options |= options
@@ -91,44 +51,30 @@ class NetworkRenderStrategy(RenderStrategy):
"tmpldata": tmpl_data,
"options": default_options,
}
endpoints = self.endpoints.copy() if self.endpoints else [self.BASE_RENDER_URL]
random.shuffle(endpoints)
last_exception = None
for endpoint in endpoints:
try:
if return_url:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.post(
f"{endpoint}/generate", json=post_data
) as resp:
if resp.status == 200:
ret = await resp.json()
return f"{endpoint}/{ret['data']['id']}"
else:
raise Exception(f"HTTP {resp.status}")
else:
# download_image_by_url 失败时抛异常
return await download_image_by_url(
f"{endpoint}/generate", post=True, post_data=post_data
)
except Exception as e:
last_exception = e
logger.warning(f"Endpoint {endpoint} failed: {e}, trying next...")
continue
# 全部失败
logger.error(f"All endpoints failed: {last_exception}")
raise RuntimeError(f"All endpoints failed: {last_exception}")
if return_url:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.post(
f"{self.BASE_RENDER_URL}/generate", json=post_data
) as resp:
ret = await resp.json()
return f"{self.BASE_RENDER_URL}/{ret['data']['id']}"
return await download_image_by_url(
f"{self.BASE_RENDER_URL}/generate", post=True, post_data=post_data
)
async def render(self, text: str, return_url: bool = False) -> str:
"""
返回图像的文件路径
"""
tmpl_str = await self.get_template()
with open(
os.path.join(self.TEMPLATE_PATH, "base.html"), "r", encoding="utf-8"
) as f:
tmpl_str = f.read()
assert tmpl_str
text = text.replace("`", "\\`")
return await self.render_custom_template(
tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url

View File

@@ -10,8 +10,10 @@ class HtmlRenderer:
self.network_strategy = NetworkRenderStrategy(endpoint_url)
self.local_strategy = LocalRenderStrategy()
async def initialize(self):
await self.network_strategy.initialize()
def set_network_endpoint(self, endpoint_url: str):
"""设置 t2i 的网络端点。"""
logger.info("文本转图像服务接口: " + endpoint_url)
self.network_strategy.set_endpoint(endpoint_url)
async def render_custom_template(
self,

View File

@@ -181,13 +181,14 @@ class RepoZipUpdator:
"""
os.makedirs(target_dir, exist_ok=True)
update_dir = ""
logger.info(f"解压文件: {zip_path}")
with zipfile.ZipFile(zip_path, "r") as z:
update_dir = z.namelist()[0]
z.extractall(target_dir)
logger.debug(f"解压文件完成: {zip_path}")
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
@@ -197,13 +198,13 @@ class RepoZipUpdator:
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
logger.debug(
logger.info(
f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}"
)
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except BaseException:
logger.warning(
logger.warn(
f"删除更新文件失败,可以手动删除 {zip_path}{os.path.join(target_dir, update_dir)}"
)

View File

@@ -6,11 +6,11 @@ from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
from .tools import ToolsRoute
from .tools import ToolsRoute # 导入新的ToolsRoute
from .conversation import ConversationRoute
from .file import FileRoute
from .session_management import SessionManagementRoute
from .persona import PersonaRoute
__all__ = [
"AuthRoute",
@@ -25,5 +25,4 @@ __all__ = [
"ConversationRoute",
"FileRoute",
"SessionManagementRoute",
"PersonaRoute",
]

View File

@@ -9,7 +9,6 @@ import asyncio
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.platform.astr_message_event import MessageSession
class ChatRoute(Route):
@@ -30,15 +29,28 @@ class ChatRoute(Route):
"/chat/get_file": ("GET", self.get_file),
"/chat/post_image": ("POST", self.post_image),
"/chat/post_file": ("POST", self.post_file),
"/chat/status": ("GET", self.status),
}
self.db = db
self.core_lifecycle = core_lifecycle
self.register_routes()
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
os.makedirs(self.imgs_dir, exist_ok=True)
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
self.conv_mgr = core_lifecycle.conversation_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
async def status(self):
has_llm_enabled = (
self.core_lifecycle.provider_manager.curr_provider_inst is not None
)
has_stt_enabled = (
self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None
)
return (
Response()
.ok(data={"llm_enabled": has_llm_enabled, "stt_enabled": has_stt_enabled})
.__dict__
)
async def get_file(self):
filename = request.args.get("filename")
@@ -119,23 +131,24 @@ class ChatRoute(Route):
if not conversation_id:
return Response().error("conversation_id is empty").__dict__
# append user message
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
# Get conversation-specific queues
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
back_queue = webchat_queue_mgr.get_or_create_back_queue(conversation_id)
# append user message
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
logger.error(f"Failed to parse conversation history: {e}")
history = []
new_his = {"type": "user", "message": message}
if image_url:
new_his["image_url"] = image_url
if audio_url:
new_his["audio_url"] = audio_url
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id=username,
sender_name=username,
history.append(new_his)
self.db.update_conversation(
username, conversation_id, history=json.dumps(history)
)
async def stream():
@@ -151,6 +164,7 @@ class ChatRoute(Route):
result_text = result["data"]
type = result.get("type")
cid = result.get("cid")
streaming = result.get("streaming", False)
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.05)
@@ -159,13 +173,17 @@ class ChatRoute(Route):
break
elif (streaming and type == "complete") or not streaming:
# append bot message
new_his = {"type": "bot", "message": result_text}
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
conversation = self.db.get_conversation_by_user_id(
username, cid
)
try:
history = json.loads(conversation.history)
except BaseException as e:
logger.error(f"Failed to parse conversation history: {e}")
history = []
history.append({"type": "bot", "message": result_text})
self.db.update_conversation(
username, cid, history=json.dumps(history)
)
except BaseException as _:
@@ -173,11 +191,11 @@ class ChatRoute(Route):
return
# Put message to conversation-specific queue
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
chat_queue = webchat_queue_mgr.get_or_create_queue(conversation_id)
await chat_queue.put(
(
username,
webchat_conv_id,
conversation_id,
{
"message": message,
"image_url": image_url, # list
@@ -199,51 +217,25 @@ class ChatRoute(Route):
)
return response
async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
"""从对话 ID 中提取 WebChat 会话 ID
NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。
"""
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin="webchat", conversation_id=conversation_id
)
if not conversation:
raise ValueError(f"Conversation with ID {conversation_id} not found.")
conv_user_id = conversation.user_id
webchat_session_id = MessageSession.from_str(conv_user_id).session_id
if "!" not in webchat_session_id:
raise ValueError(f"Invalid conv user ID: {conv_user_id}")
return webchat_session_id.split("!")[-1]
async def delete_conversation(self):
username = g.get("username", "guest")
conversation_id = request.args.get("conversation_id")
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
username = g.get("username", "guest")
# Clean up queues when deleting conversation
webchat_queue_mgr.remove_queues(conversation_id)
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
await self.conv_mgr.delete_conversation(
unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
conversation_id=conversation_id,
)
await self.platform_history_mgr.delete(
platform_id="webchat", user_id=webchat_conv_id, offset_sec=99999999
)
self.db.delete_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get("username", "guest")
webchat_conv_id = str(uuid.uuid4())
conv_id = await self.conv_mgr.new_conversation(
unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
platform_id="webchat",
content=[],
)
return Response().ok(data={"conversation_id": conv_id}).__dict__
conversation_id = str(uuid.uuid4())
self.db.new_conversation(username, conversation_id)
return Response().ok(data={"conversation_id": conversation_id}).__dict__
async def rename_conversation(self):
username = g.get("username", "guest")
post_data = await request.json
if "conversation_id" not in post_data or "title" not in post_data:
return Response().error("Missing key: conversation_id or title").__dict__
@@ -251,42 +243,20 @@ class ChatRoute(Route):
conversation_id = post_data["conversation_id"]
title = post_data["title"]
await self.conv_mgr.update_conversation(
unified_msg_origin="webchat", # fake
conversation_id=conversation_id,
title=title,
)
self.db.update_conversation_title(username, conversation_id, title=title)
return Response().ok(message="重命名成功!").__dict__
async def get_conversations(self):
conversations = await self.conv_mgr.get_conversations(platform_id="webchat")
# remove content
conversations_ = []
for conv in conversations:
conv.history = None
conversations_.append(conv)
return Response().ok(data=conversations_).__dict__
username = g.get("username", "guest")
conversations = self.db.get_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
username = g.get("username", "guest")
conversation_id = request.args.get("conversation_id")
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
# Get platform message history
history_ls = await self.platform_history_mgr.get(
platform_id="webchat", user_id=webchat_conv_id, page=1, page_size=1000
)
history_res = [history.model_dump() for history in history_ls]
return (
Response()
.ok(
data={
"history": history_res,
}
)
.__dict__
)
return Response().ok(data=conversation).__dict__

View File

@@ -4,22 +4,16 @@ import os
from .route import Route, Response, RouteContext
from astrbot.core.provider.entities import ProviderType
from quart import request
from astrbot.core.config.default import (
CONFIG_METADATA_2,
DEFAULT_VALUE_MAP,
CONFIG_METADATA_3,
CONFIG_METADATA_3_SYSTEM,
)
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.utils.astrbot_path import get_astrbot_path
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger, html_renderer
from astrbot.core import logger
from astrbot.core.provider import Provider
import asyncio
from astrbot.core.utils.t2i.network_strategy import CUSTOM_T2I_TEMPLATE_PATH
def try_cast(value: str, type_: str):
@@ -165,166 +159,33 @@ class ConfigRoute(Route):
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.config: AstrBotConfig = core_lifecycle.astrbot_config
self.acm = core_lifecycle.astrbot_config_mgr
self.routes = {
"/config/abconf/new": ("POST", self.create_abconf),
"/config/abconf": ("GET", self.get_abconf),
"/config/abconfs": ("GET", self.get_abconf_list),
"/config/abconf/delete": ("POST", self.delete_abconf),
"/config/abconf/update": ("POST", self.update_abconf),
"/config/get": ("GET", self.get_configs),
"/config/astrbot/update": ("POST", self.post_astrbot_configs),
"/config/plugin/update": ("POST", self.post_plugin_configs),
"/config/platform/new": ("POST", self.post_new_platform),
"/config/platform/update": ("POST", self.post_update_platform),
"/config/platform/delete": ("POST", self.post_delete_platform),
"/config/platform/list": ("GET", self.get_platform_list),
"/config/provider/new": ("POST", self.post_new_provider),
"/config/provider/update": ("POST", self.post_update_provider),
"/config/provider/delete": ("POST", self.post_delete_provider),
"/config/llmtools": ("GET", self.get_llm_tools),
"/config/provider/check_one": ("GET", self.check_one_provider_status),
"/config/provider/list": ("GET", self.get_provider_config_list),
"/config/provider/model_list": ("GET", self.get_provider_model_list),
"/config/astrbot/t2i-template/get": ("GET", self.get_t2i_template),
"/config/astrbot/t2i-template/save": ("POST", self.post_t2i_template),
"/config/astrbot/t2i-template/delete": ("DELETE", self.delete_t2i_template),
"/config/provider/get_session_seperate": (
"GET",
lambda: Response()
.ok({"enable": self.config["provider_settings"]["separate_provider"]})
.__dict__,
),
"/config/provider/set_session_seperate": (
"POST",
self.post_session_seperate,
),
}
self.register_routes()
async def get_t2i_template(self):
"""获取 T2I 模板"""
try:
template = await html_renderer.network_strategy.get_template()
has_custom_template = os.path.exists(CUSTOM_T2I_TEMPLATE_PATH)
return (
Response()
.ok({"template": template, "has_custom_template": has_custom_template})
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取模板失败: {str(e)}").__dict__
async def post_t2i_template(self):
"""保存 T2I 模板"""
try:
post_data = await request.json
if not post_data or "template" not in post_data:
return Response().error("缺少模板内容").__dict__
template_content = post_data["template"]
# 保存自定义模板到文件
with open(CUSTOM_T2I_TEMPLATE_PATH, "w", encoding="utf-8") as f:
f.write(template_content)
return Response().ok(message="模板保存成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"保存模板失败: {str(e)}").__dict__
async def delete_t2i_template(self):
"""删除自定义 T2I 模板,恢复默认模板"""
try:
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
os.remove(CUSTOM_T2I_TEMPLATE_PATH)
return Response().ok(message="已恢复默认模板").__dict__
else:
return Response().ok(message="未找到自定义模板文件").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"删除模板失败: {str(e)}").__dict__
async def get_abconf_list(self):
"""获取所有 AstrBot 配置文件的列表"""
abconf_list = self.acm.get_conf_list()
return Response().ok({"info_list": abconf_list}).__dict__
async def create_abconf(self):
"""创建新的 AstrBot 配置文件"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
umo_parts = post_data["umo_parts"]
name = post_data.get("name", None)
try:
conf_id = self.acm.create_conf(umo_parts=umo_parts, name=name)
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
async def get_abconf(self):
"""获取指定 AstrBot 配置文件"""
abconf_id = request.args.get("id")
system_config = request.args.get("system_config", "0").lower() == "1"
if not abconf_id and not system_config:
return Response().error("缺少配置文件 ID").__dict__
try:
if system_config:
abconf = self.acm.confs["default"]
return (
Response()
.ok({"config": abconf, "metadata": CONFIG_METADATA_3_SYSTEM})
.__dict__
)
abconf = self.acm.confs[abconf_id]
return (
Response()
.ok({"config": abconf, "metadata": CONFIG_METADATA_3})
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
async def delete_abconf(self):
"""删除指定 AstrBot 配置文件"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
conf_id = post_data.get("id")
if not conf_id:
return Response().error("缺少配置文件 ID").__dict__
try:
success = self.acm.delete_conf(conf_id)
if success:
return Response().ok(message="删除成功").__dict__
else:
return Response().error("删除失败").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"删除配置文件失败: {str(e)}").__dict__
async def update_abconf(self):
"""更新指定 AstrBot 配置文件信息"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
conf_id = post_data.get("id")
if not conf_id:
return Response().error("缺少配置文件 ID").__dict__
name = post_data.get("name")
umo_parts = post_data.get("umo_parts")
try:
success = self.acm.update_conf_info(conf_id, name=name, umo_parts=umo_parts)
if success:
return Response().ok(message="更新成功").__dict__
else:
return Response().error("更新失败").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"更新配置文件失败: {str(e)}").__dict__
async def _test_single_provider(self, provider):
"""辅助函数:测试单个 provider 的可用性"""
meta = provider.meta()
@@ -349,16 +210,11 @@ class ConfigRoute(Route):
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
)
logger.debug(
f"Received response from {status_info['name']}: {response}"
)
logger.debug(f"Received response from {status_info['name']}: {response}")
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if (
hasattr(response, "completion_text")
and response.completion_text
):
if hasattr(response, "completion_text") and response.completion_text:
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
@@ -377,48 +233,29 @@ class ConfigRoute(Route):
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
)
else:
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
)
status_info["error"] = "Test call returned None, but expected an LLMResponse object."
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.")
except asyncio.TimeoutError:
status_info["error"] = (
"Connection timed out after 45 seconds during test call."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
)
status_info["error"] = "Connection timed out after 45 seconds during test call."
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.")
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
)
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}")
logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}")
elif provider_capability_type == ProviderType.EMBEDDING:
try:
# For embedding, we can call the get_embedding method with a short prompt.
embedding_result = await provider.get_embedding("health_check")
if isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
):
if isinstance(embedding_result, list) and (not embedding_result or isinstance(embedding_result[0], float)):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"Embedding test failed: unexpected result type {type(embedding_result)}"
)
status_info["error"] = f"Embedding test failed: unexpected result type {type(embedding_result)}"
except Exception as e:
logger.error(
f"Error testing embedding provider {provider_name}: {e}",
exc_info=True,
)
logger.error(f"Error testing embedding provider {provider_name}: {e}", exc_info=True)
status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: {str(e)}"
@@ -430,71 +267,41 @@ class ConfigRoute(Route):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"TTS test failed: unexpected result type {type(audio_result)}"
)
status_info["error"] = f"TTS test failed: unexpected result type {type(audio_result)}"
except Exception as e:
logger.error(
f"Error testing TTS provider {provider_name}: {e}", exc_info=True
)
logger.error(f"Error testing TTS provider {provider_name}: {e}", exc_info=True)
status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: {str(e)}"
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
try:
logger.debug(
f"Sending health check audio to provider: {status_info['name']}"
)
sample_audio_path = os.path.join(
get_astrbot_path(), "samples", "stt_health_check.wav"
)
logger.debug(f"Sending health check audio to provider: {status_info['name']}")
sample_audio_path = os.path.join(get_astrbot_path(), "samples", "stt_health_check.wav")
if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable"
status_info["error"] = (
"STT test failed: sample audio file not found."
)
logger.warning(
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}"
)
status_info["error"] = "STT test failed: sample audio file not found."
logger.warning(f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}")
else:
text_result = await provider.get_text(sample_audio_path)
if isinstance(text_result, str) and text_result:
status_info["status"] = "available"
snippet = (
text_result[:70] + "..."
if len(text_result) > 70
else text_result
)
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'"
)
snippet = text_result[:70] + "..." if len(text_result) > 70 else text_result
logger.info(f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'")
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"STT test failed: unexpected result type {type(text_result)}"
)
logger.warning(
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}"
)
status_info["error"] = f"STT test failed: unexpected result type {type(text_result)}"
logger.warning(f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}")
except Exception as e:
logger.error(
f"Error testing STT provider {provider_name}: {e}", exc_info=True
)
logger.error(f"Error testing STT provider {provider_name}: {e}", exc_info=True)
status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {str(e)}"
else:
logger.debug(
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}"
)
logger.debug(f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}")
status_info["status"] = "available"
status_info["error"] = (
"This provider type is not tested and is assumed to be available."
)
status_info["error"] = "This provider type is not tested and is assumed to be available."
return status_info
def _error_response(
self, message: str, status_code: int = 500, log_fn=logger.error
):
def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error):
log_fn(message)
# 记录更详细的traceback信息但只在是严重错误时
if status_code == 500:
@@ -505,9 +312,7 @@ class ConfigRoute(Route):
"""API: check a single LLM Provider's status by id"""
provider_id = request.args.get("id")
if not provider_id:
return self._error_response(
"Missing provider_id parameter", 400, logger.warning
)
return self._error_response("Missing provider_id parameter", 400, logger.warning)
logger.info(f"API call: /config/provider/check_one id={provider_id}")
try:
@@ -515,21 +320,16 @@ class ConfigRoute(Route):
target = prov_mgr.inst_map.get(provider_id)
if not target:
logger.warning(
f"Provider with id '{provider_id}' not found in provider_manager."
)
return (
Response()
.error(f"Provider with id '{provider_id}' not found")
.__dict__
)
logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.")
return Response().error(f"Provider with id '{provider_id}' not found").__dict__
result = await self._test_single_provider(target)
return Response().ok(result).__dict__
except Exception as e:
return self._error_response(
f"Critical error checking provider {provider_id}: {e}", 500
f"Critical error checking provider {provider_id}: {e}",
500
)
async def get_configs(self):
@@ -540,15 +340,29 @@ class ConfigRoute(Route):
return Response().ok(await self._get_astrbot_config()).__dict__
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
async def post_session_seperate(self):
"""设置提供商会话隔离"""
post_config = await request.json
enable = post_config.get("enable", None)
if enable is None:
return Response().error("缺少参数 enable").__dict__
astrbot_config = self.core_lifecycle.astrbot_config
astrbot_config["provider_settings"]["separate_provider"] = enable
try:
astrbot_config.save_config()
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "设置成功~").__dict__
async def get_provider_config_list(self):
provider_type = request.args.get("provider_type", None)
if not provider_type:
return Response().error("缺少参数 provider_type").__dict__
provider_type_ls = provider_type.split(",")
provider_list = []
astrbot_config = self.core_lifecycle.astrbot_config
for provider in astrbot_config["provider"]:
if provider.get("provider_type", None) in provider_type_ls:
if provider.get("provider_type", None) == provider_type:
provider_list.append(provider)
return Response().ok(provider_list).__dict__
@@ -574,21 +388,11 @@ class ConfigRoute(Route):
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def get_platform_list(self):
"""获取所有平台的列表"""
platform_list = []
for platform in self.config["platform"]:
platform_list.append(platform)
return Response().ok({"platforms": platform_list}).__dict__
async def post_astrbot_configs(self):
data = await request.json
config = data.get("config", None)
conf_id = data.get("conf_id", None)
post_configs = await request.json
try:
await self._save_astrbot_configs(config, conf_id)
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
return Response().ok(None, "保存成功~").__dict__
await self._save_astrbot_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
@@ -705,6 +509,12 @@ class ConfigRoute(Route):
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
async def get_llm_tools(self):
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools = tool_mgr.get_func_desc_openai_style()
return Response().ok(tools).__dict__
async def _get_astrbot_config(self):
config = self.config
@@ -747,12 +557,10 @@ class ConfigRoute(Route):
return ret
async def _save_astrbot_configs(self, post_configs: dict, conf_id: str = None):
async def _save_astrbot_configs(self, post_configs: dict):
try:
if conf_id not in self.acm.confs:
raise ValueError(f"配置文件 {conf_id} 不存在")
astrbot_config = self.acm.confs[conf_id]
save_config(post_configs, astrbot_config, is_core=True)
save_config(post_configs, self.config, is_core=True)
await self.core_lifecycle.restart()
except Exception as e:
raise e

View File

@@ -29,7 +29,6 @@ class ConversationRoute(Route):
),
}
self.db_helper = db_helper
self.conv_mgr = core_lifecycle.conversation_manager
self.core_lifecycle = core_lifecycle
self.register_routes()
@@ -55,6 +54,7 @@ class ConversationRoute(Route):
exclude_platforms.split(",") if exclude_platforms else []
)
# 限制页面大小,防止请求过大数据
if page < 1:
page = 1
if page_size < 1:
@@ -62,11 +62,9 @@ class ConversationRoute(Route):
if page_size > 100:
page_size = 100
# 使用数据库的分页方法获取会话列表和总数,传入筛选条件
try:
(
conversations,
total_count,
) = await self.conv_mgr.get_filtered_conversations(
conversations, total_count = self.db_helper.get_filtered_conversations(
page=page,
page_size=page_size,
platforms=platform_list,
@@ -110,9 +108,7 @@ class ConversationRoute(Route):
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id, conversation_id=cid
)
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
@@ -147,18 +143,14 @@ class ConversationRoute(Route):
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id, conversation_id=cid
)
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
if title is not None or persona_id is not None:
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
title=title,
persona_id=persona_id,
)
if title is not None:
self.db_helper.update_conversation_title(user_id, cid, title)
if persona_id is not None:
self.db_helper.update_conversation_persona_id(user_id, cid, persona_id)
return Response().ok({"message": "对话信息更新成功"}).__dict__
except Exception as e:
@@ -209,17 +201,11 @@ class ConversationRoute(Route):
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
)
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id, conversation_id=cid
)
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
history = json.loads(history) if isinstance(history, str) else history
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id, conversation_id=cid, history=history
)
self.db_helper.update_conversation(user_id, cid, history)
return Response().ok({"message": "对话历史更新成功"}).__dict__

View File

@@ -1,199 +0,0 @@
import traceback
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.db import BaseDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class PersonaRoute(Route):
def __init__(
self,
context: RouteContext,
db_helper: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.routes = {
"/persona/list": ("GET", self.list_personas),
"/persona/detail": ("POST", self.get_persona_detail),
"/persona/create": ("POST", self.create_persona),
"/persona/update": ("POST", self.update_persona),
"/persona/delete": ("POST", self.delete_persona),
}
self.db_helper = db_helper
self.persona_mgr = core_lifecycle.persona_mgr
self.register_routes()
async def list_personas(self):
"""获取所有人格列表"""
try:
personas = await self.persona_mgr.get_all_personas()
return (
Response()
.ok(
[
{
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools,
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
}
for persona in personas
]
)
.__dict__
)
except Exception as e:
logger.error(f"获取人格列表失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"获取人格列表失败: {str(e)}").__dict__
async def get_persona_detail(self):
"""获取指定人格的详细信息"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
persona = await self.persona_mgr.get_persona(persona_id)
if not persona:
return Response().error("人格不存在").__dict__
return (
Response()
.ok(
{
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools,
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
}
)
.__dict__
)
except Exception as e:
logger.error(f"获取人格详情失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"获取人格详情失败: {str(e)}").__dict__
async def create_persona(self):
"""创建新人格"""
try:
data = await request.get_json()
persona_id = data.get("persona_id", "").strip()
system_prompt = data.get("system_prompt", "").strip()
begin_dialogs = data.get("begin_dialogs", [])
tools = data.get("tools")
if not persona_id:
return Response().error("人格ID不能为空").__dict__
if not system_prompt:
return Response().error("系统提示词不能为空").__dict__
# 验证 begin_dialogs 格式
if begin_dialogs and len(begin_dialogs) % 2 != 0:
return (
Response()
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
.__dict__
)
persona = await self.persona_mgr.create_persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs if begin_dialogs else None,
tools=tools if tools else None,
)
return (
Response()
.ok(
{
"message": "人格创建成功",
"persona": {
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools or [],
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
},
}
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"创建人格失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"创建人格失败: {str(e)}").__dict__
async def update_persona(self):
"""更新人格信息"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
system_prompt = data.get("system_prompt")
begin_dialogs = data.get("begin_dialogs")
tools = data.get("tools")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
# 验证 begin_dialogs 格式
if begin_dialogs is not None and len(begin_dialogs) % 2 != 0:
return (
Response()
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
.__dict__
)
await self.persona_mgr.update_persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs,
tools=tools,
)
return Response().ok({"message": "人格更新成功"}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"更新人格失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"更新人格失败: {str(e)}").__dict__
async def delete_persona(self):
"""删除人格"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
await self.persona_mgr.delete_persona(persona_id)
return Response().ok({"message": "人格删除成功"}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"删除人格失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"删除人格失败: {str(e)}").__dict__

View File

@@ -40,6 +40,8 @@ class PluginRoute(Route):
"/plugin/on": ("POST", self.on_plugin),
"/plugin/reload": ("POST", self.reload_plugins),
"/plugin/readme": ("GET", self.get_plugin_readme),
"/plugin/platform_enable/get": ("GET", self.get_plugin_platform_enable),
"/plugin/platform_enable/set": ("POST", self.set_plugin_platform_enable),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -284,6 +286,14 @@ class PluginRoute(Route):
f"{filter.parent_command_names[0]} {filter.command_name}"
)
info["cmd"] = info["cmd"].strip()
if (
self.core_lifecycle.astrbot_config["wake_prefix"]
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
> 0
):
info["cmd"] = (
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
)
elif isinstance(filter, CommandGroupFilter):
info["type"] = "指令组"
info["cmd"] = filter.get_complete_command_names()[0]
@@ -291,6 +301,14 @@ class PluginRoute(Route):
info["sub_command"] = filter.print_cmd_tree(
filter.sub_command_filters
)
if (
self.core_lifecycle.astrbot_config["wake_prefix"]
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
> 0
):
info["cmd"] = (
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
)
elif isinstance(filter, RegexFilter):
info["type"] = "正则匹配"
info["cmd"] = filter.regex_str
@@ -480,3 +498,90 @@ class PluginRoute(Route):
except Exception as e:
logger.error(f"/api/plugin/readme: {traceback.format_exc()}")
return Response().error(f"读取README文件失败: {str(e)}").__dict__
async def get_plugin_platform_enable(self):
"""获取插件在各平台的可用性配置"""
try:
platform_enable = self.core_lifecycle.astrbot_config.get(
"platform_settings", {}
).get("plugin_enable", {})
# 获取所有可用平台
platforms = []
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
platform_type = platform.get("type", "")
platform_id = platform.get("id", "")
platforms.append(
{
"name": platform_id, # 使用type作为name这是系统内部使用的平台名称
"id": platform_id, # 保留id字段以便前端可以显示
"type": platform_type,
"display_name": f"{platform_type}({platform_id})",
}
)
adjusted_platform_enable = {}
for platform_id, plugins in platform_enable.items():
adjusted_platform_enable[platform_id] = plugins
# 获取所有插件,包括系统内部插件
plugins = []
for plugin in self.plugin_manager.context.get_all_stars():
plugins.append(
{
"name": plugin.name,
"desc": plugin.desc,
"reserved": plugin.reserved, # 添加reserved标志
}
)
logger.debug(
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
)
return (
Response()
.ok(
{
"platforms": platforms,
"plugins": plugins,
"platform_enable": adjusted_platform_enable,
}
)
.__dict__
)
except Exception as e:
logger.error(f"/api/plugin/platform_enable/get: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def set_plugin_platform_enable(self):
"""设置插件在各平台的可用性配置"""
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
data = await request.json
platform_enable = data.get("platform_enable", {})
# 更新配置
config = self.core_lifecycle.astrbot_config
platform_settings = config.get("platform_settings", {})
platform_settings["plugin_enable"] = platform_enable
config["platform_settings"] = platform_settings
config.save_config()
# 更新插件的平台兼容性缓存
await self.plugin_manager.update_all_platform_compatibility()
logger.info(f"插件平台可用性配置已更新: {platform_enable}")
return Response().ok(None, "插件平台可用性配置已更新").__dict__
except Exception as e:
logger.error(f"/api/plugin/platform_enable/set: {traceback.format_exc()}")
return Response().error(str(e)).__dict__

View File

@@ -21,19 +21,17 @@ class Route:
@dataclass
class Response:
status: str | None = None
message: str | None = None
data: dict | list | None = None
status: str = None
message: str = None
data: dict = None
def error(self, message: str):
self.status = "error"
self.message = message
return self
def ok(self, data: dict | list | None = None, message: str | None = None):
def ok(self, data: dict = {}, message: str = None):
self.status = "ok"
if data is None:
data = {}
self.data = data
self.message = message
return self

View File

@@ -24,6 +24,7 @@ class SessionManagementRoute(Route):
"/session/list": ("GET", self.list_sessions),
"/session/update_persona": ("POST", self.update_session_persona),
"/session/update_provider": ("POST", self.update_session_provider),
"/session/get_session_info": ("POST", self.get_session_info),
"/session/plugins": ("GET", self.get_session_plugins),
"/session/update_plugin": ("POST", self.update_session_plugin),
"/session/update_llm": ("POST", self.update_session_llm),
@@ -31,20 +32,24 @@ class SessionManagementRoute(Route):
"/session/update_name": ("POST", self.update_session_name),
"/session/update_status": ("POST", self.update_session_status),
}
self.conv_mgr = core_lifecycle.conversation_manager
self.db_helper = db_helper
self.core_lifecycle = core_lifecycle
self.register_routes()
async def list_sessions(self):
"""获取所有会话的列表,包括 persona 和 provider 信息"""
try:
preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[])
session_conversations = {}
for pref in preferences:
session_conversations[pref.scope_id] = pref.value["val"]
provider_manager = self.core_lifecycle.provider_manager
persona_mgr = self.core_lifecycle.persona_mgr
personas = persona_mgr.personas_v3
# 获取会话对话映射
session_conversations = sp.get("session_conversation", {}) or {}
# 获取会话提供商偏好设置
session_provider_perf = sp.get("session_provider_perf", {}) or {}
# 获取可用的 personas
personas = self.core_lifecycle.star_context.provider_manager.personas
# 获取可用的 providers
provider_manager = self.core_lifecycle.star_context.provider_manager
sessions = []
@@ -54,9 +59,13 @@ class SessionManagementRoute(Route):
"session_id": session_id,
"conversation_id": conversation_id,
"persona_id": None,
"persona_name": None,
"chat_provider_id": None,
"chat_provider_name": None,
"stt_provider_id": None,
"stt_provider_name": None,
"tts_provider_id": None,
"tts_provider_name": None,
"session_enabled": SessionServiceManager.is_session_enabled(
session_id
),
@@ -81,46 +90,74 @@ class SessionManagementRoute(Route):
}
# 获取对话信息
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=session_id, conversation_id=conversation_id
conversation = self.db_helper.get_conversation_by_user_id(
session_id, conversation_id
)
if conversation:
session_info["persona_id"] = conversation.persona_id
# 查找 persona 名称
if conversation.persona_id and conversation.persona_id != "[%None]":
for persona in personas:
if persona["name"] == conversation.persona_id:
session_info["persona_id"] = persona["name"]
session_info["persona_name"] = persona["name"]
break
elif conversation.persona_id == "[%None]":
session_info["persona_id"] = "无人格"
session_info["persona_name"] = "无人格"
else:
# 使用默认人格
default_persona = persona_mgr.selected_default_persona_v3
default_persona = provider_manager.selected_default_persona
if default_persona:
session_info["persona_id"] = default_persona["name"]
session_info["persona_name"] = default_persona["name"]
# 获取 provider 信息
provider_manager = self.core_lifecycle.provider_manager
chat_provider = provider_manager.get_using_provider(
provider_type=ProviderType.CHAT_COMPLETION, umo=session_id
)
tts_provider = provider_manager.get_using_provider(
provider_type=ProviderType.TEXT_TO_SPEECH, umo=session_id
)
stt_provider = provider_manager.get_using_provider(
provider_type=ProviderType.SPEECH_TO_TEXT, umo=session_id
)
if chat_provider:
meta = chat_provider.meta()
session_info["chat_provider_id"] = meta.id
if tts_provider:
meta = tts_provider.meta()
session_info["tts_provider_id"] = meta.id
if stt_provider:
meta = stt_provider.meta()
session_info["stt_provider_id"] = meta.id
# 获取会话的 provider 偏好设置
session_perf = session_provider_perf.get(session_id, {})
# Chat completion provider
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
if chat_provider_id:
chat_provider = provider_manager.inst_map.get(chat_provider_id)
if chat_provider:
session_info["chat_provider_id"] = chat_provider_id
session_info["chat_provider_name"] = chat_provider.meta().id
else:
# 使用默认 provider
default_provider = provider_manager.curr_provider_inst
if default_provider:
session_info["chat_provider_id"] = default_provider.meta().id
session_info["chat_provider_name"] = default_provider.meta().id
# STT provider
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
if stt_provider_id:
stt_provider = provider_manager.inst_map.get(stt_provider_id)
if stt_provider:
session_info["stt_provider_id"] = stt_provider_id
session_info["stt_provider_name"] = stt_provider.meta().id
else:
# 使用默认 STT provider
default_stt_provider = provider_manager.curr_stt_provider_inst
if default_stt_provider:
session_info["stt_provider_id"] = default_stt_provider.meta().id
session_info["stt_provider_name"] = (
default_stt_provider.meta().id
)
# TTS provider
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
if tts_provider_id:
tts_provider = provider_manager.inst_map.get(tts_provider_id)
if tts_provider:
session_info["tts_provider_id"] = tts_provider_id
session_info["tts_provider_name"] = tts_provider.meta().id
else:
# 使用默认 TTS provider
default_tts_provider = provider_manager.curr_tts_provider_inst
if default_tts_provider:
session_info["tts_provider_id"] = default_tts_provider.meta().id
session_info["tts_provider_name"] = (
default_tts_provider.meta().id
)
sessions.append(session_info)
@@ -274,6 +311,133 @@ class SessionManagementRoute(Route):
logger.error(error_msg)
return Response().error(f"更新会话提供商失败: {str(e)}").__dict__
async def get_session_info(self):
"""获取指定会话的详细信息"""
try:
data = await request.get_json()
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
# 获取会话对话信息
session_conversations = sp.get("session_conversation", {}) or {}
conversation_id = session_conversations.get(session_id)
if not conversation_id:
return Response().error(f"会话 {session_id} 未找到对话").__dict__
session_info = {
"session_id": session_id,
"conversation_id": conversation_id,
"persona_id": None,
"persona_name": None,
"chat_provider_id": None,
"chat_provider_name": None,
"stt_provider_id": None,
"stt_provider_name": None,
"tts_provider_id": None,
"tts_provider_name": None,
"llm_enabled": SessionServiceManager.is_llm_enabled_for_session(
session_id
),
"tts_enabled": None, # 将在下面设置
"platform": session_id.split(":")[0]
if ":" in session_id
else "unknown",
"message_type": session_id.split(":")[1]
if session_id.count(":") >= 1
else "unknown",
"session_name": session_id.split(":")[2]
if session_id.count(":") >= 2
else session_id,
}
# 获取TTS状态
session_info["tts_enabled"] = (
SessionServiceManager.is_tts_enabled_for_session(session_id)
)
# 获取对话信息
conversation = self.db_helper.get_conversation_by_user_id(
session_id, conversation_id
)
if conversation:
session_info["persona_id"] = conversation.persona_id
# 查找 persona 名称
provider_manager = self.core_lifecycle.star_context.provider_manager
personas = provider_manager.personas
if conversation.persona_id and conversation.persona_id != "[%None]":
for persona in personas:
if persona["name"] == conversation.persona_id:
session_info["persona_name"] = persona["name"]
break
elif conversation.persona_id == "[%None]":
session_info["persona_name"] = "无人格"
else:
# 使用默认人格
default_persona = provider_manager.selected_default_persona
if default_persona:
session_info["persona_id"] = default_persona["name"]
session_info["persona_name"] = default_persona["name"]
# 获取会话的 provider 偏好设置
session_provider_perf = sp.get("session_provider_perf", {}) or {}
session_perf = session_provider_perf.get(session_id, {})
# 获取 provider 信息
provider_manager = self.core_lifecycle.star_context.provider_manager
# Chat completion provider
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
if chat_provider_id:
chat_provider = provider_manager.inst_map.get(chat_provider_id)
if chat_provider:
session_info["chat_provider_id"] = chat_provider_id
session_info["chat_provider_name"] = chat_provider.meta().id
else:
# 使用默认 provider
default_provider = provider_manager.curr_provider_inst
if default_provider:
session_info["chat_provider_id"] = default_provider.meta().id
session_info["chat_provider_name"] = default_provider.meta().id
# STT provider
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
if stt_provider_id:
stt_provider = provider_manager.inst_map.get(stt_provider_id)
if stt_provider:
session_info["stt_provider_id"] = stt_provider_id
session_info["stt_provider_name"] = stt_provider.meta().id
else:
# 使用默认 STT provider
default_stt_provider = provider_manager.curr_stt_provider_inst
if default_stt_provider:
session_info["stt_provider_id"] = default_stt_provider.meta().id
session_info["stt_provider_name"] = default_stt_provider.meta().id
# TTS provider
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
if tts_provider_id:
tts_provider = provider_manager.inst_map.get(tts_provider_id)
if tts_provider:
session_info["tts_provider_id"] = tts_provider_id
session_info["tts_provider_name"] = tts_provider.meta().id
else:
# 使用默认 TTS provider
default_tts_provider = provider_manager.curr_tts_provider_inst
if default_tts_provider:
session_info["tts_provider_id"] = default_tts_provider.meta().id
session_info["tts_provider_name"] = default_tts_provider.meta().id
return Response().ok(session_info).__dict__
except Exception as e:
error_msg = f"获取会话信息失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"获取会话信息失败: {str(e)}").__dict__
async def get_session_plugins(self):
"""获取指定会话的插件配置信息"""
try:

View File

@@ -11,7 +11,6 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.config import VERSION
from astrbot.core.utils.io import get_dashboard_version
from astrbot.core import DEMO_MODE
from astrbot.core.db.migration.helper import check_migration_needed_v4
class StatRoute(Route):
@@ -60,8 +59,6 @@ class StatRoute(Route):
)
async def get_version(self):
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
return (
Response()
.ok(
@@ -69,7 +66,6 @@ class StatRoute(Route):
"version": VERSION,
"dashboard_version": await get_dashboard_version(),
"change_pwd_hint": self.is_default_cred(),
"need_migration": need_migration,
}
)
.__dict__
@@ -88,7 +84,7 @@ class StatRoute(Route):
message_time_based_stats = []
idx = 0
for bucket_end in range(start_time, now, 3600):
for bucket_end in range(start_time, now, 1800):
cnt = 0
while (
idx < len(stat.platform)

View File

@@ -1,3 +1,5 @@
import json
import os
import traceback
import aiohttp
@@ -5,7 +7,7 @@ from quart import request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star import star_map
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .route import Response, Route, RouteContext
@@ -23,17 +25,44 @@ class ToolsRoute(Route):
"/tools/mcp/add": ("POST", self.add_mcp_server),
"/tools/mcp/update": ("POST", self.update_mcp_server),
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
"/tools/mcp/market": ("GET", self.get_mcp_markets),
"/tools/mcp/test": ("POST", self.test_mcp_connection),
"/tools/list": ("GET", self.get_tool_list),
"/tools/toggle-tool": ("POST", self.toggle_tool),
"/tools/mcp/sync-provider": ("POST", self.sync_provider),
}
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
@property
def mcp_config_path(self):
data_dir = get_astrbot_data_path()
return os.path.join(data_dir, "mcp_server.json")
def load_mcp_config(self):
if not os.path.exists(self.mcp_config_path):
# 配置文件不存在,创建默认配置
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
return DEFAULT_MCP_CONFIG
try:
with open(self.mcp_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"加载 MCP 配置失败: {e}")
return DEFAULT_MCP_CONFIG
def save_mcp_config(self, config):
try:
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
return True
except Exception as e:
logger.error(f"保存 MCP 配置失败: {e}")
return False
async def get_mcp_servers(self):
try:
config = self.tool_mgr.load_mcp_config()
config = self.load_mcp_config()
servers = []
# 获取所有服务器并添加它们的工具列表
@@ -96,14 +125,14 @@ class ToolsRoute(Route):
if not has_valid_config:
return Response().error("必须提供有效的服务器配置").__dict__
config = self.tool_mgr.load_mcp_config()
config = self.load_mcp_config()
if name in config["mcpServers"]:
return Response().error(f"服务器 {name} 已存在").__dict__
config["mcpServers"][name] = server_config
if self.tool_mgr.save_mcp_config(config):
if self.save_mcp_config(config):
try:
await self.tool_mgr.enable_mcp_server(
name, server_config, timeout=30
@@ -133,7 +162,7 @@ class ToolsRoute(Route):
if not name:
return Response().error("服务器名称不能为空").__dict__
config = self.tool_mgr.load_mcp_config()
config = self.load_mcp_config()
if name not in config["mcpServers"]:
return Response().error(f"服务器 {name} 不存在").__dict__
@@ -169,7 +198,7 @@ class ToolsRoute(Route):
config["mcpServers"][name] = server_config
if self.tool_mgr.save_mcp_config(config):
if self.save_mcp_config(config):
# 处理MCP客户端状态变化
if active:
if name in self.tool_mgr.mcp_client_dict or not only_update_active:
@@ -237,14 +266,14 @@ class ToolsRoute(Route):
if not name:
return Response().error("服务器名称不能为空").__dict__
config = self.tool_mgr.load_mcp_config()
config = self.load_mcp_config()
if name not in config["mcpServers"]:
return Response().error(f"服务器 {name} 不存在").__dict__
del config["mcpServers"][name]
if self.tool_mgr.save_mcp_config(config):
if self.save_mcp_config(config):
if name in self.tool_mgr.mcp_client_dict:
try:
await self.tool_mgr.disable_mcp_server(name, timeout=10)
@@ -266,6 +295,31 @@ class ToolsRoute(Route):
logger.error(traceback.format_exc())
return Response().error(f"删除 MCP 服务器失败: {str(e)}").__dict__
async def get_mcp_markets(self):
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 10, type=int)
BASE_URL = (
"https://api.soulter.top/astrbot/mcpservers?page={}&page_size={}".format(
page,
page_size,
)
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{BASE_URL}") as response:
if response.status == 200:
data = await response.json()
return Response().ok(data["data"]).__dict__
else:
return (
Response()
.error(f"获取市场数据失败: HTTP {response.status}")
.__dict__
)
except Exception as _:
logger.error(traceback.format_exc())
return Response().error("获取市场数据失败").__dict__
async def test_mcp_connection(self):
"""
测试 MCP 服务器连接
@@ -282,57 +336,3 @@ class ToolsRoute(Route):
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
async def get_tool_list(self):
"""获取所有注册的工具列表"""
try:
tools = self.tool_mgr.func_list
tools_dict = [tool.__dict__() for tool in tools]
return Response().ok(data=tools_dict).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取工具列表失败: {str(e)}").__dict__
async def toggle_tool(self):
"""启用或停用指定的工具"""
try:
data = await request.json
tool_name = data.get("name")
action = data.get("activate") # True or False
if not tool_name or action is None:
return Response().error("缺少必要参数: name 或 action").__dict__
if action:
try:
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
except ValueError as e:
return Response().error(f"启用工具失败: {str(e)}").__dict__
else:
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
if ok:
return Response().ok(None, "操作成功。").__dict__
else:
return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"操作工具失败: {str(e)}").__dict__
async def sync_provider(self):
"""同步 MCP 提供者配置"""
try:
data = await request.json
provider_name = data.get("name") # modelscope, or others
match provider_name:
case "modelscope":
access_token = data.get("access_token", "")
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
case _:
return Response().error(f"未知: {provider_name}").__dict__
return Response().ok(message="同步成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"同步失败: {str(e)}").__dict__

View File

@@ -7,7 +7,6 @@ from astrbot.core import logger, pip_installer
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.config.default import VERSION
from astrbot.core import DEMO_MODE
from astrbot.core.db.migration.helper import do_migration_v4, check_migration_needed_v4
class UpdateRoute(Route):
@@ -24,27 +23,11 @@ class UpdateRoute(Route):
"/update/do": ("POST", self.update_project),
"/update/dashboard": ("POST", self.update_dashboard),
"/update/pip-install": ("POST", self.install_pip_package),
"/update/migration": ("POST", self.do_migration),
}
self.astrbot_updator = astrbot_updator
self.core_lifecycle = core_lifecycle
self.register_routes()
async def do_migration(self):
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
if not need_migration:
return Response().ok(None, "不需要进行迁移。").__dict__
try:
data = await request.json
pim = data.get("platform_id_map", {})
await do_migration_v4(
self.core_lifecycle.db, pim, self.core_lifecycle.astrbot_config
)
return Response().ok(None, "迁移成功。").__dict__
except Exception as e:
logger.error(f"迁移失败: {traceback.format_exc()}")
return Response().error(f"迁移失败: {str(e)}").__dict__
async def check_update(self):
type_ = request.args.get("type", None)
@@ -65,7 +48,7 @@ class UpdateRoute(Route):
"version": f"v{VERSION}",
"has_new_version": ret is not None,
"dashboard_version": dv,
"dashboard_has_new_version": dv and dv != f"v{VERSION}",
"dashboard_has_new_version": dv != f"v{VERSION}",
},
).__dict__
except Exception as e:
@@ -99,12 +82,11 @@ class UpdateRoute(Route):
latest=latest, version=version, proxy=proxy
)
try:
await download_dashboard(
latest=latest, version=version, proxy=proxy
)
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
if latest:
try:
await download_dashboard()
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
# pip 更新依赖
logger.info("更新依赖中...")

View File

@@ -60,9 +60,6 @@ class AstrBotDashboard:
self.session_management_route = SessionManagementRoute(
self.context, db, core_lifecycle
)
self.persona_route = PersonaRoute(
self.context, db, core_lifecycle
)
self.app.add_url_rule(
"/api/plug/<path:subpath>",

View File

@@ -1,10 +0,0 @@
# What's Changed
> 新版本预告: v4.0.0 即将发布。
1. 新增: 添加对 ModelScope、Compshare优云智算的模版支持。
2. 优化: 增加插件数据缓存,优化插件市场数据获取时的稳定性。
其他更新:
1. 现已支持在 1Panel 平台通过应用商城快捷部署 AstrBot。详见[在 1Panel 部署 AstrBot](https://docs.astrbot.app/deploy/astrbot/1panel.html)

View File

@@ -1,13 +0,0 @@
# What's Changed
1. 修复: 修复插件可能存在的无法正常禁用的问题 ([#2352](https://github.com/Soulter/AstrBot/issues/2352))
2. ❗修复:当返回文本为空并且存在函数调用时错误地被终止事件,导致函数调用结果未被正常返回 ([#2491](https://github.com/Soulter/AstrBot/issues/2491))
3. 修复:修复无法清空 AstrBot 配置下的 http_proxy 代理的问题 ([#2434](https://github.com/Soulter/AstrBot/issues/2434))
4. ❗修复Gemini 下开启流式输出时,持久化的消息结果不完整 ([#2424](https://github.com/Soulter/AstrBot/issues/2424))
5. 修复:注册文件时由于 file:/// 前缀,导致文件被误判为不存在的问题 ([#2325](https://github.com/Soulter/AstrBot/issues/2325))
6. 优化: 为部分类型供应商添加默认的温度选项 ([#2321](https://github.com/Soulter/AstrBot/issues/2321))
7. 优化: 适配 Qwen3 模型非流式输出下需要传入 enable_think 参数(否则报错) ([#2424](https://github.com/Soulter/AstrBot/issues/2424))
8. 优化:支持配置工具调用轮数上限,默认 30
9. 新增: 添加 WebUI 语义化预发布版本提醒和检测功能
> 新版本预告: v4.0.0 即将发布。

Some files were not shown because too many files have changed in this diff Show More