Compare commits


5 Commits

310 changed files with 9547 additions and 9115 deletions


@@ -1,9 +1,9 @@
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# github actions
.git
# github acions
.github/
.*ignore
.git/
# User-specific stuff
.idea/
# Byte-compiled / optimized / DLL files
@@ -15,10 +15,10 @@ env/
venv*/
ENV/
.conda/
README*.md
dashboard/
data/
changelogs/
tests/
.ruff_cache/
.astrbot
astrbot.lock
.astrbot


@@ -6,13 +6,13 @@ body:
- type: markdown
attributes:
value: |
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。
- type: textarea
attributes:
label: 发生了什么
description: 描述你遇到的异常
placeholder: >
一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
一个清晰且具体的描述这个异常是什么。
validations:
required: true
@@ -55,7 +55,7 @@ body:
attributes:
label: 报错日志
description: >
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!
placeholder: >
请提供完整的报错日志或截图。
validations:

.gitignore (54 changes)

@@ -1,49 +1,35 @@
# Python related
__pycache__
.mypy_cache
.venv*
.conda/
uv.lock
.coverage
# IDE and editors
.vscode
.idea
# Logs and temporary files
botpy.log
logs/
temp
cookies.json
# Data files
.vscode
.venv*
.idea
data_v2.db
data_v3.db
data
configs/session
configs/config.yaml
**/.DS_Store
temp
cmd_config.json
# Plugins and packages
data
cookies.json
logs/
addons/plugins
packages/python_interpreter/workplace
tests/astrbot_plugin_openai
.coverage
# Dashboard
tests/astrbot_plugin_openai
chroma
dashboard/node_modules/
dashboard/dist/
.DS_Store
package-lock.json
package.json
# Operating System
**/.DS_Store
.DS_Store
# AstrBot specific
.astrbot
astrbot.lock
# Other
chroma
venv/*
packages/python_interpreter/workplace
.venv/*
.conda/
.idea
pytest.ini
.astrbot
uv.lock


@@ -12,21 +12,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
bash \
ffmpeg \
curl \
gnupg \
git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y curl gnupg \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y nodejs
RUN apt-get update && apt-get install -y curl gnupg && \
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*
RUN python -m pip install uv \
&& echo "3.11" > .python-version
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pilk --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]
CMD [ "python", "main.py" ]

Dockerfile_with_node (new file, 35 lines)

@@ -0,0 +1,35 @@
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]

README.md (114 changes)

@@ -119,73 +119,83 @@ uv run main.py
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## 支持的消息平台
## 消息平台支持情况
**官方维护**
- QQ (官方平台 & OneBot)
- Telegram
- 企微应用 & 企微智能机器人
- 微信客服 & 微信公众号
- 飞书
- 钉钉
- Slack
- Discord
- Satori
- Misskey
- Whatsapp (将支持)
- LINE (将支持)
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方平台) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企微应用 | ✔ |
| 企微智能机器人 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| Whatsapp | 将支持 |
| LINE | 将支持 |
**社区维护**
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | ✔ |
## 支持的模型服务
## ⚡ 提供商支持情况
**大模型服务**
- OpenAI 及兼容服务
- Anthropic
- Google Gemini
- Moonshot AI
- 智谱 AI
- DeepSeek
- Ollama (本地部署)
- LM Studio (本地部署)
- [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [小马算力](https://www.tokenpony.cn/3YPyf)
- [硅基流动](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
**LLMOps 平台**
- Dify
- 阿里云百炼应用
- Coze
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
**语音转文本服务**
- OpenAI Whisper
- SenseVoice
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
**文本转语音服务**
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- 阿里云百炼 TTS
- Azure TTS
- Minimax TTS
- 火山引擎 TTS
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
## ❤️ 贡献
@@ -219,7 +229,7 @@ pre-commit install
## ⭐ Star History
> [!TIP]
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我们维护这个开源项目的动力 <3
<div align="center">

astrbot.lock (new empty file)

@@ -1,19 +1,20 @@
from astrbot import logger
from astrbot.core import html_renderer, sp
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.register import register_agent as agent
from astrbot import logger
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_agent as agent
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
__all__ = [
"AstrBotConfig",
"BaseFunctionToolExecutor",
"FunctionTool",
"ToolSet",
"agent",
"logger",
"html_renderer",
"llm_tool",
"logger",
"agent",
"sp",
"ToolSet",
"FunctionTool",
"BaseFunctionToolExecutor",
]


@@ -1,17 +1,18 @@
from astrbot.core.message.message_event_result import (
MessageEventResult,
MessageChain,
CommandResult,
EventResultType,
MessageChain,
MessageEventResult,
ResultContentType,
)
from astrbot.core.platform import AstrMessageEvent
__all__ = [
"AstrMessageEvent",
"MessageEventResult",
"MessageChain",
"CommandResult",
"EventResultType",
"MessageChain",
"MessageEventResult",
"AstrMessageEvent",
"ResultContentType",
]


@@ -1,52 +1,51 @@
from astrbot.core.star.filter.custom_filter import CustomFilter
from astrbot.core.star.filter.event_message_type import (
EventMessageType,
EventMessageTypeFilter,
)
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
PlatformAdapterTypeFilter,
)
from astrbot.core.star.register import register_after_message_sent as after_message_sent
from astrbot.core.star.register import register_command as command
from astrbot.core.star.register import register_command_group as command_group
from astrbot.core.star.register import register_custom_filter as custom_filter
from astrbot.core.star.register import register_event_message_type as event_message_type
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_on_astrbot_loaded as on_astrbot_loaded
from astrbot.core.star.register import (
register_on_decorating_result as on_decorating_result,
)
from astrbot.core.star.register import register_on_llm_request as on_llm_request
from astrbot.core.star.register import register_on_llm_response as on_llm_response
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
from astrbot.core.star.register import register_permission_type as permission_type
from astrbot.core.star.register import (
register_command as command,
register_command_group as command_group,
register_event_message_type as event_message_type,
register_regex as regex,
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_platform_loaded as on_platform_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent,
)
from astrbot.core.star.register import register_regex as regex
from astrbot.core.star.filter.event_message_type import (
EventMessageTypeFilter,
EventMessageType,
)
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterTypeFilter,
PlatformAdapterType,
)
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
from astrbot.core.star.filter.custom_filter import CustomFilter
__all__ = [
"CustomFilter",
"EventMessageType",
"EventMessageTypeFilter",
"PermissionType",
"PermissionTypeFilter",
"PlatformAdapterType",
"PlatformAdapterTypeFilter",
"after_message_sent",
"command",
"command_group",
"custom_filter",
"event_message_type",
"llm_tool",
"on_astrbot_loaded",
"on_decorating_result",
"on_llm_request",
"on_llm_response",
"on_platform_loaded",
"permission_type",
"platform_adapter_type",
"regex",
"platform_adapter_type",
"permission_type",
"EventMessageTypeFilter",
"EventMessageType",
"PlatformAdapterTypeFilter",
"PlatformAdapterType",
"PermissionTypeFilter",
"CustomFilter",
"custom_filter",
"PermissionType",
"on_astrbot_loaded",
"on_platform_loaded",
"on_llm_request",
"llm_tool",
"on_decorating_result",
"after_message_sent",
"on_llm_response",
]


@@ -1,22 +1,23 @@
from astrbot.core.message.components import *
from astrbot.core.platform import (
AstrBotMessage,
AstrMessageEvent,
Group,
Platform,
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
Group,
)
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *
__all__ = [
"AstrBotMessage",
"AstrMessageEvent",
"Group",
"Platform",
"AstrBotMessage",
"MessageMember",
"MessageType",
"Platform",
"PlatformMetadata",
"register_platform_adapter",
"Group",
]


@@ -1,17 +1,17 @@
from astrbot.core.provider import Personality, Provider, STTProvider
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMetaData,
ProviderRequest,
ProviderType,
ProviderMetaData,
LLMResponse,
)
__all__ = [
"LLMResponse",
"Personality",
"Provider",
"ProviderMetaData",
"STTProvider",
"Personality",
"ProviderRequest",
"ProviderType",
"STTProvider",
"ProviderMetaData",
"LLMResponse",
]


@@ -1,7 +1,8 @@
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
from astrbot.core.star.register import (
register_star as register, # 注册插件Star
)
__all__ = ["Context", "Star", "StarTools", "register"]
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
__all__ = ["register", "Context", "Star", "StarTools"]


@@ -1,7 +1,7 @@
from astrbot.core.utils.session_waiter import (
SessionController,
SessionWaiter,
SessionController,
session_waiter,
)
__all__ = ["SessionController", "SessionWaiter", "session_waiter"]
__all__ = ["SessionWaiter", "SessionController", "session_waiter"]


@@ -1,11 +1,11 @@
"""AstrBot CLI入口"""
import sys
"""
AstrBot CLI入口
"""
import click
import sys
from . import __version__
from .commands import conf, init, plug, run
from .commands import init, run, plug, conf
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.


@@ -1,6 +1,6 @@
from .cmd_conf import conf
from .cmd_init import init
from .cmd_plug import plug
from .cmd_run import run
from .cmd_plug import plug
from .cmd_conf import conf
__all__ = ["conf", "init", "plug", "run"]
__all__ = ["init", "run", "plug", "conf"]


@@ -1,12 +1,9 @@
import hashlib
import json
import zoneinfo
from collections.abc import Callable
from typing import Any
import click
from ..utils import check_astrbot_root, get_astrbot_root
import hashlib
import zoneinfo
from typing import Any, Callable
from ..utils import get_astrbot_root, check_astrbot_root
def _validate_log_level(value: str) -> str:
@@ -14,7 +11,7 @@ def _validate_log_level(value: str) -> str:
value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException(
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
)
return value
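The `_validate_log_level` hunk above normalizes input before checking membership; a standalone sketch of the same pattern, using a plain `ValueError` in place of `click.ClickException`:

```python
# Standalone version of the log-level validation pattern above.
VALID_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")


def validate_log_level(value: str) -> str:
    """Uppercase the input, then reject anything outside VALID_LEVELS."""
    value = value.upper()
    if value not in VALID_LEVELS:
        raise ValueError(f"log level must be one of {'/'.join(VALID_LEVELS)}")
    return value


assert validate_log_level("info") == "INFO"
```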
@@ -76,7 +73,7 @@ def _load_config() -> dict[str, Any]:
root = get_astrbot_root()
if not check_astrbot_root(root):
raise click.ClickException(
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
config_path = root / "data" / "cmd_config.json"
@@ -91,7 +88,7 @@ def _load_config() -> dict[str, Any]:
try:
return json.loads(config_path.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError as e:
raise click.ClickException(f"配置文件解析失败: {e!s}")
raise click.ClickException(f"配置文件解析失败: {str(e)}")
def _save_config(config: dict[str, Any]) -> None:
@@ -99,8 +96,7 @@ def _save_config(config: dict[str, Any]) -> None:
config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
)
@@ -112,7 +108,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
obj[part] = {}
elif not isinstance(obj[part], dict):
raise click.ClickException(
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典",
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
)
obj = obj[part]
obj[parts[-1]] = value
@@ -144,6 +140,7 @@ def conf():
- callback_api_base: 回调接口基址
"""
pass
@conf.command(name="set")
@@ -151,7 +148,7 @@ def conf():
@click.argument("value")
def set_config(key: str, value: str):
"""设置配置项的值"""
if key not in CONFIG_VALIDATORS:
if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}")
config = _load_config()
@@ -173,17 +170,17 @@ def set_config(key: str, value: str):
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"设置配置失败: {e!s}")
raise click.UsageError(f"设置配置失败: {str(e)}")
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str | None = None):
def get_config(key: str = None):
"""获取配置项的值不提供key则显示所有可配置项"""
config = _load_config()
if key:
if key not in CONFIG_VALIDATORS:
if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}")
try:
@@ -194,10 +191,10 @@ def get_config(key: str | None = None):
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"获取配置失败: {e!s}")
raise click.UsageError(f"获取配置失败: {str(e)}")
else:
click.echo("当前配置:")
for key in CONFIG_VALIDATORS:
for key in CONFIG_VALIDATORS.keys():
try:
value = (
"********"


@@ -1,5 +1,4 @@
import asyncio
from pathlib import Path
import click
from filelock import FileLock, Timeout
@@ -7,14 +6,14 @@ from filelock import FileLock, Timeout
from ..utils import check_dashboard, get_astrbot_root
async def initialize_astrbot(astrbot_root: Path) -> None:
async def initialize_astrbot(astrbot_root) -> None:
"""执行 AstrBot 初始化逻辑"""
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
click.echo(f"Current Directory: {astrbot_root}")
click.echo(
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。",
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
)
if click.confirm(
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",


@@ -1,29 +1,31 @@
import re
import shutil
from pathlib import Path
import click
import shutil
from ..utils import (
PluginStatus,
get_git_repo,
build_plug_list,
manage_plugin,
PluginStatus,
check_astrbot_root,
get_astrbot_root,
get_git_repo,
manage_plugin,
)
@click.group()
def plug():
"""插件管理"""
pass
def _get_data_path() -> Path:
base = get_astrbot_root()
if not check_astrbot_root(base):
raise click.ClickException(
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
return (base / "data").resolve()
@@ -39,7 +41,7 @@ def display_plugins(plugins, title=None, color=None):
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
click.echo(
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
f"{p['author']:<15} {desc:<30}",
f"{p['author']:<15} {desc:<30}"
)
@@ -76,7 +78,7 @@ def new(name: str):
f"desc: {desc}\n"
f"version: {version}\n"
f"author: {author}\n"
f"repo: {repo}\n",
f"repo: {repo}\n"
)
# 重写 README.md
@@ -84,7 +86,7 @@ def new(name: str):
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
# 重写 main.py
with open(plug_path / "main.py", encoding="utf-8") as f:
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
content = f.read()
new_content = content.replace(
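The `display_plugins` hunk above prints fixed-width columns; a self-contained sketch of that formatting with illustrative values:

```python
# Fixed-width column output, as in display_plugins above: format specs
# pad each field, and long descriptions are truncated with an ellipsis.
def format_row(name: str, version: str, status: str,
               author: str, desc: str) -> str:
    desc = desc[:30] + ("..." if len(desc) > 30 else "")
    return f"{name:<20} {version:<10} {status:<10} {author:<15} {desc:<30}"


row = format_row("demo_plugin", "1.0.0", "installed", "alice", "x" * 40)
assert row.startswith("demo_plugin ")  # padded to 20 columns
assert row.endswith("...")             # over-long description truncated
```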


@@ -1,18 +1,19 @@
import asyncio
import os
import sys
import traceback
from pathlib import Path
import click
import asyncio
import traceback
from filelock import FileLock, Timeout
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
async def run_astrbot(astrbot_root: Path):
"""运行 AstrBot"""
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core import logger, LogManager, LogBroker, db_helper
from astrbot.core.initial_loader import InitialLoader
await check_dashboard(astrbot_root / "data")
@@ -37,7 +38,7 @@ def run(reload: bool, port: str) -> None:
if not check_astrbot_root(astrbot_root):
raise click.ClickException(
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)


@@ -1,18 +1,18 @@
from .basic import (
get_astrbot_root,
check_astrbot_root,
check_dashboard,
get_astrbot_root,
)
from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
from .version_comparator import VersionComparator
__all__ = [
"PluginStatus",
"VersionComparator",
"build_plug_list",
"get_astrbot_root",
"check_astrbot_root",
"check_dashboard",
"get_astrbot_root",
"get_git_repo",
"manage_plugin",
"build_plug_list",
"VersionComparator",
"PluginStatus",
]


@@ -21,9 +21,8 @@ def get_astrbot_root() -> Path:
async def check_dashboard(astrbot_root: Path) -> None:
"""检查是否安装了dashboard"""
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from .version_comparator import VersionComparator
try:
@@ -49,18 +48,19 @@ async def check_dashboard(astrbot_root: Path) -> None:
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
click.echo("管理面板已是最新版本")
return
try:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return
else:
try:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return
except FileNotFoundError:
click.echo("初始化管理面板目录...")
try:


@@ -1,14 +1,14 @@
import shutil
import tempfile
import httpx
import yaml
from enum import Enum
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile
import click
import httpx
import yaml
from .version_comparator import VersionComparator
@@ -32,8 +32,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
try:
with httpx.Client(
proxy=proxy if proxy else None,
follow_redirects=True,
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(release_url)
resp.raise_for_status()
@@ -56,8 +55,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
# 下载并解压
with httpx.Client(
proxy=proxy if proxy else None,
follow_redirects=True,
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(download_url)
if (
@@ -91,7 +89,6 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
Returns:
dict: 包含元数据的字典,如果读取失败则返回空字典
"""
yaml_path = plugin_dir / "metadata.yaml"
if yaml_path.exists():
@@ -110,7 +107,6 @@ def build_plug_list(plugins_dir: Path) -> list:
Returns:
list: 包含插件信息的字典列表
"""
# 获取本地插件信息
result = []
@@ -137,7 +133,7 @@ def build_plug_list(plugins_dir: Path) -> list:
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
},
}
)
# 获取在线插件列表
@@ -157,7 +153,7 @@ def build_plug_list(plugins_dir: Path) -> list:
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
},
}
)
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)
@@ -172,8 +168,7 @@ def build_plug_list(plugins_dir: Path) -> list:
)
if (
VersionComparator.compare_version(
local_plugin["version"],
online_plugin["version"],
local_plugin["version"], online_plugin["version"]
)
< 0
):
@@ -191,10 +186,7 @@ def build_plug_list(plugins_dir: Path) -> list:
def manage_plugin(
plugin: dict,
plugins_dir: Path,
is_update: bool = False,
proxy: str | None = None,
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
) -> None:
"""安装或更新插件
@@ -203,7 +195,6 @@ def manage_plugin(
plugins_dir (Path): 插件目录
is_update (bool, optional): 是否为更新操作. 默认为 False
proxy (str, optional): 代理服务器地址
"""
plugin_name = plugin["name"]
repo_url = plugin["repo"]
@@ -221,26 +212,26 @@ def manage_plugin(
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
# 备份现有插件
if is_update and backup_path is not None and backup_path.exists():
if is_update and backup_path.exists():
shutil.rmtree(backup_path)
if is_update and backup_path is not None:
if is_update:
shutil.copytree(target_path, backup_path)
try:
click.echo(
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...",
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
)
get_git_repo(repo_url, target_path, proxy)
# 更新成功,删除备份
if is_update and backup_path is not None and backup_path.exists():
if is_update and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
except Exception as e:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
if is_update and backup_path is not None and backup_path.exists():
if is_update and backup_path.exists():
shutil.move(backup_path, target_path)
raise click.ClickException(
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
)


@@ -1,4 +1,6 @@
"""拷贝自 astrbot.core.utils.version_comparator"""
"""
拷贝自 astrbot.core.utils.version_comparator
"""
import re
@@ -40,15 +42,15 @@ class VersionComparator:
for i in range(length):
if v1_parts[i] > v2_parts[i]:
return 1
if v1_parts[i] < v2_parts[i]:
elif v1_parts[i] < v2_parts[i]:
return -1
# 比较预发布标签
if v1_prerelease is None and v2_prerelease is not None:
return 1 # 没有预发布标签的版本高于有预发布标签的版本
if v1_prerelease is not None and v2_prerelease is None:
elif v1_prerelease is not None and v2_prerelease is None:
return -1 # 有预发布标签的版本低于没有预发布标签的版本
if v1_prerelease is not None and v2_prerelease is not None:
elif v1_prerelease is not None and v2_prerelease is not None:
len_pre = max(len(v1_prerelease), len(v2_prerelease))
for i in range(len_pre):
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
@@ -56,21 +58,21 @@ class VersionComparator:
if p1 is None and p2 is not None:
return -1
if p1 is not None and p2 is None:
elif p1 is not None and p2 is None:
return 1
if isinstance(p1, int) and isinstance(p2, str):
elif isinstance(p1, int) and isinstance(p2, str):
return -1
if isinstance(p1, str) and isinstance(p2, int):
elif isinstance(p1, str) and isinstance(p2, int):
return 1
if isinstance(p1, int) and isinstance(p2, int):
elif isinstance(p1, int) and isinstance(p2, int):
if p1 > p2:
return 1
if p1 < p2:
elif p1 < p2:
return -1
elif isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
if p1 < p2:
elif p1 < p2:
return -1
return 0 # 预发布标签完全相同
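The `elif` chains above implement a three-way comparison. The core numeric part can be sketched with padded tuples; this is a simplified stand-in for the project's `VersionComparator`, with no prerelease handling:

```python
def compare_version(v1: str, v2: str) -> int:
    """Return 1, -1 or 0 for dotted numeric versions (no prerelease tags)."""
    def parts(v: str) -> list[int]:
        return [int(p) for p in v.lstrip("v").split(".")]

    a, b = parts(v1), parts(v2)
    n = max(len(a), len(b))
    a += [0] * (n - len(a))          # pad so "1.2" compares equal to "1.2.0"
    b += [0] * (n - len(b))
    return (a > b) - (a < b)         # three-way compare via bool arithmetic


assert compare_version("1.10.0", "1.2.0") == 1   # numeric, not lexicographic
assert compare_version("v1.2", "1.2.0") == 0
```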


@@ -1,14 +1,12 @@
import os
from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.file_token_service import FileTokenService
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.utils.shared_preferences import SharedPreferences
from .log import LogManager, LogBroker # noqa
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from .log import LogBroker, LogManager # noqa
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
from astrbot.core.file_token_service import FileTokenService
from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹


@@ -1,9 +1,8 @@
from dataclasses import dataclass
from typing import Generic
from .hooks import BaseAgentRunHooks
from .run_context import TContext
from .tool import FunctionTool
from typing import Generic
from .run_context import TContext
from .hooks import BaseAgentRunHooks
@dataclass


@@ -1,18 +1,14 @@
from typing import Generic
from .tool import FunctionTool
from .agent import Agent
from .run_context import TContext
from .tool import FunctionTool
class HandoffTool(FunctionTool, Generic[TContext]):
"""Handoff tool for delegating tasks to another agent."""
def __init__(
self,
agent: Agent[TContext],
parameters: dict | None = None,
**kwargs,
self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
):
self.agent = agent
super().__init__(


@@ -1,13 +1,12 @@
from typing import Generic
import mcp
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.provider.entities import LLMResponse
from dataclasses import dataclass
from .run_context import ContextWrapper, TContext
from typing import Generic
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.agent.tool import FunctionTool
@dataclass
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_tool_start(
@@ -24,7 +23,5 @@ class BaseAgentRunHooks(Generic[TContext]):
tool_result: mcp.types.CallToolResult | None,
): ...
async def on_agent_done(
self,
run_context: ContextWrapper[TContext],
llm_response: LLMResponse,
self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
): ...


@@ -1,16 +1,11 @@
import asyncio
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe
from .run_context import TContext
from .tool import FunctionTool
try:
import mcp
from mcp.client.sse import sse_client
@@ -21,13 +16,13 @@ try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if config.get("mcpServers"):
if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
@@ -76,7 +71,8 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
else:
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
@@ -88,7 +84,8 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
else:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
@@ -99,7 +96,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: mcp.ClientSession | None = None
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name: str | None = None
@@ -118,7 +115,6 @@ class MCPClient:
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = _prepare_config(mcp_server_config.copy())
@@ -148,7 +144,7 @@ class MCPClient:
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context,
self._streams_context
)
# Create a new client session
@@ -158,12 +154,12 @@ class MCPClient:
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
),
)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5),
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
@@ -173,7 +169,7 @@ class MCPClient:
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context,
self._streams_context
)
# Create a new client session
@@ -184,7 +180,7 @@ class MCPClient:
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
),
)
)
else:
@@ -210,7 +206,7 @@ class MCPClient:
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport),
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
@@ -226,34 +222,3 @@ class MCPClient:
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done
class MCPTool(FunctionTool, Generic[TContext]):
"""A function tool that calls an MCP service."""
def __init__(
self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
):
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=mcp_tool.inputSchema,
)
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client
self.mcp_server_name = mcp_server_name
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> mcp.types.CallToolResult:
session = self.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=self.mcp_tool.name,
arguments=kwargs,
read_timeout_seconds=timedelta(
seconds=context.tool_call_timeout,
),
)
return res
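The `MCPClient` above leans on a single `AsyncExitStack`: each transport and session context entered during `connect` is pushed onto the stack, and one `aclose()` tears everything down in reverse order. A minimal stdlib-only sketch of that pattern (the resource names are illustrative, not the real MCP objects):

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager

log = []

@asynccontextmanager
async def resource(name):
    # Stand-in for a transport/session context manager.
    log.append(f"open {name}")
    try:
        yield name
    finally:
        log.append(f"close {name}")

async def main():
    stack = AsyncExitStack()
    await stack.enter_async_context(resource("streams"))
    await stack.enter_async_context(resource("session"))
    # One aclose() unwinds in LIFO order: session first, then streams.
    await stack.aclose()

asyncio.run(main())
```

This is why `MCPClient.cleanup` only needs `await self.exit_stack.aclose()` regardless of which transport (stdio, SSE, or streamable HTTP) was entered.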


@@ -1,168 +0,0 @@
# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
# License: Apache License 2.0
from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema
class ContentPart(BaseModel):
"""A part of the content in a message."""
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
type: str
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
type_value = getattr(cls, "type", None)
if type_value is None or not isinstance(type_value, str):
raise ValueError(invalid_subclass_error_msg)
cls.__content_part_registry[type_value] = cls
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# If we're dealing with the base ContentPart class, use custom validation
if cls.__name__ == "ContentPart":
def validate_content_part(value: Any) -> Any:
# if it's already an instance of a ContentPart subclass, return it
if hasattr(value, "__class__") and issubclass(value.__class__, cls):
return value
# if it's a dict with a type field, dispatch to the appropriate subclass
if isinstance(value, dict) and "type" in value:
type_value: Any | None = cast(dict[str, Any], value).get("type")
if not isinstance(type_value, str):
raise ValueError(f"Cannot validate {value} as ContentPart")
target_class = cls.__content_part_registry[type_value]
return target_class.model_validate(value)
raise ValueError(f"Cannot validate {value} as ContentPart")
return core_schema.no_info_plain_validator_function(validate_content_part)
# for subclasses, use the default schema
return handler(source_type)
class TextPart(ContentPart):
"""
>>> TextPart(text="Hello, world!").model_dump()
{'type': 'text', 'text': 'Hello, world!'}
"""
type: str = "text"
text: str
class ImageURLPart(ContentPart):
"""
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
{'type': 'image_url', 'image_url': 'http://example.com/image.jpg'}
"""
class ImageURL(BaseModel):
url: str
"""The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
id: str | None = None
"""The ID of the image, to allow LLMs to distinguish different images."""
type: str = "image_url"
image_url: str
class AudioURLPart(ContentPart):
"""
>>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
{'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
"""
class AudioURL(BaseModel):
url: str
"""The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
id: str | None = None
"""The ID of the audio, to allow LLMs to distinguish different audios."""
type: str = "audio_url"
audio_url: AudioURL
class ToolCall(BaseModel):
"""
A tool call requested by the assistant.
>>> ToolCall(
... id="123",
... function=ToolCall.FunctionBody(
... name="function",
... arguments="{}"
... ),
... ).model_dump()
{'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
"""
class FunctionBody(BaseModel):
name: str
arguments: str | None
type: Literal["function"] = "function"
id: str
"""The ID of the tool call."""
function: FunctionBody
"""The function body of the tool call."""
class ToolCallPart(BaseModel):
"""A part of the tool call."""
arguments_part: str | None = None
"""A part of the arguments of the tool call."""
class Message(BaseModel):
"""A message in a conversation."""
role: Literal[
"system",
"user",
"assistant",
"tool",
]
content: str | list[ContentPart]
"""The content of the message."""
class AssistantMessageSegment(Message):
"""A message segment from the assistant."""
role: Literal["assistant"] = "assistant"
tool_calls: list[ToolCall] | list[dict] | None = None
class ToolCallMessageSegment(Message):
"""A message segment representing a tool call."""
role: Literal["tool"] = "tool"
tool_call_id: str
class UserMessageSegment(Message):
"""A message segment from the user."""
role: Literal["user"] = "user"
class SystemMessageSegment(Message):
"""A message segment from the system."""
role: Literal["system"] = "system"
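The `__init_subclass__` registry plus the custom core schema above amounts to type-tag dispatch: a dict carrying a `type` field is validated into the matching `ContentPart` subclass. A stdlib-only sketch of the same idea, without pydantic (the `validate` helper here is a stand-in for the pydantic hook, not the real API):

```python
class ContentPart:
    # type tag -> subclass, filled in automatically as subclasses are defined
    _registry: dict = {}
    type = "base"

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._registry[cls.type] = cls

    @classmethod
    def validate(cls, value):
        if isinstance(value, ContentPart):
            return value
        if isinstance(value, dict) and isinstance(value.get("type"), str):
            target = cls._registry[value["type"]]
            obj = target.__new__(target)
            obj.__dict__.update(value)  # simplified: pydantic would validate fields
            return obj
        raise ValueError(f"Cannot validate {value!r} as ContentPart")

class TextPart(ContentPart):
    type = "text"

part = ContentPart.validate({"type": "text", "text": "hi"})
```

Because registration happens at class-creation time, any new `ContentPart` subclass becomes dispatchable with no extra wiring.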


@@ -1,6 +1,5 @@
import typing as T
from dataclasses import dataclass
import typing as T
from astrbot.core.message.message_event_result import MessageChain


@@ -1,8 +1,9 @@
from dataclasses import dataclass
from typing import Any, Generic
from typing_extensions import TypeVar
from astrbot.core.platform.astr_message_event import AstrMessageEvent
TContext = TypeVar("TContext", default=Any)
@@ -11,7 +12,7 @@ class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
tool_call_timeout: int = 60 # Default tool call timeout in seconds
event: AstrMessageEvent
NoContext = ContextWrapper[None]

View File

@@ -1,15 +1,13 @@
import abc
import typing as T
from enum import Enum, auto
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponse
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
class AgentState(Enum):
"""Defines the state of the agent."""
@@ -30,26 +28,31 @@ class BaseAgentRunner(T.Generic[TContext]):
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
"""Reset the agent to its initial state.
"""
Reset the agent to its initial state.
This method should be called before starting a new run.
"""
...
@abc.abstractmethod
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
"""Process a single step of the agent."""
"""
Process a single step of the agent.
"""
...
@abc.abstractmethod
def done(self) -> bool:
"""Check if the agent has completed its task.
"""
Check if the agent has completed its task.
Returns True if the agent is done, False otherwise.
"""
...
@abc.abstractmethod
def get_final_llm_resp(self) -> LLMResponse | None:
"""Get the final observation from the agent.
"""
Get the final observation from the agent.
This method should be called after the agent is done.
"""
...


@@ -1,33 +1,31 @@
import sys
import traceback
import typing as T
from mcp.types import (
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
)
from astrbot import logger
from .base import BaseAgentRunner, AgentResponse, AgentState
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponseData
from astrbot.core.provider.provider import Provider
from astrbot.core.message.message_event_result import (
MessageChain,
)
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, ToolCallMessageSegment
from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner
from mcp.types import (
TextContent,
ImageContent,
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
CallToolResult,
)
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
@@ -72,7 +70,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
@override
async def step(self):
"""Process a single step of the agent.
"""
Process a single step of the agent.
This method should return the result of the step.
"""
if not self.req:
@@ -100,7 +99,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text),
chain=MessageChain().message(llm_response.completion_text)
),
)
continue
@@ -121,8 +120,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
type="err",
data=AgentResponseData(
chain=MessageChain().message(
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}",
),
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
)
),
)
@@ -145,18 +144,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse(
type="llm_result",
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text),
chain=MessageChain().message(llm_resp.completion_text)
),
)
# If there are tool calls, handle them as well
if llm_resp.tools_call_name:
tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
for tool_call_name, tool_call_id in zip(
llm_resp.tools_call_name, llm_resp.tools_call_ids
):
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain().message(f"🔨 用工具: {tool_call_name}"),
chain=MessageChain().message(f"🔨 正在使用工具: {tool_call_name} ({tool_call_id})")
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
@@ -170,7 +171,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# Append the results to the context
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
tool_calls=llm_resp.to_openai_to_calls_model(),
role="assistant",
tool_calls=llm_resp.to_openai_tool_calls(),
content=llm_resp.completion_text,
),
tool_calls_result=tool_call_result_blocks,
@@ -205,7 +207,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
role="tool",
tool_call_id=func_tool_id,
content=f"error: 未找到工具 {func_tool_name}",
),
)
)
continue
@@ -214,7 +216,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# Get the actual handler function
if func_tool.handler:
logger.debug(
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}",
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}"
)
if func_tool.parameters and func_tool.parameters.get("properties"):
expected_params = set(func_tool.parameters["properties"].keys())
@@ -227,21 +229,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# Log the ignored parameters
ignored_params = set(func_tool_args.keys()) - set(
valid_params.keys(),
valid_params.keys()
)
if ignored_params:
logger.warning(
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}",
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}"
)
else:
# If there is no handler (e.g. MCP tools), use all parameters
valid_params = func_tool_args
logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数")
try:
await self.agent_hooks.on_tool_start(
self.run_context,
func_tool,
valid_params,
self.run_context, func_tool, valid_params
)
except Exception as e:
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
@@ -256,94 +257,79 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
async for resp in executor: # type: ignore
if isinstance(resp, CallToolResult):
res = resp
_final_resp = resp
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
),
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
),
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data,
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
),
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
),
)
yield MessageChain(
type="tool_direct_result",
).base64_image(resource.blob)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
),
)
yield MessageChain().message("返回的数据类型不受支持。")
content = res.content
aggr_text_content = ""
for cont in content:
if isinstance(cont, TextContent):
aggr_text_content += cont.text
yield MessageChain().message(cont.text)
elif isinstance(cont, ImageContent):
aggr_text_content += "\n返回了图片(已直接发送给用户)\n"
yield MessageChain(
type="tool_direct_result"
).base64_image(cont.data)
elif isinstance(cont, EmbeddedResource):
resource = cont.resource
if isinstance(resource, TextResourceContents):
aggr_text_content += resource.text
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
aggr_text_content += (
"\n返回了图片(已直接发送给用户)\n"
)
yield MessageChain(
type="tool_direct_result"
).base64_image(resource.blob)
else:
aggr_text_content += "\n返回的数据类型不受支持。\n"
yield MessageChain().message(
"返回的数据类型不受支持。"
)
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=aggr_text_content,
)
)
elif resp is None:
# The tool directly asked to send a message to the user.
# We end the agent loop right here.
# The message-sending logic was already handled in the ToolExecutor.
logger.warning(
f"{func_tool_name} 没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
)
self._transition_state(AgentState.DONE)
if res := self.run_context.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
else:
# No other types should appear here
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool,
func_tool_args,
_final_resp,
self.run_context, func_tool, func_tool_args, _final_resp
)
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
self.run_context.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {e!s}",
),
content=f"error: {str(e)}",
)
)
# Handle the function-call responses
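The control flow of `ToolLoopAgentRunner` compresses to a classic tool loop: call the model, run any requested tools, append their results as `tool` messages, and stop once a response carries no tool calls. A hedged sketch (the scripted model stub and the dict message shapes are assumptions for illustration, not the real provider API):

```python
def run_tool_loop(model, tools, history):
    while True:
        resp = model(history)
        history.append({"role": "assistant",
                        "content": resp.get("content", ""),
                        "tool_calls": resp.get("tool_calls", [])})
        if not resp.get("tool_calls"):
            # No tool calls -> final answer, loop is done
            return resp.get("content", "")
        for call in resp["tool_calls"]:
            # Execute the tool and feed the result back as a "tool" message
            result = tools[call["name"]](**call["arguments"])
            history.append({"role": "tool",
                            "tool_call_id": call["id"],
                            "content": str(result)})

# Demo with a scripted two-step "model": first request a tool, then answer.
def scripted_model(history):
    if not any(m["role"] == "tool" for m in history):
        return {"tool_calls": [
            {"id": "1", "name": "add", "arguments": {"a": 1, "b": 2}}]}
    return {"content": history[-1]["content"]}

answer = run_tool_loop(scripted_model, {"add": lambda a, b: a + b}, [])
```

The real runner adds streaming deltas, MCP content aggregation, hooks, and error handling on top of this skeleton.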


@@ -1,75 +1,55 @@
from collections.abc import Awaitable, Callable
from typing import Any, Generic
import jsonschema
import mcp
from dataclasses import dataclass
from deprecated import deprecated
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
from typing import Awaitable, Callable, Literal, Any, Optional
from .mcp_client import MCPClient
@dataclass
class ToolSchema:
"""A class representing the schema of a tool for function calling."""
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str
"""The name of the tool."""
description: str
"""The description of the tool."""
parameters: ParametersType
"""The parameters of the tool, in JSON Schema format."""
@model_validator(mode="after")
def validate_parameters(self) -> "ToolSchema":
jsonschema.validate(
self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
)
return self
@dataclass
class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling."""
parameters: dict | None = None
description: str | None = None
handler: Callable[..., Awaitable[Any]] | None = None
"""a callable that implements the tool's functionality. It should be an async function."""
"""The handler function; empty when origin is mcp"""
handler_module_path: str | None = None
"""
The module path of the handler function. This is empty when the origin is mcp.
This field must be retained, as the handler will be wrapped in functools.partial during initialization,
causing the handler's __module__ to be functools
"""Module path of the handler function; empty when origin is mcp
This field must be kept: the handler is wrapped by functools.partial at initialization, which makes the handler's __module__ become functools
"""
active: bool = True
"""
Whether the tool is active. This field is a special field for AstrBot.
You can ignore it when integrating with other frameworks.
"""
"""Whether the tool is active"""
origin: Literal["local", "mcp"] = "local"
"""Origin of the function tool: "local" for a local tool, "mcp" for an MCP service"""
# MCP-related fields
mcp_server_name: str | None = None
"""Name of the MCP server; valid when origin is "mcp"."""
mcp_client: MCPClient | None = None
"""The MCP client; valid when origin is "mcp"."""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> str | mcp.types.CallToolResult:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
def __dict__(self) -> dict[str, Any]:
"""Convert the FunctionTool to a dict"""
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
"origin": self.origin,
"mcp_server_name": self.mcp_server_name,
}
class ToolSet:
"""A set of function tools that can be used in function calling.
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
"""
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
@@ -91,7 +71,7 @@ class ToolSet:
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
def get_tool(self, name: str) -> FunctionTool | None:
def get_tool(self, name: str) -> Optional[FunctionTool]:
"""Get a tool by its name."""
for tool in self.tools:
if tool.name == name:
@@ -152,8 +132,10 @@ class ToolSet:
}
if (
tool.parameters and tool.parameters.get("properties")
) or not omit_empty_parameter_field:
tool.parameters
and tool.parameters.get("properties")
or not omit_empty_parameter_field
):
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
@@ -203,8 +185,7 @@ class ToolSet:
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"],
set(),
result["type"], set()
):
result["format"] = schema["format"]
else:
@@ -241,7 +222,7 @@ class ToolSet:
tools = []
for tool in self.tools:
d: dict[str, Any] = {
d = {
"name": tool.name,
"description": tool.description,
}
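The OpenAI-format conversion in `ToolSet` emits one function definition per tool and includes `parameters` only when the JSON Schema actually has properties (or when omission is disabled). A hedged sketch of that logic — the dict-based tool input is an assumption for illustration; field names mirror the OpenAI tools format:

```python
def to_openai_tool(tool: dict, omit_empty_parameter_field: bool = True) -> dict:
    func_def = {
        "type": "function",
        "function": {
            "name": tool["name"],
            "description": tool.get("description") or "",
        },
    }
    params = tool.get("parameters") or {}
    # Only attach "parameters" when the schema has properties,
    # unless the caller asked to always include it.
    if params.get("properties") or not omit_empty_parameter_field:
        func_def["function"]["parameters"] = params
    return func_def
```

Omitting the empty `parameters` object keeps zero-argument tools valid for providers that reject `{"type": "object", "properties": {}}`-style schemas.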


@@ -1,17 +1,11 @@
from collections.abc import AsyncGenerator
from typing import Any, Generic
import mcp
from .run_context import ContextWrapper, TContext
from typing import Any, Generic, AsyncGenerator
from .run_context import TContext, ContextWrapper
from .tool import FunctionTool
class BaseFunctionToolExecutor(Generic[TContext]):
@classmethod
async def execute(
cls,
tool: FunctionTool,
run_context: ContextWrapper[TContext],
**tool_args,
cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...


@@ -1,6 +1,4 @@
from dataclasses import dataclass
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@@ -11,4 +9,4 @@ class AstrAgentContext:
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
event: AstrMessageEvent
tool_call_timeout: int = 60 # Default tool call timeout in seconds


@@ -1,14 +1,13 @@
import os
import uuid
from typing import TypedDict, TypeVar
from astrbot.core import AstrBotConfig, logger
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
from astrbot.core.utils.shared_preferences import SharedPreferences
from typing import TypeVar, TypedDict
_VT = TypeVar("_VT")
@@ -49,10 +48,7 @@ class AstrBotConfigManager:
"""Get all abconf data"""
if self.abconf_data is None:
self.abconf_data = self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
"abconf_mapping", {}, scope="global", scope_id="global"
)
return self.abconf_data
@@ -68,7 +64,7 @@ class AstrBotConfigManager:
self.confs[uuid_] = conf
else:
logger.warning(
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping.",
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
)
continue
@@ -77,7 +73,6 @@ class AstrBotConfigManager:
Returns:
ConfInfo: a dict containing the config file's uuid, path, name, and related info
"""
# uuid -> { "path": str, "name": str }
abconf_data = self._get_abconf_data()
@@ -108,10 +103,7 @@ class AstrBotConfigManager:
) -> None:
"""Save the config-file mapping"""
abconf_data = self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
"abconf_mapping", {}, scope="global", scope_id="global"
)
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
@@ -185,17 +177,13 @@ class AstrBotConfigManager:
Raises:
ValueError: if attempting to delete the default config file
"""
if conf_id == "default":
raise ValueError("不能删除默认配置文件")
# Remove from the mapping
abconf_data = self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
@@ -203,8 +191,7 @@ class AstrBotConfigManager:
# Get the config file path
conf_path = os.path.join(
get_astrbot_config_path(),
abconf_data[conf_id]["path"],
get_astrbot_config_path(), abconf_data[conf_id]["path"]
)
# Delete the config file
@@ -237,16 +224,12 @@ class AstrBotConfigManager:
Returns:
bool: whether the update succeeded
"""
if conf_id == "default":
raise ValueError("不能更新默认配置文件的信息")
abconf_data = self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
@@ -263,10 +246,7 @@ class AstrBotConfigManager:
return True
def g(
self,
umo: str | None = None,
key: str | None = None,
default: _VT = None,
self, umo: str | None = None, key: str | None = None, default: _VT = None
) -> _VT:
"""Get a config value; when umo is None, the default config is used"""
if umo is None:


@@ -1,9 +1,9 @@
from .default import DEFAULT_CONFIG, VERSION, DB_PATH
from .astrbot_config import *
from .default import DB_PATH, DEFAULT_CONFIG, VERSION
__all__ = [
"DB_PATH",
"DEFAULT_CONFIG",
"VERSION",
"DB_PATH",
"AstrBotConfig",
]


@@ -1,11 +1,10 @@
import enum
import os
import json
import logging
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
import enum
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
logger = logging.getLogger("astrbot")
@@ -28,7 +27,7 @@ class AstrBotConfig(dict):
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict | None = None,
schema: dict = None,
):
super().__init__()
@@ -46,7 +45,7 @@ class AstrBotConfig(dict):
json.dump(default_config, f, indent=4, ensure_ascii=False)
object.__setattr__(self, "first_deploy", True) # mark first deployment
with open(config_path, encoding="utf-8-sig") as f:
with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
conf = json.loads(conf_str)
@@ -66,7 +65,7 @@ class AstrBotConfig(dict):
for k, v in schema.items():
if v["type"] not in DEFAULT_VALUE_MAP:
raise TypeError(
f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}",
f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}"
)
if "default" in v:
default = v["default"]
@@ -83,7 +82,7 @@ class AstrBotConfig(dict):
return conf
def check_config_integrity(self, refer_conf: dict, conf: dict, path=""):
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
"""Check config integrity; return True if there are new config keys or the key order differs"""
has_new = False
@@ -98,28 +97,27 @@ class AstrBotConfig(dict):
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
new_conf[key] = value
has_new = True
elif conf[key] is None:
# Config value is None; use the default
new_conf[key] = value
has_new = True
elif isinstance(value, dict):
# Recursively check child config keys
if not isinstance(conf[key], dict):
# Type mismatch; use the default
else:
if conf[key] is None:
# Config value is None; use the default
new_conf[key] = value
has_new = True
elif isinstance(value, dict):
# Recursively check child config keys
if not isinstance(conf[key], dict):
# Type mismatch; use the default
new_conf[key] = value
has_new = True
else:
# Recursively check and sync the key order
child_has_new = self.check_config_integrity(
value, conf[key], path + "." + key if path else key
)
new_conf[key] = conf[key]
has_new |= child_has_new
else:
# Recursively check and sync the key order
child_has_new = self.check_config_integrity(
value,
conf[key],
path + "." + key if path else key,
)
# Use the existing config value directly
new_conf[key] = conf[key]
has_new |= child_has_new
else:
# Use the existing config value directly
new_conf[key] = conf[key]
# 检查是否存在参考配置中没有的配置项
for key in list(conf.keys()):
@@ -142,7 +140,7 @@ class AstrBotConfig(dict):
return has_new
def save_config(self, replace_config: dict | None = None):
def save_config(self, replace_config: Dict = None):
"""Write the config to file
If replace_config is passed, the config is replaced with replace_config
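`check_config_integrity` above walks the reference config, inserts missing or `None` keys with their defaults, and recurses into nested dicts. The essence can be sketched as a pure function (a simplification: the real method also syncs key order, mutates `conf` in place, and logs each insertion):

```python
def merge_defaults(refer: dict, conf: dict) -> tuple[dict, bool]:
    """Return (merged config, whether any default had to be inserted)."""
    has_new = False
    merged = {}
    for key, value in refer.items():
        if key not in conf or conf[key] is None:
            merged[key] = value          # missing or None -> take the default
            has_new = True
        elif isinstance(value, dict) and isinstance(conf[key], dict):
            merged[key], child_new = merge_defaults(value, conf[key])
            has_new |= child_new
        else:
            merged[key] = conf[key]      # keep the user's existing value
    return merged, has_new
```

For example, merging a user config that lacks `a` against `{"a": 1, "b": {"c": 2}}` fills in the defaults and reports that the file needs rewriting.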


@@ -1,10 +1,12 @@
"""To modify the configuration, edit `data/cmd_config.json` or use the visual editor in the admin panel."""
"""
To modify the configuration, edit `data/cmd_config.json` or use the visual editor in the admin panel.
"""
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.5.5"
VERSION = "4.5.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# Default configuration
@@ -769,7 +771,6 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
"custom_extra_body": {},
"xai_native_search": False,
"modalities": ["text", "image", "tool_use"],
},
"Anthropic": {
@@ -1273,26 +1274,8 @@ CONFIG_METADATA_2 = {
"timeout": 20,
"launch_model_if_not_running": False,
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
"provider": "xinference",
"provider_type": "speech_to_text",
"enable": False,
"api_key": "",
"api_base": "http://127.0.0.1:9997",
"model": "whisper-large-v3",
"timeout": 180,
"launch_model_if_not_running": False,
},
},
"items": {
"xai_native_search": {
"description": "启用原生搜索功能",
"type": "bool",
"hint": "启用后,将通过 xAI 的 Chat Completions 原生 Live Search 进行联网检索(按需计费)。仅对 xAI 提供商生效。",
"condition": {"provider": "xai"},
},
"rerank_api_base": {
"description": "重排序模型 API Base URL",
"type": "string",
@@ -2705,9 +2688,9 @@ CONFIG_METADATA_3_SYSTEM = {
"items": {"type": "string"},
},
},
},
}
},
},
}
}


@@ -1,14 +1,13 @@
"""AstrBot session-conversation manager. It maintains two local stores: a JSON-based shared_preferences and a database.
"""
AstrBot session-conversation manager. It maintains two local stores: a JSON-based shared_preferences and a database
In AstrBot, sessions and conversations are independent: a session marks a chat window, e.g. group chat "123456789" can have one session,
and a single session can hold multiple conversations, with conversation switching and deletion supported
"""
import json
from collections.abc import Awaitable, Callable
from astrbot.core import sp
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
from typing import Dict, List, Callable, Awaitable
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
@@ -17,34 +16,31 @@ class ConversationManager:
"""Manages sessions' LLM conversations, i.e. which conversation each session is currently using."""
def __init__(self, db_helper: BaseDatabase):
self.session_conversations: dict[str, str] = {}
self.session_conversations: Dict[str, str] = {}
self.db = db_helper
self.save_interval = 60 # save every 60 seconds
# Session-deletion callbacks (for cascading cleanup, e.g. knowledge-base config)
self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = []
self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
def register_on_session_deleted(
self,
callback: Callable[[str], Awaitable[None]],
self, callback: Callable[[str], Awaitable[None]]
) -> None:
"""Register a session-deletion callback.
"""Register a session-deletion callback
Other modules can register callbacks to react to session-deletion events and perform cascading cleanup.
For example, the knowledge-base module can register a callback to clean up a session's knowledge-base config.
Args:
callback: a callback that receives the session ID (unified_msg_origin) as its argument
"""
self._on_session_deleted_callbacks.append(callback)
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
"""Trigger the session-deletion callbacks.
"""Trigger the session-deletion callbacks
Args:
unified_msg_origin: the session ID
"""
for callback in self._on_session_deleted_callbacks:
try:
@@ -53,7 +49,7 @@ class ConversationManager:
from astrbot.core import logger
logger.error(
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}",
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
)
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
@@ -79,13 +75,12 @@ class ConversationManager:
title: str | None = None,
persona_id: str | None = None,
) -> str:
"""Create a new conversation and switch the session's current conversation to it.
"""Create a new conversation and switch the session's current conversation to it
Args:
unified_msg_origin (str): unified message-origin string, in the format platform_name:message_type:session_id
Returns:
conversation_id (str): the conversation ID, a UUID string
"""
if not platform_id:
# If platform_id is not provided, parse it from unified_msg_origin
@@ -111,22 +106,18 @@ class ConversationManager:
Args:
unified_msg_origin (str): unified message-origin string, in the format platform_name:message_type:session_id
conversation_id (str): the conversation ID, a UUID string
"""
self.session_conversations[unified_msg_origin] = conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
async def delete_conversation(
self,
unified_msg_origin: str,
conversation_id: str | None = None,
self, unified_msg_origin: str, conversation_id: str | None = None
):
"""Delete a conversation of the session; when conversation_id is None, delete the session's current conversation
Args:
unified_msg_origin (str): unified message-origin string, in the format platform_name:message_type:session_id
conversation_id (str): the conversation ID, a UUID string
"""
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
@@ -142,7 +133,6 @@ class ConversationManager:
Args:
unified_msg_origin (str): unified message-origin string, in the format platform_name:message_type:session_id
"""
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
self.session_conversations.pop(unified_msg_origin, None)
@@ -158,7 +148,6 @@ class ConversationManager:
unified_msg_origin (str): unified message-origin string, in the format platform_name:message_type:session_id
Returns:
conversation_id (str): the conversation ID, a UUID string
"""
ret = self.session_conversations.get(unified_msg_origin, None)
if not ret:
@@ -173,15 +162,13 @@ class ConversationManager:
conversation_id: str,
create_if_not_exists: bool = False,
) -> Conversation | None:
"""Get a conversation of the session.
"""Get a conversation of the session
Args:
unified_msg_origin (str): unified message-origin string, in the format platform_name:message_type:session_id
conversation_id (str): the conversation ID, a UUID string
create_if_not_exists (bool): whether to create a new conversation if it does not exist
Returns:
conversation (Conversation): the conversation object
"""
conv = await self.db.get_conversation_by_id(cid=conversation_id)
if not conv and create_if_not_exists:
@@ -194,22 +181,18 @@ class ConversationManager:
return conv_res
async def get_conversations(
self,
unified_msg_origin: str | None = None,
platform_id: str | None = None,
) -> list[Conversation]:
"""获取对话列表.
self, unified_msg_origin: str | None = None, platform_id: str | None = None
) -> List[Conversation]:
"""获取对话列表
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id可选
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
Returns:
conversations (List[Conversation]): 对话对象列表
"""
convs = await self.db.get_conversations(
user_id=unified_msg_origin,
platform_id=platform_id,
user_id=unified_msg_origin, platform_id=platform_id
)
convs_res = []
for conv in convs:
@@ -225,7 +208,7 @@ class ConversationManager:
search_query: str = "",
**kwargs,
) -> tuple[list[Conversation], int]:
"""获取过滤后的对话列表.
"""获取过滤后的对话列表
Args:
page (int): 页码, 默认为 1
@@ -234,7 +217,6 @@ class ConversationManager:
search_query (str): 搜索查询字符串, 可选
Returns:
conversations (list[Conversation]): 对话对象列表
"""
convs, cnt = await self.db.get_filtered_conversations(
page=page,
@@ -256,14 +238,13 @@ class ConversationManager:
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
) -> None:
"""更新会话的对话.
):
"""更新会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if not conversation_id:
# 如果没有提供 conversation_id则获取当前的
@@ -277,20 +258,16 @@ class ConversationManager:
)
async def update_conversation_title(
self,
unified_msg_origin: str,
title: str,
conversation_id: str | None = None,
) -> None:
"""更新会话的对话标题.
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
):
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Deprecated:
Use `update_conversation` with `title` parameter instead.
"""
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
@@ -303,16 +280,15 @@ class ConversationManager:
unified_msg_origin: str,
persona_id: str,
conversation_id: str | None = None,
) -> None:
"""更新会话的对话 Persona ID.
):
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Deprecated:
Use `update_conversation` with `persona_id` parameter instead.
"""
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
@@ -320,85 +296,40 @@ class ConversationManager:
persona_id=persona_id,
)
async def add_message_pair(
self,
cid: str,
user_message: UserMessageSegment | dict,
assistant_message: AssistantMessageSegment | dict,
) -> None:
"""Add a user-assistant message pair to the conversation history.
Args:
cid (str): Conversation ID
user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict
Raises:
Exception: If the conversation with the given ID is not found
"""
conv = await self.db.get_conversation_by_id(cid=cid)
if not conv:
raise Exception(f"Conversation with id {cid} not found")
history = conv.content or []
if isinstance(user_message, UserMessageSegment):
user_msg_dict = user_message.model_dump()
else:
user_msg_dict = user_message
if isinstance(assistant_message, AssistantMessageSegment):
assistant_msg_dict = assistant_message.model_dump()
else:
assistant_msg_dict = assistant_message
history.append(user_msg_dict)
history.append(assistant_msg_dict)
await self.db.update_conversation(
cid=cid,
content=history,
)
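The normalization step in `add_message_pair` above accepts either a pydantic-style message object or a plain dict. A standalone sketch of that dispatch, where `Msg` is a stand-in for `UserMessageSegment`/`AssistantMessageSegment` (not the real types):

```python
class Msg:
    """Minimal stand-in for a pydantic-style message segment."""

    def __init__(self, role: str, content: str):
        self.role, self.content = role, content

    def model_dump(self) -> dict:
        return {"role": self.role, "content": self.content}


def append_pair(history: list[dict], user, assistant) -> list[dict]:
    # Objects exposing model_dump() are serialized; plain dicts pass through,
    # so the persisted history is always a uniform list of dicts.
    for msg in (user, assistant):
        history.append(msg.model_dump() if hasattr(msg, "model_dump") else msg)
    return history
```

Either input form ends up as a dict in the stored history, which keeps the persisted conversation format uniform.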
async def get_human_readable_context(
self,
unified_msg_origin: str,
conversation_id: str,
page: int = 1,
page_size: int = 10,
) -> tuple[list[str], int]:
"""获取人类可读的上下文.
self, unified_msg_origin, conversation_id, page=1, page_size=10
):
"""获取人类可读的上下文
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
page (int): 页码
page_size (int): 每页大小
"""
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
if not conversation:
return [], 0
history = json.loads(conversation.history)
# contexts_groups 存放按顺序的段落(每个段落是一个 str 列表),
# 之后会被展平成一个扁平的 str 列表返回。
contexts_groups: list[list[str]] = []
temp_contexts: list[str] = []
contexts = []
temp_contexts = []
for record in history:
if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant":
if record.get("content"):
if "content" in record and record["content"]:
temp_contexts.append(f"Assistant: {record['content']}")
elif "tool_calls" in record:
tool_calls_str = json.dumps(
record["tool_calls"],
ensure_ascii=False,
record["tool_calls"], ensure_ascii=False
)
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
else:
temp_contexts.append("Assistant: [未知的内容]")
contexts_groups.insert(0, temp_contexts)
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平分组后的 contexts 列表为单层字符串列表
contexts = [item for sublist in contexts_groups for item in sublist]
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
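The grouping, newest-first ordering, flattening, and 1-based page slicing above can be reproduced in a self-contained sketch (the `humanize` helper name is illustrative; the record shape mirrors the OpenAI-style history in the diff):

```python
import json


def humanize(history: list[dict], page: int = 1, page_size: int = 10) -> list[str]:
    """Render an OpenAI-style history as readable lines, newest exchange first."""
    groups: list[list[str]] = []  # one group per user/assistant exchange
    temp: list[str] = []
    for record in history:
        if record["role"] == "user":
            temp.append(f"User: {record['content']}")
        elif record["role"] == "assistant":
            if record.get("content"):
                temp.append(f"Assistant: {record['content']}")
            elif "tool_calls" in record:
                calls = json.dumps(record["tool_calls"], ensure_ascii=False)
                temp.append(f"Assistant: [tool call] {calls}")
            else:
                temp.append("Assistant: [unknown content]")
            groups.insert(0, temp)  # prepend so the latest exchange comes first
            temp = []
    flat = [line for group in groups for line in group]  # flatten the groups
    return flat[(page - 1) * page_size : page * page_size]  # page is 1-based
```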


@@ -1,5 +1,5 @@
"""Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作.
"""
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
@@ -9,44 +9,44 @@
3. 执行启动完成事件钩子
"""
import asyncio
import os
import threading
import time
import traceback
import asyncio
import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config, html_renderer
from asyncio import Queue
from astrbot.core import LogBroker, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.star import PluginManager
from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
from . import astrbot_config, html_renderer
from .event_bus import EventBus
from astrbot.core import logger, sp
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
class AstrBotCoreLifecycle:
"""AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作.
"""
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
"""
def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker # 初始化日志代理
self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库
@@ -70,11 +70,11 @@ class AstrBotCoreLifecycle:
del os.environ["no_proxy"]
logger.debug("HTTP proxy cleared")
async def initialize(self) -> None:
"""初始化 AstrBot 核心生命周期管理类.
负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
async def initialize(self):
"""
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
logger.info("AstrBot v" + VERSION)
if os.environ.get("TESTING", ""):
@@ -91,9 +91,7 @@ class AstrBotCoreLifecycle:
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
default_config=self.astrbot_config,
ucr=self.umop_config_router,
sp=sp,
default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
)
# 4.5 to 4.6 migration for umop_config_router
@@ -112,9 +110,7 @@ class AstrBotCoreLifecycle:
# 初始化供应商管理器
self.provider_manager = ProviderManager(
self.astrbot_config_mgr,
self.db,
self.persona_mgr,
self.astrbot_config_mgr, self.db, self.persona_mgr
)
# 初始化平台管理器
@@ -162,9 +158,7 @@ class AstrBotCoreLifecycle:
# 初始化事件总线
self.event_bus = EventBus(
self.event_queue,
self.pipeline_scheduler_mapping,
self.astrbot_config_mgr,
self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
)
# 记录启动时间
@@ -179,13 +173,13 @@ class AstrBotCoreLifecycle:
# 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event()
def _load(self) -> None:
"""加载事件总线和任务并初始化."""
def _load(self):
"""加载事件总线和任务并初始化"""
# 创建一个异步任务来执行事件总线的 dispatch() 方法
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
event_bus_task = asyncio.create_task(
self.event_bus.dispatch(),
name="event_bus",
self.event_bus.dispatch(), name="event_bus"
)
# 把插件中注册的所有协程函数注册到事件总线中并执行
@@ -196,17 +190,16 @@ class AstrBotCoreLifecycle:
tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_:
self.curr_tasks.append(
asyncio.create_task(self._task_wrapper(task), name=task.get_name()),
asyncio.create_task(self._task_wrapper(task), name=task.get_name())
)
self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task) -> None:
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常.
async def _task_wrapper(self, task: asyncio.Task):
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
Args:
task (asyncio.Task): 要执行的异步任务
"""
try:
await task
@@ -219,22 +212,19 @@ class AstrBotCoreLifecycle:
logger.error(f"| {line}")
logger.error("-------")
async def start(self) -> None:
"""启动 AstrBot 核心生命周期管理类.
用load加载事件总线和任务并初始化, 执行启动完成事件钩子
"""
async def start(self):
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
self._load()
logger.info("AstrBot 启动完成。")
# 执行启动完成事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAstrBotLoadedEvent,
EventType.OnAstrBotLoadedEvent
)
for handler in handlers:
try:
logger.info(
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler()
except BaseException:
@@ -243,8 +233,8 @@ class AstrBotCoreLifecycle:
# 同时运行curr_tasks中的所有任务
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self) -> None:
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器."""
async def stop(self):
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
# 请求停止所有正在运行的异步任务
for task in self.curr_tasks:
task.cancel()
@@ -255,7 +245,7 @@ class AstrBotCoreLifecycle:
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。",
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
)
await self.provider_manager.terminate()
@@ -272,16 +262,14 @@ class AstrBotCoreLifecycle:
except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
async def restart(self) -> None:
async def restart(self):
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread(
target=self.astrbot_updator._reboot,
name="restart",
daemon=True,
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()
def load_platform(self) -> list[asyncio.Task]:
@@ -293,38 +281,36 @@ class AstrBotCoreLifecycle:
asyncio.create_task(
platform_inst.run(),
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
),
)
)
return tasks
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
"""加载消息事件流水线调度器.
"""加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
mapping = {}
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id),
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
mapping[conf_id] = scheduler
return mapping
async def reload_pipeline_scheduler(self, conf_id: str) -> None:
"""重新加载消息事件流水线调度器.
async def reload_pipeline_scheduler(self, conf_id: str):
"""重新加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
if not ab_config:
raise ValueError(f"配置文件 {conf_id} 不存在")
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id),
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
self.pipeline_scheduler_mapping[conf_id] = scheduler


@@ -1,27 +1,27 @@
import abc
import datetime
import typing as T
from contextlib import asynccontextmanager
from dataclasses import dataclass
from deprecated import deprecated
from dataclasses import dataclass
from astrbot.core.db.po import (
Stats,
PlatformStat,
ConversationV2,
PlatformMessageHistory,
Attachment,
Persona,
Preference,
)
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from astrbot.core.db.po import (
Attachment,
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformStat,
Preference,
Stats,
)
@dataclass
class BaseDatabase(abc.ABC):
"""数据库基类"""
"""
数据库基类
"""
DATABASE_URL = ""
@@ -32,13 +32,12 @@ class BaseDatabase(abc.ABC):
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
self.engine, class_=AsyncSession, expire_on_commit=False
)
async def initialize(self):
"""初始化数据库连接"""
pass
@asynccontextmanager
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
@@ -92,9 +91,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
async def get_conversations(
self,
user_id: str | None = None,
platform_id: str | None = None,
self, user_id: str | None = None, platform_id: str | None = None
) -> list[ConversationV2]:
"""Get all conversations for a specific user and platform_id(optional).
@@ -109,9 +106,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
async def get_all_conversations(
self,
page: int = 1,
page_size: int = 20,
self, page: int = 1, page_size: int = 20
) -> list[ConversationV2]:
"""Get all conversations with pagination."""
...
@@ -178,10 +173,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
async def delete_platform_message_offset(
self,
platform_id: str,
user_id: str,
offset_sec: int = 86400,
self, platform_id: str, user_id: str, offset_sec: int = 86400
) -> None:
"""Delete platform message history records older than the specified offset."""
...
@@ -251,11 +243,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
async def insert_preference_or_update(
self,
scope: str,
scope_id: str,
key: str,
value: dict,
self, scope: str, scope_id: str, key: str, value: dict
) -> Preference:
"""Insert a new preference record."""
...
@@ -267,10 +255,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
async def get_preferences(
self,
scope: str,
scope_id: str | None = None,
key: str | None = None,
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""Get all preferences for a specific scope ID or key."""
...


@@ -1,33 +1,27 @@
import os
from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig
from astrbot.api import logger, sp
from .migra_3_to_4 import (
migration_conversation_table,
migration_persona_data,
migration_platform_table,
migration_preferences,
migration_webchat_data,
migration_persona_data,
migration_preferences,
)
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
"""检查是否需要进行数据库迁移
"""
检查是否需要进行数据库迁移
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4则需要进行迁移。
"""
# 仅当 data 目录下存在旧版本数据data_v3.db 文件)时才考虑迁移
data_dir = get_astrbot_data_path()
data_v3_db = os.path.join(data_dir, "data_v3.db")
if not os.path.exists(data_v3_db):
data_v3_exists = os.path.exists(get_astrbot_data_path())
if not data_v3_exists:
return False
migration_done = await db_helper.get_preference(
"global",
"global",
"migration_done_v4",
"global", "global", "migration_done_v4"
)
if migration_done:
return False
@@ -38,8 +32,9 @@ async def do_migration_v4(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
astrbot_config: AstrBotConfig,
) -> None:
"""执行数据库迁移
):
"""
执行数据库迁移
迁移旧的 webchat_conversation 表到新的 conversation 表。
迁移旧的 platform 到新的 platform_stats 表。
"""


@@ -1,18 +1,15 @@
import datetime
import json
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
import datetime
from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from .shared_preferences_v3 import sp as sp_v3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from astrbot.core.platform.astr_message_event import MessageSesion
from .. import BaseDatabase
from .shared_preferences_v3 import sp as sp_v3
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
"""
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
@@ -21,8 +18,7 @@ from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
def get_platform_id(
platform_id_map: dict[str, dict[str, str]],
old_platform_name: str,
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
@@ -31,8 +27,7 @@ def get_platform_id(
def get_platform_type(
platform_id_map: dict[str, dict[str, str]],
old_platform_name: str,
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
@@ -41,15 +36,13 @@ def get_platform_type(
async def migration_conversation_table(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1,
page_size=10000000,
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
@@ -68,14 +61,13 @@ async def migration_conversation_table(
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
platform_id_map,
session.platform_name,
platform_id_map, session.platform_name
)
session.platform_id = platform_id # 更新平台名称为新的 ID
conv_v2 = ConversationV2(
@@ -98,11 +90,10 @@ async def migration_conversation_table(
async def migration_platform_table(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
secs_from_2023_4_10_to_now = (
datetime.datetime.now(datetime.timezone.utc)
@@ -143,12 +134,10 @@ async def migration_platform_table(
if cnt == 0:
continue
platform_id = get_platform_id(
platform_id_map,
platform_stats_v3[idx].name,
platform_id_map, platform_stats_v3[idx].name
)
platform_type = get_platform_type(
platform_id_map,
platform_stats_v3[idx].name,
platform_id_map, platform_stats_v3[idx].name
)
try:
await dbsession.execute(
@@ -160,8 +149,7 @@ async def migration_platform_table(
"""),
{
"timestamp": datetime.datetime.fromtimestamp(
bucket_end,
tz=datetime.timezone.utc,
bucket_end, tz=datetime.timezone.utc
),
"platform_id": platform_id,
"platform_type": platform_type,
@@ -177,16 +165,14 @@ async def migration_platform_table(
async def migration_webchat_data(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1,
page_size=10000000,
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
@@ -205,7 +191,7 @@ async def migration_webchat_data(
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" in conv.user_id:
continue
@@ -232,10 +218,10 @@ async def migration_webchat_data(
async def migration_persona_data(
db_helper: BaseDatabase,
astrbot_config: AstrBotConfig,
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
):
"""迁移 Persona 数据到新的表中。
"""
迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
"""
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
@@ -250,15 +236,14 @@ async def migration_persona_data(
try:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
parts = []
mood_prompt = ""
user_turn = True
for mood_dialog in mood_imitation_dialogs:
if user_turn:
parts.append(f"A: {mood_dialog}\n")
mood_prompt += f"A: {mood_dialog}\n"
else:
parts.append(f"B: {mood_dialog}\n")
mood_prompt += f"B: {mood_dialog}\n"
user_turn = not user_turn
mood_prompt = "".join(parts)
system_prompt = persona.get("prompt", "")
if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
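The alternating attribution in the persona migration above (odd lines to speaker A, even lines to B) can be isolated as a small helper; a sketch under the assumption that `mood_imitation_dialogs` arrives as a flat list of strings:

```python
def build_mood_prompt(dialogs: list[str]) -> str:
    """Attribute successive dialog lines alternately to speakers A and B."""
    parts = []
    user_turn = True
    for line in dialogs:
        parts.append(f"{'A' if user_turn else 'B'}: {line}\n")
        user_turn = not user_turn  # flip speaker each line
    return "".join(parts)
```

Collecting parts in a list and joining once avoids repeated string concatenation, which is the change the diff makes.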
@@ -268,15 +253,14 @@ async def migration_persona_data(
begin_dialogs=begin_dialogs,
)
logger.info(
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。",
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
async def migration_preferences(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
# 1. global scope migration
keys = [
@@ -345,13 +329,10 @@ async def migration_preferences(
for provider_type, provider_id in perf.items():
await sp.put_async(
"umo",
str(session),
f"provider_perf_{provider_type}",
provider_id,
"umo", str(session), f"provider_perf_{provider_type}", provider_id
)
logger.info(
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}",
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
)
except Exception as e:
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)
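The `sp.put_async("umo", ..., f"provider_perf_{provider_type}", ...)` call above implies a three-part addressing scheme: scope, scope id, and key. A minimal in-memory sketch of that layout (`ScopedPrefs` is illustrative, not the real `sp` API):

```python
class ScopedPrefs:
    """Preferences keyed by the (scope, scope_id, key) triple seen above."""

    def __init__(self) -> None:
        self._store: dict[tuple[str, str, str], object] = {}

    def put(self, scope: str, scope_id: str, key: str, value) -> None:
        self._store[(scope, scope_id, key)] = value

    def get(self, scope: str, scope_id: str, key: str, default=None):
        return self._store.get((scope, scope_id, key), default)
```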


@@ -9,7 +9,7 @@ async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
if not isinstance(abconf_data, dict):
# should be unreachable
logger.warning(
f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}",
f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
)
return


@@ -1,7 +1,6 @@
import json
import os
from typing import TypeVar
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
@@ -17,7 +16,7 @@ class SharedPreferences:
def _load_preferences(self):
if os.path.exists(self.path):
try:
with open(self.path) as f:
with open(self.path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)


@@ -1,9 +1,8 @@
import sqlite3
import time
from dataclasses import dataclass
from typing import Any
from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
@dataclass
@@ -95,7 +94,7 @@ class SQLiteDatabase:
c.execute(
"""
PRAGMA table_info(webchat_conversation)
""",
"""
)
res = c.fetchall()
has_title = False
@@ -109,14 +108,14 @@ class SQLiteDatabase:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
""",
"""
)
self.conn.commit()
if not has_persona_id:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
""",
"""
)
self.conn.commit()
@@ -127,7 +126,7 @@ class SQLiteDatabase:
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: tuple = None):
def _exec_sql(self, sql: str, params: Tuple = None):
conn = self.conn
try:
c = self.conn.cursor()
@@ -175,7 +174,7 @@ class SQLiteDatabase:
"""
SELECT * FROM platform
"""
+ where_clause,
+ where_clause
)
platform = []
@@ -195,7 +194,7 @@ class SQLiteDatabase:
c.execute(
"""
SELECT SUM(count) FROM platform
""",
"""
)
res = c.fetchone()
c.close()
@@ -215,7 +214,7 @@ class SQLiteDatabase:
SELECT name, SUM(count), timestamp FROM platform
"""
+ where_clause
+ " GROUP BY name",
+ " GROUP BY name"
)
platform = []
@@ -243,7 +242,7 @@ class SQLiteDatabase:
c.close()
if not res:
return None
return
return Conversation(*res)
@@ -258,7 +257,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> tuple:
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -281,7 +280,7 @@ class SQLiteDatabase:
title = row[3]
persona_id = row[4]
conversations.append(
Conversation("", cid, "[]", created_at, updated_at, title, persona_id),
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
)
return conversations
@@ -320,10 +319,8 @@ class SQLiteDatabase:
)
def get_all_conversations(
self,
page: int = 1,
page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
@@ -369,7 +366,7 @@ class SQLiteDatabase:
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
},
}
)
return conversations, total_count
@@ -384,12 +381,12 @@ class SQLiteDatabase:
self,
page: int = 1,
page_size: int = 20,
platforms: list[str] | None = None,
message_types: list[str] | None = None,
search_query: str | None = None,
exclude_ids: list[str] | None = None,
exclude_platforms: list[str] | None = None,
) -> tuple[list[dict[str, Any]], int]:
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
@@ -425,7 +422,7 @@ class SQLiteDatabase:
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)",
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
@@ -485,7 +482,7 @@ class SQLiteDatabase:
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
},
}
)
return conversations, total_count


@@ -1,15 +1,15 @@
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TypedDict
from datetime import datetime, timezone
from dataclasses import dataclass, field
from sqlmodel import (
JSON,
Field,
SQLModel,
Text,
JSON,
UniqueConstraint,
Field,
)
from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True):
@@ -40,8 +40,7 @@ class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"
inner_conversation_id: int = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conversation_id: str = Field(
max_length=36,
@@ -51,14 +50,14 @@ class ConversationV2(SQLModel, table=True):
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
content: list | None = Field(default=None, sa_type=JSON)
content: Optional[list] = Field(default=None, sa_type=JSON)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
title: str | None = Field(default=None, max_length=255)
persona_id: str | None = Field(default=None)
title: Optional[str] = Field(default=None, max_length=255)
persona_id: Optional[str] = Field(default=None)
__table_args__ = (
UniqueConstraint(
@@ -77,15 +76,13 @@ class Persona(SQLModel, table=True):
__tablename__ = "personas"
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: list | None = Field(default=None, sa_type=JSON)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
tools: list | None = Field(default=None, sa_type=JSON)
tools: Optional[list] = Field(default=None, sa_type=JSON)
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
@@ -107,9 +104,7 @@ class Preference(SQLModel, table=True):
__tablename__ = "preferences"
id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
scope: str = Field(nullable=False)
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
@@ -143,15 +138,13 @@ class PlatformMessageHistory(SQLModel, table=True):
__tablename__ = "platform_message_history"
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: str | None = Field(default=None) # ID of the sender in the platform
sender_name: str | None = Field(
default=None,
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
sender_name: Optional[str] = Field(
default=None
) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -170,9 +163,7 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments"
inner_attachment_id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
attachment_id: str = Field(
max_length=36,

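The `Persona` model above declares an autoincrementing integer primary key and optional JSON columns. As a rough stdlib sketch, the table it maps to might look like this (column comments and sample data are illustrative, not from the project; the real schema is created by SQLModel, with JSON stored as text):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE personas (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        persona_id TEXT NOT NULL,
        system_prompt TEXT NOT NULL,
        begin_dialogs TEXT,  -- JSON list of opening dialogs, NULL = none
        tools TEXT           -- JSON list of tool names, NULL = all tools
    )"""
)
conn.execute(
    "INSERT INTO personas (persona_id, system_prompt) VALUES (?, ?)",
    ("default", "You are a helpful assistant."),
)
row = conn.execute("SELECT id, persona_id, tools FROM personas").fetchone()
```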
View File

@@ -1,27 +1,22 @@
import asyncio
import threading
import typing as T
import threading
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import (
Attachment,
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformStat,
PlatformMessageHistory,
Attachment,
Persona,
Preference,
Stats as DeprecatedStats,
Platform as DeprecatedPlatformStat,
SQLModel,
)
from astrbot.core.db.po import (
Platform as DeprecatedPlatformStat,
)
from astrbot.core.db.po import (
Stats as DeprecatedStats,
)
from sqlmodel import select, update, delete, text, func, or_, desc, col
from sqlalchemy.ext.asyncio import AsyncSession
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
@@ -62,9 +57,7 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
if timestamp is None:
timestamp = datetime.now().replace(
minute=0,
second=0,
microsecond=0,
minute=0, second=0, microsecond=0
)
current_hour = timestamp
await session.execute(
@@ -88,13 +81,13 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
result = await session.execute(
select(func.count(col(PlatformStat.platform_id))).select_from(
PlatformStat,
),
PlatformStat
)
)
count = result.scalar_one_or_none()
return count if count is not None else 0
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
async with self.get_db() as session:
session: AsyncSession
@@ -145,7 +138,7 @@ class SQLiteDatabase(BaseDatabase):
select(ConversationV2)
.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size),
.limit(page_size)
)
return result.scalars().all()
@@ -164,7 +157,7 @@ class SQLiteDatabase(BaseDatabase):
if platform_ids:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(platform_ids),
col(ConversationV2.platform_id).in_(platform_ids)
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
@@ -174,16 +167,16 @@ class SQLiteDatabase(BaseDatabase):
col(ConversationV2.content).ilike(f"%{search_query}%"),
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
),
)
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(kwargs["platforms"]),
col(ConversationV2.platform_id).in_(kwargs["platforms"])
)
# Get total count matching the filters
@@ -240,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
query = update(ConversationV2).where(
col(ConversationV2.conversation_id) == cid,
col(ConversationV2.conversation_id) == cid
)
values = {}
if title is not None:
@@ -250,7 +243,7 @@ class SQLiteDatabase(BaseDatabase):
if content is not None:
values["content"] = content
if not values:
return None
return
query = query.values(**values)
await session.execute(query)
return await self.get_conversation_by_id(cid)
@@ -261,8 +254,8 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(ConversationV2).where(
col(ConversationV2.conversation_id) == cid,
),
col(ConversationV2.conversation_id) == cid
)
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
@@ -270,9 +263,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(
col(ConversationV2.user_id) == user_id
),
delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
)
async def get_session_conversations(
@@ -291,7 +282,7 @@ class SQLiteDatabase(BaseDatabase):
select(
col(Preference.scope_id).label("session_id"),
func.json_extract(Preference.value, "$.val").label(
"conversation_id",
"conversation_id"
), # type: ignore
col(ConversationV2.persona_id).label("persona_id"),
col(ConversationV2.title).label("title"),
@@ -304,8 +295,7 @@ class SQLiteDatabase(BaseDatabase):
== ConversationV2.conversation_id,
)
.outerjoin(
Persona,
col(ConversationV2.persona_id) == Persona.persona_id,
Persona, col(ConversationV2.persona_id) == Persona.persona_id
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
@@ -318,14 +308,14 @@ class SQLiteDatabase(BaseDatabase):
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
),
)
)
# 平台筛选
if platform:
platform_pattern = f"{platform}:%"
base_query = base_query.where(
col(Preference.scope_id).like(platform_pattern),
col(Preference.scope_id).like(platform_pattern)
)
# 排序
@@ -346,8 +336,7 @@ class SQLiteDatabase(BaseDatabase):
== ConversationV2.conversation_id,
)
.outerjoin(
Persona,
col(ConversationV2.persona_id) == Persona.persona_id,
Persona, col(ConversationV2.persona_id) == Persona.persona_id
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
@@ -360,13 +349,13 @@ class SQLiteDatabase(BaseDatabase):
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
),
)
)
if platform:
platform_pattern = f"{platform}:%"
count_base_query = count_base_query.where(
col(Preference.scope_id).like(platform_pattern),
col(Preference.scope_id).like(platform_pattern)
)
total_result = await session.execute(count_base_query)
@@ -407,10 +396,7 @@ class SQLiteDatabase(BaseDatabase):
return new_history
async def delete_platform_message_offset(
self,
platform_id,
user_id,
offset_sec=86400,
self, platform_id, user_id, offset_sec=86400
):
"""Delete platform message history records older than the specified offset."""
async with self.get_db() as session:
@@ -423,15 +409,11 @@ class SQLiteDatabase(BaseDatabase):
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
col(PlatformMessageHistory.created_at) < cutoff_time,
),
)
)
async def get_platform_message_history(
self,
platform_id,
user_id,
page=1,
page_size=20,
self, platform_id, user_id, page=1, page_size=20
):
"""Get platform message history records."""
async with self.get_db() as session:
@@ -470,11 +452,7 @@ class SQLiteDatabase(BaseDatabase):
return result.scalar_one_or_none()
async def insert_persona(
self,
persona_id,
system_prompt,
begin_dialogs=None,
tools=None,
self, persona_id, system_prompt, begin_dialogs=None, tools=None
):
"""Insert a new persona record."""
async with self.get_db() as session:
@@ -506,11 +484,7 @@ class SQLiteDatabase(BaseDatabase):
return result.scalars().all()
async def update_persona(
self,
persona_id,
system_prompt=None,
begin_dialogs=None,
tools=NOT_GIVEN,
self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
):
"""Update a persona's system prompt or begin dialogs."""
async with self.get_db() as session:
@@ -525,7 +499,7 @@ class SQLiteDatabase(BaseDatabase):
if tools is not NOT_GIVEN:
values["tools"] = tools
if not values:
return None
return
query = query.values(**values)
await session.execute(query)
return await self.get_persona_by_id(persona_id)
@@ -536,7 +510,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(Persona).where(col(Persona.persona_id) == persona_id),
delete(Persona).where(col(Persona.persona_id) == persona_id)
)
async def insert_preference_or_update(self, scope, scope_id, key, value):
@@ -555,10 +529,7 @@ class SQLiteDatabase(BaseDatabase):
existing_preference.value = value
else:
new_preference = Preference(
scope=scope,
scope_id=scope_id,
key=key,
value=value,
scope=scope, scope_id=scope_id, key=key, value=value
)
session.add(new_preference)
return existing_preference or new_preference
@@ -597,7 +568,7 @@ class SQLiteDatabase(BaseDatabase):
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
col(Preference.key) == key,
),
)
)
await session.commit()
@@ -610,7 +581,7 @@ class SQLiteDatabase(BaseDatabase):
delete(Preference).where(
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
),
)
)
await session.commit()
@@ -627,7 +598,7 @@ class SQLiteDatabase(BaseDatabase):
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
select(PlatformStat).where(PlatformStat.timestamp >= start_time)
)
all_datas = result.scalars().all()
deprecated_stats = DeprecatedStats()
@@ -637,7 +608,7 @@ class SQLiteDatabase(BaseDatabase):
name=data.platform_id,
count=data.count,
timestamp=int(data.timestamp.timestamp()),
),
)
)
return deprecated_stats
@@ -659,7 +630,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
select(func.sum(PlatformStat.count)).select_from(PlatformStat)
)
total_count = result.scalar_one_or_none()
return total_count if total_count is not None else 0
@@ -685,7 +656,7 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
.where(PlatformStat.timestamp >= start_time)
.group_by(PlatformStat.platform_id),
.group_by(PlatformStat.platform_id)
)
grouped_stats = result.all()
deprecated_stats = DeprecatedStats()
@@ -695,7 +666,7 @@ class SQLiteDatabase(BaseDatabase):
name=platform_id,
count=count,
timestamp=int(start_time.timestamp()),
),
)
)
return deprecated_stats

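The conversation queries above page through results with `offset`/`limit`. The arithmetic can be sketched with stdlib `sqlite3` (table and column names here are hypothetical):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE conversations (cid INTEGER PRIMARY KEY, title TEXT)")
conn.executemany(
    "INSERT INTO conversations (title) VALUES (?)",
    [(f"conv-{i}",) for i in range(1, 8)],
)

def get_page(page: int, page_size: int) -> list[tuple]:
    # Same offset arithmetic as the methods above: skip the first
    # (page - 1) * page_size rows, then take page_size rows.
    offset = (page - 1) * page_size
    return conn.execute(
        "SELECT title FROM conversations ORDER BY cid LIMIT ? OFFSET ?",
        (page_size, offset),
    ).fetchall()
```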
View File

@@ -10,16 +10,18 @@ class Result:
class BaseVecDB:
async def initialize(self):
"""初始化向量数据库"""
"""
初始化向量数据库
"""
pass
@abc.abstractmethod
async def insert(
self,
content: str,
metadata: dict | None = None,
id: str | None = None,
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
"""插入一条文本和其对应向量,自动生成 ID 并保持一致性。"""
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
@@ -33,11 +35,11 @@ class BaseVecDB:
max_retries: int = 3,
progress_callback=None,
) -> int:
"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。
"""
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
...
@@ -50,7 +52,8 @@ class BaseVecDB:
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""搜索最相似的文档。
"""
搜索最相似的文档。
Args:
query (str): 查询文本
top_k (int): 返回的最相似文档的数量
@@ -61,7 +64,8 @@ class BaseVecDB:
@abc.abstractmethod
async def delete(self, doc_id: str) -> bool:
"""删除指定文档。
"""
删除指定文档。
Args:
doc_id (str): 要删除的文档 ID
Returns:

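`BaseVecDB` is an abstract interface: methods marked `@abc.abstractmethod` must be provided by concrete backends such as the FAISS implementation. A toy sketch of the same pattern (names here are illustrative, not the project's):

```python
import abc

class BaseStore(abc.ABC):
    """A toy analogue of BaseVecDB: subclasses must implement the interface."""

    @abc.abstractmethod
    def insert(self, content: str) -> int: ...

    @abc.abstractmethod
    def delete(self, doc_id: str) -> bool: ...

class MemoryStore(BaseStore):
    def __init__(self) -> None:
        self.docs: dict[int, str] = {}

    def insert(self, content: str) -> int:
        doc_id = len(self.docs) + 1
        self.docs[doc_id] = content
        return doc_id

    def delete(self, doc_id: str) -> bool:
        # pop returns None (falsy) when the id was never stored
        return self.docs.pop(int(doc_id), None) is not None
```

Instantiating `BaseStore` directly raises `TypeError`, which is how the abstract base enforces the contract.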
View File

@@ -1,13 +1,12 @@
import json
import os
from contextlib import asynccontextmanager
import json
from datetime import datetime
from contextlib import asynccontextmanager
from sqlalchemy import Column, Text
from sqlalchemy import Text, Column
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, MetaData, SQLModel, col, func, select, text
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
from astrbot.core import logger
@@ -21,9 +20,7 @@ class Document(BaseDocModel, table=True):
__tablename__ = "documents" # type: ignore
id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
doc_id: str = Field(nullable=False)
text: str = Field(nullable=False)
@@ -39,8 +36,7 @@ class DocumentStorage:
self.engine: AsyncEngine | None = None
self.async_session_maker: sessionmaker | None = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__),
"sqlite_init.sql",
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
@@ -54,26 +50,26 @@ class DocumentStorage:
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED",
),
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
)
)
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN user_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED",
),
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
)
)
# Create indexes
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)",
),
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
)
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)",
),
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
)
)
except BaseException:
pass
@@ -117,11 +113,10 @@ class DocumentStorage:
Returns:
list: The list of documents that match the filters.
"""
if self.engine is None:
logger.warning(
"Database connection is not initialized, returning empty result",
"Database connection is not initialized, returning empty result"
)
return []
@@ -130,7 +125,7 @@ class DocumentStorage:
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
if ids is not None and len(ids) > 0:
@@ -158,27 +153,24 @@ class DocumentStorage:
Returns:
int: The integer ID of the inserted document.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session, session.begin():
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(document)
await session.flush() # Flush to get the ID
return document.id # type: ignore
async with self.get_session() as session:
async with session.begin():
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(document)
await session.flush() # Flush to get the ID
return document.id # type: ignore
async def insert_documents_batch(
self,
doc_ids: list[str],
texts: list[str],
metadatas: list[dict],
self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
) -> list[int]:
"""Batch insert documents and return their integer IDs.
@@ -189,44 +181,44 @@ class DocumentStorage:
Returns:
list[int]: List of integer IDs of the inserted documents.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session, session.begin():
import json
async with self.get_session() as session:
async with session.begin():
import json
documents = []
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
documents.append(document)
session.add(document)
documents = []
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
documents.append(document)
session.add(document)
await session.flush() # Flush to get all IDs
return [doc.id for doc in documents] # type: ignore
await session.flush() # Flush to get all IDs
return [doc.id for doc in documents] # type: ignore
async def delete_document_by_doc_id(self, doc_id: str):
"""Delete a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to delete.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session, session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
await session.delete(document)
if document:
await session.delete(document)
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
@@ -236,7 +228,6 @@ class DocumentStorage:
Returns:
dict: The document data or None if not found.
"""
assert self.engine is not None, "Database connection is not initialized."
@@ -255,46 +246,46 @@ class DocumentStorage:
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session, session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
async def delete_documents(self, metadata_filters: dict):
"""Delete documents by their metadata filters.
Args:
metadata_filters (dict): The metadata filters to apply.
"""
if self.engine is None:
logger.warning(
"Database connection is not initialized, skipping delete operation",
"Database connection is not initialized, skipping delete operation"
)
return
async with self.get_session() as session, session.begin():
query = select(Document)
async with self.get_session() as session:
async with session.begin():
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
).params(**{f"filter_{key}": val})
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
result = await session.execute(query)
documents = result.scalars().all()
result = await session.execute(query)
documents = result.scalars().all()
for doc in documents:
await session.delete(doc)
for doc in documents:
await session.delete(doc)
async def count_documents(self, metadata_filters: dict | None = None) -> int:
"""Count documents in the database.
@@ -304,7 +295,6 @@ class DocumentStorage:
Returns:
int: The count of documents.
"""
if self.engine is None:
logger.warning("Database connection is not initialized, returning 0")
@@ -316,7 +306,7 @@ class DocumentStorage:
if metadata_filters:
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
result = await session.execute(query)
@@ -328,13 +318,12 @@ class DocumentStorage:
Returns:
list: A list of user IDs.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = text(
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL",
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
)
result = await session.execute(query)
rows = result.fetchall()
@@ -348,7 +337,6 @@ class DocumentStorage:
Returns:
dict: The converted dictionary.
"""
return {
"id": document.id,
@@ -373,7 +361,6 @@ class DocumentStorage:
dict: The converted dictionary.
Note: This method is kept for backward compatibility but is no longer used internally.
"""
return {
"id": row[0],

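Several hunks in this file collapse two nested `async with` blocks into a single comma-separated statement. The two spellings enter and exit the context managers in the same order, which a small stdlib sketch can confirm:

```python
import asyncio
from contextlib import asynccontextmanager

order: list[str] = []

@asynccontextmanager
async def ctx(name: str):
    order.append(f"enter {name}")
    try:
        yield name
    finally:
        order.append(f"exit {name}")

async def main():
    # One line replaces two nested blocks; enter/exit order is identical
    # to the nested form the diff above collapses.
    async with ctx("session") as s, ctx("begin"):
        order.append(f"body with {s}")

asyncio.run(main())
```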
View File

@@ -2,10 +2,9 @@ try:
import faiss
except ModuleNotFoundError:
raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。",
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
)
import os
import numpy as np
@@ -28,12 +27,11 @@ class EmbeddingStorage:
id (int): 向量的ID
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
assert self.index is not None, "FAISS index is not initialized."
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
await self.save_index()
@@ -46,12 +44,11 @@ class EmbeddingStorage:
ids (list[int]): 向量的ID列表
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
assert self.index is not None, "FAISS index is not initialized."
if vectors.shape[1] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}"
)
self.index.add_with_ids(vectors, np.array(ids))
await self.save_index()
@@ -64,7 +61,6 @@ class EmbeddingStorage:
k (int): 返回的最相似向量的数量
Returns:
tuple: (距离, 索引)
"""
assert self.index is not None, "FAISS index is not initialized."
faiss.normalize_L2(vector)
@@ -76,7 +72,6 @@ class EmbeddingStorage:
Args:
ids (list[int]): 要删除的向量ID列表
"""
assert self.index is not None, "FAISS index is not initialized."
id_array = np.array(ids, dtype=np.int64)
@@ -88,6 +83,5 @@ class EmbeddingStorage:
Args:
path (str): 保存索引的路径
"""
faiss.write_index(self.index, self.path)

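`EmbeddingStorage` delegates similarity search to FAISS over L2-normalized vectors, where inner product equals cosine similarity. A pure-Python sketch of that ranking (illustrative only; the real index is a FAISS index with explicit IDs):

```python
def normalize(v: list[float]) -> list[float]:
    norm = sum(x * x for x in v) ** 0.5
    return [x / norm for x in v]

def search(store: dict[int, list[float]], query: list[float], k: int) -> list[int]:
    # Rank stored (already normalized) vectors by inner product with the
    # normalized query, i.e. by cosine similarity, and return the top-k IDs.
    q = normalize(query)
    scored = [
        (sum(a * b for a, b in zip(q, vec)), vec_id)
        for vec_id, vec in store.items()
    ]
    scored.sort(reverse=True)
    return [vec_id for _, vec_id in scored[:k]]

store = {1: normalize([1.0, 0.0]), 2: normalize([0.6, 0.8]), 3: normalize([0.0, 1.0])}
```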
View File

@@ -1,18 +1,18 @@
import time
import uuid
import time
import numpy as np
from astrbot import logger
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from ..base import BaseVecDB, Result
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
from astrbot import logger
class FaissVecDB(BaseVecDB):
"""A class to represent a vector database."""
"""
A class to represent a vector database.
"""
def __init__(
self,
@@ -26,8 +26,7 @@ class FaissVecDB(BaseVecDB):
self.embedding_provider = embedding_provider
self.document_storage = DocumentStorage(doc_store_path)
self.embedding_storage = EmbeddingStorage(
embedding_provider.get_dim(),
index_store_path,
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider
@@ -36,12 +35,11 @@ class FaissVecDB(BaseVecDB):
await self.document_storage.initialize()
async def insert(
self,
content: str,
metadata: dict | None = None,
id: str | None = None,
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
"""插入一条文本和其对应向量,自动生成 ID 并保持一致性。"""
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
metadata = metadata or {}
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
@@ -65,11 +63,11 @@ class FaissVecDB(BaseVecDB):
max_retries: int = 3,
progress_callback=None,
) -> list[int]:
"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。
"""
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
metadatas = metadatas or [{} for _ in contents]
ids = ids or [str(uuid.uuid4()) for _ in contents]
@@ -85,14 +83,12 @@ class FaissVecDB(BaseVecDB):
)
end = time.time()
logger.debug(
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds.",
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
)
# 使用 DocumentStorage 的批量插入方法
int_ids = await self.document_storage.insert_documents_batch(
ids,
contents,
metadatas,
ids, contents, metadatas
)
# 批量插入向量到 FAISS
@@ -108,7 +104,8 @@ class FaissVecDB(BaseVecDB):
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""搜索最相似的文档。
"""
搜索最相似的文档。
Args:
query (str): 查询文本
@@ -119,7 +116,6 @@ class FaissVecDB(BaseVecDB):
Returns:
List[Result]: 查询结果
"""
embedding = await self.embedding_provider.get_embedding(query)
scores, indices = await self.embedding_storage.search(
@@ -132,8 +128,7 @@ class FaissVecDB(BaseVecDB):
scores[0] = 1.0 - (scores[0] / 2.0)
# NOTE: maybe the size is less than k.
fetched_docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters or {},
ids=indices[0],
metadata_filters=metadata_filters or {}, ids=indices[0]
)
if not fetched_docs:
return []
@@ -154,9 +149,7 @@ class FaissVecDB(BaseVecDB):
documents = [doc.data["text"] for doc in top_k_results]
reranked_results = await self.rerank_provider.rerank(query, documents)
reranked_results = sorted(
reranked_results,
key=lambda x: x.relevance_score,
reverse=True,
reranked_results, key=lambda x: x.relevance_score, reverse=True
)
top_k_results = [
top_k_results[reranked_result.index]
@@ -166,7 +159,9 @@ class FaissVecDB(BaseVecDB):
return top_k_results
async def delete(self, doc_id: str):
"""删除一条文档块chunk"""
"""
删除一条文档块chunk
"""
# 获得对应的 int id
result = await self.document_storage.get_document_by_doc_id(doc_id)
int_id = result["id"] if result else None
@@ -181,23 +176,23 @@ class FaissVecDB(BaseVecDB):
await self.document_storage.close()
async def count_documents(self, metadata_filter: dict | None = None) -> int:
"""计算文档数量
"""
计算文档数量
Args:
metadata_filter (dict | None): 元数据过滤器
"""
count = await self.document_storage.count_documents(
metadata_filters=metadata_filter or {},
metadata_filters=metadata_filter or {}
)
return count
async def delete_documents(self, metadata_filters: dict):
"""根据元数据过滤器删除文档"""
"""
根据元数据过滤器删除文档
"""
docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters,
offset=None,
limit=None,
metadata_filters=metadata_filters, offset=None, limit=None
)
doc_ids: list[int] = [doc["id"] for doc in docs]
await self.embedding_storage.delete(doc_ids)

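`search()` above rescales FAISS scores with `1.0 - (scores[0] / 2.0)`. Assuming the index returns squared L2 distances between unit vectors, this recovers cosine similarity, since ‖a − b‖² = 2 − 2·cos(a, b). A quick check:

```python
import math

def unit(v: list[float]) -> list[float]:
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

a, b = unit([3.0, 4.0]), unit([1.0, 2.0])
sq_l2 = sum((x - y) ** 2 for x, y in zip(a, b))  # what a flat L2 index returns
recovered = 1.0 - sq_l2 / 2.0                    # the conversion used in search()
```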
View File

@@ -1,4 +1,5 @@
"""事件总线, 用于处理事件的分发和处理
"""
事件总线, 用于处理事件的分发和处理
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
@@ -12,12 +13,10 @@ class:
import asyncio
from asyncio import Queue
from astrbot.core import logger
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.pipeline.scheduler import PipelineScheduler
from astrbot.core import logger
from .platform import AstrMessageEvent
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
class EventBus:
@@ -47,15 +46,14 @@ class EventBus:
Args:
event (AstrMessageEvent): 事件对象
"""
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name():
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}",
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
else:
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}",
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
)

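The event bus described in this module's docstring is an infinite loop over an `asyncio.Queue` that dispatches each event as it arrives; a minimal stdlib sketch of the same shape (handling is inlined here, where the real bus hands events to the pipeline scheduler):

```python
import asyncio

async def dispatch(queue: asyncio.Queue, handled: list) -> None:
    # Pull events forever; each iteration would normally spawn a task
    # running the pipeline scheduler for the event.
    while True:
        event = await queue.get()
        handled.append(event)
        queue.task_done()

async def main() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    handled: list = []
    worker = asyncio.create_task(dispatch(queue, handled))
    for event in ("msg-1", "msg-2"):
        await queue.put(event)
    await queue.join()  # wait until both events are processed
    worker.cancel()
    return handled

result = asyncio.run(main())
```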
View File

@@ -1,9 +1,9 @@
import asyncio
import os
import platform
import time
import uuid
from urllib.parse import unquote, urlparse
import time
from urllib.parse import urlparse, unquote
import platform
class FileTokenService:
@@ -40,8 +40,8 @@ class FileTokenService:
Raises:
FileNotFoundError: 当路径不存在时抛出
"""
# 处理 file:///
try:
parsed_uri = urlparse(file_path)
@@ -61,7 +61,7 @@ class FileTokenService:
if not os.path.exists(local_path):
raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})",
f"文件不存在: {local_path} (原始输入: {file_path})"
)
file_token = str(uuid.uuid4())
@@ -84,7 +84,6 @@ class FileTokenService:
Raises:
KeyError: 当令牌不存在或已过期时抛出
FileNotFoundError: 当文件本身已被删除时抛出
"""
async with self.lock:
await self._cleanup_expired_tokens()

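`FileTokenService` normalizes `file:///` URIs to local paths with `urlparse` and `unquote` before checking existence. A hedged sketch of that normalization (not the actual method; the Windows branch is an assumption about drive-letter paths):

```python
import os
from urllib.parse import unquote, urlparse

def to_local_path(file_path: str) -> str:
    """Turn a file:/// URI into a local filesystem path; pass other input through."""
    parsed = urlparse(file_path)
    if parsed.scheme == "file":
        path = unquote(parsed.path)  # decode %20 and friends
        # On Windows urlparse keeps a leading slash before the drive letter
        if os.name == "nt":
            path = path.lstrip("/")
        return path
    return file_path  # plain local paths pass through unchanged
```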
View File

@@ -1,4 +1,5 @@
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
"""
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
工作流程:
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
@@ -7,10 +8,10 @@
import asyncio
import traceback
from astrbot.core import LogBroker, logger
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard
@@ -38,10 +39,7 @@ class InitialLoader:
webui_dir = self.webui_dir
self.dashboard_server = AstrBotDashboard(
core_lifecycle,
self.db,
core_lifecycle.dashboard_shutdown_event,
webui_dir,
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
)
coro = self.dashboard_server.run()

View File

@@ -1,4 +1,6 @@
"""文档分块模块"""
"""
文档分块模块
"""
from .base import BaseChunker
from .fixed_size import FixedSizeChunker

View File

@@ -21,5 +21,4 @@ class BaseChunker(ABC):
Returns:
list[str]: 分块后的文本列表
"""

View File

@@ -18,7 +18,6 @@ class FixedSizeChunker(BaseChunker):
Args:
chunk_size: 块的大小(字符数)
chunk_overlap: 块之间的重叠字符数
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
@@ -33,7 +32,6 @@ class FixedSizeChunker(BaseChunker):
Returns:
list[str]: 分块后的文本列表
"""
chunk_size = kwargs.get("chunk_size", self.chunk_size)
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
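Given `chunk_size` and `chunk_overlap`, `FixedSizeChunker` presumably slides a fixed window with a step of `chunk_size - chunk_overlap`; a minimal sketch of that scheme (not the project's implementation):

```python
def fixed_size_chunks(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    # Slide a window of chunk_size characters, stepping forward by
    # chunk_size - chunk_overlap so consecutive chunks share the overlap.
    step = chunk_size - chunk_overlap
    return [text[i : i + chunk_size] for i in range(0, len(text), step)]

chunks = fixed_size_chunks("abcdefghijkl", chunk_size=6, chunk_overlap=2)
```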

View File

@@ -1,5 +1,4 @@
from collections.abc import Callable
from .base import BaseChunker
@@ -12,7 +11,8 @@ class RecursiveCharacterChunker(BaseChunker):
is_separator_regex: bool = False,
separators: list[str] | None = None,
):
"""初始化递归字符文本分割器
"""
初始化递归字符文本分割器
Args:
chunk_size: 每个文本块的最大大小
@@ -20,7 +20,6 @@ class RecursiveCharacterChunker(BaseChunker):
length_function: 计算文本长度的函数
is_separator_regex: 分隔符是否为正则表达式
separators: 用于分割文本的分隔符列表,按优先级排序
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
@@ -40,7 +39,8 @@ class RecursiveCharacterChunker(BaseChunker):
]
async def chunk(self, text: str, **kwargs) -> list[str]:
"""递归地将文本分割成块
"""
递归地将文本分割成块
Args:
text: 要分割的文本
@@ -49,7 +49,6 @@ class RecursiveCharacterChunker(BaseChunker):
Returns:
分割后的文本块列表
"""
if not text:
return []
@@ -91,7 +90,7 @@ class RecursiveCharacterChunker(BaseChunker):
combined_text,
chunk_size=chunk_size,
chunk_overlap=overlap,
),
)
)
current_chunk = []
current_chunk_length = 0
@@ -99,10 +98,8 @@ class RecursiveCharacterChunker(BaseChunker):
# 递归分割过大的部分
final_chunks.extend(
await self.chunk(
split,
chunk_size=chunk_size,
chunk_overlap=overlap,
),
split, chunk_size=chunk_size, chunk_overlap=overlap
)
)
# 如果添加这部分会使当前块超过chunk_size
elif current_chunk_length + split_length > chunk_size:
@@ -135,19 +132,16 @@ class RecursiveCharacterChunker(BaseChunker):
return [text]
def _split_by_character(
self,
text: str,
chunk_size: int | None = None,
overlap: int | None = None,
self, text: str, chunk_size: int | None = None, overlap: int | None = None
) -> list[str]:
"""按字符级别分割文本
"""
按字符级别分割文本
Args:
text: 要分割的文本
Returns:
分割后的文本块列表
"""
chunk_size = chunk_size or self.chunk_size
overlap = overlap or self.chunk_overlap

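`RecursiveCharacterChunker` tries separators from coarsest to finest, recursing into any piece still larger than `chunk_size`, and falls back to character-level splitting. A simplified sketch of that strategy which omits the overlap handling shown above:

```python
def recursive_split(
    text: str,
    chunk_size: int,
    separators: tuple[str, ...] = ("\n\n", "\n", " ", ""),
) -> list[str]:
    if len(text) <= chunk_size:
        return [text]
    sep, *rest = separators
    if sep == "":
        # character-level fallback, analogous to _split_by_character above
        return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]
    chunks: list[str] = []
    for part in text.split(sep):
        if len(part) > chunk_size:
            # still too large: retry with the next, finer separator
            chunks.extend(recursive_split(part, chunk_size, tuple(rest)))
        elif part:
            chunks.append(part)
    return chunks
```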
View File

@@ -1,18 +1,18 @@
from contextlib import asynccontextmanager
from pathlib import Path
from sqlalchemy import delete, func, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import col, desc
from sqlalchemy import text, func, select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
KBMedia,
KnowledgeBase,
)
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
class KBSQLiteDatabase:
@@ -21,7 +21,6 @@ class KBSQLiteDatabase:
Args:
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
"""
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
@@ -86,77 +85,77 @@ class KBSQLiteDatabase:
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
"ON knowledge_bases(kb_id)",
),
"ON knowledge_bases(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_name "
"ON knowledge_bases(kb_name)",
),
"ON knowledge_bases(kb_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
"ON knowledge_bases(created_at)",
),
"ON knowledge_bases(created_at)"
)
)
# Create indexes for the documents table
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
"ON kb_documents(doc_id)",
),
"ON kb_documents(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
"ON kb_documents(kb_id)",
),
"ON kb_documents(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_name "
"ON kb_documents(doc_name)",
),
"ON kb_documents(doc_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_type "
"ON kb_documents(file_type)",
),
"ON kb_documents(file_type)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
"ON kb_documents(created_at)",
),
"ON kb_documents(created_at)"
)
)
# Create indexes for the media table
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
"ON kb_media(media_id)",
),
"ON kb_media(media_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
"ON kb_media(doc_id)",
),
"ON kb_media(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)",
),
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_type "
"ON kb_media(media_type)",
),
"ON kb_media(media_type)"
)
)
await session.commit()
@@ -209,10 +208,7 @@ class KBSQLiteDatabase:
return result.scalar_one_or_none()
async def list_documents_by_kb(
self,
kb_id: str,
offset: int = 0,
limit: int = 100,
self, kb_id: str, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.get_db() as session:
@@ -230,7 +226,7 @@ class KBSQLiteDatabase:
"""统计知识库的文档数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KBDocument.id))).where(
col(KBDocument.kb_id) == kb_id,
col(KBDocument.kb_id) == kb_id
)
result = await session.execute(stmt)
return result.scalar() or 0
@@ -256,11 +252,12 @@ class KBSQLiteDatabase:
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
"""删除单个文档及其相关数据"""
# Delete from the knowledge-base tables
async with self.get_db() as session, session.begin():
# 删除文档记录
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
await session.execute(delete_stmt)
await session.commit()
async with self.get_db() as session:
async with session.begin():
# 删除文档记录
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
await session.execute(delete_stmt)
await session.commit()
# Delete the related vectors from the vec db
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
@@ -285,17 +282,18 @@ class KBSQLiteDatabase:
"""更新知识库统计信息"""
chunk_cnt = await vec_db.count_documents()
async with self.get_db() as session, session.begin():
update_stmt = (
update(KnowledgeBase)
.where(col(KnowledgeBase.kb_id) == kb_id)
.values(
doc_count=select(func.count(col(KBDocument.id)))
.where(col(KBDocument.kb_id) == kb_id)
.scalar_subquery(),
chunk_count=chunk_cnt,
async with self.get_db() as session:
async with session.begin():
update_stmt = (
update(KnowledgeBase)
.where(col(KnowledgeBase.kb_id) == kb_id)
.values(
doc_count=select(func.count(col(KBDocument.id)))
.where(col(KBDocument.kb_id) == kb_id)
.scalar_subquery(),
chunk_count=chunk_cnt,
)
)
)
await session.execute(update_stmt)
await session.commit()
await session.execute(update_stmt)
await session.commit()
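The `CREATE INDEX IF NOT EXISTS` statements this file issues through `session.execute(text(...))` are plain SQLite DDL, so the idempotent pattern can be demonstrated with the stdlib driver alone (table name borrowed from the diff, schema reduced to two columns):

```python
import sqlite3

# Minimal in-memory stand-in for the kb_documents table above.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE kb_documents (id INTEGER PRIMARY KEY, doc_id TEXT, kb_id TEXT)"
)
# Same pattern as the session.execute(text(...)) calls: safe to re-run on startup.
conn.execute("CREATE INDEX IF NOT EXISTS idx_doc_doc_id ON kb_documents(doc_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_doc_kb_id ON kb_documents(kb_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_doc_doc_id ON kb_documents(doc_id)")  # no-op rerun
rows = conn.execute("PRAGMA index_list('kb_documents')").fetchall()
```

`IF NOT EXISTS` is what makes calling this on every startup safe; without it the second run would raise `OperationalError`.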

View File

@@ -1,19 +1,16 @@
import json
import uuid
from pathlib import Path
import aiofiles
from astrbot.core import logger
import json
from pathlib import Path
from .models import KnowledgeBase, KBDocument, KBMedia
from .kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from .chunking.base import BaseChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .models import KBDocument, KBMedia, KnowledgeBase
from astrbot.core.provider.manager import ProviderManager
from .parsers.util import select_parser
from .chunking.base import BaseChunker
from astrbot.core import logger
class KBHelper:
@@ -48,11 +45,11 @@ class KBHelper:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
self.kb.embedding_provider_id,
self.kb.embedding_provider_id
) # type: ignore
if not ep:
raise ValueError(
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider",
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
)
return ep
@@ -60,11 +57,11 @@ class KBHelper:
if not self.kb.rerank_provider_id:
return None
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
self.kb.rerank_provider_id,
self.kb.rerank_provider_id
) # type: ignore
if not rp:
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider",
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
)
return rp
@@ -125,7 +122,6 @@ class KBHelper:
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
- current: 当前进度
- total: 总数
"""
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
@@ -166,9 +162,7 @@ class KBHelper:
await progress_callback("chunking", 0, 100)
chunks_text = await self.chunker.chunk(
text_content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
contents = []
metadatas = []
@@ -179,7 +173,7 @@ class KBHelper:
"kb_id": self.kb.kb_id,
"kb_doc_id": doc_id,
"chunk_index": idx,
},
}
)
if progress_callback:
@@ -240,9 +234,7 @@ class KBHelper:
raise e
async def list_documents(
self,
offset: int = 0,
limit: int = 100,
self, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
@@ -296,17 +288,12 @@ class KBHelper:
await session.refresh(doc)
async def get_chunks_by_doc_id(
self,
doc_id: str,
offset: int = 0,
limit: int = 100,
self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[dict]:
"""获取文档的所有块及其元数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
chunks = await vec_db.document_storage.get_documents(
metadata_filters={"kb_doc_id": doc_id},
offset=offset,
limit=limit,
metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
)
result = []
for chunk in chunks:
@@ -319,7 +306,7 @@ class KBHelper:
"chunk_index": chunk_md["chunk_index"],
"content": chunk["text"],
"char_count": len(chunk["text"]),
},
}
)
return result

View File

@@ -1,17 +1,19 @@
import traceback
from pathlib import Path
from astrbot.core import logger
from astrbot.core.provider.manager import ProviderManager
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.sparse_retriever import SparseRetriever
from .retrieval.rank_fusion import RankFusion
from .kb_db_sqlite import KBSQLiteDatabase
# from .chunking.fixed_size import FixedSizeChunker
from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .kb_helper import KBHelper
from .models import KnowledgeBase
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.rank_fusion import RankFusion
from .retrieval.sparse_retriever import SparseRetriever
FILES_PATH = "data/knowledge_base"
DB_PATH = Path(FILES_PATH) / "kb.db"
@@ -255,7 +257,6 @@ class KnowledgeBaseManager:
Returns:
str: 格式化的上下文文本
"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime, timezone
from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint
from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
class BaseKBModel(SQLModel, table=False):
@@ -17,9 +17,7 @@ class KnowledgeBase(BaseKBModel, table=True):
__tablename__ = "knowledge_bases" # type: ignore
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
kb_id: str = Field(
max_length=36,
@@ -65,9 +63,7 @@ class KBDocument(BaseKBModel, table=True):
__tablename__ = "kb_documents" # type: ignore
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
doc_id: str = Field(
max_length=36,
@@ -99,9 +95,7 @@ class KBMedia(BaseKBModel, table=True):
__tablename__ = "kb_media" # type: ignore
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
media_id: str = Field(
max_length=36,

View File

@@ -1,13 +1,15 @@
"""文档解析器模块"""
"""
文档解析器模块
"""
from .base import BaseParser, MediaItem, ParseResult
from .pdf_parser import PDFParser
from .text_parser import TextParser
from .pdf_parser import PDFParser
__all__ = [
"BaseParser",
"MediaItem",
"PDFParser",
"ParseResult",
"TextParser",
"PDFParser",
]

View File

@@ -47,5 +47,4 @@ class BaseParser(ABC):
Returns:
ParseResult: 解析结果
"""

View File

@@ -1,12 +1,11 @@
import io
import os
from markitdown_no_magika import MarkItDown, StreamInfo
from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
ParseResult,
)
from markitdown_no_magika import MarkItDown, StreamInfo
class MarkitdownParser(BaseParser):

View File

@@ -29,7 +29,6 @@ class PDFParser(BaseParser):
Returns:
ParseResult: 包含文本和图片的解析结果
"""
pdf_file = io.BytesIO(file_content)
reader = PdfReader(pdf_file)
@@ -88,7 +87,7 @@ class PDFParser(BaseParser):
file_name=f"page_{page_num}_img_{image_counter}.{ext}",
content=image_data,
mime_type=mime_type,
),
)
)
except Exception:
# A single failed image extraction does not abort the whole parse

View File

@@ -26,7 +26,6 @@ class TextParser(BaseParser):
Raises:
ValueError: 如果无法解码文件
"""
# Try multiple encodings
for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]:

View File

@@ -6,7 +6,7 @@ async def select_parser(ext: str) -> BaseParser:
from .markitdown_parser import MarkitdownParser
return MarkitdownParser()
if ext == ".pdf":
elif ext == ".pdf":
from .pdf_parser import PDFParser
return PDFParser()

View File

@@ -1,14 +1,16 @@
"""检索模块"""
"""
检索模块
"""
from .manager import RetrievalManager, RetrievalResult
from .rank_fusion import FusedResult, RankFusion
from .sparse_retriever import SparseResult, SparseRetriever
from .sparse_retriever import SparseRetriever, SparseResult
from .rank_fusion import RankFusion, FusedResult
__all__ = [
"FusedResult",
"RankFusion",
"RetrievalManager",
"RetrievalResult",
"SparseResult",
"SparseRetriever",
"SparseResult",
"RankFusion",
"FusedResult",
]

View File

@@ -4,17 +4,18 @@
"""
import time
from dataclasses import dataclass
from astrbot import logger
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from dataclasses import dataclass
from typing import List
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from ..kb_helper import KBHelper
from astrbot import logger
@dataclass
@@ -52,7 +53,6 @@ class RetrievalManager:
sparse_retriever: 稀疏检索器
rank_fusion: 结果融合器
kb_db: 知识库数据库实例
"""
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
@@ -61,11 +61,11 @@ class RetrievalManager:
async def retrieve(
self,
query: str,
kb_ids: list[str],
kb_ids: List[str],
kb_id_helper_map: dict[str, KBHelper],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> list[RetrievalResult]:
) -> List[RetrievalResult]:
"""混合检索
流程:
@@ -82,7 +82,6 @@ class RetrievalManager:
Returns:
List[RetrievalResult]: 检索结果列表
"""
if not kb_ids:
return []
@@ -115,7 +114,7 @@ class RetrievalManager:
)
time_end = time.time()
logger.debug(
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results.",
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
)
# 2. Sparse retrieval
@@ -127,7 +126,7 @@ class RetrievalManager:
)
time_end = time.time()
logger.debug(
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results.",
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
)
# 3. Rank fusion
@@ -139,7 +138,7 @@ class RetrievalManager:
)
time_end = time.time()
logger.debug(
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.",
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
)
# 4. Convert to RetrievalResult (fetch metadata)
@@ -160,7 +159,7 @@ class RetrievalManager:
"chunk_index": fr.chunk_index,
"char_count": len(fr.content),
},
),
)
)
# 5. Rerank
@@ -189,7 +188,7 @@ class RetrievalManager:
async def _dense_retrieve(
self,
query: str,
kb_ids: list[str],
kb_ids: List[str],
kb_options: dict,
):
"""稠密检索 (向量相似度)
@@ -203,7 +202,6 @@ class RetrievalManager:
Returns:
List[Result]: 检索结果列表
"""
all_results: list[Result] = []
for kb_id in kb_ids:
@@ -235,10 +233,10 @@ class RetrievalManager:
async def _rerank(
self,
query: str,
results: list[RetrievalResult],
results: List[RetrievalResult],
top_k: int,
rerank_provider: RerankProvider,
) -> list[RetrievalResult]:
) -> List[RetrievalResult]:
"""Rerank 重排序
Args:
@@ -248,7 +246,6 @@ class RetrievalManager:
Returns:
List[RetrievalResult]: 重排序后的结果列表
"""
if not results:
return []
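Stepping back from the style churn in this hunk: the control flow of `retrieve()` is dense search, then sparse search, then rank fusion, then an optional rerank, then truncation. A schematic where every callable is an assumed interface, not the project's actual signatures:

```python
import asyncio

async def hybrid_retrieve(query, dense_search, sparse_search, fuse, rerank=None,
                          top_k_fusion=20, top_m_final=5):
    """Schematic of the retrieve() flow; the callables are illustrative stand-ins."""
    dense = await dense_search(query)                 # 1. vector similarity
    sparse = await sparse_search(query)               # 2. BM25 keyword match
    fused = fuse(dense, sparse, top_k=top_k_fusion)   # 3. rank fusion
    if rerank is not None:                            # 4. optional rerank provider
        fused = await rerank(query, fused)
    return fused[:top_m_final]                        # 5. final truncation
```

The two retrieval passes are awaited sequentially here, matching the timing logs in the diff; running them concurrently with `asyncio.gather` would be a separate design decision.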

View File

@@ -37,7 +37,6 @@ class RankFusion:
Args:
kb_db: 知识库数据库实例
k: RRF 参数,用于平滑排名
"""
self.kb_db = kb_db
self.k = k
@@ -60,7 +59,6 @@ class RankFusion:
Returns:
List[FusedResult]: 融合后的结果列表
"""
# 1. Build rank mappings
dense_ranks = {
@@ -103,9 +101,7 @@ class RankFusion:
# 4. 排序
sorted_ids = sorted(
rrf_scores.keys(),
key=lambda cid: rrf_scores[cid],
reverse=True,
rrf_scores.keys(), key=lambda cid: rrf_scores[cid], reverse=True
)[:top_k]
# 5. 构建融合结果
@@ -122,7 +118,7 @@ class RankFusion:
kb_id=sr.kb_id,
content=sr.content,
score=rrf_scores[identifier],
),
)
)
elif identifier in vec_doc_id_to_dense:
# 从向量检索获取信息,需要从数据库获取块的详细信息
@@ -136,7 +132,7 @@ class RankFusion:
kb_id=chunk_md["kb_id"],
content=vec_result.data["text"],
score=rrf_scores[identifier],
),
)
)
return fused_results
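The RRF scoring that this hunk restyles reduces to a few lines. `rrf_fuse` below is an illustrative helper operating on plain ID lists rather than the project's `RankFusion` class; `k=60` matches the usual RRF smoothing default:

```python
def rrf_fuse(dense_ranked, sparse_ranked, k=60, top_k=5):
    """Reciprocal Rank Fusion: score(d) = sum over result lists of 1 / (k + rank)."""
    scores = {}
    for ranked in (dense_ranked, sparse_ranked):
        for rank, doc_id in enumerate(ranked, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    # Highest RRF score first, truncated to top_k.
    return sorted(scores, key=scores.get, reverse=True)[:top_k]
```

Documents that appear in both lists accumulate two reciprocal terms, which is why items ranked mid-list in both retrievers can beat an item ranked first in only one.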

View File

@@ -3,15 +3,13 @@
使用 BM25 算法进行基于关键词的文档检索
"""
import json
import os
from dataclasses import dataclass
import jieba
import os
import json
from dataclasses import dataclass
from rank_bm25 import BM25Okapi
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
@dataclass
@@ -39,7 +37,6 @@ class SparseRetriever:
Args:
kb_db: 知识库数据库实例
"""
self.kb_db = kb_db
self._index_cache = {} # 缓存 BM25 索引
@@ -67,7 +64,6 @@ class SparseRetriever:
Returns:
List[SparseResult]: 检索结果列表
"""
# 1. Fetch all candidate chunks
top_k_sparse = 0
@@ -77,9 +73,7 @@ class SparseRetriever:
if not vec_db:
continue
result = await vec_db.document_storage.get_documents(
metadata_filters={},
limit=None,
offset=None,
metadata_filters={}, limit=None, offset=None
)
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
result = [
@@ -128,7 +122,7 @@ class SparseRetriever:
kb_id=chunk["kb_id"],
content=chunk["text"],
score=float(score),
),
)
)
results.sort(key=lambda x: x.score, reverse=True)
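The `BM25Okapi` object used above comes from the third-party `rank_bm25` package. For review purposes, the scoring it performs is roughly the following stdlib-only sketch (this uses the +1-smoothed IDF rather than rank_bm25's epsilon floor for negative IDF, so exact scores differ slightly):

```python
import math

def bm25_scores(query_tokens, corpus_tokens, k1=1.5, b=0.75):
    """Simplified BM25 scoring over a pre-tokenized corpus (illustrative only)."""
    n = len(corpus_tokens)
    avgdl = sum(len(doc) for doc in corpus_tokens) / n
    # Document frequency per term.
    df = {}
    for doc in corpus_tokens:
        for term in set(doc):
            df[term] = df.get(term, 0) + 1
    scores = []
    for doc in corpus_tokens:
        score, dl = 0.0, len(doc)
        for term in query_tokens:
            if term not in df:
                continue
            tf = doc.count(term)
            idf = math.log((n - df[term] + 0.5) / (df[term] + 0.5) + 1)
            # Term frequency saturation (k1) and length normalization (b).
            score += idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))
        scores.append(score)
    return scores
```

In this file the tokenization step is done with `jieba`, which is what makes BM25 usable on unsegmented Chinese text.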

View File

@@ -1,4 +1,5 @@
"""日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
"""
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
const:
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
@@ -20,14 +21,14 @@ function:
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
"""
import asyncio
import logging
import colorlog
import asyncio
import os
import sys
from asyncio import Queue
from collections import deque
import colorlog
from asyncio import Queue
from typing import List
# Log cache size
CACHED_SIZE = 200
@@ -51,7 +52,6 @@ def is_plugin_path(pathname):
Returns:
bool: 如果路径来自插件目录,则返回 True否则返回 False
"""
if not pathname:
return False
@@ -68,7 +68,6 @@ def get_short_level_name(level_name):
Returns:
str: 四个字母的日志级别缩写
"""
level_map = {
"DEBUG": "DBUG",
@@ -88,14 +87,13 @@ class LogBroker:
def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: list[Queue] = [] # 订阅者列表
self.subscribers: List[Queue] = [] # 订阅者列表
def register(self) -> Queue:
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Returns:
Queue: 订阅者的队列, 可用于接收日志消息
"""
q = Queue(maxsize=CACHED_SIZE + 10)
self.subscribers.append(q)
@@ -106,7 +104,6 @@ class LogBroker:
Args:
q (Queue): 需要取消订阅的队列
"""
self.subscribers.remove(q)
@@ -116,7 +113,6 @@ class LogBroker:
Args:
log_entry (dict): 日志消息, 包含日志级别和日志内容.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
"""
self.log_cache.append(log_entry)
for q in self.subscribers:
@@ -142,7 +138,6 @@ class LogQueueHandler(logging.Handler):
Args:
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record)
self.log_broker.publish(
@@ -150,7 +145,7 @@ class LogQueueHandler(logging.Handler):
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
},
}
)
@@ -169,7 +164,6 @@ class LogManager:
Returns:
logging.Logger: 返回配置好的日志记录器
"""
logger = logging.getLogger(log_name)
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
@@ -177,10 +171,10 @@ class LogManager:
return logger
# 如果logger没有处理器
console_handler = logging.StreamHandler(
sys.stdout,
sys.stdout
) # 创建一个StreamHandler用于控制台输出
console_handler.setLevel(
logging.DEBUG,
logging.DEBUG
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
@@ -201,8 +195,7 @@ class LogManager:
class FileNameFilter(logging.Filter):
"""文件名过滤器类, 用于修改日志记录的文件名格式
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式
"""
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record):
@@ -238,7 +231,6 @@ class LogManager:
Args:
logger (logging.Logger): 日志记录器
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
"""
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
@@ -248,7 +240,7 @@ class LogManager:
# 为队列处理器设置相同格式的formatter
handler.setFormatter(
logging.Formatter(
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s",
),
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
)
)
logger.addHandler(handler)
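The publish/subscribe pattern running through this file — a ring buffer replayed to each new subscriber, then fan-out to per-subscriber queues — can be shown synchronously with stdlib queues (the real `LogBroker` uses `asyncio.Queue`; names below are illustrative):

```python
import queue
from collections import deque

class MiniLogBroker:
    """Sketch of the broker above: ring buffer plus fan-out to subscriber queues."""

    def __init__(self, cached_size=200):
        self.log_cache = deque(maxlen=cached_size)  # recent history for late joiners
        self.subscribers = []

    def register(self):
        q = queue.Queue(maxsize=cached_size + 10 if (cached_size := 200) else 0)
        for entry in self.log_cache:                # replay cached history first
            q.put_nowait(entry)
        self.subscribers.append(q)
        return q

    def unregister(self, q):
        self.subscribers.remove(q)

    def publish(self, entry):
        self.log_cache.append(entry)
        for q in self.subscribers:
            try:
                q.put_nowait(entry)                 # drop entry if a subscriber is full
            except queue.Full:
                pass
```

The `maxlen` deque is what bounds memory: old entries fall off automatically, and slow subscribers lose messages rather than block the publisher.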

View File

@@ -1,4 +1,5 @@
"""MIT License
"""
MIT License
Copyright (c) 2021 Lxns-Network
@@ -25,6 +26,7 @@ import asyncio
import base64
import json
import os
import typing as T
import uuid
from enum import Enum
@@ -36,36 +38,60 @@ from astrbot.core.utils.io import download_file, download_image_by_url, file_to_
class ComponentType(str, Enum):
# Basic Segment Types
Plain = "Plain" # plain text message
Image = "Image" # image
Record = "Record" # audio
Video = "Video" # video
File = "File" # file attachment
Plain = "Plain" # 纯文本消息
Face = "Face" # QQ表情
Record = "Record" # 语音
Video = "Video" # 视频
At = "At" # At
Node = "Node" # 转发消息的一个节点
Nodes = "Nodes" # 转发消息的多个节点
Poke = "Poke" # QQ 戳一戳
Image = "Image" # 图片
Reply = "Reply" # 回复
Forward = "Forward" # 转发消息
File = "File" # 文件
# IM-specific Segment Types
Face = "Face" # Emoji segment for Tencent QQ platform
At = "At" # mention a user in IM apps
Node = "Node" # a node in a forwarded message
Nodes = "Nodes" # a forwarded message consisting of multiple nodes
Poke = "Poke" # a poke message for Tencent QQ platform
Reply = "Reply" # a reply message segment
Forward = "Forward" # a forwarded message segment
RPS = "RPS" # TODO
Dice = "Dice" # TODO
Shake = "Shake" # TODO
Anonymous = "Anonymous" # TODO
Share = "Share"
Contact = "Contact" # TODO
Location = "Location" # TODO
Music = "Music"
RedBag = "RedBag"
Xml = "Xml"
Json = "Json"
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
class BaseMessageComponent(BaseModel):
type: ComponentType
def toString(self):
output = f"[CQ:{self.type.lower()}"
for k, v in self.__dict__.items():
if k == "type" or v is None:
continue
if k == "_type":
k = "type"
if isinstance(v, bool):
v = 1 if v else 0
output += ",%s=%s" % (
k,
str(v)
.replace("&", "&amp;")
.replace(",", "&#44;")
.replace("[", "&#91;")
.replace("]", "&#93;"),
)
output += "]"
return output
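One detail worth noting in `toString` above: the `&` replacement must run before the others, otherwise the `&` characters inside the entities produced for `,`, `[`, and `]` would be double-escaped. Isolated as a helper:

```python
def cq_escape(value):
    """CQ-code parameter escaping as in toString(); '&' must be replaced first."""
    return (
        str(value)
        .replace("&", "&amp;")   # first, so later entities are not re-escaped
        .replace(",", "&#44;")
        .replace("[", "&#91;")
        .replace("]", "&#93;")
    )
```

Reversing the order would turn `,` into `&#44;` and then corrupt it to `&amp;#44;`.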
def toDict(self):
data = {}
for k, v in self.__dict__.items():
@@ -84,11 +110,18 @@ class BaseMessageComponent(BaseModel):
class Plain(BaseMessageComponent):
type = ComponentType.Plain
text: str
convert: bool | None = True
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
def __init__(self, text: str, convert: bool = True, **_):
super().__init__(text=text, convert=convert, **_)
def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
if not self.convert:
return self.text
return (
self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
)
def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}}
@@ -106,17 +139,17 @@ class Face(BaseMessageComponent):
class Record(BaseMessageComponent):
type = ComponentType.Record
file: str | None = ""
magic: bool | None = False
url: str | None = ""
cache: bool | None = True
proxy: bool | None = True
timeout: int | None = 0
file: T.Optional[str] = ""
magic: T.Optional[bool] = False
url: T.Optional[str] = ""
cache: T.Optional[bool] = True
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
path: str | None
path: T.Optional[str]
def __init__(self, file: str | None, **_):
for k in _:
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():
if k == "url":
pass
# Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}")
@@ -141,16 +174,15 @@ class Record(BaseMessageComponent):
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
return self.file[8:]
if self.file.startswith("http"):
elif self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
if self.file.startswith("base64://"):
elif self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -158,16 +190,16 @@ class Record(BaseMessageComponent):
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
if os.path.exists(self.file):
elif os.path.exists(self.file):
return os.path.abspath(self.file)
raise Exception(f"not a valid file: {self.file}")
else:
raise Exception(f"not a valid file: {self.file}")
async def convert_to_base64(self) -> str:
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
Returns:
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if not self.file:
@@ -187,14 +219,14 @@ class Record(BaseMessageComponent):
return bs64_data
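Both `convert_to_base64` implementations touched in this diff hinge on the `base64://` URI convention. The round trip, with a hypothetical temp path standing in for the real `data/temp` directory:

```python
import base64
import os
import tempfile

payload = b"\x89PNG fake media bytes"  # stand-in for real audio/image data
uri = "base64://" + base64.b64encode(payload).decode()

# Decoding mirrors the removeprefix("base64://") branch above.
raw = base64.b64decode(uri.removeprefix("base64://"))
path = os.path.join(tempfile.gettempdir(), "astrbot_demo.bin")  # hypothetical file
with open(path, "wb") as f:
    f.write(raw)
```

The `removeprefix` call (Python 3.9+) is why the scheme prefix never reaches `b64decode`, which would otherwise fail on the `://` characters.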
async def register_to_file_service(self) -> str:
"""将语音注册到文件服务。
"""
将语音注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
@@ -213,10 +245,10 @@ class Record(BaseMessageComponent):
class Video(BaseMessageComponent):
type = ComponentType.Video
file: str
cover: str | None = ""
c: int | None = 2
cover: T.Optional[str] = ""
c: T.Optional[int] = 2
# 额外
path: str | None = ""
path: T.Optional[str] = ""
def __init__(self, file: str, **_):
super().__init__(file=file, **_)
@@ -236,31 +268,32 @@ class Video(BaseMessageComponent):
Returns:
str: 视频的本地路径,以绝对路径表示。
"""
url = self.file
if url and url.startswith("file:///"):
return url[8:]
if url and url.startswith("http"):
elif url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
raise Exception(f"download failed: {url}")
if os.path.exists(url):
else:
raise Exception(f"download failed: {url}")
elif os.path.exists(url):
return os.path.abspath(url)
raise Exception(f"not a valid file: {url}")
else:
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self):
"""将视频注册到文件服务。
"""
将视频注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
@@ -297,8 +330,8 @@ class Video(BaseMessageComponent):
class At(BaseMessageComponent):
type = ComponentType.At
qq: int | str # 此处str为all时代表所有人
name: str | None = ""
qq: T.Union[int, str] # 此处str为all时代表所有人
name: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
@@ -338,12 +371,20 @@ class Shake(BaseMessageComponent): # TODO
super().__init__(**_)
class Anonymous(BaseMessageComponent): # TODO
type = ComponentType.Anonymous
ignore: T.Optional[bool] = False
def __init__(self, **_):
super().__init__(**_)
class Share(BaseMessageComponent):
type = ComponentType.Share
url: str
title: str
content: str | None = ""
image: str | None = ""
content: T.Optional[str] = ""
image: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
@@ -352,7 +393,7 @@ class Share(BaseMessageComponent):
class Contact(BaseMessageComponent): # TODO
type = ComponentType.Contact
_type: str # type 字段冲突
id: int | None = 0
id: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
@@ -362,8 +403,8 @@ class Location(BaseMessageComponent): # TODO
type = ComponentType.Location
lat: float
lon: float
title: str | None = ""
content: str | None = ""
title: T.Optional[str] = ""
content: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
@@ -372,12 +413,12 @@ class Location(BaseMessageComponent): # TODO
class Music(BaseMessageComponent):
type = ComponentType.Music
_type: str
id: int | None = 0
url: str | None = ""
audio: str | None = ""
title: str | None = ""
content: str | None = ""
image: str | None = ""
id: T.Optional[int] = 0
url: T.Optional[str] = ""
audio: T.Optional[str] = ""
title: T.Optional[str] = ""
content: T.Optional[str] = ""
image: T.Optional[str] = ""
def __init__(self, **_):
# for k in _.keys():
@@ -388,18 +429,18 @@ class Music(BaseMessageComponent):
class Image(BaseMessageComponent):
type = ComponentType.Image
file: str | None = ""
_type: str | None = ""
subType: int | None = 0
url: str | None = ""
cache: bool | None = True
id: int | None = 40000
c: int | None = 2
file: T.Optional[str] = ""
_type: T.Optional[str] = ""
subType: T.Optional[int] = 0
url: T.Optional[str] = ""
cache: T.Optional[bool] = True
id: T.Optional[int] = 40000
c: T.Optional[int] = 2
# 额外
path: str | None = ""
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识
path: T.Optional[str] = ""
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: str | None, **_):
def __init__(self, file: T.Optional[str], **_):
super().__init__(file=file, **_)
@staticmethod
@@ -429,17 +470,16 @@ class Image(BaseMessageComponent):
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
return url[8:]
if url.startswith("http"):
elif url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
if url.startswith("base64://"):
elif url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -447,16 +487,16 @@ class Image(BaseMessageComponent):
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
if os.path.exists(url):
elif os.path.exists(url):
return os.path.abspath(url)
raise Exception(f"not a valid file: {url}")
else:
raise Exception(f"not a valid file: {url}")
async def convert_to_base64(self) -> str:
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
Returns:
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url or self.file
@@ -477,14 +517,14 @@ class Image(BaseMessageComponent):
return bs64_data
async def register_to_file_service(self) -> str:
"""将图片注册到文件服务。
"""
将图片注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
@@ -502,34 +542,42 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
type = ComponentType.Reply
id: str | int
id: T.Union[str, int]
"""所引用的消息 ID"""
chain: list["BaseMessageComponent"] | None = []
chain: T.Optional[T.List["BaseMessageComponent"]] = []
"""被引用的消息段列表"""
sender_id: int | None | str = 0
sender_id: T.Optional[int] | T.Optional[str] = 0
"""被引用的消息对应的发送者的 ID"""
sender_nickname: str | None = ""
sender_nickname: T.Optional[str] = ""
"""被引用的消息对应的发送者的昵称"""
time: int | None = 0
time: T.Optional[int] = 0
"""被引用的消息发送时间"""
message_str: str | None = ""
message_str: T.Optional[str] = ""
"""被引用的消息解析后的纯文本消息字符串"""
text: str | None = ""
text: T.Optional[str] = ""
"""deprecated"""
qq: int | None = 0
qq: T.Optional[int] = 0
"""deprecated"""
seq: int | None = 0
seq: T.Optional[int] = 0
"""deprecated"""
def __init__(self, **_):
super().__init__(**_)
class RedBag(BaseMessageComponent):
type = ComponentType.RedBag
title: str
def __init__(self, **_):
super().__init__(**_)
class Poke(BaseMessageComponent):
type: str = ComponentType.Poke
id: int | None = 0
qq: int | None = 0
id: T.Optional[int] = 0
qq: T.Optional[int] = 0
def __init__(self, type: str, **_):
type = f"Poke:{type}"
@@ -548,12 +596,12 @@ class Node(BaseMessageComponent):
"""群合并转发消息"""
type = ComponentType.Node
id: int | None = 0 # 忽略
name: str | None = "" # qq昵称
uin: str | None = "0" # qq号
content: list[BaseMessageComponent] | None = []
seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
content: T.Optional[list[BaseMessageComponent]] = []
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0 # 忽略
def __init__(self, content: list[BaseMessageComponent], **_):
if isinstance(content, Node):
@@ -571,7 +619,7 @@ class Node(BaseMessageComponent):
{
"type": comp.type.lower(),
"data": {"file": f"base64://{bs64}"},
},
}
)
elif isinstance(comp, Plain):
# For Plain segments, we need to handle the plain differently
@@ -600,9 +648,9 @@ class Node(BaseMessageComponent):
class Nodes(BaseMessageComponent):
type = ComponentType.Nodes
nodes: list[Node]
nodes: T.List[Node]
def __init__(self, nodes: list[Node], **_):
def __init__(self, nodes: T.List[Node], **_):
super().__init__(nodes=nodes, **_)
def toDict(self):
@@ -624,10 +672,19 @@ class Nodes(BaseMessageComponent):
return ret
class Xml(BaseMessageComponent):
type = ComponentType.Xml
data: str
resid: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
class Json(BaseMessageComponent):
type = ComponentType.Json
data: str | dict
resid: int | None = 0
data: T.Union[str, dict]
resid: T.Optional[int] = 0
def __init__(self, data, **_):
if isinstance(data, dict):
@@ -635,18 +692,50 @@ class Json(BaseMessageComponent):
super().__init__(data=data, **_)
class CardImage(BaseMessageComponent):
type = ComponentType.CardImage
file: str
cache: T.Optional[bool] = True
minwidth: T.Optional[int] = 400
minheight: T.Optional[int] = 400
maxwidth: T.Optional[int] = 500
maxheight: T.Optional[int] = 500
source: T.Optional[str] = ""
icon: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
@staticmethod
def fromFileSystem(path, **_):
return CardImage(file=f"file:///{os.path.abspath(path)}", **_)
class TTS(BaseMessageComponent):
type = ComponentType.TTS
text: str
def __init__(self, **_):
super().__init__(**_)
class Unknown(BaseMessageComponent):
type = ComponentType.Unknown
text: str
def toString(self):
return ""
class File(BaseMessageComponent):
"""文件消息段"""
"""
文件消息段
"""
type = ComponentType.File
name: str | None = "" # 名字
file_: str | None = "" # 本地路径
url: str | None = "" # url
name: T.Optional[str] = "" # 名字
file_: T.Optional[str] = "" # 本地路径
url: T.Optional[str] = "" # url
def __init__(self, name: str, file: str = "", url: str = ""):
"""文件消息段。"""
@@ -654,11 +743,11 @@ class File(BaseMessageComponent):
@property
def file(self) -> str:
"""获取文件路径如果文件不存在但有URL则同步下载文件
"""
获取文件路径如果文件不存在但有URL则同步下载文件
Returns:
str: 文件路径
"""
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
@@ -668,16 +757,19 @@ class File(BaseMessageComponent):
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段"
)
)
return ""
# 等待下载完成
loop.run_until_complete(self._download_file())
else:
# 等待下载完成
loop.run_until_complete(self._download_file())
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
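The `file` property in the hunk above refuses to block when an event loop is already running, and only downloads synchronously otherwise. A minimal sketch of that guard, using a hypothetical `fetch_sync` helper rather than the actual AstrBot API:

```python
import asyncio

def fetch_sync(coro_factory):
    """Run an async download synchronously, refusing inside a running loop.

    Mirrors the guard above; this helper and fake_download are illustrative.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running: safe to block until the coroutine finishes.
        return asyncio.run(coro_factory())
    # A loop is already running: blocking here would deadlock it.
    return None

async def fake_download():
    await asyncio.sleep(0)
    return "/tmp/file.bin"

path = fetch_sync(fake_download)      # no loop running -> blocks and returns

async def caller():
    return fetch_sync(fake_download)  # called inside a loop -> refuses

inside = asyncio.run(caller())
```

This is why the warning tells callers to use `await get_file()` instead of touching `<File>.file` from async code.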
@@ -685,11 +777,11 @@ class File(BaseMessageComponent):
@file.setter
def file(self, value: str):
"""向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
"""
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
Args:
value (str): 文件路径或URL
"""
if value.startswith("http://") or value.startswith("https://"):
self.url = value
@@ -704,7 +796,6 @@ class File(BaseMessageComponent):
注意,如果为 True也可能返回文件路径。
Returns:
str: 文件路径或者 http 下载链接
"""
if allow_return_url and self.url:
return self.url
@@ -722,19 +813,20 @@ class File(BaseMessageComponent):
"""下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
fname = self.name if self.name else uuid.uuid4().hex
file_path = os.path.join(download_dir, fname)
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
async def register_to_file_service(self):
"""将文件注册到文件服务。
"""
将文件注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
@@ -772,38 +864,41 @@ class File(BaseMessageComponent):
class WechatEmoji(BaseMessageComponent):
type = ComponentType.WechatEmoji
md5: str | None = ""
md5_len: int | None = 0
cdnurl: str | None = ""
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
ComponentTypes = {
# Basic Message Segments
"plain": Plain,
"text": Plain,
"image": Image,
"face": Face,
"record": Record,
"video": Video,
"file": File,
# IM-specific Message Segments
"face": Face,
"at": At,
"rps": RPS,
"dice": Dice,
"shake": Shake,
"anonymous": Anonymous,
"share": Share,
"contact": Contact,
"location": Location,
"music": Music,
"image": Image,
"reply": Reply,
"redbag": RedBag,
"poke": Poke,
"forward": Forward,
"node": Node,
"nodes": Nodes,
"xml": Xml,
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown,
"file": File,
"WechatEmoji": WechatEmoji,
}
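The `ComponentTypes` mapping above drives a simple string-to-class dispatch when deserializing message segments. A sketch of how such a registry is typically consumed, with stand-in classes instead of the real components:

```python
# Stand-in segment classes; the real ones live in astrbot.core.message.components.
class Plain:
    def __init__(self, text):
        self.text = text

class Face:
    def __init__(self, id):
        self.id = id

COMPONENT_TYPES = {"plain": Plain, "text": Plain, "face": Face}

def from_dict(seg: dict):
    """Pick the segment class by its "type" key and build it from "data"."""
    cls = COMPONENT_TYPES.get(seg["type"])
    if cls is None:
        raise ValueError(f"unknown segment type: {seg['type']}")
    return cls(**seg["data"])

msg = from_dict({"type": "text", "data": {"text": "hi"}})
```

Note that several keys ("plain"/"text") can alias the same class, exactly as in the mapping above.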

View File

@@ -1,16 +1,15 @@
import enum
from collections.abc import AsyncGenerator
from typing import List, Optional, Union, AsyncGenerator
from dataclasses import dataclass, field
from typing_extensions import deprecated
from astrbot.core.message.components import (
BaseMessageComponent,
Plain,
Image,
At,
AtAll,
BaseMessageComponent,
Image,
Plain,
)
from typing_extensions import deprecated
@dataclass
@@ -21,18 +20,18 @@ class MessageChain:
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
"""
chain: list[BaseMessageComponent] = field(default_factory=list)
use_t2i_: bool | None = None # None 为跟随用户设置
type: str | None = None
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
type: Optional[str] = None
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
def message(self, message: str):
"""添加一条文本消息到消息链 `chain` 中。
Example:
CommandResult().message("Hello ").message("world!")
# 输出 Hello world!
@@ -40,10 +39,11 @@ class MessageChain:
self.chain.append(Plain(message))
return self
def at(self, name: str, qq: str | int):
def at(self, name: str, qq: Union[str, int]):
"""添加一条 At 消息到消息链 `chain` 中。
Example:
CommandResult().at("张三", "12345678910")
# 输出 @张三
@@ -55,6 +55,7 @@ class MessageChain:
"""添加一条 AtAll 消息到消息链 `chain` 中。
Example:
CommandResult().at_all()
# 输出 @所有人
@@ -67,6 +68,7 @@ class MessageChain:
"""添加一条错误消息到消息链 `chain` 中
Example:
CommandResult().error("解析失败")
"""
@@ -80,6 +82,7 @@ class MessageChain:
如果需要发送本地图片,请使用 `file_image` 方法。
Example:
CommandResult().image("https://example.com/image.jpg")
"""
@@ -93,7 +96,6 @@ class MessageChain:
如果需要发送网络图片,请使用 `url_image` 方法。
CommandResult().image("image.jpg")
"""
self.chain.append(Image.fromFileSystem(path))
return self
@@ -112,7 +114,6 @@ class MessageChain:
Args:
use_t2i (bool): 是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
"""
self.use_t2i_ = use_t2i
return self
@@ -124,7 +125,7 @@ class MessageChain:
def squash_plain(self):
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
if not self.chain:
return None
return
new_chain = []
first_plain = None
@@ -152,7 +153,6 @@ class EventResultType(enum.Enum):
Attributes:
CONTINUE: 事件将会继续传播
STOP: 事件将会终止传播
"""
CONTINUE = enum.auto()
@@ -181,18 +181,17 @@ class MessageEventResult(MessageChain):
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`result_type` (EventResultType): 事件处理的结果类型。
"""
result_type: EventResultType | None = field(
default_factory=lambda: EventResultType.CONTINUE,
result_type: Optional[EventResultType] = field(
default_factory=lambda: EventResultType.CONTINUE
)
result_content_type: ResultContentType | None = field(
default_factory=lambda: ResultContentType.GENERAL_RESULT,
result_content_type: Optional[ResultContentType] = field(
default_factory=lambda: ResultContentType.GENERAL_RESULT
)
async_stream: AsyncGenerator | None = None
async_stream: Optional[AsyncGenerator] = None
"""异步流"""
def stop_event(self) -> "MessageEventResult":
@@ -206,7 +205,9 @@ class MessageEventResult(MessageChain):
return self
def is_stopped(self) -> bool:
"""是否终止事件传播。"""
"""
是否终止事件传播。
"""
return self.result_type == EventResultType.STOP
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
@@ -219,7 +220,6 @@ class MessageEventResult(MessageChain):
Args:
result_type (EventResultType): 事件处理的结果类型。
"""
self.result_content_type = typ
return self
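The `MessageChain` methods in this diff each end with `return self`, which is what lets calls like `CommandResult().message(...).at(...)` chain. A minimal sketch of the same fluent-builder pattern, with a stand-in `Chain` class:

```python
# Minimal fluent builder: each method appends a segment and returns self,
# so calls chain left to right. Tuples stand in for real segment objects.
class Chain:
    def __init__(self):
        self.chain = []

    def message(self, text: str) -> "Chain":
        self.chain.append(("plain", text))
        return self

    def at(self, name: str, qq) -> "Chain":
        # qq may be str or int, as in the Union[str, int] signature above.
        self.chain.append(("at", name, str(qq)))
        return self

result = Chain().message("Hello ").at("张三", 12345).message("!")
```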

View File

@@ -1,8 +1,8 @@
from astrbot import logger
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, Personality
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.platform.message_session import MessageSession
from astrbot import logger
DEFAULT_PERSONALITY = Personality(
prompt="You are a helpful and friendly assistant.",
@@ -41,14 +41,12 @@ class PersonaManager:
return persona
async def get_default_persona_v3(
self,
umo: str | MessageSession | None = None,
self, umo: str | MessageSession | None = None
) -> Personality:
"""获取默认 persona"""
cfg = self.acm.get_conf(umo)
default_persona_id = cfg.get("provider_settings", {}).get(
"default_personality",
"default",
"default_personality", "default"
)
if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY
@@ -68,19 +66,16 @@ class PersonaManager:
async def update_persona(
self,
persona_id: str,
system_prompt: str | None = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id)
if not existing_persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
persona = await self.db.update_persona(
persona_id,
system_prompt,
begin_dialogs,
tools=tools,
persona_id, system_prompt, begin_dialogs, tools=tools
)
if persona:
for i, p in enumerate(self.personas):
@@ -105,10 +100,7 @@ class PersonaManager:
if await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} already exists.")
new_persona = await self.db.insert_persona(
persona_id,
system_prompt,
begin_dialogs,
tools=tools,
persona_id, system_prompt, begin_dialogs, tools=tools
)
self.personas.append(new_persona)
self.get_v3_persona_data()
@@ -123,7 +115,6 @@ class PersonaManager:
- list[dict]: 包含 persona 配置的字典列表。
- list[Personality]: 包含 Personality 对象的列表。
- Personality: 默认选择的 Personality 对象。
"""
v3_persona_config = [
{
@@ -145,7 +136,7 @@ class PersonaManager:
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。",
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
@@ -155,7 +146,7 @@ class PersonaManager:
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
},
}
)
user_turn = not user_turn
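The loop above expands a persona's `begin_dialogs` preset into alternating user/assistant messages tagged to skip persistence. A sketch of that expansion as a standalone function (names are illustrative):

```python
def expand_dialogs(begin_dialogs):
    """Turn an even-length list of strings into alternating role messages."""
    if len(begin_dialogs) % 2 != 0:
        return []  # the real code logs an error and drops the preset
    out, user_turn = [], True
    for dialog in begin_dialogs:
        out.append({
            "role": "user" if user_turn else "assistant",
            "content": dialog,
            "_no_save": None,  # marker: not persisted to the DB
        })
        user_turn = not user_turn
    return out

msgs = expand_dialogs(["hi", "hello!", "how are you", "fine"])
bad = expand_dialogs(["only one"])
```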

View File

@@ -27,15 +27,15 @@ STAGES_ORDER = [
]
__all__ = [
"ContentSafetyCheckStage",
"EventResultType",
"MessageEventResult",
"PreProcessStage",
"ProcessStage",
"RateLimitStage",
"RespondStage",
"ResultDecorateStage",
"SessionStatusCheckStage",
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
"MessageEventResult",
"EventResultType",
]

View File

@@ -1,11 +1,9 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from .strategies.strategy import StrategySelector
@@ -21,10 +19,8 @@ class ContentSafetyCheckStage(Stage):
self.strategy_selector = StrategySelector(config)
async def process(
self,
event: AstrMessageEvent,
check_text: str | None = None,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent, check_text: str | None = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
@@ -32,8 +28,8 @@ class ContentSafetyCheckStage(Stage):
if event.is_at_or_wake_command:
event.set_result(
MessageEventResult().message(
"你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。",
),
"你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"
)
)
yield
event.stop_event()

View File

@@ -1,7 +1,8 @@
import abc
from typing import Tuple
class ContentSafetyStrategy(abc.ABC):
@abc.abstractmethod
def check(self, content: str) -> tuple[bool, str]:
def check(self, content: str) -> Tuple[bool, str]:
raise NotImplementedError
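The abstract `check` contract above means concrete strategies must return an `(ok, info)` tuple, and the base class itself cannot be instantiated. A sketch with stand-in classes:

```python
import abc

class Strategy(abc.ABC):
    @abc.abstractmethod
    def check(self, content: str) -> tuple[bool, str]:
        raise NotImplementedError

class AllowAll(Strategy):
    def check(self, content: str) -> tuple[bool, str]:
        return True, ""

try:
    Strategy()  # abstract base: instantiation must fail
    instantiable = True
except TypeError:
    instantiable = False

ok, info = AllowAll().check("anything")
```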

View File

@@ -1,8 +1,9 @@
"""使用此功能应该先 pip install baidu-aip"""
from aip import AipContentCensor
"""
使用此功能应该先 pip install baidu-aip
"""
from . import ContentSafetyStrategy
from aip import AipContentCensor
class BaiduAipStrategy(ContentSafetyStrategy):
@@ -18,12 +19,12 @@ class BaiduAipStrategy(ContentSafetyStrategy):
return False, ""
if res["conclusionType"] == 1:
return True, ""
if "data" not in res:
return False, ""
count = len(res["data"])
parts = [f"百度审核服务发现 {count} 处违规:\n"]
for i in res["data"]:
parts.append(f"{i['msg']}\n")
parts.append("\n判断结果:" + res["conclusion"])
info = "".join(parts)
return False, info
else:
if "data" not in res:
return False, ""
count = len(res["data"])
info = f"百度审核服务发现 {count} 处违规:\n"
for i in res["data"]:
info += f"{i['msg']}\n"
info += "\n判断结果:" + res["conclusion"]
return False, info
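The hunk above swaps repeated string `+=` for a parts list joined once. Both build the same report text; the sketch below contrasts the two styles (function names and messages are illustrative):

```python
def report_concat(msgs, conclusion):
    # Old style: repeated += copies the growing string each iteration.
    info = f"found {len(msgs)} violations:\n"
    for m in msgs:
        info += f"{m}\n"
    info += "\nresult: " + conclusion
    return info

def report_join(msgs, conclusion):
    # New style: collect parts, join once at the end.
    parts = [f"found {len(msgs)} violations:\n"]
    for m in msgs:
        parts.append(f"{m}\n")
    parts.append("\nresult: " + conclusion)
    return "".join(parts)

a = report_concat(["ad", "spam"], "blocked")
b = report_join(["ad", "spam"], "blocked")
```

The output is identical; append-then-join simply avoids quadratic copying on long violation lists.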

View File

@@ -1,5 +1,4 @@
import re
from . import ContentSafetyStrategy

View File

@@ -1,16 +1,16 @@
from astrbot import logger
from . import ContentSafetyStrategy
from typing import List, Tuple
from astrbot import logger
class StrategySelector:
def __init__(self, config: dict) -> None:
self.enabled_strategies: list[ContentSafetyStrategy] = []
self.enabled_strategies: List[ContentSafetyStrategy] = []
if config["internal_keywords"]["enable"]:
from .keywords import KeywordsStrategy
self.enabled_strategies.append(
KeywordsStrategy(config["internal_keywords"]["extra_keywords"]),
KeywordsStrategy(config["internal_keywords"]["extra_keywords"])
)
if config["baidu_aip"]["enable"]:
try:
@@ -23,10 +23,10 @@ class StrategySelector:
config["baidu_aip"]["app_id"],
config["baidu_aip"]["api_key"],
config["baidu_aip"]["secret_key"],
),
)
)
def check(self, content: str) -> tuple[bool, str]:
def check(self, content: str) -> Tuple[bool, str]:
for strategy in self.enabled_strategies:
ok, info = strategy.check(content)
if not ok:

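The `StrategySelector.check` loop shown above short-circuits on the first enabled strategy that flags the content. A sketch with stand-in strategies:

```python
class Keywords:
    """Stand-in for KeywordsStrategy: flags content containing any word."""
    def __init__(self, words):
        self.words = words

    def check(self, content):
        for w in self.words:
            if w in content:
                return False, f"keyword hit: {w}"
        return True, ""

class Selector:
    def __init__(self, strategies):
        self.strategies = strategies

    def check(self, content):
        # Run strategies in order; stop at the first failure.
        for s in self.strategies:
            ok, info = s.check(content)
            if not ok:
                return False, info
        return True, ""

sel = Selector([Keywords(["forbidden"])])
ok1, info1 = sel.check("all good here")
ok2, info2 = sel.check("this is forbidden text")
```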
View File

@@ -1,9 +1,7 @@
from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler, call_local_llm_tool
from .context_utils import call_handler, call_event_hook
@dataclass
@@ -15,4 +13,3 @@ class PipelineContext:
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
call_local_llm_tool = call_local_llm_tool

View File

@@ -1,14 +1,11 @@
import inspect
import traceback
import typing as T
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
@@ -29,7 +26,6 @@ async def call_handler(
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
@@ -84,17 +80,14 @@ async def call_event_hook(
Returns:
bool: 如果事件被终止,返回 True
#
"""
#"""
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type,
plugins_name=event.plugins_name,
hook_type, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args, **kwargs)
except BaseException:
@@ -102,71 +95,8 @@ async def call_event_hook(
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。",
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return True
return event.is_stopped()
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
method_name: str,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
elif method_name == "call":
ready_to_call = handler(context, *args, **kwargs)
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
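`call_handler` and the removed `call_local_llm_tool` above both branch on whether the handler is an async generator or a plain coroutine, draining either into a uniform stream of results. A minimal sketch of that dispatch:

```python
import asyncio
import inspect

async def drain(handler, *args):
    """Call a handler that may be an async generator or a coroutine function,
    and collect everything it produces into one list."""
    ready = handler(*args)
    results = []
    if inspect.isasyncgen(ready):
        async for item in ready:
            results.append(item)
    elif inspect.iscoroutine(ready):
        results.append(await ready)
    return results

async def gen_handler(x):
    yield x
    yield x + 1

async def coro_handler(x):
    return x * 10

gen_out = asyncio.run(drain(gen_handler, 1))
coro_out = asyncio.run(drain(coro_handler, 1))
```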

View File

@@ -1,14 +1,12 @@
import traceback
import asyncio
import random
import traceback
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.message.components import Image, Plain, Record
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Record, Image
@register_stage
@@ -22,9 +20,8 @@ class PreProcessStage(Stage):
self.platform_settings: dict = self.config.get("platform_settings", {})
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""在处理事件之前的预处理"""
# 平台特异配置platform_specific.<platform>.pre_ack_emoji
supported = {"telegram", "lark"}
@@ -71,7 +68,7 @@ class PreProcessStage(Stage):
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
if not stt_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。",
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
)
return
message_chain = event.get_messages()

View File

@@ -1,24 +1,15 @@
"""本地 Agent 模式的 LLM 调用 Stage"""
"""
本地 Agent 模式的 LLM 调用 Stage
"""
import asyncio
import copy
import json
import traceback
from datetime import timedelta
from collections.abc import AsyncGenerator
from typing import Any
from mcp.types import CallToolResult
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
@@ -31,14 +22,21 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.register import llm_tools
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_local_llm_tool
from ...context import PipelineContext, call_event_hook, call_handler
from ..stage import Stage
from ..utils import inject_kb_context
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.star_handler import star_map
from astrbot.core.astr_agent_context import AstrAgentContext
try:
import mcp
@@ -61,22 +59,23 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
elif isinstance(tool, MCPTool):
if tool.origin == "local":
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
elif tool.origin == "mcp":
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
else:
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
raise Exception(f"Unknown function origin: {tool.origin}")
@classmethod
async def _execute_handoff(
@@ -114,22 +113,18 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
event=run_context.context.event,
)
event = run_context.context.event
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
await run_context.event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=run_context.tool_call_timeout,
context=astr_agent_ctx, event=run_context.event
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
@@ -151,7 +146,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
result = (
@@ -179,46 +174,25 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
event = run_context.context.event
if not event:
if not run_context.event:
raise ValueError("Event must be provided for local function tools.")
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
logger.debug(f"Found call in: {ty}")
is_override_call = True
break
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
if not tool.handler and not hasattr(tool, "run"):
raise ValueError("Tool must have a valid handler or 'run' method.")
awaitable = tool.handler or getattr(tool, "run")
awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
wrapper = call_local_llm_tool(
context=run_context,
wrapper = call_handler(
event=run_context.event,
handler=awaitable,
method_name=method_name,
**tool_args,
)
# async for resp in wrapper:
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.tool_call_timeout,
timeout=run_context.context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
@@ -233,24 +207,10 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
)
except StopAsyncIteration:
break
@@ -262,7 +222,19 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
res = await tool.call(run_context, **tool_args)
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
read_timeout_seconds=timedelta(
seconds=run_context.context.tool_call_timeout
),
)
if not res:
return
yield res
@@ -272,31 +244,18 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
run_context.event, EventType.OnLLMResponseEvent, llm_response
)
async def on_tool_end(
self,
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
run_context.context.event.clear_result()
MAIN_AGENT_HOOKS = MainAgentHooks()
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
astr_event = agent_runner.run_context.event
while step_idx < max_step:
step_idx += 1
try:
@@ -331,18 +290,19 @@ async def run_agent(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
),
)
)
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
@@ -372,7 +332,7 @@ class LLMRequestSubStage(Stage):
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。"
)
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
@@ -407,9 +367,7 @@ class LLMRequestSubStage(Stage):
return conversation
async def process(
self,
event: AstrMessageEvent,
_nested: bool = False,
self, event: AstrMessageEvent, _nested: bool = False
) -> None | AsyncGenerator[None, None]:
req: ProviderRequest | None = None
@@ -465,9 +423,7 @@ class LLMRequestSubStage(Stage):
# 应用知识库
try:
await inject_kb_context(
umo=event.unified_msg_origin,
p_ctx=self.ctx,
req=req,
umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
)
except Exception as e:
logger.error(f"调用知识库时遇到问题: {e}")
@@ -519,7 +475,7 @@ class LLMRequestSubStage(Stage):
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
if "tool_use" not in provider_cfg:
logger.debug(
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
)
req.func_tool = None
# 插件可用性设置
@@ -542,22 +498,19 @@ class LLMRequestSubStage(Stage):
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
)
astr_agent_ctx = AstrAgentContext(
provider=provider,
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
event=event,
tool_call_timeout=self.tool_call_timeout,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
@@ -569,8 +522,8 @@ class LLMRequestSubStage(Stage):
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use),
),
run_agent(agent_runner, self.max_step, self.show_tool_use)
)
)
yield
if agent_runner.done():
@@ -587,7 +540,7 @@ class LLMRequestSubStage(Stage):
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
)
else:
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
@@ -607,21 +560,17 @@ class LLMRequestSubStage(Stage):
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)
)
async def _handle_webchat(
self,
event: AstrMessageEvent,
req: ProviderRequest,
prov: Provider,
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
if not req.conversation:
return
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin,
req.conversation.cid,
event.unified_msg_origin, req.conversation.cid
)
if conversation and not req.conversation.title:
messages = json.loads(conversation.history)
@@ -658,7 +607,7 @@ class LLMRequestSubStage(Stage):
)
if llm_resp and llm_resp.completion_text:
logger.debug(
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
)
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
@@ -701,9 +650,7 @@ class LLMRequestSubStage(Stage):
messages.append({"role": "assistant", "content": llm_response.completion_text})
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=messages,
event.unified_msg_origin, req.conversation.cid, history=messages
)
def fix_messages(self, messages: list[dict]) -> list[dict]:

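The persistence hunk above strips messages carrying the `_no_save` marker key (such as injected persona preset dialogs) before writing history back to the conversation store. A sketch of that filter with illustrative data:

```python
history = [
    {"role": "user", "content": "preset", "_no_save": None},
    {"role": "user", "content": "real question"},
    {"role": "assistant", "content": "real answer"},
]

# Keep only messages without the _no_save marker, as in the hunk above.
to_save = list(filter(lambda item: "_no_save" not in item, history))
```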
View File

@@ -1,17 +1,16 @@
"""本地 Agent 模式的 AstrBot 插件调用 Stage"""
import traceback
from collections.abc import AsyncGenerator
from typing import Any
from astrbot.core import logger
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
"""
本地 Agent 模式的 AstrBot 插件调用 Stage
"""
from ...context import PipelineContext, call_handler
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.star.star import star_map
import traceback
class StarRequestSubStage(Stage):
@@ -22,14 +21,13 @@ class StarRequestSubStage(Stage):
self.ctx = ctx
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers",
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
"activated_handlers"
)
handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra(
"handlers_parsed_params",
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra(
"handlers_parsed_params"
)
if not handlers_parsed_params:
handlers_parsed_params = {}
@@ -39,7 +37,7 @@ class StarRequestSubStage(Stage):
md = star_map.get(handler.handler_module_path)
if not md:
logger.warning(
f"Cannot find plugin for given handler module path: {handler.handler_module_path}",
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
)
continue
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")

View File

@@ -1,14 +1,12 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.star.star_handler import StarHandlerMetadata
from ..context import PipelineContext
from typing import List, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core import logger
@register_stage
@@ -24,12 +22,11 @@ class ProcessStage(Stage):
await self.star_request_sub_stage.initialize(ctx)
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件"""
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers",
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
"activated_handlers"
)
# A plugin handler was activated
if activated_handlers:

View File

@@ -1,7 +1,6 @@
from astrbot.api import logger, sp
from astrbot.core.provider.entities import ProviderRequest
from ..context import PipelineContext
from astrbot.core.provider.entities import ProviderRequest
from astrbot.api import logger, sp
async def inject_kb_context(
@@ -9,14 +8,14 @@ async def inject_kb_context(
p_ctx: PipelineContext,
req: ProviderRequest,
) -> None:
"""Inject knowledge base context into the provider request
"""inject knowledge base context into the provider request
Args:
umo: Unique message object (session ID)
p_ctx: Pipeline context
req: Provider request
"""
kb_mgr = p_ctx.plugin_manager.context.kb_manager
# 1. Read the session-level configuration first
@@ -46,7 +45,7 @@ async def inject_kb_context(
if invalid_kb_ids:
logger.warning(
f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}",
f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}"
)
if not kb_names:

View File

@@ -1,19 +1,18 @@
import asyncio
from collections import defaultdict, deque
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta
from collections import defaultdict, deque
from typing import DefaultDict, Deque, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.config.astrbot_config import RateLimitStrategy
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from ..context import PipelineContext
from ..stage import Stage, register_stage
@register_stage
class RateLimitStage(Stage):
"""检查是否需要限制消息发送的限流器。
"""
检查是否需要限制消息发送的限流器。
使用 Fixed Window 算法。
如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
@@ -21,30 +20,32 @@ class RateLimitStage(Stage):
def __init__(self):
# Per-session queue of request timestamps
self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque)
# One lock per session to avoid concurrency conflicts
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self.locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
# Rate-limit parameters
self.rate_limit_count: int = 0
self.rate_limit_time: timedelta = timedelta(0)
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化限流器,根据配置设置限流参数。"""
"""
初始化限流器,根据配置设置限流参数。
"""
self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][
"count"
]
self.rate_limit_time = timedelta(
seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"],
seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"]
)
self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][
"strategy"
] # stall or discard
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
"""检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""
检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
Args:
event (AstrMessageEvent): 当前消息事件。
@@ -52,7 +53,6 @@ class RateLimitStage(Stage):
Returns:
MessageEventResult: A result that continues or stops event processing.
"""
session_id = event.session_id
now = datetime.now()
@@ -66,33 +66,32 @@ class RateLimitStage(Stage):
if len(timestamps) < self.rate_limit_count:
timestamps.append(now)
break
next_window_time = timestamps[0] + self.rate_limit_time
stall_duration = (next_window_time - now).total_seconds() + 0.3
else:
next_window_time = timestamps[0] + self.rate_limit_time
stall_duration = (next_window_time - now).total_seconds() + 0.3
match self.rl_strategy:
case RateLimitStrategy.STALL.value:
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。",
)
await asyncio.sleep(stall_duration)
now = datetime.now()
case RateLimitStrategy.DISCARD.value:
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。",
)
return event.stop_event()
match self.rl_strategy:
case RateLimitStrategy.STALL.value:
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
)
await asyncio.sleep(stall_duration)
now = datetime.now()
case RateLimitStrategy.DISCARD.value:
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
)
return event.stop_event()
def _remove_expired_timestamps(
self,
timestamps: deque[datetime],
now: datetime,
self, timestamps: Deque[datetime], now: datetime
) -> None:
"""移除时间窗口外的时间戳。
"""
移除时间窗口外的时间戳。
Args:
timestamps (Deque[datetime]): 当前会话的时间戳队列。
now (datetime): 当前时间,用于计算过期时间。
"""
expiry_threshold: datetime = now - self.rate_limit_time
while timestamps and timestamps[0] < expiry_threshold:

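The hunks above implement a per-session fixed-window limiter: expired timestamps are evicted, a request is admitted while the window holds fewer than `count` entries, and otherwise the stall duration is derived from when the oldest timestamp expires (plus 0.3 s of slack). A minimal self-contained sketch of the admission check (class and method names are ours, not AstrBot's; locking and the stall/discard strategies are left to the caller):

```python
from collections import defaultdict, deque
from datetime import datetime, timedelta


class FixedWindowLimiter:
    """Sketch of RateLimitStage's fixed-window admission logic."""

    def __init__(self, count: int, window: timedelta):
        self.count = count
        self.window = window
        self.timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)

    def check(self, session_id: str, now: datetime) -> float:
        """Return 0.0 if the request is admitted, else seconds to stall before retrying."""
        q = self.timestamps[session_id]
        # Evict timestamps that have fallen out of the window
        expiry = now - self.window
        while q and q[0] < expiry:
            q.popleft()
        if len(q) < self.count:
            q.append(now)
            return 0.0
        # The next window opens when the oldest timestamp expires (+0.3 s slack, as above)
        return (q[0] + self.window - now).total_seconds() + 0.3
```

Under the STALL strategy the caller sleeps for the returned duration and retries; under DISCARD it drops the event instead.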
View File

@@ -1,27 +1,25 @@
import random
import asyncio
import math
import random
from collections.abc import AsyncGenerator
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext, call_event_hook
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.components import BaseMessageComponent, ComponentType
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
from ..context import PipelineContext, call_event_hook
from ..stage import Stage, register_stage
@register_stage
class RespondStage(Stage):
# Maps each component type to its non-empty validator
_component_validators = {
Comp.Plain: lambda comp: bool(
comp.text and comp.text.strip(),
comp.text and comp.text.strip()
), # plain-text messages need a strip
Comp.Face: lambda comp: comp.id is not None, # QQ emoji face
Comp.Record: lambda comp: bool(comp.file), # voice
@@ -60,7 +58,7 @@ class RespondStage(Stage):
"segmented_reply"
]["interval_method"]
self.log_base = float(
ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"],
ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"]
)
interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][
"interval"
@@ -88,16 +86,17 @@ class RespondStage(Stage):
wc = await self._word_cnt(comp.text)
i = math.log(wc + 1, self.log_base)
return random.uniform(i, i + 0.5)
return random.uniform(1, 1.75)
# random
return random.uniform(self.interval[0], self.interval[1])
else:
return random.uniform(1, 1.75)
else:
# random
return random.uniform(self.interval[0], self.interval[1])
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
"""检查消息链是否为空
Args:
chain (list[BaseMessageComponent]): 包含消息对象的列表
"""
if not chain:
return True
@@ -151,9 +150,8 @@ class RespondStage(Stage):
return extracted
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
return
@@ -161,7 +159,7 @@ class RespondStage(Stage):
return
logger.info(
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}",
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
if result.result_content_type == ResultContentType.STREAMING_RESULT:
@@ -170,13 +168,12 @@ class RespondStage(Stage):
return
# Streaming results are handed directly to the platform adapter
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented",
False,
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_id()})")
await event.send_streaming(result.async_stream, use_fallback)
return
if len(result.chain) > 0:
elif len(result.chain) > 0:
# Check path mappings
if mappings := self.platform_settings.get("path_mapping", []):
for idx, component in enumerate(result.chain):
@@ -215,7 +212,7 @@ class RespondStage(Stage):
if not result.chain or len(result.chain) == 0:
# may fix #2670
logger.warning(
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
)
return
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
@@ -240,7 +237,7 @@ class RespondStage(Stage):
):
# may fix #2670
logger.warning(
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}",
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
)
return
sep_comps = self._extract_comp(

View File

@@ -1,7 +1,7 @@
import re
import time
import traceback
from collections.abc import AsyncGenerator
from typing import AsyncGenerator, Union
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
@@ -30,7 +30,8 @@ class ResultDecorateStage(Stage):
self.t2i_word_threshold = ctx.astrbot_config["t2i_word_threshold"]
try:
self.t2i_word_threshold = int(self.t2i_word_threshold)
self.t2i_word_threshold = max(self.t2i_word_threshold, 50)
if self.t2i_word_threshold < 50:
self.t2i_word_threshold = 50
except BaseException:
self.t2i_word_threshold = 150
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
@@ -45,7 +46,7 @@ class ResultDecorateStage(Stage):
self.words_count_threshold = int(
ctx.astrbot_config["platform_settings"]["segmented_reply"][
"words_count_threshold"
],
]
)
self.enable_segmented_reply = ctx.astrbot_config["platform_settings"][
"segmented_reply"
@@ -70,9 +71,8 @@ class ResultDecorateStage(Stage):
await self.content_safe_check_stage.initialize(ctx)
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None or not result.chain:
return
@@ -94,36 +94,34 @@ class ResultDecorateStage(Stage):
if isinstance(comp, Plain):
text += comp.text
async for _ in self.content_safe_check_stage.process(
event,
check_text=text,
event, check_text=text
):
yield
# Pre-send event hooks
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnDecoratingResultEvent,
plugins_name=event.plugins_name,
EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
if is_stream:
logger.warning(
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
)
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。"
)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。",
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
@@ -162,9 +160,7 @@ class ResultDecorateStage(Stage):
new_chain.append(comp)
continue
split_response = re.findall(
self.regex,
comp.text,
re.DOTALL | re.MULTILINE,
self.regex, comp.text, re.DOTALL | re.MULTILINE
)
if not split_response:
new_chain.append(comp)
@@ -181,7 +177,7 @@ class ResultDecorateStage(Stage):
# TTS
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
event.unified_msg_origin,
event.unified_msg_origin
)
if (
@@ -191,7 +187,7 @@ class ResultDecorateStage(Stage):
):
if not tts_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。"
)
else:
new_chain = []
@@ -203,7 +199,7 @@ class ResultDecorateStage(Stage):
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
)
new_chain.append(comp)
continue
@@ -221,7 +217,7 @@ class ResultDecorateStage(Stage):
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path,
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
@@ -230,7 +226,7 @@ class ResultDecorateStage(Stage):
Record(
file=url or audio_path,
url=url or audio_path,
),
)
)
if dual_output:
new_chain.append(comp)
@@ -246,13 +242,12 @@ class ResultDecorateStage(Stage):
elif (
result.use_t2i_ is None and self.ctx.astrbot_config["t2i"]
) or result.use_t2i_:
parts = []
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
parts.append("\n\n" + comp.text)
plain_str += "\n\n" + comp.text
else:
break
plain_str = "".join(parts)
if plain_str and len(plain_str) > self.t2i_word_threshold:
render_start = time.time()
try:
@@ -267,7 +262,7 @@ class ResultDecorateStage(Stage):
return
if time.time() - render_start > 3:
logger.warning(
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。",
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
)
if url:
if url.startswith("http"):
@@ -291,9 +286,7 @@ class ResultDecorateStage(Stage):
word_cnt += len(comp.text)
if word_cnt > self.forward_threshold:
node = Node(
uin=event.get_self_id(),
name="AstrBot",
content=[*result.chain],
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
)
result.chain = [node]
@@ -305,8 +298,7 @@ class ResultDecorateStage(Stage):
and event.get_message_type() != MessageType.FRIEND_MESSAGE
):
result.chain.insert(
0,
At(qq=event.get_sender_id(), name=event.get_sender_name()),
0, At(qq=event.get_sender_id(), name=event.get_sender_name())
)
if len(result.chain) > 1 and isinstance(result.chain[1], Plain):
result.chain[1].text = "\n" + result.chain[1].text

View File

@@ -1,11 +1,9 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform import AstrMessageEvent
from . import STAGES_ORDER
from .context import PipelineContext
from .stage import registered_stages
from .context import PipelineContext
from typing import AsyncGenerator
from astrbot.core.platform import AstrMessageEvent
from astrbot.core import logger
class PipelineScheduler:
@@ -13,7 +11,7 @@ class PipelineScheduler:
def __init__(self, context: PipelineContext):
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__),
key=lambda x: STAGES_ORDER.index(x.__name__)
) # sort the stages into pipeline order
self.ctx = context # context object
self.stages = [] # stage instances
@@ -31,13 +29,12 @@ class PipelineScheduler:
Args:
event (AstrMessageEvent): The event object
from_stage (int): Index of the stage to start from; defaults to 0
"""
for i in range(from_stage, len(self.stages)):
stage = self.stages[i] # the stage to execute now
# logger.debug(f"Executing stage {stage.__class__.__name__}")
coroutine = stage.process(
event,
event
) # call the stage's process method; it returns a coroutine or an async generator
if isinstance(coroutine, AsyncGenerator):
@@ -46,7 +43,7 @@ class PipelineScheduler:
# Pause point (yield) after pre-processing; the remaining stages run below
if event.is_stopped():
logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。",
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
)
break
@@ -56,7 +53,7 @@ class PipelineScheduler:
# Resume point after all later stages have finished; run post-processing
if event.is_stopped():
logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。",
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
)
break
else:
@@ -73,7 +70,6 @@ class PipelineScheduler:
Args:
event (AstrMessageEvent): The event object
"""
await self._process_stages(event)

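The scheduler above treats a stage's `process` return as either a plain coroutine or an async generator: each `yield` is the point where all later stages run before the stage resumes for its post-processing. A stripped-down, runnable sketch of that control flow (stop-event handling omitted; stage and function names are ours):

```python
import asyncio
from collections.abc import AsyncGenerator


async def run_stages(stages, event, from_stage=0):
    for i in range(from_stage, len(stages)):
        result = stages[i](event)
        if isinstance(result, AsyncGenerator):
            async for _ in result:
                # At each yield: run every later stage, then resume this stage
                await run_stages(stages, event, i + 1)
            break  # later stages already ran via the recursive call
        await result  # plain coroutine: just await it


async def wrap(trace):
    trace.append("pre")   # pre-processing before the yield
    yield
    trace.append("post")  # post-processing after later stages finish


async def final(trace):
    trace.append("handle")


trace: list[str] = []
asyncio.run(run_stages([wrap, final], trace))
```

Running this yields the order pre → handle → post, which is exactly the wrap-around semantics the yield point gives a stage.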
View File

@@ -1,11 +1,9 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_llm_manager import SessionServiceManager
from ..context import PipelineContext
from ..stage import Stage, register_stage
from astrbot.core import logger
@register_stage
@@ -17,21 +15,19 @@ class SessionStatusCheckStage(Stage):
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# Check whether the session is enabled at all
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
conv_id = await self.conv_mgr.get_curr_conversation_id(
event.unified_msg_origin,
event.unified_msg_origin
)
if not conv_id:
await self.conv_mgr.new_conversation(
event.unified_msg_origin,
platform_id=event.get_platform_id(),
event.unified_msg_origin, platform_id=event.get_platform_id()
)
event.stop_event()

View File

@@ -1,13 +1,10 @@
from __future__ import annotations
import abc
from collections.abc import AsyncGenerator
from typing import List, AsyncGenerator, Union, Type
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
registered_stages: list[type[Stage]] = [] # all registered Stage implementation classes
registered_stages: List[Type[Stage]] = [] # all registered Stage implementation classes
def register_stage(cls):
@@ -25,21 +22,18 @@ class Stage(abc.ABC):
Args:
ctx (PipelineContext): The pipeline context, including the config and plugin manager
"""
raise NotImplementedError
@abc.abstractmethod
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件
Args:
event (AstrMessageEvent): 事件对象,包含事件的相关信息
Returns:
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
"""
raise NotImplementedError

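`register_stage` above populates `registered_stages`; its body is not shown in this excerpt, but it presumably follows the standard class-decorator registry pattern. A self-contained sketch under that assumption (the list name `registered` and `DemoStage` are ours):

```python
registered: list[type] = []


def register_stage(cls):
    """Class decorator: record the Stage implementation and return it unchanged."""
    registered.append(cls)
    return cls


@register_stage
class DemoStage:
    async def process(self, event):
        return None
```

The scheduler can then sort `registered` by `STAGES_ORDER.index(cls.__name__)`, as the `PipelineScheduler` constructor does.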
View File

@@ -1,11 +1,11 @@
from collections.abc import AsyncGenerator
from typing import AsyncGenerator, Union
from astrbot import logger
from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -30,12 +30,10 @@ class WakingCheckStage(Stage):
Args:
ctx (PipelineContext): The pipeline context, including the config and plugin manager
"""
self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply",
True,
"no_permission_reply", True
)
# Whether private chats need a wake_prefix to wake the bot
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
@@ -43,18 +41,15 @@ class WakingCheckStage(Stage):
].get("friend_message_needs_wake_prefix", False)
# Whether to ignore messages sent by the bot itself
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
"ignore_bot_self_message",
False,
"ignore_bot_self_message", False
)
self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get(
"ignore_at_all",
False,
"ignore_at_all", False
)
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
if (
self.ignore_bot_self_message
and event.get_self_id() == event.get_sender_id()
@@ -128,8 +123,7 @@ class WakingCheckStage(Stage):
logger.debug(f"enabled_plugins_name: {enabled_plugins_name}")
for handler in star_handlers_registry.get_handlers_by_event_type(
EventType.AdapterMessageEvent,
plugins_name=event.plugins_name,
EventType.AdapterMessageEvent, plugins_name=event.plugins_name
):
# filters must all pass (AND logic)
passed = True
@@ -144,14 +138,15 @@ class WakingCheckStage(Stage):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
permission_filter_raise_error = filter.raise_error
elif not filter.filter(event, self.ctx.astrbot_config):
passed = False
break
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
break
except Exception as e:
await event.send(
MessageEventResult().message(
f"插件 {star_map[handler.handler_module_path].name}: {e}",
),
f"插件 {star_map[handler.handler_module_path].name}: {e}"
)
)
event.stop_event()
passed = False
@@ -164,11 +159,11 @@ class WakingCheckStage(Stage):
if self.no_permission_reply:
await event.send(
MessageChain().message(
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。",
),
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
)
)
logger.info(
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。",
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
)
event.stop_event()
return
@@ -190,8 +185,7 @@ class WakingCheckStage(Stage):
# Filter plugin handlers by the session's config
activated_handlers = SessionPluginManager.filter_handlers_by_session(
event,
activated_handlers,
event, activated_handlers
)
event.set_extra("activated_handlers", activated_handlers)

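The waking-check hunks above AND together a handler's filters, but treat permission filters specially: a permission failure is remembered (so the user can be told) while the remaining filters are still evaluated, whereas any other failing filter skips the handler outright. A sketch of that split (the dict-based filter shape and function name are ours; AstrBot uses filter objects with a `raise_error` field):

```python
def evaluate_filters(filters, event) -> tuple[bool, bool]:
    """AND a handler's filters; track permission failures separately.

    Returns (passed, permission_not_pass).
    """
    permission_not_pass = False
    for f in filters:
        ok = f["check"](event)
        if f.get("is_permission"):
            if not ok:
                permission_not_pass = True  # remember it, but keep evaluating
        elif not ok:
            return False, permission_not_pass  # an ordinary filter failed: skip handler
    return True, permission_not_pass
```

A handler that passes every ordinary filter but trips a permission filter is the case where the stage sends the "permission insufficient" reply and stops the event.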