Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
1c090299b1 feat: tauri app 2025-11-10 15:11:59 +08:00
134 changed files with 15841 additions and 1396 deletions

79
.github/workflows/build-app.yml vendored Normal file
View File

@@ -0,0 +1,79 @@
name: Build Desktop App
on:
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
build:
strategy:
fail-fast: false
matrix:
platform: [macos-latest, ubuntu-latest, windows-latest]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: 20
- name: Install Rust
uses: dtolnay/rust-toolchain@stable
- name: Install dependencies (Ubuntu)
if: matrix.platform == 'ubuntu-latest'
run: |
sudo apt-get update
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev libappindicator3-dev librsvg2-dev patchelf
- name: Install Python dependencies
run: |
pip install uv
uv sync
- name: Build Python backend with Nuitka
run: |
pip install nuitka
python build_nuitka.py
- name: Install Node dependencies
working-directory: ./dashboard
run: npm install
- name: Build Tauri app
working-directory: ./dashboard
run: npm run tauri:build
- name: Upload artifacts (macOS)
if: matrix.platform == 'macos-latest'
uses: actions/upload-artifact@v4
with:
name: astrbot-macos
path: dashboard/src-tauri/target/release/bundle/dmg/*.dmg
- name: Upload artifacts (Windows)
if: matrix.platform == 'windows-latest'
uses: actions/upload-artifact@v4
with:
name: astrbot-windows
path: dashboard/src-tauri/target/release/bundle/msi/*.msi
- name: Upload artifacts (Linux)
if: matrix.platform == 'ubuntu-latest'
uses: actions/upload-artifact@v4
with:
name: astrbot-linux
path: |
dashboard/src-tauri/target/release/bundle/deb/*.deb
dashboard/src-tauri/target/release/bundle/appimage/*.AppImage

2
.gitignore vendored
View File

@@ -32,6 +32,7 @@ tests/astrbot_plugin_openai
# Dashboard
dashboard/node_modules/
dashboard/dist/
dashboard/src-tauri/target
package-lock.json
package.json
@@ -47,3 +48,4 @@ astrbot.lock
chroma
venv/*
pytest.ini
build/

287
BUILD_INSTRUCTIONS.md Normal file
View File

@@ -0,0 +1,287 @@
# AstrBot 桌面应用构建指南
本指南介绍如何使用 Nuitka 将 Python 后端打包并集成到 Tauri 桌面应用中。
## 前置要求
### 系统要求
- Python 3.10+
- Node.js 20+
- Rust (通过 rustup 安装)
- UV 包管理器
### macOS 额外要求
- Xcode Command Line Tools: `xcode-select --install`
### Linux 额外要求
```bash
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev \
libappindicator3-dev librsvg2-dev patchelf
```
### Windows 额外要求
- Visual Studio 2019+ with C++ build tools
- Windows 10 SDK
## 构建步骤
### 1. 安装 Python 依赖
```bash
pip install uv
uv sync
```
### 2. 安装 Nuitka
```bash
pip install nuitka
```
### 3. 构建 Python 后端
```bash
python build_nuitka.py
```
这会使用 Nuitka 将 `main.py` 编译为独立可执行文件,输出到 `build/nuitka/` 目录。
**注意**: Nuitka 编译过程可能需要 10-30 分钟,取决于您的系统性能。
### 4. 安装前端依赖
```bash
cd dashboard
npm install
```
### 5. 构建 Tauri 应用
```bash
npm run tauri:build
```
构建脚本会自动:
1. 运行 `build_nuitka.py` 编译 Python 后端
2. 将编译好的可执行文件复制到 `src-tauri/resources/` 目录
3. 构建 Tauri 应用并打包所有资源
### 6. 查找构建产物
构建完成后,您可以在以下位置找到安装包:
- **macOS**: `dashboard/src-tauri/target/release/bundle/dmg/AstrBot_*.dmg`
- **Windows**: `dashboard/src-tauri/target/release/bundle/msi/AstrBot_*.msi`
- **Linux**:
- `dashboard/src-tauri/target/release/bundle/deb/astrbot_*.deb`
- `dashboard/src-tauri/target/release/bundle/appimage/astrbot_*.AppImage`
## 开发模式
在开发时,您可能不想每次都完整编译 Python 后端。
### 仅开发 Tauri + Vue
```bash
cd dashboard
npm run tauri:dev
```
这会启动开发服务器,但不会自动启动 Python 后端。您需要手动运行:
```bash
uv run main.py
```
### 测试完整集成
如果您想测试 Tauri 自动启动 Python 后端的功能:
1. 先编译一次 Python 后端:
```bash
python build_nuitka.py
```
2. 手动复制到资源目录:
```bash
# macOS
cp -r build/nuitka/main.app dashboard/src-tauri/resources/astrbot-backend.app
# Windows
copy build\nuitka\main.exe dashboard\src-tauri\resources\astrbot-backend.exe
# Linux
cp build/nuitka/main.bin dashboard/src-tauri/resources/astrbot-backend
```
3. 运行开发模式:
```bash
cd dashboard
npm run tauri:dev
```
## Nuitka 构建选项说明
`build_nuitka.py` 脚本使用以下关键选项:
- `--standalone`: 创建包含所有依赖的独立目录
- `--onefile`: 将所有内容打包到单个可执行文件
- `--follow-imports`: 自动跟踪所有 Python 导入
- `--include-package`: 明确包含特定包
- `--include-data-dir`: 包含数据目录(插件、配置等)
### 自定义构建
如果您需要修改构建选项,编辑 `build_nuitka.py`:
```python
# 添加更多要包含的包
include_packages = [
"astrbot",
"your_custom_package",
# ...
]
# 添加更多数据目录
data_includes = [
"data/config",
"your_custom_data",
# ...
]
```
## 常见问题
### 1. Nuitka 编译失败
**问题**: 编译时出现 "module not found" 错误
**解决方案**: 在 `build_nuitka.py` 中添加缺失的包到 `include_packages` 列表
### 2. 运行时找不到资源文件
**问题**: 应用启动后提示找不到配置文件或插件
**解决方案**: 确保在 `build_nuitka.py` 中使用 `--include-data-dir` 包含了所有必要的数据目录
### 3. macOS 安全警告
**问题**: macOS 提示"应用来自未知开发者"
**解决方案**:
```bash
# 临时解除限制
sudo spctl --master-disable
# 或者为特定应用授权
xattr -cr /Applications/AstrBot.app
```
对于生产发布,您需要:
1. 注册 Apple Developer 账号
2. 对应用进行代码签名
3. 提交公证 (Notarization)
### 4. Windows Defender 报毒
**问题**: Windows Defender 或其他杀毒软件报毒
**解决方案**:
- 这是 Nuitka 打包程序的常见问题
- 可以使用 `--windows-company-name``--windows-product-name` 添加元数据
- 对于生产发布,需要购买代码签名证书
### 5. Linux 依赖问题
**问题**: 在某些 Linux 发行版上缺少共享库
**解决方案**: 使用 AppImage 格式,它包含所有依赖:
```bash
# 构建时会自动生成 AppImage
npm run tauri:build
```
## 优化构建大小
默认的 `--onefile` 模式会生成较大的可执行文件。如果需要减小体积:
1. 移除不需要的包
2. 使用 `--standalone` 而不是 `--onefile`
3. 排除不必要的数据文件
修改 `build_nuitka.py`:
```python
# 移除 --onefile使用 --standalone
nuitka_cmd = [
sys.executable,
"-m", "nuitka",
"--standalone", # 只使用 standalone
# "--onefile", # 注释掉 onefile
# ...
]
```
## CI/CD 集成
项目已配置 GitHub Actions 工作流 (`.github/workflows/build-app.yml`),可以自动为所有平台构建应用。
推送标签时自动触发:
```bash
git tag v4.5.7
git push origin v4.5.7
```
或手动触发:
在 GitHub Actions 页面选择 "Build Desktop App" 工作流并点击 "Run workflow"
## 发布清单
在发布新版本前:
- [ ] 更新版本号
- `pyproject.toml` - Python 项目版本
- `dashboard/package.json` - Node 项目版本
- `dashboard/src-tauri/Cargo.toml` - Rust 项目版本
- `dashboard/src-tauri/tauri.conf.json` - Tauri 配置版本
- [ ] 运行代码检查
```bash
uv run ruff check .
uv run ruff format .
```
- [ ] 本地测试构建
```bash
python build_nuitka.py
cd dashboard && npm run tauri:build
```
- [ ] 测试安装包
- 安装生成的安装包
- 验证应用启动
- 验证 Python 后端自动启动
- 测试核心功能
- [ ] 创建发布标签
```bash
git tag -a v4.5.7 -m "Release v4.5.7"
git push origin v4.5.7
```
## 技术架构
```
┌─────────────────────────────────────┐
│ Tauri Desktop App │
│ (Rust + WebView) │
│ │
│ ┌─────────────────────────────┐ │
│ │ Vue.js Dashboard │ │
│ │ (Frontend UI) │ │
│ └─────────────────────────────┘ │
│ │
│ ┌─────────────────────────────┐ │
│ │ Python Backend │ │
│ │ (Nuitka Compiled) │ │
│ │ - AstrBot Core │ │
│ │ - Plugins │ │
│ │ - API Server │ │
│ └─────────────────────────────┘ │
│ │
│ HTTP/WebSocket │
│ localhost:6185 │
└─────────────────────────────────────┘
```
## 参考资源
- [Nuitka 文档](https://nuitka.net/doc/user-manual.html)
- [Tauri 文档](https://tauri.app/v1/guides/)
- [AstrBot 文档](https://astrbot.fun)

View File

@@ -36,8 +36,7 @@ from astrbot.core.star.config import *
# provider
from astrbot.core.provider import Provider, ProviderMetaData
from astrbot.core.db.po import Personality
from astrbot.core.provider import Provider, Personality, ProviderMetaData
# platform
from astrbot.core.platform import (

View File

@@ -1,5 +1,4 @@
from astrbot.core.db.po import Personality
from astrbot.core.provider import Provider, STTProvider
from astrbot.core.provider import Personality, Provider, STTProvider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMetaData,

View File

@@ -76,7 +76,7 @@ class ImageURLPart(ContentPart):
"""The ID of the image, to allow LLMs to distinguish different images."""
type: str = "image_url"
image_url: ImageURL
image_url: str
class AudioURLPart(ContentPart):

View File

@@ -1,21 +1,16 @@
from dataclasses import dataclass
from typing import Any, Generic
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import TypeVar
from .message import Message
TContext = TypeVar("TContext", default=Any)
@dataclass(config={"arbitrary_types_allowed": True})
@dataclass
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 60 # Default tool call timeout in seconds

View File

@@ -40,13 +40,6 @@ class BaseAgentRunner(T.Generic[TContext]):
"""Process a single step of the agent."""
...
@abc.abstractmethod
async def step_until_done(
self, max_step: int
) -> T.AsyncGenerator[AgentResponse, None]:
"""Process steps until the agent is done."""
...
@abc.abstractmethod
def done(self) -> bool:
"""Check if the agent has completed its task.

View File

@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..message import AssistantMessageSegment, ToolCallMessageSegment
from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
@@ -55,20 +55,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.agent_hooks = agent_hooks
self.run_context = run_context
messages = []
# append existing messages in the run context
for msg in request.contexts:
messages.append(Message.model_validate(msg))
if request.prompt is not None:
m = await request.assemble_context()
messages.append(Message.model_validate(m))
if request.system_prompt:
messages.insert(
0,
Message(role="system", content=request.system_prompt),
)
self.run_context.messages = messages
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
@@ -110,22 +96,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
elif llm_response.completion_text:
else:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text),
),
)
elif llm_response.reasoning_content:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain(type="reasoning").message(
llm_response.reasoning_content,
),
),
)
continue
llm_resp_result = llm_response
break # got final response
@@ -153,13 +130,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
# record the final assistant message
self.run_context.messages.append(
Message(
role="assistant",
content=llm_resp.completion_text or "",
),
)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
@@ -186,16 +156,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain(type="tool_call").message(
f"🔨 调用工具: {tool_call_name}"
),
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
result.type = "tool_call_result"
yield AgentResponse(
type="tool_call_result",
data=AgentResponseData(chain=result),
@@ -208,23 +175,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
),
tool_calls_result=tool_call_result_blocks,
)
# record the assistant message with tool calls
self.run_context.messages.extend(
tool_calls_result.to_openai_messages_model()
)
self.req.append_tool_calls_result(tool_calls_result)
async def step_until_done(
self, max_step: int
) -> T.AsyncGenerator[AgentResponse, None]:
"""Process steps until the agent is done."""
step_count = 0
while not self.done() and step_count < max_step:
step_count += 1
async for resp in self.step():
yield resp
async def _handle_function_tools(
self,
req: ProviderRequest,

View File

@@ -4,13 +4,12 @@ from typing import Any, Generic
import jsonschema
import mcp
from deprecated import deprecated
from pydantic import Field, model_validator
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
ToolExecResult = str | mcp.types.CallToolResult
@dataclass
@@ -56,14 +55,15 @@ class FunctionTool(ToolSchema, Generic[TContext]):
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> str | mcp.types.CallToolResult:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
@dataclass
class ToolSet:
"""A set of function tools that can be used in function calling.
@@ -71,7 +71,8 @@ class ToolSet:
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
"""
tools: list[FunctionTool] = Field(default_factory=list)
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
"""Check if the tool set is empty."""

View File

@@ -1,19 +1,14 @@
from pydantic import Field
from pydantic.dataclasses import dataclass
from dataclasses import dataclass
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.context import Context
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@dataclass(config={"arbitrary_types_allowed": True})
@dataclass
class AstrAgentContext:
context: Context
"""The star context instance"""
provider: Provider
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
event: AstrMessageEvent
"""The message event associated with the agent context."""
extra: dict[str, str] = Field(default_factory=dict)
"""Customized extra data."""
AgentContextWrapper = ContextWrapper[AstrAgentContext]

View File

@@ -1,36 +0,0 @@
from typing import Any
from mcp.types import CallToolResult
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.pipeline.context_utils import call_event_hook
from astrbot.core.star.star_handler import EventType
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
)
async def on_tool_end(
self,
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
run_context.context.event.clear_result()
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
pass
MAIN_AGENT_HOOKS = MainAgentHooks()

View File

@@ -1,80 +0,0 @@
import traceback
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
stream_to_general: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use:
await astr_event.send(resp.data["chain"])
continue
if stream_to_general and resp.type == "streaming_delta":
continue
if stream_to_general or not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
),
)
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
chain = resp.data["chain"]
if chain.type == "reasoning" and not show_reasoning:
# display the reasoning content only when configured
continue
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return

View File

@@ -1,246 +0,0 @@
import asyncio
import inspect
import traceback
import typing as T
import mcp
from astrbot import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import (
CommandResult,
MessageChain,
MessageEventResult,
)
from astrbot.core.provider.register import llm_tools
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
elif isinstance(tool, MCPTool):
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
else:
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input")
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
ctx = run_context.context.context
event = run_context.context.event
umo = event.unified_msg_origin
prov_id = await ctx.get_current_chat_provider_id(umo)
llm_resp = await ctx.tool_loop_agent(
event=event,
chat_provider_id=prov_id,
prompt=input_,
system_prompt=tool.agent.instructions,
tools=toolset,
max_steps=30,
run_hooks=tool.agent.run_hooks,
)
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
)
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
is_override_call = True
break
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
method_name=method_name,
**tool_args,
)
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
res = await tool.call(run_context, **tool_args)
if not res:
return
yield res
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
method_name: str,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
elif method_name == "call":
ready_to_call = handler(context, *args, **kwargs)
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.5.8"
VERSION = "4.5.6"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置
@@ -68,7 +68,7 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"unsupported_streaming_strategy": "realtime_segmenting",
"streaming_segmented": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
},
@@ -880,23 +880,6 @@ CONFIG_METADATA_2 = {
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"Groq": {
"id": "groq_default",
"provider": "groq",
"type": "groq_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.groq.com/openai/v1",
"timeout": 120,
"model_config": {
"model": "openai/gpt-oss-20b",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"302.AI": {
"id": "302ai",
"provider": "302ai",
@@ -2010,8 +1993,8 @@ CONFIG_METADATA_2 = {
"show_tool_use_status": {
"type": "bool",
},
"unsupported_streaming_strategy": {
"type": "string",
"streaming_segmented": {
"type": "bool",
},
"max_agent_step": {
"description": "工具调用轮数上限",
@@ -2316,15 +2299,9 @@ CONFIG_METADATA_3 = {
"description": "流式回复",
"type": "bool",
},
"provider_settings.unsupported_streaming_strategy": {
"description": "不支持流式回复的平台",
"type": "string",
"options": ["realtime_segmenting", "turn_off"],
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
"labels": ["实时分段回复", "关闭流式回复"],
"condition": {
"provider_settings.streaming_response": True,
},
"provider_settings.streaming_segmented": {
"description": "不支持流式回复的平台采取分段输出",
"type": "bool",
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",

View File

@@ -1,9 +0,0 @@
from __future__ import annotations
class AstrBotError(Exception):
"""Base exception for all AstrBot errors."""
class ProviderNotFoundError(AstrBotError):
"""Raised when a specified provider is not found."""

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler
from .context_utils import call_event_hook, call_handler, call_local_llm_tool
@dataclass
@@ -15,3 +15,4 @@ class PipelineContext:
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
call_local_llm_tool = call_local_llm_tool

View File

@@ -3,6 +3,8 @@ import traceback
import typing as T
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
@@ -105,3 +107,66 @@ async def call_event_hook(
return True
return event.is_stopped()
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
method_name: str,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
elif method_name == "call":
ready_to_call = handler(context, *args, **kwargs)
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -3,10 +3,20 @@
import asyncio
import copy
import json
import traceback
from collections.abc import AsyncGenerator
from typing import Any
from mcp.types import CallToolResult
from astrbot.core import logger
from astrbot.core.agent.tool import ToolSet
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import Image
@@ -21,19 +31,324 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
from ....astr_agent_context import AgentContextWrapper
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....astr_agent_run_util import AgentRunner, run_agent
from ....astr_agent_tool_exec import FunctionToolExecutor
from ...context import PipelineContext, call_event_hook
from ...context import PipelineContext, call_event_hook, call_local_llm_tool
from ..stage import Stage
from ..utils import inject_kb_context
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
elif isinstance(tool, MCPTool):
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
else:
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input", "agent")
agent_runner = AgentRunner()
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description or "",
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
)
astr_agent_ctx = AstrAgentContext(
provider=run_context.context.provider,
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
event=run_context.context.event,
)
event = run_context.context.event
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=run_context.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
streaming=run_context.context.streaming,
)
async for _ in run_agent(agent_runner, 15, True):
pass
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
)
result = (
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
)
text_content = mcp.types.TextContent(
type="text",
text=result,
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
logger.debug(f"Found call in: {ty}")
is_override_call = True
break
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
method_name=method_name,
**tool_args,
)
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
res = await tool.call(run_context, **tool_args)
if not res:
return
yield res
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
)
async def on_tool_end(
self,
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
run_context.context.event.clear_result()
MAIN_AGENT_HOOKS = MainAgentHooks()
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue
if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
),
)
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
@@ -48,15 +363,11 @@ class LLMRequestSubStage(Stage):
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_reasoning = settings.get("display_reasoning_text", False)
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -95,12 +406,67 @@ class LLMRequestSubStage(Stage):
raise RuntimeError("无法创建新的对话。")
return conversation
async def _apply_kb_context(
async def process(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""应用知识库上下文到请求中"""
_nested: bool = False,
) -> None | AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest(prompt="", image_urls=[])
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# 应用知识库
try:
await inject_kb_context(
umo=event.unified_msg_origin,
@@ -110,40 +476,43 @@ class LLMRequestSubStage(Stage):
except Exception as e:
logger.error(f"调用知识库时遇到问题: {e}")
def _truncate_contexts(
self,
contexts: list[dict],
) -> list[dict]:
"""截断上下文列表,确保不超过最大长度"""
if self.max_context_length == -1:
return contexts
# 执行请求 LLM 前事件钩子。
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if len(contexts) // 2 <= self.max_context_length:
return contexts
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
truncated_contexts = contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(truncated_contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
# max context length
if (
self.max_context_length != -1 # -1 为不限制
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(req.contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
req.contexts = req.contexts[index:]
return truncated_contexts
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
def _modalities_fix(
self,
provider: Provider,
req: ProviderRequest,
):
"""检查提供商的模态能力,清理请求中的不支持内容"""
# fix messages
req.contexts = self.fix_messages(req.contexts)
# check provider modalities
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
@@ -157,13 +526,7 @@ class LLMRequestSubStage(Stage):
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
)
req.func_tool = None
def _plugin_tool_fix(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""根据事件中的插件设置,过滤请求中的工具列表"""
# 插件可用性设置
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
@@ -177,6 +540,80 @@ class LLMRequestSubStage(Stage):
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
astr_agent_ctx = AstrAgentContext(
provider=provider,
first_provider_request=req,
curr_provider_request=req,
streaming=streaming_response,
event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
)
if streaming_response:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use),
),
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)
async def _handle_webchat(
self,
event: AstrMessageEvent,
@@ -224,6 +661,9 @@ class LLMRequestSubStage(Stage):
),
)
if llm_resp and llm_resp.completion_text:
logger.debug(
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
)
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
return
@@ -251,9 +691,6 @@ class LLMRequestSubStage(Stage):
logger.debug("LLM 响应为空,不保存记录。")
return
if req.contexts is None:
req.contexts = []
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入
@@ -273,7 +710,7 @@ class LLMRequestSubStage(Stage):
history=messages,
)
def _fix_messages(self, messages: list[dict]) -> list[dict]:
def fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
@@ -288,184 +725,3 @@ class LLMRequestSubStage(Stage):
else:
fixed_messages.append(message)
return fixed_messages
async def process(
self,
event: AstrMessageEvent,
_nested: bool = False,
) -> None | AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
logger.debug("ready to request llm provider")
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
logger.debug("acquired session lock for llm request")
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest()
req.prompt = ""
req.image_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix and not event.message_str.startswith(
self.provider_wake_prefix
):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# apply knowledge base context
await self._apply_kb_context(event, req)
# call event hook
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
self._fix_messages(req.contexts)
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
# check provider modalities, if provider does not support image/tool_use, clear them in request.
self._modalities_fix(provider, req)
# filter tools, only keep tools from this pipeline's selected plugins
self._plugin_tool_fix(event, req)
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
astr_agent_ctx = AstrAgentContext(
context=self.ctx.plugin_manager.context,
event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
)
if streaming_response and not stream_to_general:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
show_reasoning=self.show_reasoning,
),
),
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain()
.message(final_llm_resp.completion_text)
.chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
async for _ in run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
stream_to_general,
show_reasoning=self.show_reasoning,
):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)

View File

@@ -10,6 +10,7 @@ from astrbot.core.message.message_event_result import MessageChain, ResultConten
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
from ..context import PipelineContext, call_event_hook
from ..stage import Stage, register_stage
@@ -168,15 +169,12 @@ class RespondStage(Stage):
logger.warning("async_stream 为空,跳过发送。")
return
# 流式结果直接交付平台适配器处理
realtime_segmenting = (
self.config.get("provider_settings", {}).get(
"unsupported_streaming_strategy",
"realtime_segmenting",
)
== "realtime_segmenting"
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented",
False,
)
logger.info(f"应用流式输出({event.get_platform_id()})")
await event.send_streaming(result.async_stream, realtime_segmenting)
await event.send_streaming(result.async_stream, use_fallback)
return
if len(result.chain) > 0:
# 检查路径映射
@@ -220,20 +218,21 @@ class RespondStage(Stage):
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
)
return
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
else:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}

View File

@@ -16,6 +16,3 @@ class PlatformMetadata:
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str | None = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""
support_streaming_message: bool = True
"""平台是否支持真实流式传输"""

View File

@@ -14,7 +14,6 @@ def register_platform_adapter(
default_config_tmpl: dict | None = None,
adapter_display_name: str | None = None,
logo_path: str | None = None,
support_streaming_message: bool = True,
):
"""用于注册平台适配器的带参装饰器。
@@ -43,7 +42,6 @@ def register_platform_adapter(
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
support_streaming_message=support_streaming_message,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls

View File

@@ -29,7 +29,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
@register_platform_adapter(
"aiocqhttp",
"适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
support_streaming_message=False,
)
class AiocqhttpAdapter(Platform):
def __init__(
@@ -50,7 +49,6 @@ class AiocqhttpAdapter(Platform):
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
support_streaming_message=False,
)
self.bot = CQHttp(

View File

@@ -37,9 +37,7 @@ class MyEventHandler(dingtalk_stream.EventHandler):
return AckMessage.STATUS_OK, "OK"
@register_platform_adapter(
"dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
)
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
class DingtalkPlatformAdapter(Platform):
def __init__(
self,
@@ -76,14 +74,6 @@ class DingtalkPlatformAdapter(Platform):
)
self.client_ = client # 用于 websockets 的 client
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
if not dingtalk_id:
return dingtalk_id
prefix = "$:LWCP_v1:$"
if dingtalk_id.startswith(prefix):
return dingtalk_id[len(prefix) :]
return dingtalk_id
async def send_by_session(
self,
session: MessageSesion,
@@ -96,7 +86,6 @@ class DingtalkPlatformAdapter(Platform):
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
support_streaming_message=False,
)
async def convert_msg(
@@ -113,10 +102,10 @@ class DingtalkPlatformAdapter(Platform):
else MessageType.FRIEND_MESSAGE
)
abm.sender = MessageMember(
user_id=self._id_to_sid(message.sender_id),
user_id=message.sender_id,
nickname=message.sender_nick,
)
abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.self_id = message.chatbot_user_id
abm.message_id = message.message_id
abm.raw_message = message
@@ -124,8 +113,8 @@ class DingtalkPlatformAdapter(Platform):
# 处理所有被 @ 的用户(包括机器人自己,因 at_users 已包含)
if message.at_users:
for user in message.at_users:
if id := self._id_to_sid(user.dingtalk_id):
abm.message.append(At(qq=id))
if user.dingtalk_id:
abm.message.append(At(qq=user.dingtalk_id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id

View File

@@ -34,9 +34,7 @@ else:
# 注册平台适配器
@register_platform_adapter(
"discord", "Discord 适配器 (基于 Pycord)", support_streaming_message=False
)
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
class DiscordPlatformAdapter(Platform):
def __init__(
self,
@@ -113,7 +111,6 @@ class DiscordPlatformAdapter(Platform):
"Discord 适配器",
id=self.config.get("id"),
default_config_tmpl=self.config,
support_streaming_message=False,
)
@override

View File

@@ -23,9 +23,7 @@ from ...register import register_platform_adapter
from .lark_event import LarkMessageEvent
@register_platform_adapter(
"lark", "飞书机器人官方 API 适配器", support_streaming_message=False
)
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
class LarkPlatformAdapter(Platform):
def __init__(
self,
@@ -117,7 +115,6 @@ class LarkPlatformAdapter(Platform):
name="lark",
description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):

View File

@@ -45,9 +45,7 @@ MAX_FILE_UPLOAD_COUNT = 16
DEFAULT_UPLOAD_CONCURRENCY = 3
@register_platform_adapter(
"misskey", "Misskey 平台适配器", support_streaming_message=False
)
@register_platform_adapter("misskey", "Misskey 平台适配器")
class MisskeyPlatformAdapter(Platform):
def __init__(
self,
@@ -122,7 +120,6 @@ class MisskeyPlatformAdapter(Platform):
description="Misskey 平台适配器",
id=self.config.get("id", "misskey"),
default_config_tmpl=default_config,
support_streaming_message=False,
)
async def run(self):

View File

@@ -29,7 +29,8 @@ from astrbot.core.platform.astr_message_event import MessageSession
@register_platform_adapter(
"satori", "Satori 协议适配器", support_streaming_message=False
"satori",
"Satori 协议适配器",
)
class SatoriPlatformAdapter(Platform):
def __init__(
@@ -59,7 +60,6 @@ class SatoriPlatformAdapter(Platform):
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
support_streaming_message=False,
)
self.ws: ClientConnection | None = None

View File

@@ -30,7 +30,6 @@ from .slack_event import SlackMessageEvent
@register_platform_adapter(
"slack",
"适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
support_streaming_message=False,
)
class SlackAdapter(Platform):
def __init__(
@@ -69,7 +68,6 @@ class SlackAdapter(Platform):
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"),
support_streaming_message=False,
)
# 初始化 Slack Web Client

View File

@@ -109,7 +109,6 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
reasoning_content = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator:
@@ -125,22 +124,16 @@ class WebChatMessageEvent(AstrMessageEvent):
)
final_data = ""
continue
r = await WebChatMessageEvent._send(
final_data += await WebChatMessageEvent._send(
chain,
session_id=self.session_id,
streaming=True,
)
if chain.type == "reasoning":
reasoning_content += chain.get_plain_text()
else:
final_data += r
await web_chat_back_queue.put(
{
"type": "complete", # complete means we return the final result
"data": final_data,
"reasoning": reasoning_content,
"streaming": True,
"cid": cid,
},

View File

@@ -32,9 +32,7 @@ except ImportError as e:
)
@register_platform_adapter(
"wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
)
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
class WeChatPadProAdapter(Platform):
def __init__(
self,
@@ -53,7 +51,6 @@ class WeChatPadProAdapter(Platform):
name="wechatpadpro",
description="WeChatPadPro 消息平台适配器",
id=self.config.get("id", "wechatpadpro"),
support_streaming_message=False,
)
# 保存配置信息

View File

@@ -110,7 +110,7 @@ class WecomServer:
await self.shutdown_event.wait()
@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
@register_platform_adapter("wecom", "wecom 适配器")
class WecomPlatformAdapter(Platform):
def __init__(
self,
@@ -196,7 +196,6 @@ class WecomPlatformAdapter(Platform):
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
support_streaming_message=False,
)
@override

View File

@@ -113,9 +113,7 @@ class WecomServer:
await self.shutdown_event.wait()
@register_platform_adapter(
"weixin_official_account", "微信公众平台 适配器", support_streaming_message=False
)
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
class WeixinOfficialAccountPlatformAdapter(Platform):
def __init__(
self,
@@ -197,7 +195,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
support_streaming_message=False,
)
@override

View File

@@ -1,4 +1,4 @@
from .entities import ProviderMetaData
from .provider import Provider, STTProvider
from .provider import Personality, Provider, STTProvider
__all__ = ["Provider", "ProviderMetaData", "STTProvider"]
__all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"]

View File

@@ -30,31 +30,18 @@ class ProviderType(enum.Enum):
@dataclass
class ProviderMeta:
"""The basic metadata of a provider instance."""
id: str
"""the unique id of the provider instance that user configured"""
model: str | None
"""the model name of the provider instance currently used"""
class ProviderMetaData:
type: str
"""the name of the provider adapter, such as openai, ollama"""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
"""the capability type of the provider adapter"""
@dataclass
class ProviderMetaData(ProviderMeta):
"""The metadata of a provider adapter for registration."""
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
"""the short description of the provider adapter"""
"""提供商适配器描述"""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Any = None
"""the class type of the provider adapter"""
default_config_tmpl: dict | None = None
"""the default configuration template of the provider adapter"""
"""平台的默认配置模板"""
provider_display_name: str | None = None
"""the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
@@ -73,20 +60,12 @@ class ToolCallsResult:
]
return ret
def to_openai_messages_model(
self,
) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
return [
self.tool_calls_info,
*self.tool_calls_result,
]
@dataclass
class ProviderRequest:
prompt: str | None = None
prompt: str
"""提示词"""
session_id: str | None = ""
session_id: str = ""
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
@@ -202,28 +181,25 @@ class ProviderRequest:
@dataclass
class LLMResponse:
role: str
"""The role of the message, e.g., assistant, tool, err"""
"""角色, assistant, tool, err"""
result_chain: MessageChain | None = None
"""A chain of message components representing the text completion from LLM."""
"""返回的消息链"""
tools_call_args: list[dict[str, Any]] = field(default_factory=list)
"""Tool call arguments."""
"""工具调用参数"""
tools_call_name: list[str] = field(default_factory=list)
"""Tool call names."""
"""工具调用名称"""
tools_call_ids: list[str] = field(default_factory=list)
"""Tool call IDs."""
reasoning_content: str = ""
"""The reasoning content extracted from the LLM, if any."""
"""工具调用 ID"""
raw_completion: (
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
) = None
"""The raw completion response from the LLM provider."""
_new_record: dict[str, Any] | None = None
_completion_text: str = ""
"""The plain text of the completion."""
is_chunk: bool = False
"""Indicates if the response is a chunked response."""
"""是否是流式输出的单个 Chunk"""
def __init__(
self,
@@ -237,6 +213,7 @@ class LLMResponse:
| GenerateContentResponse
| AnthropicMessage
| None = None,
_new_record: dict[str, Any] | None = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
@@ -264,6 +241,7 @@ class LLMResponse:
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
self.is_chunk = is_chunk
@property

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import copy
import json
import os
from collections.abc import Awaitable, Callable
@@ -25,16 +24,7 @@ SUPPORTED_TYPES = [
"boolean",
] # json schema 支持的数据类型
PY_TO_JSON_TYPE = {
"int": "number",
"float": "number",
"bool": "boolean",
"str": "string",
"dict": "object",
"list": "array",
"tuple": "array",
"set": "array",
}
# alias
FuncTool = FunctionTool
@@ -116,7 +106,7 @@ class FunctionToolManager:
def spec_to_func(
self,
name: str,
func_args: list[dict],
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
@@ -125,9 +115,10 @@ class FunctionToolManager:
"properties": {},
}
for param in func_args:
p = copy.deepcopy(param)
p.pop("name", None)
params["properties"][param["name"]] = p
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
return FuncTool(
name=name,
parameters=params,

View File

@@ -241,8 +241,6 @@ class ProviderManager:
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
@@ -356,8 +354,6 @@ class ProviderManager:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = cls_type(provider_config, self.provider_settings)
@@ -398,6 +394,7 @@ class ProviderManager:
inst = cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
)
if getattr(inst, "initialize", None):

View File

@@ -1,18 +1,28 @@
import abc
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Personality
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMeta,
ProviderType,
RerankResult,
ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
@dataclass
class ProviderMeta:
id: str
model: str
type: str
provider_type: ProviderType
class AbstractProvider(abc.ABC):
"""Provider Abstract Class"""
@@ -33,15 +43,15 @@ class AbstractProvider(abc.ABC):
"""Get the provider metadata"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
if not meta_data:
raise ValueError(f"Provider type {provider_type_name} not registered")
meta = ProviderMeta(
id=self.provider_config.get("id", "default"),
provider_type = meta_data.provider_type if meta_data else None
if provider_type is None:
raise ValueError(f"Cannot find provider type: {provider_type_name}")
return ProviderMeta(
id=self.provider_config["id"],
model=self.get_model(),
type=provider_type_name,
provider_type=meta_data.provider_type,
provider_type=provider_type,
)
return meta
class Provider(AbstractProvider):
@@ -51,10 +61,15 @@ class Provider(AbstractProvider):
self,
provider_config: dict,
provider_settings: dict,
default_persona: Personality | None = None,
) -> None:
super().__init__(provider_config)
self.provider_settings = provider_settings
self.curr_personality = default_persona
"""维护了当前的使用的 persona即人格。可能为 None"""
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError

View File

@@ -36,8 +36,6 @@ def register_provider_adapter(
default_config_tmpl["id"] = provider_type_name
pm = ProviderMetaData(
id="default", # will be replaced when instantiated
model=None,
type=provider_type_name,
desc=desc,
provider_type=provider_type,

View File

@@ -25,10 +25,12 @@ class ProviderAnthropic(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.chosen_api_key: str = ""

View File

@@ -20,10 +20,12 @@ class ProviderCoze(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:

View File

@@ -8,7 +8,7 @@ from dashscope.app.application_response import ApplicationResponse
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from .. import Provider
from .. import Personality, Provider
from ..entities import LLMResponse
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
@@ -20,11 +20,13 @@ class ProviderDashscope(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
default_persona: Personality | None = None,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:

View File

@@ -18,10 +18,12 @@ class ProviderDify(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:

View File

@@ -53,10 +53,12 @@ class ProviderGoogleGenAI(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_keys: list = super().get_keys()
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
@@ -324,18 +326,8 @@ class ProviderGoogleGenAI(Provider):
return gemini_contents
def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
"""Extract reasoning content from candidate parts"""
if not candidate.content or not candidate.content.parts:
return ""
thought_buf: list[str] = [
(p.text or "") for p in candidate.content.parts if p.thought
]
return "".join(thought_buf).strip()
@staticmethod
def _process_content_parts(
self,
candidate: types.Candidate,
llm_response: LLMResponse,
) -> MessageChain:
@@ -366,11 +358,6 @@ class ProviderGoogleGenAI(Provider):
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
raise Exception("API 返回的 candidate.content.parts 为空。")
# 提取 reasoning content
reasoning = self._extract_reasoning_content(candidate)
if reasoning:
llm_response.reasoning_content = reasoning
chain = []
part: types.Part
@@ -528,7 +515,6 @@ class ProviderGoogleGenAI(Provider):
# Accumulate the complete response text for the final response
accumulated_text = ""
accumulated_reasoning = ""
final_response = None
async for chunk in result:
@@ -553,19 +539,9 @@ class ProviderGoogleGenAI(Provider):
yield llm_response
return
_f = False
# 提取 reasoning content
reasoning = self._extract_reasoning_content(chunk.candidates[0])
if reasoning:
_f = True
accumulated_reasoning += reasoning
llm_response.reasoning_content = reasoning
if chunk.text:
_f = True
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
if _f:
yield llm_response
if chunk.candidates[0].finish_reason:
@@ -583,10 +559,6 @@ class ProviderGoogleGenAI(Provider):
if not final_response:
final_response = LLMResponse("assistant", is_chunk=False)
# Set the complete accumulated reasoning in the final response
if accumulated_reasoning:
final_response.reasoning_content = accumulated_reasoning
# Set the complete accumulated text in the final response
if accumulated_text:
final_response.result_chain = MessageChain(

View File

@@ -1,15 +0,0 @@
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter(
"groq_chat_completion", "Groq Chat Completion Provider Adapter"
)
class ProviderGroq(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.reasoning_key = "reasoning"

View File

@@ -4,14 +4,12 @@ import inspect
import json
import os
import random
import re
from collections.abc import AsyncGenerator
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
import astrbot.core.message.components as Comp
from astrbot import logger
@@ -30,8 +28,17 @@ from ..register import register_provider_adapter
"OpenAI API Chat Completion 提供商适配器",
)
class ProviderOpenAIOfficial(Provider):
def __init__(self, provider_config, provider_settings) -> None:
super().__init__(provider_config, provider_settings)
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.chosen_api_key = None
self.api_keys: list = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
@@ -46,8 +53,9 @@ class ProviderOpenAIOfficial(Provider):
for key in self.custom_headers:
self.custom_headers[key] = str(self.custom_headers[key])
# 适配 azure openai #332
if "api_version" in provider_config:
# Using Azure OpenAI API
# 使用 azure api
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
@@ -56,7 +64,7 @@ class ProviderOpenAIOfficial(Provider):
timeout=self.timeout,
)
else:
# Using OpenAI Official API
# 使用 openai api
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
@@ -72,8 +80,6 @@ class ProviderOpenAIOfficial(Provider):
model = model_config.get("model", "unknown")
self.set_model(model)
self.reasoning_key = "reasoning_content"
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
@@ -151,7 +157,7 @@ class ProviderOpenAIOfficial(Provider):
logger.debug(f"completion: {completion}")
llm_response = await self._parse_openai_completion(completion, tools)
llm_response = await self.parse_openai_completion(completion, tools)
return llm_response
@@ -204,78 +210,36 @@ class ProviderOpenAIOfficial(Provider):
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
# logger.debug(f"chunk delta: {delta}")
# handle the content delta
reasoning = self._extract_reasoning_content(chunk)
_y = False
if reasoning:
llm_response.reasoning_content = reasoning
_y = True
# 处理文本内容
if delta.content:
completion_text = delta.content
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)],
)
_y = True
if _y:
yield llm_response
final_completion = state.get_final_completion()
llm_response = await self._parse_openai_completion(final_completion, tools)
llm_response = await self.parse_openai_completion(final_completion, tools)
yield llm_response
def _extract_reasoning_content(
self,
completion: ChatCompletion | ChatCompletionChunk,
) -> str:
"""Extract reasoning content from OpenAI ChatCompletion if available."""
reasoning_text = ""
if len(completion.choices) == 0:
return reasoning_text
if isinstance(completion, ChatCompletion):
choice = completion.choices[0]
reasoning_attr = getattr(choice.message, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
elif isinstance(completion, ChatCompletionChunk):
delta = completion.choices[0].delta
reasoning_attr = getattr(delta, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
return reasoning_text
async def _parse_openai_completion(
async def parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
"""Parse OpenAI ChatCompletion into LLMResponse"""
"""解析 OpenAI ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
# parse the text completion
if choice.message.content is not None:
# text completion
completion_text = str(choice.message.content).strip()
# specially, some providers may set <think> tags around reasoning content in the completion text,
# we use regex to remove them, and store then in reasoning_content field
reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
matches = reasoning_pattern.findall(completion_text)
if matches:
llm_response.reasoning_content = "\n".join(
[match.strip() for match in matches],
)
completion_text = reasoning_pattern.sub("", completion_text).strip()
llm_response.result_chain = MessageChain().message(completion_text)
# parse the reasoning content if any
# the priority is higher than the <think> tag extraction
llm_response.reasoning_content = self._extract_reasoning_content(completion)
# parse tool calls if any
if choice.message.tool_calls and tools is not None:
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_call_ids = []
@@ -301,11 +265,11 @@ class ProviderOpenAIOfficial(Provider):
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_call_ids
# specially handle finish reason
if choice.finish_reason == "content_filter":
raise Exception(
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
)
if llm_response.completion_text is None and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")

View File

@@ -12,5 +12,10 @@ class ProviderZhipu(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
default_persona=None,
) -> None:
super().__init__(provider_config, provider_settings)
super().__init__(
provider_config,
provider_settings,
default_persona,
)

View File

@@ -5,10 +5,6 @@ from typing import Any
from deprecated import deprecated
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import Message
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
@@ -17,10 +13,10 @@ from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import (
@@ -35,7 +31,6 @@ from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
)
from ..exceptions import ProviderNotFoundError
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from .star import StarMetadata, star_map, star_registry
@@ -80,153 +75,6 @@ class Context:
self.astrbot_config_mgr = astrbot_config_mgr
self.kb_manager = knowledge_base_manager
async def llm_generate(
self,
*,
chat_provider_id: str,
prompt: str | None = None,
image_urls: list[str] | None = None,
tools: ToolSet | None = None,
system_prompt: str | None = None,
contexts: list[Message] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`.
.. versionadded:: 4.5.7 (sdk)
Args:
chat_provider_id: The chat provider ID to use.
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
tools: ToolSet of tools available to the LLM
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
contexts: context messages for the LLM
**kwargs: Additional keyword arguments for LLM generation, OpenAI compatible
Raises:
ChatProviderNotFoundError: If the specified chat provider ID is not found
Exception: For other errors during LLM generation
"""
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
if not prov or not isinstance(prov, Provider):
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
llm_resp = await prov.text_chat(
prompt=prompt,
image_urls=image_urls,
func_tool=tools,
contexts=contexts,
system_prompt=system_prompt,
**kwargs,
)
return llm_resp
async def tool_loop_agent(
self,
*,
event: AstrMessageEvent,
chat_provider_id: str,
prompt: str | None = None,
image_urls: list[str] | None = None,
tools: ToolSet | None = None,
system_prompt: str | None = None,
contexts: list[Message] | None = None,
max_steps: int = 30,
tool_call_timeout: int = 60,
**kwargs: Any,
) -> LLMResponse:
"""Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.
If you do not pass the agent_context parameter, the method will recreate a new agent context.
.. versionadded:: 4.5.7 (sdk)
Args:
chat_provider_id: The chat provider ID to use.
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
tools: ToolSet of tools available to the LLM
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
Returns:
The final LLMResponse after tool calls are completed.
Raises:
ChatProviderNotFoundError: If the specified chat provider ID is not found
Exception: For other errors during LLM generation
"""
# Import here to avoid circular imports
from astrbot.core.astr_agent_context import (
AgentContextWrapper,
AstrAgentContext,
)
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
if not prov or not isinstance(prov, Provider):
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
agent_context = kwargs.get("agent_context")
context_ = []
for msg in contexts or []:
if isinstance(msg, Message):
context_.append(msg.model_dump())
else:
context_.append(msg)
request = ProviderRequest(
prompt=prompt,
image_urls=image_urls or [],
func_tool=tools,
contexts=context_,
system_prompt=system_prompt or "",
)
if agent_context is None:
agent_context = AstrAgentContext(
context=self,
event=event,
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
await agent_runner.reset(
provider=prov,
request=request,
run_context=AgentContextWrapper(
context=agent_context,
tool_call_timeout=tool_call_timeout,
),
tool_executor=tool_executor,
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
)
async for _ in agent_runner.step_until_done(max_steps):
pass
llm_resp = agent_runner.get_final_llm_resp()
if not llm_resp:
raise Exception("Agent did not produce a final LLM response")
return llm_resp
async def get_current_chat_provider_id(self, umo: str) -> str:
"""Get the ID of the currently used chat provider.
Args:
umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used.
Raises:
ProviderNotFoundError: If the specified chat provider is not found
"""
prov = self.get_using_provider(umo)
if not prov:
raise ProviderNotFoundError("Provider not found")
return prov.meta().id
def get_registered_star(self, star_name: str) -> StarMetadata | None:
"""根据插件名获取插件的 Metadata"""
for star in star_registry:
@@ -259,6 +107,10 @@ class Context:
"""
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
def register_provider(self, provider: Provider):
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(
self,
provider_id: str,
@@ -337,6 +189,45 @@ class Context:
return self._config
return self.astrbot_config_mgr.get_conf(umo)
def get_db(self) -> BaseDatabase:
"""获取 AstrBot 数据库。"""
return self._db
def get_event_queue(self) -> Queue:
"""获取事件队列。"""
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
"""获取指定类型的平台适配器。
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
if isinstance(platform_type, str):
if name == platform_type:
return platform
elif (
name in ADAPTER_NAME_2_TYPE
and ADAPTER_NAME_2_TYPE[name] & platform_type
):
return platform
def get_platform_inst(self, platform_id: str) -> Platform | None:
"""获取指定 ID 的平台适配器实例。
Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform
async def send_message(
self,
session: str | MessageSesion,
@@ -409,49 +300,6 @@ class Context:
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
"""
def get_event_queue(self) -> Queue:
"""获取事件队列。"""
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
"""获取指定类型的平台适配器。
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
if isinstance(platform_type, str):
if name == platform_type:
return platform
elif (
name in ADAPTER_NAME_2_TYPE
and ADAPTER_NAME_2_TYPE[name] & platform_type
):
return platform
def get_platform_inst(self, platform_id: str) -> Platform | None:
"""获取指定 ID 的平台适配器实例。
Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform
def get_db(self) -> BaseDatabase:
"""获取 AstrBot 数据库。"""
return self._db
def register_provider(self, provider: Provider):
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
self.provider_manager.provider_insts.append(provider)
def register_llm_tool(
self,
name: str,

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import re
from collections.abc import Awaitable, Callable
from typing import Any
@@ -12,7 +11,7 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from ..filter.command import CommandFilter
@@ -418,37 +417,18 @@ def register_llm_tool(name: str | None = None, **kwargs):
docstring = docstring_parser.parse(func_doc)
args = []
for arg in docstring.params:
sub_type_name = None
type_name = arg.type_name
if not type_name:
raise ValueError(
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 的参数 {arg.arg_name} 缺少类型注释。",
)
# parse type_name to handle cases like "list[string]"
match = re.match(r"(\w+)\[(\w+)\]", type_name)
if match:
type_name = match.group(1)
sub_type_name = match.group(2)
type_name = PY_TO_JSON_TYPE.get(type_name, type_name)
if sub_type_name:
sub_type_name = PY_TO_JSON_TYPE.get(sub_type_name, sub_type_name)
if type_name not in SUPPORTED_TYPES or (
sub_type_name and sub_type_name not in SUPPORTED_TYPES
):
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}",
)
arg_json_schema = {
"type": type_name,
"name": arg.arg_name,
"description": arg.description,
}
if sub_type_name:
if type_name == "array":
arg_json_schema["items"] = {"type": sub_type_name}
args.append(arg_json_schema)
args.append(
{
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description,
},
)
# print(llm_tool_name, registering_agent)
if not registering_agent:
doc_desc = docstring.description.strip() if docstring.description else ""
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)

View File

@@ -204,8 +204,6 @@ class ChatRoute(Route):
):
# 追加机器人消息
new_his = {"type": "bot", "message": result_text}
if "reasoning" in result:
new_his["reasoning"] = result["reasoning"]
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,

134
build_nuitka.py Normal file
View File

@@ -0,0 +1,134 @@
#!/usr/bin/env python3
"""
Use Nuitka to build the AstrBot project into standalone executables
"""
import os
import platform
import subprocess
import sys
from pathlib import Path
def get_platform_info():
"""fetch the current platform information"""
system = platform.system()
machine = platform.machine()
return system, machine
def build_with_nuitka():
"""use Nuitka to build the project"""
system, machine = get_platform_info()
print(f"🚀 Starting build for {system} ({machine}) platform...")
# Output directory
output_dir = Path("build/nuitka")
output_dir.mkdir(parents=True, exist_ok=True)
# Base Nuitka command
nuitka_cmd = [
sys.executable,
"-m",
"nuitka",
"--standalone", # Create standalone directory
"--onefile", # Single file mode
"--follow-imports", # Follow all imports
"--enable-plugin=multiprocessing", # Enable multiprocessing support
"--output-dir=build/nuitka", # Output directory
"--quiet", # Reduce output verbosity
"--assume-yes-for-downloads", # Automatically download dependencies
"--jobs=4", # Use multiple CPU cores
]
# include specific packages
include_packages = [
"astrbot",
]
for pkg in include_packages:
nuitka_cmd.extend([f"--include-package={pkg}"])
# include data directories
# data_includes = [
# "data/config",
# "data/plugins",
# "data/temp",
# ]
# for data_dir in data_includes:
# if os.path.exists(data_dir):
# nuitka_cmd.extend([f"--include-data-dir={data_dir}={data_dir}"])
# include packages directory (built-in plugins)
# if os.path.exists("packages"):
# nuitka_cmd.extend(["--include-data-dir=packages=packages"])
# Platform specific settings
if system == "Darwin": # macOS
nuitka_cmd.extend(
[
"--macos-create-app-bundle", # Create .app bundle
"--macos-app-name=AstrBot",
]
)
# macOS icon (if exists)
icon_path = "dashboard/src-tauri/icons/icon.icns"
if os.path.exists(icon_path):
nuitka_cmd.extend([f"--macos-app-icon={icon_path}"])
elif system == "Windows":
nuitka_cmd.extend(
[
"--windows-console-mode=disable", # 无控制台窗口
]
)
# Windows icon (if exists)
icon_path = "dashboard/src-tauri/icons/icon.ico"
if os.path.exists(icon_path):
nuitka_cmd.extend([f"--windows-icon-from-ico={icon_path}"])
# Main file to compile
nuitka_cmd.append("main.py")
print(f"📦 Executing command: {' '.join(nuitka_cmd)}")
try:
subprocess.run(nuitka_cmd, check=True)
print("✅ Nuitka build successful!")
# Find the generated executable
if system == "Darwin":
built_file = list(output_dir.glob("*.app"))
if built_file:
print(f"Generated macOS app: {built_file[0]}")
elif system == "Windows":
built_file = list(output_dir.glob("*.exe"))
if built_file:
print(f"Generated Windows executable: {built_file[0]}")
else: # Linux
built_file = list(output_dir.glob("main.bin"))
if built_file:
print(f"Generated Linux executable: {built_file[0]}")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Nuitka build failed: {e}")
return False
if __name__ == "__main__":
print("=" * 60)
print("AstrBot Nuitka Builder")
print("=" * 60)
# 构建
if build_with_nuitka():
print("\n" + "=" * 60)
print("🎉 Build Complete!")
print("=" * 60)
else:
print("\n" + "=" * 60)
print("❌ Build Failed")
print("=" * 60)
sys.exit(1)

134
build_pyinstaller.py Normal file
View File

@@ -0,0 +1,134 @@
#!/usr/bin/env python3
"""
Use PyInstaller to build the AstrBot project into standalone executables
"""
import platform
import subprocess
import sys
from pathlib import Path
def get_platform_info():
"""fetch the current platform information"""
system = platform.system()
machine = platform.machine()
return system, machine
def build_with_pyinstaller():
"""use PyInstaller to build the project"""
system, machine = get_platform_info()
print(f"🚀 Starting build for {system} ({machine}) platform...")
# Output directory
output_dir = Path("build/pyinstaller")
output_dir.mkdir(parents=True, exist_ok=True)
# Base PyInstaller command
pyinstaller_cmd = [
sys.executable,
"-m",
"PyInstaller",
"--clean", # Clean cache before build
"--noconfirm", # Replace output directory without asking
"--onefile", # Single file mode
"--distpath=build/pyinstaller/dist", # Distribution directory
"--workpath=build/pyinstaller/build", # Work directory
"--specpath=build/pyinstaller", # Spec file directory
"--name=AstrBot", # Output executable name
]
# Platform specific settings
# if system == "Darwin": # macOS
# # macOS icon (if exists)
# icon_path = "dashboard/src-tauri/icons/icon.icns"
# if os.path.exists(icon_path):
# pyinstaller_cmd.extend([f"--icon={icon_path}"])
# # Create .app bundle
# pyinstaller_cmd.extend(["--windowed"])
# elif system == "Windows":
# # Windows icon (if exists)
# icon_path = "dashboard/src-tauri/icons/icon.ico"
# if os.path.exists(icon_path):
# pyinstaller_cmd.extend([f"--icon={icon_path}"])
# # No console window
# pyinstaller_cmd.extend(["--windowed"])
# else: # Linux
# pyinstaller_cmd.extend(["--console"])
# Main file to compile
pyinstaller_cmd.append("main.py")
print(f"📦 Executing command: {' '.join(pyinstaller_cmd)}")
try:
subprocess.run(pyinstaller_cmd, check=True)
print("✅ PyInstaller build successful!")
# Find the generated executable
dist_dir = output_dir / "dist"
if system == "Darwin":
built_file = list(dist_dir.glob("AstrBot.app"))
if not built_file:
built_file = list(dist_dir.glob("AstrBot"))
if built_file:
print(f"📱 Generated macOS app: {built_file[0]}")
elif system == "Windows":
built_file = list(dist_dir.glob("AstrBot.exe"))
if built_file:
print(f"💻 Generated Windows executable: {built_file[0]}")
else: # Linux
built_file = list(dist_dir.glob("AstrBot"))
if built_file:
print(f"🐧 Generated Linux executable: {built_file[0]}")
print(f"\n📁 Output directory: {dist_dir.absolute()}")
return True
except subprocess.CalledProcessError as e:
print(f"❌ PyInstaller build failed: {e}")
return False
except Exception as e:
print(f"❌ Unexpected error: {e}")
return False
def install_pyinstaller():
"""Install PyInstaller if not already installed"""
try:
import PyInstaller
print(f"✅ PyInstaller already installed (version {PyInstaller.__version__})")
return True
except ImportError:
print("📥 PyInstaller not found, installing...")
try:
subprocess.run(
[sys.executable, "-m", "pip", "install", "pyinstaller"], check=True
)
print("✅ PyInstaller installed successfully!")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Failed to install PyInstaller: {e}")
return False
if __name__ == "__main__":
print("=" * 60)
print("AstrBot PyInstaller Builder")
print("=" * 60)
# Check and install PyInstaller
if not install_pyinstaller():
sys.exit(1)
# Build
if build_with_pyinstaller():
print("\n" + "=" * 60)
print("🎉 Build Complete!")
print("=" * 60)
else:
print("\n" + "=" * 60)
print("❌ Build Failed")
print("=" * 60)
sys.exit(1)

View File

@@ -1,12 +0,0 @@
## What's Changed
1. 新增:支持为 OpenAI API 提供商自定义请求头 ([#3581](https://github.com/AstrBotDevs/AstrBot/issues/3581))
2. 新增:为 WebChat 为 Thinking 模型添加思考过程展示功能;支持快捷切换流式输出 / 非流式输出。([#3632](https://github.com/AstrBotDevs/AstrBot/issues/3632))
3. 新增:优化插件调用 LLM 和 Agent 的路径,为 Context 类引入多个调用 LLM 和 Agent 的便捷方法 ([#3636](https://github.com/AstrBotDevs/AstrBot/issues/3636))
4. 优化:改善不支持流式输出的消息平台的回退策略 ([#3547](https://github.com/AstrBotDevs/AstrBot/issues/3547))
5. 优化当同一个会话umo下同时有多个请求时执行排队处理避免并发请求导致的上下文混乱问题 ([#3607](https://github.com/AstrBotDevs/AstrBot/issues/3607))
6. 优化:优化 WebUI 的登录界面和 Changelog 页面的显示效果
7. 修复:修复在知识库名字过长的情况下,“选择知识库”按钮显示异常的问题 ([#3582](https://github.com/AstrBotDevs/AstrBot/issues/3582))
8. 修复:修复部分情况下,分段消息发送时导致的死锁问题(由 PR #3607 引入)
9. 修复:钉钉适配器使用部分指令无法生效的问题 ([#3634](https://github.com/AstrBotDevs/AstrBot/issues/3634))
10. 其他:为部分适配器添加缺失的 send_streaming 方法 ([#3545](https://github.com/AstrBotDevs/AstrBot/issues/3545))

View File

@@ -1,5 +0,0 @@
## What's Changed
hot fix of 4.5.7
fix: 无法正常发送图片,报错 `pydantic_core._pydantic_core.ValidationError`

225
dashboard/TAURI_README.md Normal file
View File

@@ -0,0 +1,225 @@
# AstrBot Dashboard - Tauri 桌面应用
本项目现已支持通过 Tauri 构建为桌面应用,同时保持与 Web 版本的兼容性。
## 环境要求
### 系统依赖
**macOS:**
```bash
# 安装 Xcode Command Line Tools
xcode-select --install
```
**Windows:**
- 安装 [Microsoft Visual Studio C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/)
- 安装 [WebView2](https://developer.microsoft.com/en-us/microsoft-edge/webview2/)
**Linux (Ubuntu/Debian):**
```bash
sudo apt update
sudo apt install libwebkit2gtk-4.0-dev \
build-essential \
curl \
wget \
file \
libssl-dev \
libgtk-3-dev \
libayatana-appindicator3-dev \
librsvg2-dev
```
### Rust 环境
```bash
# 安装 Rust
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# 验证安装
rustc --version
cargo --version
```
## 安装依赖
```bash
cd dashboard
npm install
```
## 开发模式
### Web 端开发(不变)
```bash
npm run dev
```
访问 http://localhost:3000
### 桌面端开发
```bash
npm run tauri:dev
```
这会同时启动:
1. Vite 开发服务器(端口 3000
2. Tauri 桌面应用窗口
热重载功能正常工作,修改代码后会自动刷新。
## 构建
### Web 端构建(不变)
```bash
npm run build
```
输出目录:`dist/`
### 桌面端构建
```bash
npm run tauri:build
```
构建产物位置:
- **macOS**: `src-tauri/target/release/bundle/dmg/`
- **Windows**: `src-tauri/target/release/bundle/msi/`
- **Linux**: `src-tauri/target/release/bundle/deb/``appimage/`
## 图标设置
### 自动生成图标
准备一个至少 512x512 像素的 PNG 图标,然后运行:
```bash
npm run tauri icon path/to/your/icon.png
```
### 手动设置图标
将以下图标放入 `src-tauri/icons/` 目录:
- `32x32.png`
- `128x128.png`
- `128x128@2x.png`
- `icon.icns` (macOS)
- `icon.ico` (Windows)
## 代码兼容性
项目已配置为同时支持 Web 和桌面端,使用相同的代码库。
### 环境检测工具
`src/utils/tauri.ts` 中提供了环境检测工具:
```typescript
import { isTauri, isWeb, PlatformAPI } from '@/utils/tauri';
// 检测运行环境
if (isTauri()) {
console.log('运行在桌面应用中');
} else {
console.log('运行在浏览器中');
}
// 获取正确的 API 端点
const baseURL = PlatformAPI.getBaseURL();
```
### API 调用注意事项
- **Web 端**: 使用 Vite 代理API 路径为 `/api/*`
- **桌面端**: 直接连接到 `http://127.0.0.1:6185`
已在 `PlatformAPI.getBaseURL()` 中处理,使用 axios 时:
```typescript
import axios from 'axios';
import { PlatformAPI } from '@/utils/tauri';
const api = axios.create({
baseURL: PlatformAPI.getBaseURL()
});
```
## 配置说明
### tauri.conf.json
主要配置项:
- `build.devPath`: 开发服务器地址http://localhost:3000
- `build.distDir`: 构建输出目录(../dist
- `tauri.allowlist`: API 权限配置
- `tauri.windows`: 窗口配置(大小、标题等)
### 安全性
默认配置已启用必要的权限:
- 文件系统访问(限定在 APPDATA 目录)
- HTTP 请求(限定到本地后端)
- 窗口控制
- 对话框(打开/保存文件)
可在 `tauri.conf.json``allowlist` 部分调整权限。
## 后端连接
桌面应用需要后端服务运行在 `http://127.0.0.1:6185`
### 启动流程
1. 启动 AstrBot 后端:
```bash
cd /path/to/AstrBot
uv run main.py
```
2. 启动桌面应用:
```bash
cd dashboard
npm run tauri:dev
```
或直接运行打包后的应用(后端需要已启动)。
## 常见问题
### Q: 桌面应用无法连接到后端?
确保:
1. AstrBot 后端正在运行(`uv run main.py`
2. 后端监听在 `127.0.0.1:6185`
3. 防火墙未阻止连接
### Q: 图标未显示?
检查 `src-tauri/icons/` 目录中是否有所需的图标文件,或使用 `npm run tauri icon` 命令生成。
### Q: 构建失败?
- 确保已安装 Rust 和系统依赖
- 运行 `cargo clean` 清理缓存后重试
- 检查 Rust 版本(需要 1.60+
### Q: Web 端功能是否受影响?
不受影响。`npm run dev` 和 `npm run build` 的行为完全不变。
## 开发建议
1. **优先使用 Web 端开发**: 更快的热重载,更好的调试体验
2. **定期测试桌面端**: 确保跨平台兼容性
3. **使用环境检测**: 针对不同平台提供最佳体验
4. **注意 API 差异**: Web 和桌面端的某些 API 可能有差异
## 更多资源
- [Tauri 官方文档](https://tauri.app/)
- [Tauri API 参考](https://tauri.app/v1/api/js/)
- [Tauri Discord 社区](https://discord.com/invite/tauri)

View File

@@ -10,10 +10,14 @@
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
"preview": "vite preview --port 5050",
"typecheck": "vue-tsc --noEmit",
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore",
"tauri": "tauri",
"tauri:dev": "tauri dev",
"tauri:build": "tauri build"
},
"dependencies": {
"@guolao/vue-monaco-editor": "^1.5.4",
"@tauri-apps/api": "^2.9.0",
"@tiptap/starter-kit": "2.1.7",
"@tiptap/vue-3": "2.1.7",
"apexcharts": "3.42.0",
@@ -43,6 +47,7 @@
"devDependencies": {
"@mdi/font": "7.2.96",
"@rushstack/eslint-patch": "1.3.3",
"@tauri-apps/cli": "^2.9.4",
"@types/chance": "1.1.3",
"@types/markdown-it": "^14.1.2",
"@types/node": "^20.5.7",

4509
dashboard/pnpm-lock.yaml generated Normal file

File diff suppressed because it is too large Load Diff

3
dashboard/src-tauri/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
# Tauri specific
src-tauri/target/
src-tauri/WixTools/

4692
dashboard/src-tauri/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
[package]
name = "astrbot-dashboard"
version = "4.5.6"
description = "AstrBot"
authors = ["AstrBot Team"]
license = "AGPL-3.0"
repository = "https://github.com/AstrBotDevs/AstrBot"
default-run = "astrbot-dashboard"
edition = "2021"
rust-version = "1.91.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
tauri-build = { version = "2", features = [] }
[dependencies]
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
tauri = { version = "2.9.2", features = ["macos-private-api", "protocol-asset"] }
tauri-plugin-opener = "2"
[features]
# this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled.
# If you use cargo directly instead of tauri's cli you can use this feature flag to switch between tauri's `dev` and `build` modes.
# DO NOT REMOVE!!
custom-protocol = [ "tauri/custom-protocol" ]

View File

@@ -0,0 +1,3 @@
fn main() {
tauri_build::build()
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
{}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

View File

@@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
<background android:drawable="@color/ic_launcher_background"/>
</adaptive-icon>

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.6 KiB

View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="ic_launcher_background">#fff</color>
</resources>

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 47 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 602 B

Some files were not shown because too many files have changed in this diff Show More