Compare commits
5 Commits
v4.5.8
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
381f7f4405 | ||
|
|
8f38e748cd | ||
|
|
f83484a8c0 | ||
|
|
56e3ddd62a | ||
|
|
80948be41d |
@@ -1,9 +1,9 @@
|
||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
# github actions
|
||||
.git
|
||||
# github acions
|
||||
.github/
|
||||
.*ignore
|
||||
.git/
|
||||
# User-specific stuff
|
||||
.idea/
|
||||
# Byte-compiled / optimized / DLL files
|
||||
@@ -15,10 +15,10 @@ env/
|
||||
venv*/
|
||||
ENV/
|
||||
.conda/
|
||||
README*.md
|
||||
dashboard/
|
||||
data/
|
||||
changelogs/
|
||||
tests/
|
||||
.ruff_cache/
|
||||
.astrbot
|
||||
astrbot.lock
|
||||
.astrbot
|
||||
2
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
2
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
@@ -16,7 +16,7 @@ body:
|
||||
|
||||
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
不熟悉 JSON ?可以从 [此站](https://plugins.astrbot.app) 右下角提交。
|
||||
不熟悉 JSON ?可以从 [此处](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
|
||||
|
||||
- type: textarea
|
||||
id: plugin-info
|
||||
|
||||
18
Dockerfile
18
Dockerfile
@@ -12,21 +12,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
bash \
|
||||
ffmpeg \
|
||||
curl \
|
||||
gnupg \
|
||||
git \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN apt-get update && apt-get install -y curl gnupg \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
|
||||
&& apt-get install -y nodejs
|
||||
RUN apt-get update && apt-get install -y curl gnupg && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m pip install uv \
|
||||
&& echo "3.11" > .python-version
|
||||
RUN python -m pip install uv
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pilk --no-cache-dir --system
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD ["python", "main.py"]
|
||||
CMD [ "python", "main.py" ]
|
||||
|
||||
35
Dockerfile_with_node
Normal file
35
Dockerfile_with_node
Normal file
@@ -0,0 +1,35 @@
|
||||
FROM python:3.10-slim
|
||||
|
||||
WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
libffi-dev \
|
||||
libssl-dev \
|
||||
curl \
|
||||
unzip \
|
||||
ca-certificates \
|
||||
bash \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Installation of Node.js
|
||||
ENV NVM_DIR="/root/.nvm"
|
||||
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
|
||||
. "$NVM_DIR/nvm.sh" && \
|
||||
nvm install 22 && \
|
||||
nvm use 22
|
||||
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
|
||||
|
||||
RUN python -m pip install uv
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD ["python", "main.py"]
|
||||
116
README.md
116
README.md
@@ -8,7 +8,7 @@
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=1" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
@@ -119,73 +119,83 @@ uv run main.py
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## 支持的消息平台
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
**官方维护**
|
||||
|
||||
- QQ (官方平台 & OneBot)
|
||||
- Telegram
|
||||
- 企微应用 & 企微智能机器人
|
||||
- 微信客服 & 微信公众号
|
||||
- 飞书
|
||||
- 钉钉
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- Whatsapp (将支持)
|
||||
- LINE (将支持)
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| QQ(官方平台) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企微应用 | ✔ |
|
||||
| 企微智能机器人 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
| 微信公众号 | ✔ |
|
||||
| 飞书 | ✔ |
|
||||
| 钉钉 | ✔ |
|
||||
| Slack | ✔ |
|
||||
| Discord | ✔ |
|
||||
| Satori | ✔ |
|
||||
| Misskey | ✔ |
|
||||
| Whatsapp | 将支持 |
|
||||
| LINE | 将支持 |
|
||||
|
||||
**社区维护**
|
||||
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
|
||||
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | ✔ |
|
||||
|
||||
## 支持的模型服务
|
||||
## ⚡ 提供商支持情况
|
||||
|
||||
**大模型服务**
|
||||
|
||||
- OpenAI 及兼容服务
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- 智谱 AI
|
||||
- DeepSeek
|
||||
- Ollama (本地部署)
|
||||
- LM Studio (本地部署)
|
||||
- [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [小马算力](https://www.tokenpony.cn/3YPyf)
|
||||
- [硅基流动](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**LLMOps 平台**
|
||||
|
||||
- Dify
|
||||
- 阿里云百炼应用
|
||||
- Coze
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
|
||||
| Anthropic | ✔ | |
|
||||
| Google Gemini | ✔ | |
|
||||
| Moonshot AI | ✔ | |
|
||||
| 智谱 AI | ✔ | |
|
||||
| DeepSeek | ✔ | |
|
||||
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
|
||||
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
|
||||
| 硅基流动 | ✔ | |
|
||||
| PPIO 派欧云 | ✔ | |
|
||||
| ModelScope | ✔ | |
|
||||
| OneAPI | ✔ | |
|
||||
| Dify | ✔ | |
|
||||
| 阿里云百炼应用 | ✔ | |
|
||||
| Coze | ✔ | |
|
||||
|
||||
**语音转文本服务**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| Whisper | ✔ | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 本地部署 |
|
||||
|
||||
**文本转语音服务**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- 阿里云百炼 TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- 火山引擎 TTS
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| OpenAI TTS | ✔ | |
|
||||
| Gemini TTS | ✔ | |
|
||||
| GSVI | ✔ | GPT-Sovits-Inference |
|
||||
| GPT-SoVITs | ✔ | GPT-Sovits |
|
||||
| FishAudio | ✔ | |
|
||||
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | |
|
||||
| Azure TTS | ✔ | |
|
||||
| Minimax TTS | ✔ | |
|
||||
| 火山引擎 TTS | ✔ | |
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
@@ -219,7 +229,7 @@ pre-commit install
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> [!TIP]
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我们维护这个开源项目的动力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
@@ -36,8 +36,7 @@ from astrbot.core.star.config import *
|
||||
|
||||
|
||||
# provider
|
||||
from astrbot.core.provider import Provider, ProviderMetaData
|
||||
from astrbot.core.db.po import Personality
|
||||
from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
||||
|
||||
# platform
|
||||
from astrbot.core.platform import (
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from astrbot.core.db.po import Personality
|
||||
from astrbot.core.provider import Provider, STTProvider
|
||||
from astrbot.core.provider import Personality, Provider, STTProvider
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderMetaData,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
@@ -48,7 +48,7 @@ def init() -> None:
|
||||
|
||||
try:
|
||||
with lock.acquire():
|
||||
asyncio.run(initialize_astrbot(astrbot_root))
|
||||
anyio.run(initialize_astrbot, astrbot_root)
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path):
|
||||
async def run_astrbot(astrbot_root: Path) -> None:
|
||||
"""运行 AstrBot"""
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
@@ -53,7 +53,7 @@ def run(reload: bool, port: str) -> None:
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
with lock.acquire():
|
||||
asyncio.run(run_astrbot(astrbot_root))
|
||||
anyio.run(run_astrbot, astrbot_root)
|
||||
except KeyboardInterrupt:
|
||||
click.echo("AstrBot 已关闭...")
|
||||
except Timeout:
|
||||
|
||||
@@ -76,7 +76,7 @@ class ImageURLPart(ContentPart):
|
||||
"""The ID of the image, to allow LLMs to distinguish different images."""
|
||||
|
||||
type: str = "image_url"
|
||||
image_url: ImageURL
|
||||
image_url: str
|
||||
|
||||
|
||||
class AudioURLPart(ContentPart):
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from .message import Message
|
||||
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
context: TContext
|
||||
messages: list[Message] = Field(default_factory=list)
|
||||
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
|
||||
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
||||
|
||||
|
||||
|
||||
@@ -40,13 +40,6 @@ class BaseAgentRunner(T.Generic[TContext]):
|
||||
"""Process a single step of the agent."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def step_until_done(
|
||||
self, max_step: int
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""Process steps until the agent is done."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def done(self) -> bool:
|
||||
"""Check if the agent has completed its task.
|
||||
|
||||
@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..message import AssistantMessageSegment, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
@@ -55,20 +55,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
messages = []
|
||||
# append existing messages in the run context
|
||||
for msg in request.contexts:
|
||||
messages.append(Message.model_validate(msg))
|
||||
if request.prompt is not None:
|
||||
m = await request.assemble_context()
|
||||
messages.append(Message.model_validate(m))
|
||||
if request.system_prompt:
|
||||
messages.insert(
|
||||
0,
|
||||
Message(role="system", content=request.system_prompt),
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
@@ -110,22 +96,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=llm_response.result_chain),
|
||||
)
|
||||
elif llm_response.completion_text:
|
||||
else:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_response.completion_text),
|
||||
),
|
||||
)
|
||||
elif llm_response.reasoning_content:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain(type="reasoning").message(
|
||||
llm_response.reasoning_content,
|
||||
),
|
||||
),
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
break # got final response
|
||||
@@ -153,13 +130,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
# record the final assistant message
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
),
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
@@ -186,16 +156,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain(type="tool_call").message(
|
||||
f"🔨 调用工具: {tool_call_name}"
|
||||
),
|
||||
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
result.type = "tool_call_result"
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
data=AgentResponseData(chain=result),
|
||||
@@ -208,23 +175,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
# record the assistant message with tool calls
|
||||
self.run_context.messages.extend(
|
||||
tool_calls_result.to_openai_messages_model()
|
||||
)
|
||||
|
||||
self.req.append_tool_calls_result(tool_calls_result)
|
||||
|
||||
async def step_until_done(
|
||||
self, max_step: int
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""Process steps until the agent is done."""
|
||||
step_count = 0
|
||||
while not self.done() and step_count < max_step:
|
||||
step_count += 1
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
|
||||
@@ -4,13 +4,12 @@ from typing import Any, Generic
|
||||
import jsonschema
|
||||
import mcp
|
||||
from deprecated import deprecated
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
ToolExecResult = str | mcp.types.CallToolResult
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -56,14 +55,23 @@ class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
|
||||
|
||||
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
|
||||
def __dict__(self) -> dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"parameters": self.parameters,
|
||||
"description": self.description,
|
||||
"active": self.active,
|
||||
}
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[TContext], **kwargs
|
||||
) -> str | mcp.types.CallToolResult:
|
||||
"""Run the tool with the given arguments. The handler field has priority."""
|
||||
raise NotImplementedError(
|
||||
"FunctionTool.call() must be implemented by subclasses or set a handler."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolSet:
|
||||
"""A set of function tools that can be used in function calling.
|
||||
|
||||
@@ -71,7 +79,8 @@ class ToolSet:
|
||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
|
||||
"""
|
||||
|
||||
tools: list[FunctionTool] = Field(default_factory=list)
|
||||
def __init__(self, tools: list[FunctionTool] | None = None):
|
||||
self.tools: list[FunctionTool] = tools or []
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the tool set is empty."""
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from dataclasses import dataclass
|
||||
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class AstrAgentContext:
|
||||
context: Context
|
||||
"""The star context instance"""
|
||||
provider: Provider
|
||||
first_provider_request: ProviderRequest
|
||||
curr_provider_request: ProviderRequest
|
||||
streaming: bool
|
||||
event: AstrMessageEvent
|
||||
"""The message event associated with the agent context."""
|
||||
extra: dict[str, str] = Field(default_factory=dict)
|
||||
"""Customized extra data."""
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.pipeline.context_utils import call_event_hook
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnLLMResponseEvent,
|
||||
llm_response,
|
||||
)
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
tool: FunctionTool[Any],
|
||||
tool_args: dict | None,
|
||||
tool_result: CallToolResult | None,
|
||||
):
|
||||
run_context.context.event.clear_result()
|
||||
|
||||
|
||||
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
pass
|
||||
|
||||
|
||||
MAIN_AGENT_HOOKS = MainAgentHooks()
|
||||
@@ -1,80 +0,0 @@
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner,
|
||||
max_step: int = 30,
|
||||
show_tool_use: bool = True,
|
||||
stream_to_general: bool = False,
|
||||
show_reasoning: bool = False,
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use:
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if stream_to_general and resp.type == "streaming_delta":
|
||||
continue
|
||||
|
||||
if stream_to_general or not agent_runner.streaming:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
astr_event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
),
|
||||
)
|
||||
yield
|
||||
astr_event.clear_result()
|
||||
elif resp.type == "streaming_delta":
|
||||
chain = resp.data["chain"]
|
||||
if chain.type == "reasoning" and not show_reasoning:
|
||||
# display the reasoning content only when configured
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||
return
|
||||
@@ -1,246 +0,0 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
|
||||
import mcp
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.mcp_client import MCPTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import (
|
||||
CommandResult,
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
)
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@classmethod
|
||||
async def execute(cls, tool, run_context, **tool_args):
|
||||
"""执行函数调用。
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
||||
**kwargs: 函数调用的参数。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
||||
|
||||
"""
|
||||
if isinstance(tool, HandoffTool):
|
||||
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
elif isinstance(tool, MCPTool):
|
||||
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
else:
|
||||
async for r in cls._execute_local(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_handoff(
|
||||
cls,
|
||||
tool: HandoffTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
input_ = tool_args.get("input")
|
||||
|
||||
# make toolset for the agent
|
||||
tools = tool.agent.tools
|
||||
if tools:
|
||||
toolset = ToolSet()
|
||||
for t in tools:
|
||||
if isinstance(t, str):
|
||||
_t = llm_tools.get_func(t)
|
||||
if _t:
|
||||
toolset.add_tool(_t)
|
||||
elif isinstance(t, FunctionTool):
|
||||
toolset.add_tool(t)
|
||||
else:
|
||||
toolset = None
|
||||
|
||||
ctx = run_context.context.context
|
||||
event = run_context.context.event
|
||||
umo = event.unified_msg_origin
|
||||
prov_id = await ctx.get_current_chat_provider_id(umo)
|
||||
llm_resp = await ctx.tool_loop_agent(
|
||||
event=event,
|
||||
chat_provider_id=prov_id,
|
||||
prompt=input_,
|
||||
system_prompt=tool.agent.instructions,
|
||||
tools=toolset,
|
||||
max_steps=30,
|
||||
run_hooks=tool.agent.run_hooks,
|
||||
)
|
||||
yield mcp.types.CallToolResult(
|
||||
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
event = run_context.context.event
|
||||
if not event:
|
||||
raise ValueError("Event must be provided for local function tools.")
|
||||
|
||||
is_override_call = False
|
||||
for ty in type(tool).mro():
|
||||
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
|
||||
is_override_call = True
|
||||
break
|
||||
|
||||
# 检查 tool 下有没有 run 方法
|
||||
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
|
||||
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
||||
|
||||
awaitable = None
|
||||
method_name = ""
|
||||
if tool.handler:
|
||||
awaitable = tool.handler
|
||||
method_name = "decorator_handler"
|
||||
elif is_override_call:
|
||||
awaitable = tool.call
|
||||
method_name = "call"
|
||||
elif hasattr(tool, "run"):
|
||||
awaitable = getattr(tool, "run")
|
||||
method_name = "run"
|
||||
if awaitable is None:
|
||||
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
||||
|
||||
wrapper = call_local_llm_tool(
|
||||
context=run_context,
|
||||
handler=awaitable,
|
||||
method_name=method_name,
|
||||
**tool_args,
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
anext(wrapper),
|
||||
timeout=run_context.tool_call_timeout,
|
||||
)
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
if res := run_context.context.event.get_result():
|
||||
if res.chain:
|
||||
try:
|
||||
await event.send(
|
||||
MessageChain(
|
||||
chain=res.chain,
|
||||
type="tool_direct_result",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Tool 直接发送消息失败: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield None
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(
|
||||
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def _execute_mcp(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
res = await tool.call(run_context, **tool_args)
|
||||
if not res:
|
||||
return
|
||||
yield res
|
||||
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
event = context.context.event
|
||||
|
||||
try:
|
||||
if method_name == "run" or method_name == "decorator_handler":
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
elif method_name == "call":
|
||||
ready_to_call = handler(context, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.5.8"
|
||||
VERSION = "4.5.1"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -68,7 +68,7 @@ DEFAULT_CONFIG = {
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"unsupported_streaming_strategy": "realtime_segmenting",
|
||||
"streaming_segmented": False,
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
},
|
||||
@@ -740,7 +740,6 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||
@@ -756,7 +755,6 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -770,7 +768,6 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"xai_native_search": False,
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
@@ -802,7 +799,6 @@ CONFIG_METADATA_2 = {
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -817,7 +813,6 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "llama-3.1-8b",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -834,7 +829,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-1.5-flash",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -876,24 +870,6 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"Groq": {
|
||||
"id": "groq_default",
|
||||
"provider": "groq",
|
||||
"type": "groq_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.groq.com/openai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "openai/gpt-oss-20b",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
@@ -907,7 +883,6 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -924,7 +899,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -941,7 +915,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
@@ -957,7 +930,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
@@ -972,7 +944,6 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "moonshotai/Kimi-K2-Instruct",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -986,7 +957,6 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -1002,8 +972,6 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Dify": {
|
||||
@@ -1060,7 +1028,6 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -1073,7 +1040,6 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.fastgpt.in/api/v1",
|
||||
"timeout": 60,
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"Whisper(API)": {
|
||||
@@ -1355,12 +1321,6 @@ CONFIG_METADATA_2 = {
|
||||
"render_type": "checkbox",
|
||||
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
|
||||
},
|
||||
"custom_headers": {
|
||||
"description": "自定义添加请求头",
|
||||
"type": "dict",
|
||||
"items": {},
|
||||
"hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。",
|
||||
},
|
||||
"custom_extra_body": {
|
||||
"description": "自定义请求体参数",
|
||||
"type": "dict",
|
||||
@@ -2010,8 +1970,8 @@ CONFIG_METADATA_2 = {
|
||||
"show_tool_use_status": {
|
||||
"type": "bool",
|
||||
},
|
||||
"unsupported_streaming_strategy": {
|
||||
"type": "string",
|
||||
"streaming_segmented": {
|
||||
"type": "bool",
|
||||
},
|
||||
"max_agent_step": {
|
||||
"description": "工具调用轮数上限",
|
||||
@@ -2316,15 +2276,9 @@ CONFIG_METADATA_3 = {
|
||||
"description": "流式回复",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.unsupported_streaming_strategy": {
|
||||
"description": "不支持流式回复的平台",
|
||||
"type": "string",
|
||||
"options": ["realtime_segmenting", "turn_off"],
|
||||
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
|
||||
"labels": ["实时分段回复", "关闭流式回复"],
|
||||
"condition": {
|
||||
"provider_settings.streaming_response": True,
|
||||
},
|
||||
"provider_settings.streaming_segmented": {
|
||||
"description": "不支持流式回复的平台采取分段输出",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
|
||||
@@ -14,7 +14,8 @@ import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from asyncio import Queue
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot.core import LogBroker, logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -104,7 +105,9 @@ class AstrBotCoreLifecycle:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
self._event_queue_send, self.event_queue = anyio.create_memory_object_stream[
|
||||
object
|
||||
](0)
|
||||
|
||||
# 初始化人格管理器
|
||||
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
|
||||
@@ -118,7 +121,9 @@ class AstrBotCoreLifecycle:
|
||||
)
|
||||
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
self.platform_manager = PlatformManager(
|
||||
self.astrbot_config, self._event_queue_send
|
||||
)
|
||||
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
@@ -131,7 +136,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self._event_queue_send,
|
||||
self.astrbot_config,
|
||||
self.db,
|
||||
self.provider_manager,
|
||||
|
||||
@@ -271,7 +271,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(
|
||||
col(ConversationV2.user_id) == user_id
|
||||
col(ConversationV2.user_id) == user_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""事件总线, 用于处理事件的分发和处理
|
||||
"""事件总线, 用于处理事件的分发和处理.
|
||||
|
||||
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
|
||||
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||
|
||||
@@ -10,8 +11,8 @@ class:
|
||||
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from asyncio import Queue
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -25,28 +26,29 @@ class EventBus:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: Queue,
|
||||
event_queue: MemoryObjectReceiveStream[AstrMessageEvent],
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
):
|
||||
astrbot_config_mgr: AstrBotConfigManager | None = None,
|
||||
) -> None:
|
||||
self.event_queue = event_queue # 事件队列
|
||||
# abconf uuid -> scheduler
|
||||
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
|
||||
async def dispatch(self):
|
||||
async def dispatch(self) -> None:
|
||||
while True:
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
event: AstrMessageEvent = await self.event_queue.receive()
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
anyio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None:
|
||||
"""用于记录事件信息
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
event: 事件对象
|
||||
conf_name: 配置名称
|
||||
|
||||
"""
|
||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class AstrBotError(Exception):
|
||||
"""Base exception for all AstrBot errors."""
|
||||
|
||||
|
||||
class ProviderNotFoundError(AstrBotError):
|
||||
"""Raised when a specified provider is not found."""
|
||||
@@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import anyio
|
||||
|
||||
|
||||
class FileTokenService:
|
||||
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
|
||||
|
||||
def __init__(self, default_timeout: float = 300):
|
||||
self.lock = asyncio.Lock()
|
||||
self.staged_files = {} # token: (file_path, expire_time)
|
||||
def __init__(self, default_timeout: float = 300) -> None:
|
||||
self.lock = anyio.Lock()
|
||||
self.staged_files: dict = {} # token: (file_path, expire_time)
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
async def _cleanup_expired_tokens(self):
|
||||
|
||||
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.star import PluginManager
|
||||
|
||||
from .context_utils import call_event_hook, call_handler
|
||||
from .context_utils import call_event_hook, call_handler, call_local_llm_tool
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -15,3 +15,4 @@ class PipelineContext:
|
||||
astrbot_config_id: str
|
||||
call_handler = call_handler
|
||||
call_event_hook = call_event_hook
|
||||
call_local_llm_tool = call_local_llm_tool
|
||||
|
||||
@@ -3,6 +3,8 @@ import traceback
|
||||
import typing as T
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star import star_map
|
||||
@@ -105,3 +107,66 @@ async def call_event_hook(
|
||||
return True
|
||||
|
||||
return event.is_stopped()
|
||||
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
event = context.context.event
|
||||
|
||||
try:
|
||||
if method_name == "run" or method_name == "decorator_handler":
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
elif method_name == "call":
|
||||
ready_to_call = handler(context, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
|
||||
@@ -3,10 +3,20 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.mcp_client import MCPTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import Image
|
||||
@@ -21,19 +31,324 @@ from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from ....astr_agent_context import AgentContextWrapper
|
||||
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....astr_agent_run_util import AgentRunner, run_agent
|
||||
from ....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from ...context import PipelineContext, call_event_hook
|
||||
from ...context import PipelineContext, call_event_hook, call_local_llm_tool
|
||||
from ..stage import Stage
|
||||
from ..utils import inject_kb_context
|
||||
|
||||
try:
|
||||
import mcp
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@classmethod
|
||||
async def execute(cls, tool, run_context, **tool_args):
|
||||
"""执行函数调用。
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
||||
**kwargs: 函数调用的参数。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
||||
|
||||
"""
|
||||
if isinstance(tool, HandoffTool):
|
||||
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
elif isinstance(tool, MCPTool):
|
||||
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
else:
|
||||
async for r in cls._execute_local(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_handoff(
|
||||
cls,
|
||||
tool: HandoffTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
input_ = tool_args.get("input", "agent")
|
||||
agent_runner = AgentRunner()
|
||||
|
||||
# make toolset for the agent
|
||||
tools = tool.agent.tools
|
||||
if tools:
|
||||
toolset = ToolSet()
|
||||
for t in tools:
|
||||
if isinstance(t, str):
|
||||
_t = llm_tools.get_func(t)
|
||||
if _t:
|
||||
toolset.add_tool(_t)
|
||||
elif isinstance(t, FunctionTool):
|
||||
toolset.add_tool(t)
|
||||
else:
|
||||
toolset = None
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=input_,
|
||||
system_prompt=tool.description or "",
|
||||
image_urls=[], # 暂时不传递原始 agent 的上下文
|
||||
contexts=[], # 暂时不传递原始 agent 的上下文
|
||||
func_tool=toolset,
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=run_context.context.provider,
|
||||
first_provider_request=run_context.context.first_provider_request,
|
||||
curr_provider_request=request,
|
||||
streaming=run_context.context.streaming,
|
||||
event=run_context.context.event,
|
||||
)
|
||||
|
||||
event = run_context.context.event
|
||||
|
||||
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
|
||||
await event.send(
|
||||
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
|
||||
)
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=run_context.context.provider,
|
||||
request=request,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=run_context.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
streaming=run_context.context.streaming,
|
||||
)
|
||||
|
||||
async for _ in run_agent(agent_runner, 15, True):
|
||||
pass
|
||||
|
||||
if agent_runner.done():
|
||||
llm_response = agent_runner.get_final_llm_resp()
|
||||
|
||||
if not llm_response:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
|
||||
)
|
||||
|
||||
result = (
|
||||
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
|
||||
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
|
||||
)
|
||||
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=result,
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
event = run_context.context.event
|
||||
if not event:
|
||||
raise ValueError("Event must be provided for local function tools.")
|
||||
|
||||
is_override_call = False
|
||||
for ty in type(tool).mro():
|
||||
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
|
||||
logger.debug(f"Found call in: {ty}")
|
||||
is_override_call = True
|
||||
break
|
||||
|
||||
# 检查 tool 下有没有 run 方法
|
||||
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
|
||||
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
||||
|
||||
awaitable = None
|
||||
method_name = ""
|
||||
if tool.handler:
|
||||
awaitable = tool.handler
|
||||
method_name = "decorator_handler"
|
||||
elif is_override_call:
|
||||
awaitable = tool.call
|
||||
method_name = "call"
|
||||
elif hasattr(tool, "run"):
|
||||
awaitable = getattr(tool, "run")
|
||||
method_name = "run"
|
||||
if awaitable is None:
|
||||
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
||||
|
||||
wrapper = call_local_llm_tool(
|
||||
context=run_context,
|
||||
handler=awaitable,
|
||||
method_name=method_name,
|
||||
**tool_args,
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
anext(wrapper),
|
||||
timeout=run_context.tool_call_timeout,
|
||||
)
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
if res := run_context.context.event.get_result():
|
||||
if res.chain:
|
||||
try:
|
||||
await event.send(
|
||||
MessageChain(
|
||||
chain=res.chain,
|
||||
type="tool_direct_result",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Tool 直接发送消息失败: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield None
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(
|
||||
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def _execute_mcp(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
res = await tool.call(run_context, **tool_args)
|
||||
if not res:
|
||||
return
|
||||
yield res
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnLLMResponseEvent,
|
||||
llm_response,
|
||||
)
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
tool: FunctionTool[Any],
|
||||
tool_args: dict | None,
|
||||
tool_result: CallToolResult | None,
|
||||
):
|
||||
run_context.context.event.clear_result()
|
||||
|
||||
|
||||
MAIN_AGENT_HOOKS = MainAgentHooks()
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner,
|
||||
max_step: int = 30,
|
||||
show_tool_use: bool = True,
|
||||
) -> AsyncGenerator[MessageChain, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
resp.data["chain"].type = "tool_call_result"
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use or astr_event.get_platform_name() == "webchat":
|
||||
resp.data["chain"].type = "tool_call"
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if not agent_runner.streaming:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
astr_event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
),
|
||||
)
|
||||
yield
|
||||
astr_event.clear_result()
|
||||
elif resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||
return
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
@@ -48,15 +363,11 @@ class LLMRequestSubStage(Stage):
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
]
|
||||
self.max_step: int = settings.get("max_agent_step", 30)
|
||||
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
||||
if isinstance(self.max_step, bool): # workaround: #2622
|
||||
self.max_step = 30
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
@@ -95,12 +406,63 @@ class LLMRequestSubStage(Stage):
|
||||
raise RuntimeError("无法创建新的对话。")
|
||||
return conversation
|
||||
|
||||
async def _apply_kb_context(
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""应用知识库上下文到请求中"""
|
||||
_nested: bool = False,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if self.provider_wake_prefix:
|
||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# 应用知识库
|
||||
try:
|
||||
await inject_kb_context(
|
||||
umo=event.unified_msg_origin,
|
||||
@@ -110,40 +472,43 @@ class LLMRequestSubStage(Stage):
|
||||
except Exception as e:
|
||||
logger.error(f"调用知识库时遇到问题: {e}")
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
# max context length
|
||||
if (
|
||||
self.max_context_length != -1 # -1 为不限制
|
||||
and len(req.contexts) // 2 > self.max_context_length
|
||||
):
|
||||
logger.debug("上下文长度超过限制,将截断。")
|
||||
req.contexts = req.contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(req.contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
req.contexts = req.contexts[index:]
|
||||
|
||||
return truncated_contexts
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""检查提供商的模态能力,清理请求中的不支持内容"""
|
||||
# fix messages
|
||||
req.contexts = self.fix_messages(req.contexts)
|
||||
|
||||
# check provider modalities
|
||||
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
@@ -157,13 +522,7 @@ class LLMRequestSubStage(Stage):
|
||||
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
|
||||
)
|
||||
req.func_tool = None
|
||||
|
||||
def _plugin_tool_fix(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""根据事件中的插件设置,过滤请求中的工具列表"""
|
||||
# 插件可用性设置
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
@@ -177,6 +536,80 @@ class LLMRequestSubStage(Stage):
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=provider,
|
||||
first_provider_request=req,
|
||||
curr_provider_request=req,
|
||||
streaming=self.streaming_response,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=self.streaming_response,
|
||||
)
|
||||
|
||||
if self.streaming_response:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(agent_runner, self.max_step, self.show_tool_use),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_webchat(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -224,6 +657,9 @@ class LLMRequestSubStage(Stage):
|
||||
),
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
logger.debug(
|
||||
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
|
||||
)
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
@@ -251,9 +687,6 @@ class LLMRequestSubStage(Stage):
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
if req.contexts is None:
|
||||
req.contexts = []
|
||||
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
@@ -273,7 +706,7 @@ class LLMRequestSubStage(Stage):
|
||||
history=messages,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
def fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
@@ -288,184 +721,3 @@ class LLMRequestSubStage(Stage):
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
_nested: bool = False,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if self.provider_wake_prefix and not event.message_str.startswith(
|
||||
self.provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# apply knowledge base context
|
||||
await self._apply_kb_context(event, req)
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import RateLimitStrategy
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -19,11 +20,11 @@ class RateLimitStage(Stage):
|
||||
如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# 存储每个会话的请求时间队列
|
||||
self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
|
||||
# 为每个会话设置一个锁,避免并发冲突
|
||||
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self.locks: defaultdict[str, anyio.Lock] = defaultdict(anyio.Lock)
|
||||
# 限流参数
|
||||
self.rate_limit_count: int = 0
|
||||
self.rate_limit_time: timedelta = timedelta(0)
|
||||
@@ -74,7 +75,7 @@ class RateLimitStage(Stage):
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。",
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
await anyio.sleep(stall_duration)
|
||||
now = datetime.now()
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
logger.info(
|
||||
|
||||
@@ -10,6 +10,7 @@ from astrbot.core.message.message_event_result import MessageChain, ResultConten
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from ..context import PipelineContext, call_event_hook
|
||||
from ..stage import Stage, register_stage
|
||||
@@ -168,15 +169,12 @@ class RespondStage(Stage):
|
||||
logger.warning("async_stream 为空,跳过发送。")
|
||||
return
|
||||
# 流式结果直接交付平台适配器处理
|
||||
realtime_segmenting = (
|
||||
self.config.get("provider_settings", {}).get(
|
||||
"unsupported_streaming_strategy",
|
||||
"realtime_segmenting",
|
||||
)
|
||||
== "realtime_segmenting"
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented",
|
||||
False,
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_id()})")
|
||||
await event.send_streaming(result.async_stream, realtime_segmenting)
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
return
|
||||
if len(result.chain) > 0:
|
||||
# 检查路径映射
|
||||
@@ -220,20 +218,21 @@ class RespondStage(Stage):
|
||||
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
|
||||
)
|
||||
return
|
||||
for comp in result.chain:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
if comp.type in need_separately:
|
||||
await event.send(MessageChain([comp]))
|
||||
else:
|
||||
await event.send(MessageChain([*header_comps, comp]))
|
||||
header_comps.clear()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for comp in result.chain:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
if comp.type in need_separately:
|
||||
await event.send(MessageChain([comp]))
|
||||
else:
|
||||
await event.send(MessageChain([*header_comps, comp]))
|
||||
header_comps.clear()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
if all(
|
||||
comp.type in {ComponentType.Reply, ComponentType.At}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from asyncio import Queue
|
||||
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -12,7 +13,7 @@ from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
|
||||
class PlatformManager:
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
||||
def __init__(self, config: AstrBotConfig, event_queue: MemoryObjectSendStream):
|
||||
self.platform_insts: list[Platform] = []
|
||||
"""加载的 Platform 的实例"""
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import abc
|
||||
import uuid
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
@@ -13,7 +14,7 @@ from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
class Platform(abc.ABC):
|
||||
def __init__(self, event_queue: Queue):
|
||||
def __init__(self, event_queue: MemoryObjectSendStream):
|
||||
super().__init__()
|
||||
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
||||
self._event_queue = event_queue
|
||||
@@ -45,7 +46,7 @@ class Platform(abc.ABC):
|
||||
|
||||
def commit_event(self, event: AstrMessageEvent):
|
||||
"""提交一个事件到事件队列。"""
|
||||
self._event_queue.put_nowait(event)
|
||||
self._event_queue.send_nowait(event)
|
||||
|
||||
def get_client(self):
|
||||
"""获取平台的客户端对象。"""
|
||||
|
||||
@@ -16,6 +16,3 @@ class PlatformMetadata:
|
||||
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
||||
logo_path: str | None = None
|
||||
"""平台适配器的 logo 文件路径(相对于插件目录)"""
|
||||
|
||||
support_streaming_message: bool = True
|
||||
"""平台是否支持真实流式传输"""
|
||||
|
||||
@@ -14,7 +14,6 @@ def register_platform_adapter(
|
||||
default_config_tmpl: dict | None = None,
|
||||
adapter_display_name: str | None = None,
|
||||
logo_path: str | None = None,
|
||||
support_streaming_message: bool = True,
|
||||
):
|
||||
"""用于注册平台适配器的带参装饰器。
|
||||
|
||||
@@ -43,7 +42,6 @@ def register_platform_adapter(
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
support_streaming_message=support_streaming_message,
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -29,7 +29,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
@register_platform_adapter(
|
||||
"aiocqhttp",
|
||||
"适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
support_streaming_message=False,
|
||||
)
|
||||
class AiocqhttpAdapter(Platform):
|
||||
def __init__(
|
||||
@@ -50,7 +49,6 @@ class AiocqhttpAdapter(Platform):
|
||||
name="aiocqhttp",
|
||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
self.bot = CQHttp(
|
||||
@@ -109,7 +107,7 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, event: Event) -> AstrBotMessage | None:
|
||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
logger.debug(f"[aiocqhttp] RawMessage {event}")
|
||||
|
||||
if event["post_type"] == "message":
|
||||
@@ -224,7 +222,7 @@ class AiocqhttpAdapter(Platform):
|
||||
err = f"aiocqhttp: 无法识别的消息类型: {event.message!s},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。"
|
||||
logger.critical(err)
|
||||
try:
|
||||
await self.bot.send(event, err)
|
||||
self.bot.send(event, err)
|
||||
except BaseException as e:
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
return None
|
||||
|
||||
@@ -37,9 +37,7 @@ class MyEventHandler(dingtalk_stream.EventHandler):
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
|
||||
)
|
||||
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
|
||||
class DingtalkPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -76,14 +74,6 @@ class DingtalkPlatformAdapter(Platform):
|
||||
)
|
||||
self.client_ = client # 用于 websockets 的 client
|
||||
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
|
||||
if not dingtalk_id:
|
||||
return dingtalk_id
|
||||
prefix = "$:LWCP_v1:$"
|
||||
if dingtalk_id.startswith(prefix):
|
||||
return dingtalk_id[len(prefix) :]
|
||||
return dingtalk_id
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
session: MessageSesion,
|
||||
@@ -96,7 +86,6 @@ class DingtalkPlatformAdapter(Platform):
|
||||
name="dingtalk",
|
||||
description="钉钉机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(
|
||||
@@ -113,10 +102,10 @@ class DingtalkPlatformAdapter(Platform):
|
||||
else MessageType.FRIEND_MESSAGE
|
||||
)
|
||||
abm.sender = MessageMember(
|
||||
user_id=self._id_to_sid(message.sender_id),
|
||||
user_id=message.sender_id,
|
||||
nickname=message.sender_nick,
|
||||
)
|
||||
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
||||
abm.self_id = message.chatbot_user_id
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
|
||||
@@ -124,8 +113,8 @@ class DingtalkPlatformAdapter(Platform):
|
||||
# 处理所有被 @ 的用户(包括机器人自己,因 at_users 已包含)
|
||||
if message.at_users:
|
||||
for user in message.at_users:
|
||||
if id := self._id_to_sid(user.dingtalk_id):
|
||||
abm.message.append(At(qq=id))
|
||||
if user.dingtalk_id:
|
||||
abm.message.append(At(qq=user.dingtalk_id))
|
||||
abm.group_id = message.conversation_id
|
||||
if self.unique_session:
|
||||
abm.session_id = abm.sender.user_id
|
||||
@@ -227,7 +216,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
self._event_queue.send_nowait(event)
|
||||
|
||||
async def run(self):
|
||||
# await self.client_.start()
|
||||
|
||||
@@ -34,9 +34,7 @@ else:
|
||||
|
||||
|
||||
# 注册平台适配器
|
||||
@register_platform_adapter(
|
||||
"discord", "Discord 适配器 (基于 Pycord)", support_streaming_message=False
|
||||
)
|
||||
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
|
||||
class DiscordPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -92,7 +90,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
message_obj.self_id = self.client_self_id
|
||||
message_obj.session_id = session.session_id
|
||||
message_obj.message = message_chain.chain
|
||||
message_obj.message = message_chain
|
||||
|
||||
# 创建临时事件对象来发送消息
|
||||
temp_event = DiscordPlatformEvent(
|
||||
@@ -113,7 +111,6 @@ class DiscordPlatformAdapter(Platform):
|
||||
"Discord 适配器",
|
||||
id=self.config.get("id"),
|
||||
default_config_tmpl=self.config,
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
from collections.abc import AsyncGenerator
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
@@ -21,6 +20,11 @@ from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata
|
||||
from .client import DiscordBotClient
|
||||
from .components import DiscordEmbed, DiscordView
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# 自定义Discord视图组件(兼容旧版本)
|
||||
class DiscordViewComponent(BaseMessageComponent):
|
||||
@@ -44,6 +48,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.client = client
|
||||
self.interaction_followup_webhook = interaction_followup_webhook
|
||||
|
||||
@override
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息到Discord平台"""
|
||||
# 解析消息链为 Discord 所需的对象
|
||||
@@ -92,21 +97,6 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return None
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _get_channel(self) -> discord.abc.Messageable | None:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
@@ -193,7 +183,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
BytesIO(img_bytes),
|
||||
filename=filename or "image.png",
|
||||
)
|
||||
except (ValueError, TypeError, binascii.Error):
|
||||
except (ValueError, TypeError, base64.binascii.Error):
|
||||
logger.debug(
|
||||
f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}",
|
||||
)
|
||||
|
||||
@@ -23,9 +23,7 @@ from ...register import register_platform_adapter
|
||||
from .lark_event import LarkMessageEvent
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"lark", "飞书机器人官方 API 适配器", support_streaming_message=False
|
||||
)
|
||||
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
|
||||
class LarkPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -117,7 +115,6 @@ class LarkPlatformAdapter(Platform):
|
||||
name="lark",
|
||||
description="飞书机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
@@ -227,7 +224,7 @@ class LarkPlatformAdapter(Platform):
|
||||
bot=self.lark_api,
|
||||
)
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
self._event_queue.send_nowait(event)
|
||||
|
||||
async def run(self):
|
||||
# self.client.start()
|
||||
|
||||
@@ -45,9 +45,7 @@ MAX_FILE_UPLOAD_COUNT = 16
|
||||
DEFAULT_UPLOAD_CONCURRENCY = 3
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"misskey", "Misskey 平台适配器", support_streaming_message=False
|
||||
)
|
||||
@register_platform_adapter("misskey", "Misskey 平台适配器")
|
||||
class MisskeyPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -122,7 +120,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
description="Misskey 平台适配器",
|
||||
id=self.config.get("id", "misskey"),
|
||||
default_config_tmpl=default_config,
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
||||
@@ -29,7 +29,8 @@ from astrbot.core.platform.astr_message_event import MessageSession
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"satori", "Satori 协议适配器", support_streaming_message=False
|
||||
"satori",
|
||||
"Satori 协议适配器",
|
||||
)
|
||||
class SatoriPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
@@ -59,7 +60,6 @@ class SatoriPlatformAdapter(Platform):
|
||||
name="satori",
|
||||
description="Satori 通用协议适配器",
|
||||
id=self.config["id"],
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
self.ws: ClientConnection | None = None
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
import anyio
|
||||
from quart import Quart, Response, request
|
||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
@@ -40,7 +40,7 @@ class SlackWebhookClient:
|
||||
logging.getLogger("quart.app").setLevel(logging.WARNING)
|
||||
logging.getLogger("quart.serving").setLevel(logging.WARNING)
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.shutdown_event = anyio.Event()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置路由"""
|
||||
|
||||
@@ -30,7 +30,6 @@ from .slack_event import SlackMessageEvent
|
||||
@register_platform_adapter(
|
||||
"slack",
|
||||
"适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
support_streaming_message=False,
|
||||
)
|
||||
class SlackAdapter(Platform):
|
||||
def __init__(
|
||||
@@ -69,7 +68,6 @@ class SlackAdapter(Platform):
|
||||
name="slack",
|
||||
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
# 初始化 Slack Web Client
|
||||
@@ -84,7 +82,7 @@ class SlackAdapter(Platform):
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
blocks, text = await SlackMessageEvent._parse_slack_blocks(
|
||||
blocks, text = SlackMessageEvent._parse_slack_blocks(
|
||||
message_chain=message_chain,
|
||||
web_client=self.web_client,
|
||||
)
|
||||
|
||||
@@ -163,9 +163,6 @@ class WebChatAdapter(Platform):
|
||||
_, _, payload = message.raw_message # type: ignore
|
||||
message_event.set_extra("selected_provider", payload.get("selected_provider"))
|
||||
message_event.set_extra("selected_model", payload.get("selected_model"))
|
||||
message_event.set_extra(
|
||||
"enable_streaming", payload.get("enable_streaming", True)
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
|
||||
@@ -109,7 +109,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
reasoning_content = ""
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
async for chain in generator:
|
||||
@@ -125,22 +124,16 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
|
||||
r = await WebChatMessageEvent._send(
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
chain,
|
||||
session_id=self.session_id,
|
||||
streaming=True,
|
||||
)
|
||||
if chain.type == "reasoning":
|
||||
reasoning_content += chain.get_plain_text()
|
||||
else:
|
||||
final_data += r
|
||||
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "complete", # complete means we return the final result
|
||||
"data": final_data,
|
||||
"reasoning": reasoning_content,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
},
|
||||
|
||||
@@ -32,9 +32,7 @@ except ImportError as e:
|
||||
)
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
|
||||
)
|
||||
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
|
||||
class WeChatPadProAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -53,7 +51,6 @@ class WeChatPadProAdapter(Platform):
|
||||
name="wechatpadpro",
|
||||
description="WeChatPadPro 消息平台适配器",
|
||||
id=self.config.get("id", "wechatpadpro"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
# 保存配置信息
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
@@ -51,21 +50,6 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||
await self._send_voice(session, comp)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return None
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
|
||||
b64 = await comp.convert_to_base64()
|
||||
raw = self._validate_base64(b64)
|
||||
|
||||
@@ -110,7 +110,7 @@ class WecomServer:
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
|
||||
@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
|
||||
@register_platform_adapter("wecom", "wecom 适配器")
|
||||
class WecomPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -196,7 +196,6 @@ class WecomPlatformAdapter(Platform):
|
||||
"wecom",
|
||||
"wecom 适配器",
|
||||
id=self.config.get("id", "wecom"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -10,7 +10,7 @@ import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import random
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
@@ -139,12 +139,6 @@ class PKCS7Encoder:
|
||||
class Prpcrypt:
|
||||
"""提供接收和推送给企业微信消息的加解密接口"""
|
||||
|
||||
# 16位随机字符串的范围常量
|
||||
# randbelow(RANDOM_RANGE) 返回 [0, 8999999999999999](两端都包含,即包含0和8999999999999999)
|
||||
# 加上 MIN_RANDOM_VALUE 后得到 [1000000000000000, 9999999999999999](两端都包含)即16位数字
|
||||
MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位)
|
||||
RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位)
|
||||
|
||||
def __init__(self, key):
|
||||
# self.key = base64.b64decode(key+"=")
|
||||
self.key = key
|
||||
@@ -213,9 +207,7 @@ class Prpcrypt:
|
||||
"""随机生成16位字符串
|
||||
@return: 16位字符串
|
||||
"""
|
||||
return str(
|
||||
secrets.randbelow(self.RANDOM_RANGE) + self.MIN_RANDOM_VALUE
|
||||
).encode()
|
||||
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
||||
|
||||
|
||||
class WXBizJsonMsgCrypt:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""企业微信智能机器人 API 客户端
|
||||
"""企业微信智能机器人 API 客户端.
|
||||
|
||||
处理消息加密解密、API 调用等
|
||||
"""
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
处理企业微信智能机器人的 HTTP 回调请求
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import quart
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -41,7 +41,7 @@ class WecomAIBotServer:
|
||||
self.app = quart.Quart(__name__)
|
||||
self._setup_routes()
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.shutdown_event = anyio.Event()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置 Quart 路由"""
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import random
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
@@ -53,7 +53,7 @@ def generate_random_string(length: int = 10) -> str:
|
||||
|
||||
"""
|
||||
letters = string.ascii_letters + string.digits
|
||||
return "".join(secrets.choice(letters) for _ in range(length))
|
||||
return "".join(random.choice(letters) for _ in range(length))
|
||||
|
||||
|
||||
def calculate_image_md5(image_data: bytes) -> str:
|
||||
|
||||
@@ -113,9 +113,7 @@ class WecomServer:
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"weixin_official_account", "微信公众平台 适配器", support_streaming_message=False
|
||||
)
|
||||
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
|
||||
class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -197,7 +195,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
"weixin_official_account",
|
||||
"微信公众平台 适配器",
|
||||
id=self.config.get("id", "weixin_official_account"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .entities import ProviderMetaData
|
||||
from .provider import Provider, STTProvider
|
||||
from .provider import Personality, Provider, STTProvider
|
||||
|
||||
__all__ = ["Provider", "ProviderMetaData", "STTProvider"]
|
||||
__all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"]
|
||||
|
||||
@@ -30,31 +30,18 @@ class ProviderType(enum.Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta:
|
||||
"""The basic metadata of a provider instance."""
|
||||
|
||||
id: str
|
||||
"""the unique id of the provider instance that user configured"""
|
||||
model: str | None
|
||||
"""the model name of the provider instance currently used"""
|
||||
class ProviderMetaData:
|
||||
type: str
|
||||
"""the name of the provider adapter, such as openai, ollama"""
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
"""the capability type of the provider adapter"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData(ProviderMeta):
|
||||
"""The metadata of a provider adapter for registration."""
|
||||
|
||||
"""提供商适配器名称,如 openai, ollama"""
|
||||
desc: str = ""
|
||||
"""the short description of the provider adapter"""
|
||||
"""提供商适配器描述"""
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
cls_type: Any = None
|
||||
"""the class type of the provider adapter"""
|
||||
|
||||
default_config_tmpl: dict | None = None
|
||||
"""the default configuration template of the provider adapter"""
|
||||
"""平台的默认配置模板"""
|
||||
provider_display_name: str | None = None
|
||||
"""the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""
|
||||
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -73,20 +60,12 @@ class ToolCallsResult:
|
||||
]
|
||||
return ret
|
||||
|
||||
def to_openai_messages_model(
|
||||
self,
|
||||
) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
|
||||
return [
|
||||
self.tool_calls_info,
|
||||
*self.tool_calls_result,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest:
|
||||
prompt: str | None = None
|
||||
prompt: str
|
||||
"""提示词"""
|
||||
session_id: str | None = ""
|
||||
session_id: str = ""
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
@@ -202,28 +181,25 @@ class ProviderRequest:
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
"""The role of the message, e.g., assistant, tool, err"""
|
||||
"""角色, assistant, tool, err"""
|
||||
result_chain: MessageChain | None = None
|
||||
"""A chain of message components representing the text completion from LLM."""
|
||||
"""返回的消息链"""
|
||||
tools_call_args: list[dict[str, Any]] = field(default_factory=list)
|
||||
"""Tool call arguments."""
|
||||
"""工具调用参数"""
|
||||
tools_call_name: list[str] = field(default_factory=list)
|
||||
"""Tool call names."""
|
||||
"""工具调用名称"""
|
||||
tools_call_ids: list[str] = field(default_factory=list)
|
||||
"""Tool call IDs."""
|
||||
reasoning_content: str = ""
|
||||
"""The reasoning content extracted from the LLM, if any."""
|
||||
"""工具调用 ID"""
|
||||
|
||||
raw_completion: (
|
||||
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
|
||||
) = None
|
||||
"""The raw completion response from the LLM provider."""
|
||||
_new_record: dict[str, Any] | None = None
|
||||
|
||||
_completion_text: str = ""
|
||||
"""The plain text of the completion."""
|
||||
|
||||
is_chunk: bool = False
|
||||
"""Indicates if the response is a chunked response."""
|
||||
"""是否是流式输出的单个 Chunk"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -237,6 +213,7 @@ class LLMResponse:
|
||||
| GenerateContentResponse
|
||||
| AnthropicMessage
|
||||
| None = None,
|
||||
_new_record: dict[str, Any] | None = None,
|
||||
is_chunk: bool = False,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
@@ -264,6 +241,7 @@ class LLMResponse:
|
||||
self.tools_call_name = tools_call_name
|
||||
self.tools_call_ids = tools_call_ids
|
||||
self.raw_completion = raw_completion
|
||||
self._new_record = _new_record
|
||||
self.is_chunk = is_chunk
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
@@ -25,16 +25,7 @@ SUPPORTED_TYPES = [
|
||||
"boolean",
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
PY_TO_JSON_TYPE = {
|
||||
"int": "number",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
"str": "string",
|
||||
"dict": "object",
|
||||
"list": "array",
|
||||
"tuple": "array",
|
||||
"set": "array",
|
||||
}
|
||||
|
||||
# alias
|
||||
FuncTool = FunctionTool
|
||||
|
||||
@@ -108,7 +99,7 @@ class FunctionToolManager:
|
||||
self.func_list: list[FuncTool] = []
|
||||
self.mcp_client_dict: dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_client_event: dict[str, asyncio.Event] = {}
|
||||
self.mcp_client_event: dict[str, anyio.Event] = {}
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
@@ -116,7 +107,7 @@ class FunctionToolManager:
|
||||
def spec_to_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list[dict],
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
) -> FuncTool:
|
||||
@@ -125,9 +116,10 @@ class FunctionToolManager:
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
p = copy.deepcopy(param)
|
||||
p.pop("name", None)
|
||||
params["properties"][param["name"]] = p
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
return FuncTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
@@ -215,7 +207,7 @@ class FunctionToolManager:
|
||||
for name in mcp_server_json_obj:
|
||||
cfg = mcp_server_json_obj[name]
|
||||
if cfg.get("active", True):
|
||||
event = asyncio.Event()
|
||||
event = anyio.Event()
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, cfg, event),
|
||||
)
|
||||
@@ -225,7 +217,7 @@ class FunctionToolManager:
|
||||
self,
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
event: anyio.Event,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
@@ -316,7 +308,7 @@ class FunctionToolManager:
|
||||
self,
|
||||
name: str,
|
||||
config: dict,
|
||||
event: asyncio.Event | None = None,
|
||||
event: anyio.Event | None = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
@@ -325,7 +317,7 @@ class FunctionToolManager:
|
||||
Args:
|
||||
name (str): The name of the MCP server.
|
||||
config (dict): Configuration for the MCP server.
|
||||
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||
event (anyio.Event): Event to signal when the MCP client is ready.
|
||||
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||
timeout (int): Timeout for the initialization.
|
||||
|
||||
@@ -335,7 +327,7 @@ class FunctionToolManager:
|
||||
|
||||
"""
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
event = anyio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
|
||||
@@ -241,8 +241,6 @@ class ProviderManager:
|
||||
)
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "groq_chat_completion":
|
||||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
@@ -356,8 +354,6 @@ class ProviderManager:
|
||||
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
||||
return
|
||||
|
||||
provider_metadata.id = provider_config["id"]
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
@@ -398,6 +394,7 @@ class ProviderManager:
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
self.selected_default_persona,
|
||||
)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
|
||||
@@ -1,18 +1,28 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.db.po import Personality
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderMeta,
|
||||
ProviderType,
|
||||
RerankResult,
|
||||
ToolCallsResult,
|
||||
)
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta:
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
provider_type: ProviderType
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
"""Provider Abstract Class"""
|
||||
|
||||
@@ -33,15 +43,15 @@ class AbstractProvider(abc.ABC):
|
||||
"""Get the provider metadata"""
|
||||
provider_type_name = self.provider_config["type"]
|
||||
meta_data = provider_cls_map.get(provider_type_name)
|
||||
if not meta_data:
|
||||
raise ValueError(f"Provider type {provider_type_name} not registered")
|
||||
meta = ProviderMeta(
|
||||
id=self.provider_config.get("id", "default"),
|
||||
provider_type = meta_data.provider_type if meta_data else None
|
||||
if provider_type is None:
|
||||
raise ValueError(f"Cannot find provider type: {provider_type_name}")
|
||||
return ProviderMeta(
|
||||
id=self.provider_config["id"],
|
||||
model=self.get_model(),
|
||||
type=provider_type_name,
|
||||
provider_type=meta_data.provider_type,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
return meta
|
||||
|
||||
|
||||
class Provider(AbstractProvider):
|
||||
@@ -51,10 +61,15 @@ class Provider(AbstractProvider):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
default_persona: Personality | None = None,
|
||||
) -> None:
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality = default_persona
|
||||
"""维护了当前的使用的 persona,即人格。可能为 None"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -36,8 +36,6 @@ def register_provider_adapter(
|
||||
default_config_tmpl["id"] = provider_type_name
|
||||
|
||||
pm = ProviderMetaData(
|
||||
id="default", # will be replaced when instantiated
|
||||
model=None,
|
||||
type=provider_type_name,
|
||||
desc=desc,
|
||||
provider_type=provider_type,
|
||||
|
||||
@@ -25,10 +25,12 @@ class ProviderAnthropic(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
self.chosen_api_key: str = ""
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
@@ -54,9 +54,7 @@ class OTTSProvider:
|
||||
async def _generate_signature(self) -> str:
|
||||
await self._sync_time()
|
||||
timestamp = int(time.time()) + self.time_offset
|
||||
nonce = "".join(
|
||||
secrets.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(10)
|
||||
)
|
||||
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
|
||||
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
|
||||
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
|
||||
|
||||
|
||||
@@ -20,10 +20,12 @@ class ProviderCoze(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
|
||||
@@ -8,7 +8,7 @@ from dashscope.app.application_response import ApplicationResponse
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from .. import Provider
|
||||
from .. import Personality, Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
@@ -20,11 +20,13 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
default_persona: Personality | None = None,
|
||||
) -> None:
|
||||
Provider.__init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
if not self.api_key:
|
||||
|
||||
@@ -18,10 +18,12 @@ class ProviderDify(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
if not self.api_key:
|
||||
|
||||
@@ -53,10 +53,12 @@ class ProviderGoogleGenAI(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_keys: list = super().get_keys()
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
@@ -324,18 +326,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
return gemini_contents
|
||||
|
||||
def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
|
||||
"""Extract reasoning content from candidate parts"""
|
||||
if not candidate.content or not candidate.content.parts:
|
||||
return ""
|
||||
|
||||
thought_buf: list[str] = [
|
||||
(p.text or "") for p in candidate.content.parts if p.thought
|
||||
]
|
||||
return "".join(thought_buf).strip()
|
||||
|
||||
@staticmethod
|
||||
def _process_content_parts(
|
||||
self,
|
||||
candidate: types.Candidate,
|
||||
llm_response: LLMResponse,
|
||||
) -> MessageChain:
|
||||
@@ -366,11 +358,6 @@ class ProviderGoogleGenAI(Provider):
|
||||
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
|
||||
raise Exception("API 返回的 candidate.content.parts 为空。")
|
||||
|
||||
# 提取 reasoning content
|
||||
reasoning = self._extract_reasoning_content(candidate)
|
||||
if reasoning:
|
||||
llm_response.reasoning_content = reasoning
|
||||
|
||||
chain = []
|
||||
part: types.Part
|
||||
|
||||
@@ -528,7 +515,6 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
# Accumulate the complete response text for the final response
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
final_response = None
|
||||
|
||||
async for chunk in result:
|
||||
@@ -553,19 +539,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
yield llm_response
|
||||
return
|
||||
|
||||
_f = False
|
||||
|
||||
# 提取 reasoning content
|
||||
reasoning = self._extract_reasoning_content(chunk.candidates[0])
|
||||
if reasoning:
|
||||
_f = True
|
||||
accumulated_reasoning += reasoning
|
||||
llm_response.reasoning_content = reasoning
|
||||
if chunk.text:
|
||||
_f = True
|
||||
accumulated_text += chunk.text
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||
if _f:
|
||||
yield llm_response
|
||||
|
||||
if chunk.candidates[0].finish_reason:
|
||||
@@ -583,10 +559,6 @@ class ProviderGoogleGenAI(Provider):
|
||||
if not final_response:
|
||||
final_response = LLMResponse("assistant", is_chunk=False)
|
||||
|
||||
# Set the complete accumulated reasoning in the final response
|
||||
if accumulated_reasoning:
|
||||
final_response.reasoning_content = accumulated_reasoning
|
||||
|
||||
# Set the complete accumulated text in the final response
|
||||
if accumulated_text:
|
||||
final_response.result_chain = MessageChain(
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"groq_chat_completion", "Groq Chat Completion Provider Adapter"
|
||||
)
|
||||
class ProviderGroq(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.reasoning_key = "reasoning"
|
||||
@@ -4,14 +4,12 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
@@ -30,37 +28,37 @@ from ..register import register_provider_adapter
|
||||
"OpenAI API Chat Completion 提供商适配器",
|
||||
)
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(self, provider_config, provider_settings) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: list = super().get_keys()
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
self.custom_headers = provider_config.get("custom_headers", {})
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
if not isinstance(self.custom_headers, dict) or not self.custom_headers:
|
||||
self.custom_headers = None
|
||||
else:
|
||||
for key in self.custom_headers:
|
||||
self.custom_headers[key] = str(self.custom_headers[key])
|
||||
|
||||
# 适配 azure openai #332
|
||||
if "api_version" in provider_config:
|
||||
# Using Azure OpenAI API
|
||||
# 使用 azure api
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
api_version=provider_config.get("api_version", None),
|
||||
default_headers=self.custom_headers,
|
||||
base_url=provider_config.get("api_base", ""),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
else:
|
||||
# Using OpenAI Official API
|
||||
# 使用 openai api
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
default_headers=self.custom_headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
@@ -72,8 +70,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
model = model_config.get("model", "unknown")
|
||||
self.set_model(model)
|
||||
|
||||
self.reasoning_key = "reasoning_content"
|
||||
|
||||
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
|
||||
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
|
||||
|
||||
@@ -151,7 +147,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
logger.debug(f"completion: {completion}")
|
||||
|
||||
llm_response = await self._parse_openai_completion(completion, tools)
|
||||
llm_response = await self.parse_openai_completion(completion, tools)
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -204,78 +200,36 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
# logger.debug(f"chunk delta: {delta}")
|
||||
# handle the content delta
|
||||
reasoning = self._extract_reasoning_content(chunk)
|
||||
_y = False
|
||||
if reasoning:
|
||||
llm_response.reasoning_content = reasoning
|
||||
_y = True
|
||||
# 处理文本内容
|
||||
if delta.content:
|
||||
completion_text = delta.content
|
||||
llm_response.result_chain = MessageChain(
|
||||
chain=[Comp.Plain(completion_text)],
|
||||
)
|
||||
_y = True
|
||||
if _y:
|
||||
yield llm_response
|
||||
|
||||
final_completion = state.get_final_completion()
|
||||
llm_response = await self._parse_openai_completion(final_completion, tools)
|
||||
llm_response = await self.parse_openai_completion(final_completion, tools)
|
||||
|
||||
yield llm_response
|
||||
|
||||
def _extract_reasoning_content(
|
||||
self,
|
||||
completion: ChatCompletion | ChatCompletionChunk,
|
||||
) -> str:
|
||||
"""Extract reasoning content from OpenAI ChatCompletion if available."""
|
||||
reasoning_text = ""
|
||||
if len(completion.choices) == 0:
|
||||
return reasoning_text
|
||||
if isinstance(completion, ChatCompletion):
|
||||
choice = completion.choices[0]
|
||||
reasoning_attr = getattr(choice.message, self.reasoning_key, None)
|
||||
if reasoning_attr:
|
||||
reasoning_text = str(reasoning_attr)
|
||||
elif isinstance(completion, ChatCompletionChunk):
|
||||
delta = completion.choices[0].delta
|
||||
reasoning_attr = getattr(delta, self.reasoning_key, None)
|
||||
if reasoning_attr:
|
||||
reasoning_text = str(reasoning_attr)
|
||||
return reasoning_text
|
||||
|
||||
async def _parse_openai_completion(
|
||||
async def parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: ToolSet | None
|
||||
) -> LLMResponse:
|
||||
"""Parse OpenAI ChatCompletion into LLMResponse"""
|
||||
"""解析 OpenAI 的 ChatCompletion 响应"""
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
# parse the text completion
|
||||
if choice.message.content is not None:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
# specially, some providers may set <think> tags around reasoning content in the completion text,
|
||||
# we use regex to remove them, and store then in reasoning_content field
|
||||
reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
||||
matches = reasoning_pattern.findall(completion_text)
|
||||
if matches:
|
||||
llm_response.reasoning_content = "\n".join(
|
||||
[match.strip() for match in matches],
|
||||
)
|
||||
completion_text = reasoning_pattern.sub("", completion_text).strip()
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
|
||||
# parse the reasoning content if any
|
||||
# the priority is higher than the <think> tag extraction
|
||||
llm_response.reasoning_content = self._extract_reasoning_content(completion)
|
||||
|
||||
# parse tool calls if any
|
||||
if choice.message.tool_calls and tools is not None:
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
tool_call_ids = []
|
||||
@@ -301,11 +255,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
llm_response.tools_call_ids = tool_call_ids
|
||||
|
||||
# specially handle finish reason
|
||||
if choice.finish_reason == "content_filter":
|
||||
raise Exception(
|
||||
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
|
||||
)
|
||||
|
||||
if llm_response.completion_text is None and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
@@ -12,5 +12,10 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
import logging
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
from deprecated import deprecated
|
||||
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
@@ -17,10 +13,10 @@ from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.provider.provider import (
|
||||
@@ -35,7 +31,6 @@ from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
)
|
||||
|
||||
from ..exceptions import ProviderNotFoundError
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
@@ -55,7 +50,7 @@ class Context:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: Queue,
|
||||
event_queue: MemoryObjectSendStream,
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager,
|
||||
@@ -80,153 +75,6 @@ class Context:
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
self.kb_manager = knowledge_base_manager
|
||||
|
||||
async def llm_generate(
|
||||
self,
|
||||
*,
|
||||
chat_provider_id: str,
|
||||
prompt: str | None = None,
|
||||
image_urls: list[str] | None = None,
|
||||
tools: ToolSet | None = None,
|
||||
system_prompt: str | None = None,
|
||||
contexts: list[Message] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`.
|
||||
|
||||
.. versionadded:: 4.5.7 (sdk)
|
||||
|
||||
Args:
|
||||
chat_provider_id: The chat provider ID to use.
|
||||
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
|
||||
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
|
||||
tools: ToolSet of tools available to the LLM
|
||||
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
|
||||
contexts: context messages for the LLM
|
||||
**kwargs: Additional keyword arguments for LLM generation, OpenAI compatible
|
||||
|
||||
Raises:
|
||||
ChatProviderNotFoundError: If the specified chat provider ID is not found
|
||||
Exception: For other errors during LLM generation
|
||||
"""
|
||||
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
|
||||
if not prov or not isinstance(prov, Provider):
|
||||
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt=prompt,
|
||||
image_urls=image_urls,
|
||||
func_tool=tools,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
**kwargs,
|
||||
)
|
||||
return llm_resp
|
||||
|
||||
async def tool_loop_agent(
|
||||
self,
|
||||
*,
|
||||
event: AstrMessageEvent,
|
||||
chat_provider_id: str,
|
||||
prompt: str | None = None,
|
||||
image_urls: list[str] | None = None,
|
||||
tools: ToolSet | None = None,
|
||||
system_prompt: str | None = None,
|
||||
contexts: list[Message] | None = None,
|
||||
max_steps: int = 30,
|
||||
tool_call_timeout: int = 60,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.
|
||||
If you do not pass the agent_context parameter, the method will recreate a new agent context.
|
||||
|
||||
.. versionadded:: 4.5.7 (sdk)
|
||||
|
||||
Args:
|
||||
chat_provider_id: The chat provider ID to use.
|
||||
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
|
||||
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
|
||||
tools: ToolSet of tools available to the LLM
|
||||
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
|
||||
contexts: context messages for the LLM
|
||||
max_steps: Maximum number of tool calls before stopping the loop
|
||||
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
||||
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
||||
agent_context: AstrAgentContext - context to use for the agent
|
||||
|
||||
Returns:
|
||||
The final LLMResponse after tool calls are completed.
|
||||
|
||||
Raises:
|
||||
ChatProviderNotFoundError: If the specified chat provider ID is not found
|
||||
Exception: For other errors during LLM generation
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from astrbot.core.astr_agent_context import (
|
||||
AgentContextWrapper,
|
||||
AstrAgentContext,
|
||||
)
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
|
||||
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
|
||||
if not prov or not isinstance(prov, Provider):
|
||||
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
|
||||
|
||||
agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
|
||||
agent_context = kwargs.get("agent_context")
|
||||
|
||||
context_ = []
|
||||
for msg in contexts or []:
|
||||
if isinstance(msg, Message):
|
||||
context_.append(msg.model_dump())
|
||||
else:
|
||||
context_.append(msg)
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=prompt,
|
||||
image_urls=image_urls or [],
|
||||
func_tool=tools,
|
||||
contexts=context_,
|
||||
system_prompt=system_prompt or "",
|
||||
)
|
||||
if agent_context is None:
|
||||
agent_context = AstrAgentContext(
|
||||
context=self,
|
||||
event=event,
|
||||
)
|
||||
agent_runner = ToolLoopAgentRunner()
|
||||
tool_executor = FunctionToolExecutor()
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
request=request,
|
||||
run_context=AgentContextWrapper(
|
||||
context=agent_context,
|
||||
tool_call_timeout=tool_call_timeout,
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
agent_hooks=agent_hooks,
|
||||
streaming=kwargs.get("stream", False),
|
||||
)
|
||||
async for _ in agent_runner.step_until_done(max_steps):
|
||||
pass
|
||||
llm_resp = agent_runner.get_final_llm_resp()
|
||||
if not llm_resp:
|
||||
raise Exception("Agent did not produce a final LLM response")
|
||||
return llm_resp
|
||||
|
||||
async def get_current_chat_provider_id(self, umo: str) -> str:
|
||||
"""Get the ID of the currently used chat provider.
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used.
|
||||
|
||||
Raises:
|
||||
ProviderNotFoundError: If the specified chat provider is not found
|
||||
|
||||
"""
|
||||
prov = self.get_using_provider(umo)
|
||||
if not prov:
|
||||
raise ProviderNotFoundError("Provider not found")
|
||||
return prov.meta().id
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata | None:
|
||||
"""根据插件名获取插件的 Metadata"""
|
||||
for star in star_registry:
|
||||
@@ -259,6 +107,10 @@ class Context:
|
||||
"""
|
||||
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(
|
||||
self,
|
||||
provider_id: str,
|
||||
@@ -337,6 +189,45 @@ class Context:
|
||||
return self._config
|
||||
return self.astrbot_config_mgr.get_conf(umo)
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
"""获取 AstrBot 数据库。"""
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> MemoryObjectSendStream:
|
||||
"""获取事件队列。"""
|
||||
return self._event_queue
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
||||
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
|
||||
"""获取指定类型的平台适配器。
|
||||
|
||||
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
name = platform.meta().name
|
||||
if isinstance(platform_type, str):
|
||||
if name == platform_type:
|
||||
return platform
|
||||
elif (
|
||||
name in ADAPTER_NAME_2_TYPE
|
||||
and ADAPTER_NAME_2_TYPE[name] & platform_type
|
||||
):
|
||||
return platform
|
||||
|
||||
def get_platform_inst(self, platform_id: str) -> Platform | None:
|
||||
"""获取指定 ID 的平台适配器实例。
|
||||
|
||||
Args:
|
||||
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
|
||||
|
||||
Returns:
|
||||
Platform: 平台适配器实例,如果未找到则返回 None。
|
||||
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().id == platform_id:
|
||||
return platform
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
session: str | MessageSesion,
|
||||
@@ -409,49 +300,6 @@ class Context:
|
||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||
"""
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
"""获取事件队列。"""
|
||||
return self._event_queue
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
||||
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
|
||||
"""获取指定类型的平台适配器。
|
||||
|
||||
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
name = platform.meta().name
|
||||
if isinstance(platform_type, str):
|
||||
if name == platform_type:
|
||||
return platform
|
||||
elif (
|
||||
name in ADAPTER_NAME_2_TYPE
|
||||
and ADAPTER_NAME_2_TYPE[name] & platform_type
|
||||
):
|
||||
return platform
|
||||
|
||||
def get_platform_inst(self, platform_id: str) -> Platform | None:
|
||||
"""获取指定 ID 的平台适配器实例。
|
||||
|
||||
Args:
|
||||
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
|
||||
|
||||
Returns:
|
||||
Platform: 平台适配器实例,如果未找到则返回 None。
|
||||
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().id == platform_id:
|
||||
return platform
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
"""获取 AstrBot 数据库。"""
|
||||
return self._db
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def register_llm_tool(
|
||||
self,
|
||||
name: str,
|
||||
|
||||
@@ -96,7 +96,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
prefix + "│ ",
|
||||
event=event,
|
||||
cfg=cfg,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
@@ -12,7 +11,7 @@ from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from ..filter.command import CommandFilter
|
||||
@@ -418,37 +417,18 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
||||
docstring = docstring_parser.parse(func_doc)
|
||||
args = []
|
||||
for arg in docstring.params:
|
||||
sub_type_name = None
|
||||
type_name = arg.type_name
|
||||
if not type_name:
|
||||
raise ValueError(
|
||||
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 的参数 {arg.arg_name} 缺少类型注释。",
|
||||
)
|
||||
# parse type_name to handle cases like "list[string]"
|
||||
match = re.match(r"(\w+)\[(\w+)\]", type_name)
|
||||
if match:
|
||||
type_name = match.group(1)
|
||||
sub_type_name = match.group(2)
|
||||
type_name = PY_TO_JSON_TYPE.get(type_name, type_name)
|
||||
if sub_type_name:
|
||||
sub_type_name = PY_TO_JSON_TYPE.get(sub_type_name, sub_type_name)
|
||||
if type_name not in SUPPORTED_TYPES or (
|
||||
sub_type_name and sub_type_name not in SUPPORTED_TYPES
|
||||
):
|
||||
if arg.type_name not in SUPPORTED_TYPES:
|
||||
raise ValueError(
|
||||
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}",
|
||||
)
|
||||
|
||||
arg_json_schema = {
|
||||
"type": type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description,
|
||||
}
|
||||
if sub_type_name:
|
||||
if type_name == "array":
|
||||
arg_json_schema["items"] = {"type": sub_type_name}
|
||||
args.append(arg_json_schema)
|
||||
|
||||
args.append(
|
||||
{
|
||||
"type": arg.type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description,
|
||||
},
|
||||
)
|
||||
# print(llm_tool_name, registering_agent)
|
||||
if not registering_agent:
|
||||
doc_desc = docstring.description.strip() if docstring.description else ""
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
|
||||
@@ -680,18 +680,11 @@ class PluginManager:
|
||||
|
||||
return plugin_info
|
||||
|
||||
async def uninstall_plugin(
|
||||
self,
|
||||
plugin_name: str,
|
||||
delete_config: bool = False,
|
||||
delete_data: bool = False,
|
||||
):
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
"""卸载指定的插件。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要卸载的插件名称
|
||||
delete_config (bool): 是否删除插件配置文件,默认为 False
|
||||
delete_data (bool): 是否删除插件数据,默认为 False
|
||||
|
||||
Raises:
|
||||
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
|
||||
@@ -721,7 +714,6 @@ class PluginManager:
|
||||
|
||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||
|
||||
# 删除插件文件夹
|
||||
try:
|
||||
remove_dir(os.path.join(ppath, root_dir_name))
|
||||
except Exception as e:
|
||||
@@ -729,51 +721,6 @@ class PluginManager:
|
||||
f"移除插件成功,但是删除插件文件夹失败: {e!s}。您可以手动删除该文件夹,位于 addons/plugins/ 下。",
|
||||
)
|
||||
|
||||
# 删除插件配置文件
|
||||
if delete_config and root_dir_name:
|
||||
config_file = os.path.join(
|
||||
self.plugin_config_path,
|
||||
f"{root_dir_name}_config.json",
|
||||
)
|
||||
if os.path.exists(config_file):
|
||||
try:
|
||||
os.remove(config_file)
|
||||
logger.info(f"已删除插件 {plugin_name} 的配置文件")
|
||||
except Exception as e:
|
||||
logger.warning(f"删除插件配置文件失败: {e!s}")
|
||||
|
||||
# 删除插件持久化数据
|
||||
# 注意:需要检查两个可能的目录名(plugin_data 和 plugins_data)
|
||||
# data/temp 目录可能被多个插件共享,不自动删除以防误删
|
||||
if delete_data and root_dir_name:
|
||||
data_base_dir = os.path.dirname(ppath) # data/
|
||||
|
||||
# 删除 data/plugin_data 下的插件持久化数据(单数形式,新版本)
|
||||
plugin_data_dir = os.path.join(
|
||||
data_base_dir, "plugin_data", root_dir_name
|
||||
)
|
||||
if os.path.exists(plugin_data_dir):
|
||||
try:
|
||||
remove_dir(plugin_data_dir)
|
||||
logger.info(
|
||||
f"已删除插件 {plugin_name} 的持久化数据 (plugin_data)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"删除插件持久化数据失败 (plugin_data): {e!s}")
|
||||
|
||||
# 删除 data/plugins_data 下的插件持久化数据(复数形式,旧版本兼容)
|
||||
plugins_data_dir = os.path.join(
|
||||
data_base_dir, "plugins_data", root_dir_name
|
||||
)
|
||||
if os.path.exists(plugins_data_dir):
|
||||
try:
|
||||
remove_dir(plugins_data_dir)
|
||||
logger.info(
|
||||
f"已删除插件 {plugin_name} 的持久化数据 (plugins_data)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"删除插件持久化数据失败 (plugins_data): {e!s}")
|
||||
|
||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||
"""解绑并移除一个插件。
|
||||
|
||||
|
||||
@@ -30,7 +30,9 @@ class UmopConfigRouter:
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
return all(
|
||||
p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls, strict=False)
|
||||
)
|
||||
|
||||
def get_conf_id_for_umop(self, umo: str) -> str | None:
|
||||
"""根据 UMO 获取对应的配置文件 ID
|
||||
|
||||
@@ -105,31 +105,16 @@ async def download_image_by_url(
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
||||
# 关闭SSL验证(仅在证书验证失败时作为fallback)
|
||||
logger.warning(
|
||||
f"SSL certificate verification failed for {url}. "
|
||||
"Disabling SSL verification (CERT_NONE) as a fallback. "
|
||||
"This is insecure and exposes the application to man-in-the-middle attacks. "
|
||||
"Please investigate and resolve certificate issues."
|
||||
)
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
ssl_context.set_ciphers("DEFAULT")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if post:
|
||||
async with session.post(url, json=post_data, ssl=ssl_context) as resp:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
return save_temp_img(await resp.read())
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -172,19 +157,9 @@ async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
end="",
|
||||
)
|
||||
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
||||
# 关闭SSL验证(仅在证书验证失败时作为fallback)
|
||||
logger.warning(
|
||||
"SSL 证书验证失败,已关闭 SSL 验证(不安全,仅用于临时下载)。请检查目标服务器的证书配置。"
|
||||
)
|
||||
logger.warning(
|
||||
f"SSL certificate verification failed for {url}. "
|
||||
"Falling back to unverified connection (CERT_NONE). "
|
||||
"This is insecure and exposes the application to man-in-the-middle attacks. "
|
||||
"Please investigate certificate issues with the remote server."
|
||||
)
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
ssl_context.set_ciphers("DEFAULT")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
|
||||
total_size = int(resp.headers.get("content-length", 0))
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""会话控制"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import copy
|
||||
@@ -8,11 +10,13 @@ import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
|
||||
USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
|
||||
FILTERS: list["SessionFilter"] = [] # 存储 SessionFilter 实例
|
||||
USER_SESSIONS: dict[str, SessionWaiter] = {} # 存储 SessionWaiter 实例
|
||||
FILTERS: list[SessionFilter] = [] # 存储 SessionFilter 实例
|
||||
|
||||
|
||||
class SessionController:
|
||||
@@ -20,16 +24,16 @@ class SessionController:
|
||||
|
||||
def __init__(self):
|
||||
self.future = asyncio.Future()
|
||||
self.current_event: asyncio.Event = None
|
||||
self.current_event: anyio.Event | None = None
|
||||
"""当前正在等待的所用的异步事件"""
|
||||
self.ts: float = None
|
||||
self.ts: float | None = None
|
||||
"""上次保持(keep)开始时的时间"""
|
||||
self.timeout: float | int = None
|
||||
self.timeout: float | int | None = None
|
||||
"""上次保持(keep)开始时的超时时间"""
|
||||
|
||||
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
||||
|
||||
def stop(self, error: Exception = None):
|
||||
def stop(self, error: Exception | None = None):
|
||||
"""立即结束这个会话"""
|
||||
if not self.future.done():
|
||||
if error:
|
||||
@@ -53,7 +57,9 @@ class SessionController:
|
||||
self.stop()
|
||||
return
|
||||
else:
|
||||
left_timeout = self.timeout - (new_ts - self.ts)
|
||||
current_timeout = self.timeout if self.timeout is not None else 0
|
||||
current_ts = self.ts if self.ts is not None else new_ts
|
||||
left_timeout = current_timeout - (new_ts - current_ts)
|
||||
timeout = left_timeout + timeout
|
||||
if timeout <= 0:
|
||||
self.stop()
|
||||
@@ -62,18 +68,19 @@ class SessionController:
|
||||
if self.current_event and not self.current_event.is_set():
|
||||
self.current_event.set() # 通知上一个 keep 结束
|
||||
|
||||
new_event = asyncio.Event()
|
||||
new_event = anyio.Event()
|
||||
self.ts = new_ts
|
||||
self.current_event = new_event
|
||||
self.timeout = timeout
|
||||
|
||||
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||
anyio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||
|
||||
async def _holding(self, event: asyncio.Event, timeout: int):
|
||||
async def _holding(self, event: anyio.Event, timeout_seconds: float):
|
||||
"""等待事件结束或超时"""
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
with anyio.move_on_after(timeout_seconds):
|
||||
await event.wait()
|
||||
except TimeoutError:
|
||||
if not self.future.done():
|
||||
self.future.set_exception(TimeoutError("等待超时"))
|
||||
except asyncio.CancelledError:
|
||||
@@ -105,10 +112,12 @@ class SessionWaiter:
|
||||
session_filter: SessionFilter,
|
||||
session_id: str,
|
||||
record_history_chains: bool,
|
||||
):
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.session_filter = session_filter
|
||||
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
|
||||
self.handler: (
|
||||
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
|
||||
) = None # 处理函数
|
||||
|
||||
self.session_controller = SessionController()
|
||||
self.record_history_chains = record_history_chains
|
||||
@@ -119,15 +128,15 @@ class SessionWaiter:
|
||||
|
||||
async def register_wait(
|
||||
self,
|
||||
handler: Callable[[str], Awaitable[Any]],
|
||||
timeout: int = 30,
|
||||
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
timeout_seconds: int = 30,
|
||||
) -> Any:
|
||||
"""等待外部输入并处理"""
|
||||
self.handler = handler
|
||||
USER_SESSIONS[self.session_id] = self
|
||||
|
||||
# 开始一个会话保持事件
|
||||
self.session_controller.keep(timeout, reset_timeout=True)
|
||||
self.session_controller.keep(timeout_seconds, reset_timeout=True)
|
||||
|
||||
try:
|
||||
return await self.session_controller.future
|
||||
@@ -137,7 +146,7 @@ class SessionWaiter:
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _cleanup(self, error: Exception = None):
|
||||
def _cleanup(self, error: Exception | None = None):
|
||||
"""清理会话"""
|
||||
USER_SESSIONS.pop(self.session_id, None)
|
||||
try:
|
||||
@@ -153,6 +162,10 @@ class SessionWaiter:
|
||||
if not session or session.session_controller.future.done():
|
||||
return
|
||||
|
||||
# 此时 session 不会是 None,因为上面的检查
|
||||
if session is None:
|
||||
return
|
||||
|
||||
async with session._lock:
|
||||
if not session.session_controller.future.done():
|
||||
if session.record_history_chains:
|
||||
@@ -161,7 +174,8 @@ class SessionWaiter:
|
||||
)
|
||||
try:
|
||||
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
||||
await session.handler(session.session_controller, event)
|
||||
if session.handler is not None:
|
||||
await session.handler(session.session_controller, event)
|
||||
except Exception as e:
|
||||
session.session_controller.stop(e)
|
||||
|
||||
@@ -173,11 +187,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
||||
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[str], Awaitable[Any]]):
|
||||
def decorator(
|
||||
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(
|
||||
event: AstrMessageEvent,
|
||||
session_filter: SessionFilter = None,
|
||||
session_filter: SessionFilter | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
import anyio
|
||||
import jwt
|
||||
from quart import request
|
||||
|
||||
@@ -44,7 +44,7 @@ class AuthRoute(Route):
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
await anyio.sleep(3)
|
||||
return Response().error("用户名或密码错误").__dict__
|
||||
|
||||
async def edit_account(self):
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from quart import Response as QuartResponse
|
||||
from quart import g, make_response, request
|
||||
|
||||
@@ -125,8 +126,6 @@ class ChatRoute(Route):
|
||||
audio_url = post_data.get("audio_url")
|
||||
selected_provider = post_data.get("selected_provider")
|
||||
selected_model = post_data.get("selected_model")
|
||||
enable_streaming = post_data.get("enable_streaming", True) # 默认为 True
|
||||
|
||||
if not message and not image_url and not audio_url:
|
||||
return (
|
||||
Response()
|
||||
@@ -190,8 +189,8 @@ class ChatRoute(Route):
|
||||
|
||||
try:
|
||||
if not client_disconnected:
|
||||
await asyncio.sleep(0.05)
|
||||
except asyncio.CancelledError:
|
||||
await anyio.sleep(0.05)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||
client_disconnected = True
|
||||
|
||||
@@ -204,8 +203,6 @@ class ChatRoute(Route):
|
||||
):
|
||||
# 追加机器人消息
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
if "reasoning" in result:
|
||||
new_his["reasoning"] = result["reasoning"]
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
@@ -228,7 +225,6 @@ class ChatRoute(Route):
|
||||
"audio_url": audio_url,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"enable_streaming": enable_streaming,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@@ -817,7 +817,8 @@ class ConfigRoute(Route):
|
||||
cached_token = self._logo_token_cache[cache_key]
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl or not isinstance(
|
||||
platform_default_tmpl[platform.name], dict
|
||||
platform_default_tmpl[platform.name],
|
||||
dict,
|
||||
):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
platform_default_tmpl[platform.name]["logo_token"] = cached_token
|
||||
@@ -846,7 +847,8 @@ class ConfigRoute(Route):
|
||||
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl or not isinstance(
|
||||
platform_default_tmpl[platform.name], dict
|
||||
platform_default_tmpl[platform.name],
|
||||
dict,
|
||||
):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
|
||||
|
||||
@@ -395,15 +395,9 @@ class PluginRoute(Route):
|
||||
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
delete_config = post_data.get("delete_config", False)
|
||||
delete_data = post_data.get("delete_data", False)
|
||||
try:
|
||||
logger.info(f"正在卸载插件 {plugin_name}")
|
||||
await self.plugin_manager.uninstall_plugin(
|
||||
plugin_name,
|
||||
delete_config=delete_config,
|
||||
delete_data=delete_data,
|
||||
)
|
||||
await self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
logger.info(f"卸载插件 {plugin_name} 成功")
|
||||
return Response().ok(None, "卸载成功").__dict__
|
||||
except Exception as e:
|
||||
|
||||
@@ -296,15 +296,7 @@ class ToolsRoute(Route):
|
||||
"""获取所有注册的工具列表"""
|
||||
try:
|
||||
tools = self.tool_mgr.func_list
|
||||
tools_dict = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
"active": tool.active,
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
tools_dict = [tool.__dict__() for tool in tools]
|
||||
return Response().ok(data=tools_dict).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
## What's Changed
|
||||
|
||||
1. 修复:>= Python 3.12 版本下可能导致 LLM Tool 注册错误的问题。
|
||||
2. 优化:更好地适配 Class 方式注册 LLM Tool 的场景。引入 `call` 方法。
|
||||
3. 新增:`ConversationManager` 类支持 `add_message_pair` 方法,简化对话消息的添加操作。
|
||||
4. 新增:增加对 Tool Parameters 的参数验证,确保工具参数符合 JSON Schema 标准。
|
||||
5. 新增:增加 LLM Message Schema 定义,提升消息结构的规范性和一致性。
|
||||
6. 新增:支持对 WebUI 的侧边栏模块进行自定义配置(入口在侧边栏下方的设置页中)。
|
||||
@@ -1,5 +0,0 @@
|
||||
## What's Changed
|
||||
|
||||
> hotfix version of 4.5.2
|
||||
|
||||
1. 修复:修正 `get_tool_list` 方法中工具字典推导式的错误导致的 WebUI MCP 页面工具列表无法显示的问题。
|
||||
@@ -1,5 +0,0 @@
|
||||
## What's Changed
|
||||
|
||||
1. 修复:Docker 镜像部分依赖问题导致某些情况下无法启动容器的问题;
|
||||
2. 优化:插件卡片样式
|
||||
3. 修复:部分情况下 Windows 一键启动部署时,更新 / 部署失败的问题;
|
||||
@@ -1,3 +0,0 @@
|
||||
## What's Changed
|
||||
|
||||
1. 修复:部署失败
|
||||
@@ -1,3 +0,0 @@
|
||||
## What's Changed
|
||||
|
||||
1. 修复:构建失败
|
||||
@@ -1,12 +0,0 @@
|
||||
## What's Changed
|
||||
|
||||
1. 新增:支持为 OpenAI API 提供商自定义请求头 ([#3581](https://github.com/AstrBotDevs/AstrBot/issues/3581))
|
||||
2. 新增:为 WebChat 为 Thinking 模型添加思考过程展示功能;支持快捷切换流式输出 / 非流式输出。([#3632](https://github.com/AstrBotDevs/AstrBot/issues/3632))
|
||||
3. 新增:优化插件调用 LLM 和 Agent 的路径,为 Context 类引入多个调用 LLM 和 Agent 的便捷方法 ([#3636](https://github.com/AstrBotDevs/AstrBot/issues/3636))
|
||||
4. 优化:改善不支持流式输出的消息平台的回退策略 ([#3547](https://github.com/AstrBotDevs/AstrBot/issues/3547))
|
||||
5. 优化:当同一个会话(umo)下同时有多个请求时,执行排队处理,避免并发请求导致的上下文混乱问题 ([#3607](https://github.com/AstrBotDevs/AstrBot/issues/3607))
|
||||
6. 优化:优化 WebUI 的登录界面和 Changelog 页面的显示效果
|
||||
7. 修复:修复在知识库名字过长的情况下,“选择知识库”按钮显示异常的问题 ([#3582](https://github.com/AstrBotDevs/AstrBot/issues/3582))
|
||||
8. 修复:修复部分情况下,分段消息发送时导致的死锁问题(由 PR #3607 引入)
|
||||
9. 修复:钉钉适配器使用部分指令无法生效的问题 ([#3634](https://github.com/AstrBotDevs/AstrBot/issues/3634))
|
||||
10. 其他:为部分适配器添加缺失的 send_streaming 方法 ([#3545](https://github.com/AstrBotDevs/AstrBot/issues/3545))
|
||||
@@ -1,5 +0,0 @@
|
||||
## What's Changed
|
||||
|
||||
hot fix of 4.5.7
|
||||
|
||||
fix: 无法正常发送图片,报错 `pydantic_core._pydantic_core.ValidationError`
|
||||
3
dashboard/.gitignore
vendored
3
dashboard/.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
node_modules/
|
||||
.DS_Store
|
||||
dist/
|
||||
.DS_Store
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 58 KiB |
@@ -1,70 +1,55 @@
|
||||
<template>
|
||||
<v-card class="chat-page-card" elevation="0" rounded="0">
|
||||
<v-card class="chat-page-card">
|
||||
<v-card-text class="chat-page-container">
|
||||
<!-- 遮罩层 (手机端) -->
|
||||
<div class="mobile-overlay" v-if="isMobile && mobileMenuOpen" @click="closeMobileSidebar"></div>
|
||||
|
||||
<div class="chat-layout">
|
||||
<div class="sidebar-panel"
|
||||
:class="{
|
||||
'sidebar-collapsed': sidebarCollapsed && !isMobile,
|
||||
'mobile-sidebar-open': isMobile && mobileMenuOpen,
|
||||
'mobile-sidebar': isMobile
|
||||
}"
|
||||
:style="{ 'background-color': isDark ? sidebarCollapsed ? '#1e1e1e' : '#2d2d2d' : sidebarCollapsed ? '#ffffff' : '#f1f4f9' }"
|
||||
<div class="sidebar-panel" :class="{ 'sidebar-collapsed': sidebarCollapsed }"
|
||||
:style="{ 'background-color': isDark ? sidebarCollapsed ? '#1e1e1e' : '#2d2d2d' : sidebarCollapsed ? '#ffffff' : '#f5f5f5' }"
|
||||
@mouseenter="handleSidebarMouseEnter" @mouseleave="handleSidebarMouseLeave">
|
||||
|
||||
<div style="display: flex; align-items: center; justify-content: center; padding: 16px; padding-bottom: 0px;"
|
||||
v-if="chatboxMode">
|
||||
<img width="50" src="@/assets/images/icon-no-shadow.svg" alt="AstrBot Logo">
|
||||
<img width="50" src="@/assets/images/astrbot_logo_mini.webp" alt="AstrBot Logo">
|
||||
<span v-if="!sidebarCollapsed"
|
||||
style="font-weight: 1000; font-size: 26px; margin-left: 8px;">AstrBot</span>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="sidebar-collapse-btn-container" v-if="!isMobile">
|
||||
<div class="sidebar-collapse-btn-container">
|
||||
<v-btn icon class="sidebar-collapse-btn" @click="toggleSidebar" variant="text"
|
||||
color="deep-purple">
|
||||
<v-icon>{{ (sidebarCollapsed || (!sidebarCollapsed && sidebarHoverExpanded)) ?
|
||||
'mdi-chevron-right' : 'mdi-chevron-left' }}</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
<!-- 手机端关闭按钮 -->
|
||||
<div class="sidebar-collapse-btn-container" v-if="isMobile">
|
||||
<v-btn icon class="sidebar-collapse-btn" @click="closeMobileSidebar" variant="text"
|
||||
color="deep-purple">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div style="padding: 16px; padding-top: 8px;">
|
||||
<v-btn block variant="text" class="new-chat-btn" @click="newC" :disabled="!currCid"
|
||||
v-if="!sidebarCollapsed || isMobile" prepend-icon="mdi-plus"
|
||||
v-if="!sidebarCollapsed" prepend-icon="mdi-plus"
|
||||
style="background-color: transparent !important; border-radius: 4px;">{{
|
||||
tm('actions.newChat') }}</v-btn>
|
||||
<v-btn icon="mdi-plus" rounded="lg" @click="newC" :disabled="!currCid" v-if="sidebarCollapsed && !isMobile"
|
||||
<v-btn icon="mdi-plus" rounded="lg" @click="newC" :disabled="!currCid" v-if="sidebarCollapsed"
|
||||
elevation="0"></v-btn>
|
||||
</div>
|
||||
<div v-if="!sidebarCollapsed || isMobile">
|
||||
<div v-if="!sidebarCollapsed">
|
||||
<v-divider class="mx-4"></v-divider>
|
||||
</div>
|
||||
|
||||
|
||||
<div style="overflow-y: auto; flex-grow: 1;" :class="{ 'fade-in': sidebarHoverExpanded }"
|
||||
v-if="!sidebarCollapsed || isMobile">
|
||||
v-if="!sidebarCollapsed">
|
||||
<v-card v-if="conversations.length > 0" flat style="background-color: transparent;">
|
||||
<v-list density="compact" nav class="conversation-list"
|
||||
style="background-color: transparent;" v-model:selected="selectedConversations"
|
||||
@update:selected="getConversationMessages">
|
||||
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
|
||||
rounded="lg" class="conversation-item" active-color="secondary">
|
||||
<v-list-item-title v-if="!sidebarCollapsed || isMobile" class="conversation-title">{{ item.title
|
||||
<v-list-item-title v-if="!sidebarCollapsed" class="conversation-title">{{ item.title
|
||||
|| tm('conversation.newConversation') }}</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="!sidebarCollapsed || isMobile" class="timestamp">{{
|
||||
<v-list-item-subtitle v-if="!sidebarCollapsed" class="timestamp">{{
|
||||
formatDate(item.updated_at)
|
||||
}}</v-list-item-subtitle>
|
||||
}}</v-list-item-subtitle>
|
||||
|
||||
<template v-if="!sidebarCollapsed || isMobile" v-slot:append>
|
||||
<template v-if="!sidebarCollapsed" v-slot:append>
|
||||
<div class="conversation-actions">
|
||||
<v-btn icon="mdi-pencil" size="x-small" variant="text"
|
||||
class="edit-title-btn"
|
||||
@@ -81,7 +66,7 @@
|
||||
<v-fade-transition>
|
||||
<div class="no-conversations" v-if="conversations.length === 0">
|
||||
<v-icon icon="mdi-message-text-outline" size="large" color="grey-lighten-1"></v-icon>
|
||||
<div class="no-conversations-text" v-if="!sidebarCollapsed || sidebarHoverExpanded || isMobile">
|
||||
<div class="no-conversations-text" v-if="!sidebarCollapsed || sidebarHoverExpanded">
|
||||
{{ tm('conversation.noHistory') }}</div>
|
||||
</div>
|
||||
</v-fade-transition>
|
||||
@@ -93,17 +78,12 @@
|
||||
<div class="chat-content-panel">
|
||||
|
||||
<div class="conversation-header fade-in">
|
||||
<!-- 手机端菜单按钮 -->
|
||||
<v-btn icon class="mobile-menu-btn" @click="toggleMobileSidebar" v-if="isMobile" variant="text">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<!-- <div v-if="currCid && getCurrentConversation">
|
||||
<div v-if="currCid && getCurrentConversation">
|
||||
<h3
|
||||
style="max-width: 300px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap;">
|
||||
{{ getCurrentConversation.title || tm('conversation.newConversation') }}</h3>
|
||||
<span style="font-size: 12px;">{{ formatDate(getCurrentConversation.updated_at) }}</span>
|
||||
</div> -->
|
||||
</div>
|
||||
<div class="conversation-header-actions">
|
||||
<!-- router 推送到 /chatbox -->
|
||||
<v-tooltip :text="tm('actions.fullscreen')" v-if="!chatboxMode">
|
||||
@@ -137,6 +117,7 @@
|
||||
</v-tooltip>
|
||||
</div>
|
||||
</div>
|
||||
<v-divider v-if="currCid && getCurrentConversation" class="conversation-divider"></v-divider>
|
||||
|
||||
<MessageList v-if="messages && messages.length > 0" :messages="messages" :isDark="isDark"
|
||||
:isStreaming="isStreaming || isConvRunning" @openImagePreview="openImagePreview"
|
||||
@@ -146,34 +127,36 @@
|
||||
<span>Hello, I'm</span>
|
||||
<span class="bot-name">AstrBot ⭐</span>
|
||||
</div>
|
||||
<div class="welcome-hint markdown-content">
|
||||
<span>{{ t('core.common.type') }}</span>
|
||||
<code>help</code>
|
||||
<span>{{ tm('shortcuts.help') }} 😊</span>
|
||||
</div>
|
||||
<div class="welcome-hint markdown-content">
|
||||
<span>{{ t('core.common.longPress') }}</span>
|
||||
<code>Ctrl + B</code>
|
||||
<span>{{ tm('shortcuts.voiceRecord') }} 🎤</span>
|
||||
</div>
|
||||
<div class="welcome-hint markdown-content">
|
||||
<span>{{ t('core.common.press') }}</span>
|
||||
<code>Ctrl + V</code>
|
||||
<span>{{ tm('shortcuts.pasteImage') }} 🏞️</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 输入区域 -->
|
||||
<div class="input-area fade-in">
|
||||
<div class="input-container"
|
||||
<div
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px;">
|
||||
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown"
|
||||
:disabled="isStreaming" @click:clear="clearMessage" placeholder="Ask AstrBot..."
|
||||
:disabled="isStreaming" @click:clear="clearMessage"
|
||||
placeholder="Ask AstrBot..."
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
<div
|
||||
style="display: flex; justify-content: space-between; align-items: center; padding: 0px 12px;">
|
||||
<div
|
||||
style="display: flex; justify-content: flex-start; margin-top: 4px; align-items: center; gap: 8px;">
|
||||
style="display: flex; justify-content: space-between; align-items: center; padding: 0px 8px;">
|
||||
<div style="display: flex; justify-content: flex-start; margin-top: 4px;">
|
||||
<!-- 选择提供商和模型 -->
|
||||
<ProviderModelSelector ref="providerModelSelector" />
|
||||
<!-- 流式响应开关 -->
|
||||
<v-tooltip
|
||||
:text="enableStreaming ? tm('streaming.enabled') : tm('streaming.disabled')"
|
||||
location="top">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-chip v-bind="props" @click="toggleStreaming" size="x-small"
|
||||
class="streaming-toggle-chip">
|
||||
<v-icon start :icon="enableStreaming ? 'mdi-flash' : 'mdi-flash-off'"
|
||||
size="small"></v-icon>
|
||||
{{ enableStreaming ? tm('streaming.on') : tm('streaming.off') }}
|
||||
</v-chip>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
</div>
|
||||
<div
|
||||
style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;">
|
||||
@@ -192,6 +175,7 @@
|
||||
class="send-btn" size="small" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- 附件预览区 -->
|
||||
@@ -258,7 +242,6 @@ import ProviderModelSelector from '@/components/chat/ProviderModelSelector.vue';
|
||||
import MessageList from '@/components/chat/MessageList.vue';
|
||||
import 'highlight.js/styles/github.css';
|
||||
import { useToast } from '@/utils/toast';
|
||||
import { useTheme } from 'vuetify';
|
||||
|
||||
export default {
|
||||
name: 'ChatPage',
|
||||
@@ -275,12 +258,10 @@ export default {
|
||||
}, setup() {
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n('features/chat');
|
||||
const theme = useTheme();
|
||||
|
||||
return {
|
||||
t,
|
||||
tm,
|
||||
theme,
|
||||
router,
|
||||
ref
|
||||
};
|
||||
@@ -306,7 +287,7 @@ export default {
|
||||
// Ctrl键长按相关变量
|
||||
ctrlKeyDown: false,
|
||||
ctrlKeyTimer: null,
|
||||
ctrlKeyLongPressThreshold: 300, // 长按阈值,单位毫秒
|
||||
ctrlKeyLongPressThreshold: 300, // 长按阈值,单位毫秒
|
||||
|
||||
mediaCache: {}, // Add a cache to store media blobs
|
||||
|
||||
@@ -332,13 +313,6 @@ export default {
|
||||
|
||||
isToastedRunningInfo: false, // To avoid multiple toasts
|
||||
activeSSECount: 0, // Track number of active SSE connections
|
||||
|
||||
// 流式响应开关
|
||||
enableStreaming: true, // 默认开启流式响应
|
||||
|
||||
// 手机端相关变量
|
||||
isMobile: false,
|
||||
mobileMenuOpen: false,
|
||||
}
|
||||
},
|
||||
|
||||
@@ -417,18 +391,6 @@ export default {
|
||||
this.sidebarCollapsed = true; // 默认折叠状态
|
||||
}
|
||||
|
||||
// 从 localStorage 读取流式响应开关状态,默认为 true(开启)
|
||||
const savedStreamingState = localStorage.getItem('enableStreaming');
|
||||
if (savedStreamingState !== null) {
|
||||
this.enableStreaming = JSON.parse(savedStreamingState);
|
||||
} else {
|
||||
this.enableStreaming = true; // 默认开启
|
||||
}
|
||||
|
||||
// 检测是否为手机端
|
||||
this.checkMobile();
|
||||
window.addEventListener('resize', this.checkMobile);
|
||||
|
||||
// 设置输入框标签
|
||||
this.inputFieldLabel = this.tm('input.chatPrompt');
|
||||
this.getConversations();
|
||||
@@ -451,9 +413,6 @@ export default {
|
||||
beforeUnmount() {
|
||||
// 移除keyup事件监听
|
||||
document.removeEventListener('keyup', this.handleInputKeyUp);
|
||||
|
||||
// 移除resize事件监听
|
||||
window.removeEventListener('resize', this.checkMobile);
|
||||
|
||||
// 清除悬停定时器
|
||||
if (this.sidebarHoverTimer) {
|
||||
@@ -468,28 +427,6 @@ export default {
|
||||
const customizer = useCustomizerStore();
|
||||
const newTheme = customizer.uiTheme === 'PurpleTheme' ? 'PurpleThemeDark' : 'PurpleTheme';
|
||||
customizer.SET_UI_THEME(newTheme);
|
||||
this.theme.global.name.value = newTheme;
|
||||
},
|
||||
// 检测是否为手机端
|
||||
checkMobile() {
|
||||
this.isMobile = window.innerWidth <= 768;
|
||||
// 如果切换到桌面端,关闭手机菜单
|
||||
if (!this.isMobile) {
|
||||
this.mobileMenuOpen = false;
|
||||
}
|
||||
},
|
||||
// 切换手机端菜单
|
||||
toggleMobileSidebar() {
|
||||
this.mobileMenuOpen = !this.mobileMenuOpen;
|
||||
},
|
||||
// 关闭手机端菜单
|
||||
closeMobileSidebar() {
|
||||
this.mobileMenuOpen = false;
|
||||
},
|
||||
// 切换流式响应
|
||||
toggleStreaming() {
|
||||
this.enableStreaming = !this.enableStreaming;
|
||||
localStorage.setItem('enableStreaming', JSON.stringify(this.enableStreaming));
|
||||
},
|
||||
// 切换侧边栏折叠状态
|
||||
toggleSidebar() {
|
||||
@@ -504,7 +441,7 @@ export default {
|
||||
|
||||
// 侧边栏鼠标悬停处理
|
||||
handleSidebarMouseEnter() {
|
||||
if (!this.sidebarCollapsed || this.isMobile) return;
|
||||
if (!this.sidebarCollapsed) return;
|
||||
|
||||
this.sidebarHovered = true;
|
||||
|
||||
@@ -731,11 +668,6 @@ export default {
|
||||
return
|
||||
}
|
||||
|
||||
// 手机端关闭侧边栏
|
||||
if (this.isMobile) {
|
||||
this.closeMobileSidebar();
|
||||
}
|
||||
|
||||
axios.get('/api/chat/get_conversation?conversation_id=' + cid[0]).then(async response => {
|
||||
this.currCid = cid[0];
|
||||
// Update the selected conversation in the sidebar
|
||||
@@ -816,10 +748,6 @@ export default {
|
||||
this.currCid = '';
|
||||
this.selectedConversations = []; // 清除选中状态
|
||||
this.messages = [];
|
||||
// 手机端关闭侧边栏
|
||||
if (this.isMobile) {
|
||||
this.closeMobileSidebar();
|
||||
}
|
||||
if (this.$route.path.startsWith('/chatbox')) {
|
||||
this.$router.push('/chatbox');
|
||||
} else {
|
||||
@@ -939,8 +867,7 @@ export default {
|
||||
image_url: imageNamesToSend,
|
||||
audio_url: audioNameToSend ? [audioNameToSend] : [],
|
||||
selected_provider: selectedProviderId,
|
||||
selected_model: selectedModelName,
|
||||
enable_streaming: this.enableStreaming
|
||||
selected_model: selectedModelName
|
||||
})
|
||||
});
|
||||
|
||||
@@ -1016,26 +943,17 @@ export default {
|
||||
"content": bot_resp
|
||||
});
|
||||
} else if (chunk_json.type === 'plain') {
|
||||
const chain_type = chunk_json.chain_type || 'normal';
|
||||
|
||||
if (!in_streaming) {
|
||||
message_obj = {
|
||||
type: 'bot',
|
||||
message: this.ref(chain_type === 'reasoning' ? '' : chunk_json.data),
|
||||
reasoning: this.ref(chain_type === 'reasoning' ? chunk_json.data : ''),
|
||||
message: this.ref(chunk_json.data),
|
||||
}
|
||||
this.messages.push({
|
||||
"content": message_obj
|
||||
});
|
||||
in_streaming = true;
|
||||
} else {
|
||||
if (chain_type === 'reasoning') {
|
||||
// Append to reasoning content
|
||||
message_obj.reasoning.value += chunk_json.data;
|
||||
} else {
|
||||
// Append to normal message
|
||||
message_obj.message.value += chunk_json.data;
|
||||
}
|
||||
message_obj.message.value += chunk_json.data;
|
||||
}
|
||||
} else if (chunk_json.type === 'update_title') {
|
||||
// 更新对话标题
|
||||
@@ -1183,17 +1101,6 @@ export default {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
/* 流式响应开关芯片样式 */
|
||||
.streaming-toggle-chip {
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.streaming-toggle-chip:hover {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.welcome-title {
|
||||
font-size: 28px;
|
||||
margin-bottom: 16px;
|
||||
@@ -1234,6 +1141,7 @@ export default {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05) !important;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
@@ -1258,7 +1166,7 @@ export default {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 0;
|
||||
border-right: 1px solid rgba(0, 0, 0, 0.04);
|
||||
border-right: 1px solid rgba(0, 0, 0, 0.05);
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
position: relative;
|
||||
@@ -1280,77 +1188,6 @@ export default {
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
/* 手机端菜单按钮 */
|
||||
.mobile-menu-btn {
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
/* 手机端遮罩层 */
|
||||
.mobile-overlay {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.5);
|
||||
z-index: 999;
|
||||
animation: fadeIn 0.3s ease;
|
||||
}
|
||||
|
||||
/* 手机端侧边栏 */
|
||||
.mobile-sidebar {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
bottom: 0;
|
||||
max-width: 280px !important;
|
||||
min-width: 280px !important;
|
||||
transform: translateX(-100%);
|
||||
transition: transform 0.3s ease;
|
||||
z-index: 1000;
|
||||
}
|
||||
|
||||
.mobile-sidebar-open {
|
||||
transform: translateX(0) !important;
|
||||
}
|
||||
|
||||
/* 手机端样式调整 */
|
||||
@media (max-width: 768px) {
|
||||
.sidebar-panel:not(.mobile-sidebar) {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.chat-content-panel {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
/* 手机端去掉容器padding */
|
||||
.chat-page-container {
|
||||
padding: 0 !important;
|
||||
}
|
||||
|
||||
/* 手机端输入区域样式 */
|
||||
.input-area {
|
||||
padding: 0 !important;
|
||||
}
|
||||
|
||||
.input-container {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
margin: 0 !important;
|
||||
border-radius: 0 !important;
|
||||
border-left: none !important;
|
||||
border-right: none !important;
|
||||
border-bottom: none !important;
|
||||
}
|
||||
|
||||
#input-field {
|
||||
border-radius: 0 !important;
|
||||
border-left: none !important;
|
||||
border-right: none !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* 侧边栏折叠按钮 */
|
||||
.sidebar-collapse-btn-container {
|
||||
margin: 16px;
|
||||
@@ -1430,12 +1267,25 @@ export default {
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.v-chip {
|
||||
.status-chips {
|
||||
display: flex;
|
||||
flex-wrap: nowrap;
|
||||
gap: 8px;
|
||||
margin-bottom: 8px;
|
||||
transition: opacity 0.25s ease;
|
||||
}
|
||||
|
||||
.status-chips .v-chip {
|
||||
flex: 1 1 0;
|
||||
justify-content: center;
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.status-chip {
|
||||
font-size: 12px;
|
||||
height: 24px !important;
|
||||
}
|
||||
|
||||
.no-conversations {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
@@ -33,23 +33,10 @@
|
||||
<v-avatar class="bot-avatar" size="36">
|
||||
<v-progress-circular :index="index" v-if="isStreaming && index === messages.length - 1" indeterminate size="28"
|
||||
width="2"></v-progress-circular>
|
||||
<v-icon v-else-if="messages[index - 1]?.content.type !== 'bot'" size="64" color="#8fb6d2">mdi-star-four-points-small</v-icon>
|
||||
<span v-else-if="messages[index - 1]?.content.type !== 'bot'" class="text-h2">✨</span>
|
||||
</v-avatar>
|
||||
<div class="bot-message-content">
|
||||
<div class="message-bubble bot-bubble">
|
||||
<!-- Reasoning Block (Collapsible) -->
|
||||
<div v-if="msg.content.reasoning && msg.content.reasoning.trim()" class="reasoning-container">
|
||||
<div class="reasoning-header" @click="toggleReasoning(index)">
|
||||
<v-icon size="small" class="reasoning-icon">
|
||||
{{ isReasoningExpanded(index) ? 'mdi-chevron-down' : 'mdi-chevron-right' }}
|
||||
</v-icon>
|
||||
<span class="reasoning-label">{{ tm('reasoning.thinking') }}</span>
|
||||
</div>
|
||||
<div v-if="isReasoningExpanded(index)" class="reasoning-content">
|
||||
<div v-html="md.render(msg.content.reasoning)" class="markdown-content reasoning-text"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Text -->
|
||||
<div v-if="msg.content.message && msg.content.message.trim()"
|
||||
v-html="md.render(msg.content.message)" class="markdown-content"></div>
|
||||
@@ -138,8 +125,7 @@ export default {
|
||||
copiedMessages: new Set(),
|
||||
isUserNearBottom: true,
|
||||
scrollThreshold: 1,
|
||||
scrollTimer: null,
|
||||
expandedReasoning: new Set(), // Track which reasoning blocks are expanded
|
||||
scrollTimer: null
|
||||
};
|
||||
},
|
||||
mounted() {
|
||||
@@ -156,22 +142,6 @@ export default {
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
// Toggle reasoning expansion state
|
||||
toggleReasoning(messageIndex) {
|
||||
if (this.expandedReasoning.has(messageIndex)) {
|
||||
this.expandedReasoning.delete(messageIndex);
|
||||
} else {
|
||||
this.expandedReasoning.add(messageIndex);
|
||||
}
|
||||
// Force reactivity
|
||||
this.expandedReasoning = new Set(this.expandedReasoning);
|
||||
},
|
||||
|
||||
// Check if reasoning is expanded
|
||||
isReasoningExpanded(messageIndex) {
|
||||
return this.expandedReasoning.has(messageIndex);
|
||||
},
|
||||
|
||||
// 复制代码到剪贴板
|
||||
copyCodeToClipboard(code) {
|
||||
navigator.clipboard.writeText(code).then(() => {
|
||||
@@ -378,7 +348,7 @@ export default {
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(0);
|
||||
transform: translateY(10px);
|
||||
}
|
||||
|
||||
to {
|
||||
@@ -569,69 +539,6 @@ export default {
|
||||
.fade-in {
|
||||
animation: fadeIn 0.3s ease-in-out;
|
||||
}
|
||||
|
||||
/* Reasoning 区块样式 */
|
||||
.reasoning-container {
|
||||
margin-bottom: 12px;
|
||||
margin-top: 6px;
|
||||
border: 1px solid var(--v-theme-border);
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
width: fit-content;
|
||||
}
|
||||
|
||||
.v-theme--dark .reasoning-container {
|
||||
background-color: rgba(103, 58, 183, 0.08);
|
||||
}
|
||||
|
||||
.reasoning-header {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
padding: 8px 8px;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
transition: background-color 0.2s ease;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.reasoning-header:hover {
|
||||
background-color: rgba(103, 58, 183, 0.08);
|
||||
}
|
||||
|
||||
.v-theme--dark .reasoning-header:hover {
|
||||
background-color: rgba(103, 58, 183, 0.15);
|
||||
}
|
||||
|
||||
.reasoning-icon {
|
||||
margin-right: 6px;
|
||||
color: var(--v-theme-secondary);
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.reasoning-label {
|
||||
font-size: 13px;
|
||||
font-weight: 500;
|
||||
color: var(--v-theme-secondary);
|
||||
letter-spacing: 0.3px;
|
||||
}
|
||||
|
||||
.reasoning-content {
|
||||
padding: 0px 12px;
|
||||
border-top: 1px solid var(--v-theme-border);
|
||||
color: gray;
|
||||
animation: fadeIn 0.2s ease-in-out;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.reasoning-text {
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
color: var(--v-theme-secondaryText);
|
||||
}
|
||||
|
||||
.v-theme--dark .reasoning-text {
|
||||
opacity: 0.85;
|
||||
}
|
||||
</style>
|
||||
|
||||
<style>
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
<template>
|
||||
<div>
|
||||
<!-- 选择提供商和模型按钮 -->
|
||||
<v-chip class="text-none" variant="tonal" size="x-small"
|
||||
<v-btn class="text-none" variant="tonal" rounded="xl" size="small"
|
||||
v-if="selectedProviderId && selectedModelName" @click="openDialog">
|
||||
{{ selectedProviderId }} / {{ selectedModelName }}
|
||||
</v-chip>
|
||||
<v-chip variant="tonal" rounded="xl" size="x-small" v-else @click="openDialog">
|
||||
</v-btn>
|
||||
<v-btn variant="tonal" rounded="xl" size="small" v-else @click="openDialog">
|
||||
选择模型
|
||||
</v-chip>
|
||||
</v-btn>
|
||||
|
||||
<!-- 选择提供商和模型对话框 -->
|
||||
<v-dialog v-model="showDialog" max-width="800" persistent>
|
||||
|
||||
@@ -154,8 +154,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<div class="w-100" v-if="!itemMeta?._special">
|
||||
<!-- Select input for JSON selector -->
|
||||
<v-select v-if="itemMeta?.options" v-model="createSelectorModel(itemKey).value"
|
||||
:items="itemMeta?.labels ? itemMeta.options.map((value, index) => ({ title: itemMeta.labels[index] || value, value: value })) : itemMeta.options"
|
||||
:disabled="itemMeta?.readonly" density="compact" variant="outlined"
|
||||
:items="itemMeta?.options" :disabled="itemMeta?.readonly" density="compact" variant="outlined"
|
||||
class="config-field" hide-details></v-select>
|
||||
|
||||
<!-- Code Editor for JSON selector -->
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import { ref, computed, inject } from 'vue';
|
||||
import { useCustomizerStore } from "@/stores/customizer";
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import UninstallConfirmDialog from './UninstallConfirmDialog.vue';
|
||||
|
||||
const props = defineProps({
|
||||
extension: {
|
||||
@@ -32,7 +31,6 @@ const emit = defineEmits([
|
||||
]);
|
||||
|
||||
const reveal = ref(false);
|
||||
const showUninstallDialog = ref(false);
|
||||
|
||||
// 国际化
|
||||
const { tm } = useModuleI18n('features/extension');
|
||||
@@ -57,11 +55,19 @@ const installExtension = async () => {
|
||||
};
|
||||
|
||||
const uninstallExtension = async () => {
|
||||
showUninstallDialog.value = true;
|
||||
};
|
||||
if (typeof $confirm !== "function") {
|
||||
console.error(tm("card.errors.confirmNotRegistered"));
|
||||
return;
|
||||
}
|
||||
|
||||
const handleUninstallConfirm = (options: { deleteConfig: boolean; deleteData: boolean }) => {
|
||||
emit("uninstall", props.extension, options);
|
||||
const confirmed = await $confirm({
|
||||
title: tm("dialogs.uninstall.title"),
|
||||
message: tm("dialogs.uninstall.message"),
|
||||
});
|
||||
|
||||
if (confirmed) {
|
||||
emit("uninstall", props.extension);
|
||||
}
|
||||
};
|
||||
|
||||
const toggleActivation = () => {
|
||||
@@ -214,12 +220,6 @@ const viewReadme = () => {
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
|
||||
<!-- 卸载确认对话框 -->
|
||||
<UninstallConfirmDialog
|
||||
v-model="showUninstallDialog"
|
||||
@confirm="handleUninstallConfirm"
|
||||
/>
|
||||
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
<template>
|
||||
<div class="d-flex align-center justify-space-between" style="gap: 8px;">
|
||||
<div style="flex: 1; min-width: 0; overflow: hidden;">
|
||||
<span v-if="!modelValue || (Array.isArray(modelValue) && modelValue.length === 0)"
|
||||
style="color: rgb(var(--v-theme-primaryText));">
|
||||
未选择
|
||||
</span>
|
||||
<div v-else class="d-flex flex-wrap gap-1">
|
||||
<v-chip
|
||||
v-for="name in modelValue"
|
||||
:key="name"
|
||||
size="small"
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
closable
|
||||
@click:close="removeKnowledgeBase(name)"
|
||||
style="max-width: 100%;">
|
||||
<span class="text-truncate" style="max-width: 200px;">{{ name }}</span>
|
||||
</v-chip>
|
||||
</div>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<span v-if="!modelValue || (Array.isArray(modelValue) && modelValue.length === 0)"
|
||||
style="color: rgb(var(--v-theme-primaryText));">
|
||||
未选择
|
||||
</span>
|
||||
<div v-else class="d-flex flex-wrap gap-1">
|
||||
<v-chip
|
||||
v-for="name in modelValue"
|
||||
:key="name"
|
||||
size="small"
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
closable
|
||||
@click:close="removeKnowledgeBase(name)">
|
||||
{{ name }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog" style="flex-shrink: 0;">
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
</v-btn>
|
||||
</div>
|
||||
@@ -223,11 +220,4 @@ function goToKnowledgeBasePage() {
|
||||
.gap-1 {
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.text-truncate {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
display: inline-block;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -70,6 +70,10 @@ const formatTitle = (title: string) => {
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.logo-image img:hover {
|
||||
transform: scale(1.05);
|
||||
}
|
||||
|
||||
.logo-text {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
@@ -1,290 +0,0 @@
|
||||
<template>
|
||||
<div style="margin-top: 16px;">
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="outlined"
|
||||
size="small"
|
||||
@click="openDialog"
|
||||
style="margin-bottom: 8px;"
|
||||
>
|
||||
{{ t('features.settings.sidebar.customize.title') }}
|
||||
</v-btn>
|
||||
|
||||
<v-dialog v-model="dialog" max-width="700px">
|
||||
<v-card>
|
||||
<v-card-title class="d-flex justify-space-between align-center">
|
||||
<span>{{ t('features.settings.sidebar.customize.title') }}</span>
|
||||
<v-btn
|
||||
icon="mdi-close"
|
||||
variant="text"
|
||||
@click="dialog = false"
|
||||
></v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text>
|
||||
<p class="text-body-2 mb-4">{{ t('features.settings.sidebar.customize.subtitle') }}</p>
|
||||
|
||||
<v-row>
|
||||
<v-col cols="12" md="6">
|
||||
<div class="mb-2 font-weight-medium">{{ t('features.settings.sidebar.customize.mainItems') }}</div>
|
||||
<v-list
|
||||
density="compact"
|
||||
class="custom-list"
|
||||
@dragover.prevent
|
||||
@drop="handleDropToList($event, 'main')"
|
||||
>
|
||||
<v-list-item
|
||||
v-for="(item, index) in mainItems"
|
||||
:key="item.title"
|
||||
class="mb-1 draggable-item"
|
||||
draggable="true"
|
||||
@dragstart="handleDragStart($event, 'main', index)"
|
||||
@dragover.prevent
|
||||
@drop.stop="handleDrop($event, 'main', index)"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-icon :icon="item.icon" size="small" class="mr-2"></v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ t(item.title) }}</v-list-item-title>
|
||||
<template v-slot:append>
|
||||
<v-btn
|
||||
icon="mdi-arrow-right"
|
||||
variant="text"
|
||||
size="x-small"
|
||||
@click="moveToMore(index)"
|
||||
></v-btn>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" md="6">
|
||||
<div class="mb-2 font-weight-medium">{{ t('features.settings.sidebar.customize.moreItems') }}</div>
|
||||
<v-list
|
||||
density="compact"
|
||||
class="custom-list"
|
||||
@dragover.prevent
|
||||
@drop="handleDropToList($event, 'more')"
|
||||
>
|
||||
<v-list-item
|
||||
v-for="(item, index) in moreItems"
|
||||
:key="item.title"
|
||||
class="mb-1 draggable-item"
|
||||
draggable="true"
|
||||
@dragstart="handleDragStart($event, 'more', index)"
|
||||
@dragover.prevent
|
||||
@drop.stop="handleDrop($event, 'more', index)"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-icon :icon="item.icon" size="small" class="mr-2"></v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ t(item.title) }}</v-list-item-title>
|
||||
<template v-slot:append>
|
||||
<v-btn
|
||||
icon="mdi-arrow-left"
|
||||
variant="text"
|
||||
size="x-small"
|
||||
@click="moveToMain(index)"
|
||||
></v-btn>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions>
|
||||
<v-btn
|
||||
color="error"
|
||||
variant="text"
|
||||
@click="resetToDefault"
|
||||
>
|
||||
{{ t('features.settings.sidebar.customize.reset') }}
|
||||
</v-btn>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="saveCustomization"
|
||||
>
|
||||
{{ t('core.actions.save') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, onMounted } from 'vue';
|
||||
import { useI18n } from '@/i18n/composables';
|
||||
import sidebarItems from '@/layouts/full/vertical-sidebar/sidebarItem';
|
||||
import {
|
||||
getSidebarCustomization,
|
||||
setSidebarCustomization,
|
||||
clearSidebarCustomization
|
||||
} from '@/utils/sidebarCustomization';
|
||||
|
||||
const { t } = useI18n();
|
||||
|
||||
const dialog = ref(false);
|
||||
const mainItems = ref([]);
|
||||
const moreItems = ref([]);
|
||||
const draggedItem = ref(null);
|
||||
|
||||
function initializeItems() {
|
||||
const customization = getSidebarCustomization();
|
||||
|
||||
if (customization) {
|
||||
// Load from customization
|
||||
const allItemsMap = new Map();
|
||||
|
||||
sidebarItems.forEach(item => {
|
||||
if (item.children) {
|
||||
item.children.forEach(child => {
|
||||
allItemsMap.set(child.title, child);
|
||||
});
|
||||
} else {
|
||||
allItemsMap.set(item.title, item);
|
||||
}
|
||||
});
|
||||
|
||||
mainItems.value = customization.mainItems
|
||||
.map(title => allItemsMap.get(title))
|
||||
.filter(item => item);
|
||||
|
||||
moreItems.value = customization.moreItems
|
||||
.map(title => allItemsMap.get(title))
|
||||
.filter(item => item);
|
||||
} else {
|
||||
// Load default structure
|
||||
mainItems.value = sidebarItems.filter(item => !item.children);
|
||||
|
||||
const moreGroup = sidebarItems.find(item => item.title === 'core.navigation.groups.more');
|
||||
moreItems.value = moreGroup ? [...moreGroup.children] : [];
|
||||
}
|
||||
}
|
||||
|
||||
function openDialog() {
|
||||
initializeItems();
|
||||
dialog.value = true;
|
||||
}
|
||||
|
||||
function handleDragStart(event, listType, index) {
|
||||
draggedItem.value = {
|
||||
type: listType,
|
||||
index: index,
|
||||
item: listType === 'main' ? mainItems.value[index] : moreItems.value[index]
|
||||
};
|
||||
event.dataTransfer.effectAllowed = 'move';
|
||||
}
|
||||
|
||||
function handleDrop(event, targetListType, targetIndex) {
|
||||
event.preventDefault();
|
||||
|
||||
if (!draggedItem.value) return;
|
||||
|
||||
const sourceListType = draggedItem.value.type;
|
||||
const sourceIndex = draggedItem.value.index;
|
||||
const item = draggedItem.value.item;
|
||||
|
||||
// Remove from source
|
||||
if (sourceListType === 'main') {
|
||||
mainItems.value.splice(sourceIndex, 1);
|
||||
} else {
|
||||
moreItems.value.splice(sourceIndex, 1);
|
||||
}
|
||||
|
||||
// Add to target
|
||||
if (targetListType === 'main') {
|
||||
mainItems.value.splice(targetIndex, 0, item);
|
||||
} else {
|
||||
moreItems.value.splice(targetIndex, 0, item);
|
||||
}
|
||||
|
||||
draggedItem.value = null;
|
||||
}
|
||||
|
||||
function handleDropToList(event, targetListType) {
|
||||
event.preventDefault();
|
||||
|
||||
if (!draggedItem.value) return;
|
||||
|
||||
const sourceListType = draggedItem.value.type;
|
||||
const sourceIndex = draggedItem.value.index;
|
||||
const item = draggedItem.value.item;
|
||||
|
||||
// Remove from source
|
||||
if (sourceListType === 'main') {
|
||||
mainItems.value.splice(sourceIndex, 1);
|
||||
} else {
|
||||
moreItems.value.splice(sourceIndex, 1);
|
||||
}
|
||||
|
||||
// Add to target list at the end
|
||||
if (targetListType === 'main') {
|
||||
mainItems.value.push(item);
|
||||
} else {
|
||||
moreItems.value.push(item);
|
||||
}
|
||||
|
||||
draggedItem.value = null;
|
||||
}
|
||||
|
||||
function moveToMore(index) {
|
||||
const item = mainItems.value.splice(index, 1)[0];
|
||||
moreItems.value.push(item);
|
||||
}
|
||||
|
||||
function moveToMain(index) {
|
||||
const item = moreItems.value.splice(index, 1)[0];
|
||||
mainItems.value.push(item);
|
||||
}
|
||||
|
||||
function saveCustomization() {
|
||||
const config = {
|
||||
mainItems: mainItems.value.map(item => item.title),
|
||||
moreItems: moreItems.value.map(item => item.title)
|
||||
};
|
||||
|
||||
setSidebarCustomization(config);
|
||||
|
||||
// Notify the sidebar to reload
|
||||
window.dispatchEvent(new CustomEvent('sidebar-customization-changed'));
|
||||
|
||||
dialog.value = false;
|
||||
}
|
||||
|
||||
function resetToDefault() {
|
||||
clearSidebarCustomization();
|
||||
initializeItems();
|
||||
|
||||
// Notify the sidebar to reload
|
||||
window.dispatchEvent(new CustomEvent('sidebar-customization-changed'));
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
initializeItems();
|
||||
});
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.draggable-item {
|
||||
cursor: move;
|
||||
border: 1px solid rgba(var(--v-border-color), var(--v-border-opacity));
|
||||
border-radius: 4px;
|
||||
background-color: rgba(var(--v-theme-surface));
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.draggable-item:hover {
|
||||
background-color: rgba(var(--v-theme-primary), 0.1);
|
||||
border-color: rgba(var(--v-theme-primary), 0.3);
|
||||
}
|
||||
|
||||
.custom-list {
|
||||
min-height: 200px;
|
||||
border: 1px dashed rgba(var(--v-border-color), var(--v-border-opacity));
|
||||
border-radius: 4px;
|
||||
padding: 8px;
|
||||
}
|
||||
</style>
|
||||
@@ -1,135 +0,0 @@
|
||||
<template>
|
||||
<v-dialog
|
||||
v-model="show"
|
||||
max-width="500"
|
||||
@click:outside="handleCancel"
|
||||
@keydown.esc="handleCancel"
|
||||
>
|
||||
<v-card>
|
||||
<v-card-title class="text-h5">
|
||||
{{ tm('dialogs.uninstall.title') }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text>
|
||||
<div class="mb-4">
|
||||
{{ tm('dialogs.uninstall.message') }}
|
||||
</div>
|
||||
|
||||
<v-divider class="my-4"></v-divider>
|
||||
|
||||
<div class="text-subtitle-2 mb-3">{{ t('core.common.actions') }}:</div>
|
||||
|
||||
<v-checkbox
|
||||
v-model="deleteConfig"
|
||||
:label="tm('dialogs.uninstall.deleteConfig')"
|
||||
color="warning"
|
||||
hide-details
|
||||
class="mb-2"
|
||||
>
|
||||
<template v-slot:append>
|
||||
<v-tooltip location="top">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-icon v-bind="props" size="small" color="grey">mdi-information-outline</v-icon>
|
||||
</template>
|
||||
<span>{{ tm('dialogs.uninstall.configHint') }}</span>
|
||||
</v-tooltip>
|
||||
</template>
|
||||
</v-checkbox>
|
||||
|
||||
<v-checkbox
|
||||
v-model="deleteData"
|
||||
:label="tm('dialogs.uninstall.deleteData')"
|
||||
color="error"
|
||||
hide-details
|
||||
>
|
||||
<template v-slot:append>
|
||||
<v-tooltip location="top">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-icon v-bind="props" size="small" color="grey">mdi-information-outline</v-icon>
|
||||
</template>
|
||||
<span>{{ tm('dialogs.uninstall.dataHint') }}</span>
|
||||
</v-tooltip>
|
||||
</template>
|
||||
</v-checkbox>
|
||||
|
||||
<v-alert
|
||||
v-if="deleteConfig || deleteData"
|
||||
type="warning"
|
||||
variant="tonal"
|
||||
density="compact"
|
||||
class="mt-4"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-icon>mdi-alert</v-icon>
|
||||
</template>
|
||||
{{ t('messages.validation.operation_cannot_be_undone') }}
|
||||
</v-alert>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn
|
||||
color="grey"
|
||||
variant="text"
|
||||
@click="handleCancel"
|
||||
>
|
||||
{{ t('core.common.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="error"
|
||||
variant="elevated"
|
||||
@click="handleConfirm"
|
||||
>
|
||||
{{ t('core.common.confirm') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue';
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
});
|
||||
|
||||
const emit = defineEmits(['update:modelValue', 'confirm', 'cancel']);
|
||||
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n('features/extension');
|
||||
|
||||
const show = ref(props.modelValue);
|
||||
const deleteConfig = ref(false);
|
||||
const deleteData = ref(false);
|
||||
|
||||
watch(() => props.modelValue, (val) => {
|
||||
show.value = val;
|
||||
if (val) {
|
||||
// 重置选项
|
||||
deleteConfig.value = false;
|
||||
deleteData.value = false;
|
||||
}
|
||||
});
|
||||
|
||||
watch(show, (val) => {
|
||||
emit('update:modelValue', val);
|
||||
});
|
||||
|
||||
const handleConfirm = () => {
|
||||
emit('confirm', {
|
||||
deleteConfig: deleteConfig.value,
|
||||
deleteData: deleteData.value,
|
||||
});
|
||||
show.value = false;
|
||||
};
|
||||
|
||||
const handleCancel = () => {
|
||||
emit('cancel');
|
||||
show.value = false;
|
||||
};
|
||||
</script>
|
||||
@@ -18,6 +18,5 @@
|
||||
"refresh": "Refresh",
|
||||
"submit": "Submit",
|
||||
"reset": "Reset",
|
||||
"clear": "Clear",
|
||||
"save": "Save"
|
||||
"clear": "Clear"
|
||||
}
|
||||
@@ -56,9 +56,6 @@
|
||||
"linkText": "View master branch commit history (click copy on the right to copy)",
|
||||
"confirm": "Confirm Switch"
|
||||
},
|
||||
"releaseNotes": {
|
||||
"title": "Release Notes"
|
||||
},
|
||||
"dashboardUpdate": {
|
||||
"title": "Update Dashboard to Latest Version Only",
|
||||
"currentVersion": "Current Version",
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
"login": "Login",
|
||||
"username": "Username",
|
||||
"password": "Password",
|
||||
"defaultHint": "Default username and password: astrbot",
|
||||
"logo": {
|
||||
"title": "AstrBot Dashboard",
|
||||
"subtitle": "Welcome"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user