Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
6a0d0a8a3a fix: remove unnecessary provider check
fixes: #3815
2025-11-29 23:14:23 +08:00
18 changed files with 214 additions and 461 deletions

2
.gitignore vendored
View File

@@ -48,5 +48,3 @@ astrbot.lock
chroma chroma
venv/* venv/*
pytest.ini pytest.ini
AGENTS.md
IFLOW.md

View File

@@ -1 +1 @@
__version__ = "4.7.3" __version__ = "4.7.1"

View File

@@ -3,7 +3,7 @@
from typing import Any, ClassVar, Literal, cast from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema from pydantic_core import core_schema
@@ -145,39 +145,22 @@ class Message(BaseModel):
"tool", "tool",
] ]
content: str | list[ContentPart] | None = None content: str | list[ContentPart]
"""The content of the message.""" """The content of the message."""
tool_calls: list[ToolCall] | list[dict] | None = None
"""The tool calls of the message."""
tool_call_id: str | None = None
"""The ID of the tool call."""
@model_validator(mode="after")
def check_content_required(self):
# assistant + tool_calls is not None: allow content to be None
if self.role == "assistant" and self.tool_calls is not None:
return self
# other all cases: content is required
if self.content is None:
raise ValueError(
"content is required unless role='assistant' and tool_calls is not None"
)
return self
class AssistantMessageSegment(Message): class AssistantMessageSegment(Message):
"""A message segment from the assistant.""" """A message segment from the assistant."""
role: Literal["assistant"] = "assistant" role: Literal["assistant"] = "assistant"
tool_calls: list[ToolCall] | list[dict] | None = None
class ToolCallMessageSegment(Message): class ToolCallMessageSegment(Message):
"""A message segment representing a tool call.""" """A message segment representing a tool call."""
role: Literal["tool"] = "tool" role: Literal["tool"] = "tool"
tool_call_id: str
class UserMessageSegment(Message): class UserMessageSegment(Message):

View File

@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.7.3" VERSION = "4.7.1"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置 # 默认配置
@@ -73,7 +73,6 @@ DEFAULT_CONFIG = {
"coze_agent_runner_provider_id": "", "coze_agent_runner_provider_id": "",
"dashscope_agent_runner_provider_id": "", "dashscope_agent_runner_provider_id": "",
"unsupported_streaming_strategy": "realtime_segmenting", "unsupported_streaming_strategy": "realtime_segmenting",
"reachability_check": False,
"max_agent_step": 30, "max_agent_step": 30,
"tool_call_timeout": 60, "tool_call_timeout": 60,
}, },
@@ -91,7 +90,6 @@ DEFAULT_CONFIG = {
"group_icl_enable": False, "group_icl_enable": False,
"group_message_max_cnt": 300, "group_message_max_cnt": 300,
"image_caption": False, "image_caption": False,
"image_caption_provider_id": "",
"active_reply": { "active_reply": {
"enable": False, "enable": False,
"method": "possibility_reply", "method": "possibility_reply",
@@ -2111,9 +2109,6 @@ CONFIG_METADATA_2 = {
"image_caption": { "image_caption": {
"type": "bool", "type": "bool",
}, },
"image_caption_provider_id": {
"type": "string",
},
"image_caption_prompt": { "image_caption_prompt": {
"type": "string", "type": "string",
}, },
@@ -2497,11 +2492,6 @@ CONFIG_METADATA_3 = {
"description": "开启 TTS 时同时输出语音和文字内容", "description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool", "type": "bool",
}, },
"provider_settings.reachability_check": {
"description": "提供商可达性检测",
"type": "bool",
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
},
}, },
"condition": { "condition": {
"provider_settings.enable": True, "provider_settings.enable": True,
@@ -2795,16 +2785,7 @@ CONFIG_METADATA_3 = {
"provider_ltm_settings.image_caption": { "provider_ltm_settings.image_caption": {
"description": "自动理解图片", "description": "自动理解图片",
"type": "bool", "type": "bool",
"hint": "需要设置群聊图片转述模型。", "hint": "需要设置默认图片转述模型。",
},
"provider_ltm_settings.image_caption_provider_id": {
"description": "群聊图片转述模型",
"type": "string",
"_special": "select_provider",
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。",
"condition": {
"provider_ltm_settings.image_caption": True,
},
}, },
"provider_ltm_settings.active_reply.enable": { "provider_ltm_settings.active_reply.enable": {
"description": "主动回复", "description": "主动回复",

View File

@@ -2,7 +2,7 @@ import asyncio
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from astrbot.core import astrbot_config, logger from astrbot.core import logger
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner, DashscopeAgentRunner,
@@ -88,7 +88,7 @@ class ThirdPartyAgentSubStage(Stage):
return return
self.prov_cfg: dict = next( self.prov_cfg: dict = next(
(p for p in astrbot_config["provider"] if p["id"] == self.prov_id), (p for p in self.conf["provider"] if p["id"] == self.prov_id),
{}, {},
) )
if not self.prov_id: if not self.prov_id:

View File

@@ -250,7 +250,7 @@ class DingtalkPlatformAdapter(Platform):
async def terminate(self): async def terminate(self):
def monkey_patch_close(): def monkey_patch_close():
raise KeyboardInterrupt("Graceful shutdown") raise Exception("Graceful shutdown")
self.client_.open_connection = monkey_patch_close self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown") await self.client_.websocket.close(code=1000, reason="Graceful shutdown")

View File

@@ -1,6 +1,5 @@
import abc import abc
import asyncio import asyncio
import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from astrbot.core.agent.message import Message from astrbot.core.agent.message import Message
@@ -12,7 +11,6 @@ from astrbot.core.provider.entities import (
ToolCallsResult, ToolCallsResult,
) )
from astrbot.core.provider.register import provider_cls_map from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
class AbstractProvider(abc.ABC): class AbstractProvider(abc.ABC):
@@ -45,14 +43,6 @@ class AbstractProvider(abc.ABC):
) )
return meta return meta
async def test(self) -> bool:
"""test the provider is a
Returns:
bool: the provider is available
"""
return True
class Provider(AbstractProvider): class Provider(AbstractProvider):
"""Chat Provider""" """Chat Provider"""
@@ -175,16 +165,6 @@ class Provider(AbstractProvider):
return dicts return dicts
async def test(self, timeout: float = 45.0) -> bool:
try:
response = await asyncio.wait_for(
self.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=timeout,
)
return response is not None
except Exception:
return False
class STTProvider(AbstractProvider): class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -197,20 +177,6 @@ class STTProvider(AbstractProvider):
"""获取音频的文本""" """获取音频的文本"""
raise NotImplementedError raise NotImplementedError
async def test(self) -> bool:
try:
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
return False
text_result = await self.get_text(sample_audio_path)
return isinstance(text_result, str) and bool(text_result)
except Exception:
return False
class TTSProvider(AbstractProvider): class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -223,13 +189,6 @@ class TTSProvider(AbstractProvider):
"""获取文本的音频,返回音频文件路径""" """获取文本的音频,返回音频文件路径"""
raise NotImplementedError raise NotImplementedError
async def test(self) -> bool:
try:
audio_result = await self.get_audio("hi")
return isinstance(audio_result, str) and bool(audio_result)
except Exception:
return False
class EmbeddingProvider(AbstractProvider): class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -252,15 +211,6 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度""" """获取向量的维度"""
... ...
async def test(self) -> bool:
try:
embedding_result = await self.get_embedding("health_check")
return isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
)
except Exception:
return False
async def get_embeddings_batch( async def get_embeddings_batch(
self, self,
texts: list[str], texts: list[str],
@@ -344,10 +294,3 @@ class RerankProvider(AbstractProvider):
) -> list[RerankResult]: ) -> list[RerankResult]:
"""获取查询和文档的重排序分数""" """获取查询和文档的重排序分数"""
... ...
async def test(self) -> bool:
try:
await self.rerank("Apple", documents=["apple", "banana"])
return True
except Exception:
return False

View File

@@ -18,8 +18,11 @@ from astrbot.core.config.i18n_utils import ConfigMetadataI18n
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.provider.register import provider_registry from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry from astrbot.core.star.star import star_registry
from astrbot.core.utils.astrbot_path import get_astrbot_path
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
@@ -353,26 +356,169 @@ class ConfigRoute(Route):
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
) )
try: if provider_capability_type == ProviderType.CHAT_COMPLETION:
result = await provider.test() try:
if result: logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
status_info["status"] = "available" response = await asyncio.wait_for(
logger.info( provider.text_chat(prompt="REPLY `PONG` ONLY"),
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", timeout=45.0,
)
logger.debug(
f"Received response from {status_info['name']}: {response}",
)
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if (
hasattr(response, "completion_text")
and response.completion_text
):
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
else response.completion_text
)
elif hasattr(response, "result_chain") and response.result_chain:
try:
response_text_snippet = (
response.result_chain.get_plain_text()[:70] + "..."
if len(response.result_chain.get_plain_text()) > 70
else response.result_chain.get_plain_text()
)
except Exception as _:
pass
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
)
else:
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
)
except asyncio.TimeoutError:
status_info["error"] = (
"Connection timed out after 45 seconds during test call."
) )
else:
status_info["error"] = "Provider test returned False."
logger.warning( logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.", f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
) )
except Exception as e: except Exception as e:
error_message = str(e) error_message = str(e)
status_info["error"] = error_message status_info["error"] = error_message
logger.warning( logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}", f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
) )
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
)
elif provider_capability_type == ProviderType.EMBEDDING:
try:
# For embedding, we can call the get_embedding method with a short prompt.
embedding_result = await provider.get_embedding("health_check")
if isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"Embedding test failed: unexpected result type {type(embedding_result)}"
)
except Exception as e:
logger.error(
f"Error testing embedding provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: {e!s}"
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
try:
# For TTS, we can call the get_audio method with a short prompt.
audio_result = await provider.get_audio("你好")
if isinstance(audio_result, str) and audio_result:
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"TTS test failed: unexpected result type {type(audio_result)}"
)
except Exception as e:
logger.error(
f"Error testing TTS provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: {e!s}"
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
try:
logger.debug(
f"Sending health check audio to provider: {status_info['name']}",
)
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable"
status_info["error"] = (
"STT test failed: sample audio file not found."
)
logger.warning(
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
)
else:
text_result = await provider.get_text(sample_audio_path)
if isinstance(text_result, str) and text_result:
status_info["status"] = "available"
snippet = (
text_result[:70] + "..."
if len(text_result) > 70
else text_result
)
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
)
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"STT test failed: unexpected result type {type(text_result)}"
)
logger.warning(
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
)
except Exception as e:
logger.error(
f"Error testing STT provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {e!s}"
elif provider_capability_type == ProviderType.RERANK:
try:
assert isinstance(provider, RerankProvider)
await provider.rerank("Apple", documents=["apple", "banana"])
status_info["status"] = "available"
except Exception as e:
logger.error(
f"Error testing rerank provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"Rerank test failed: {e!s}"
else:
logger.debug( logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}", f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}",
)
status_info["status"] = "available"
status_info["error"] = (
"This provider type is not tested and is assumed to be available."
) )
return status_info return status_info

View File

@@ -1,4 +1,3 @@
import asyncio
import json import json
import os import os
import ssl import ssl
@@ -20,10 +19,6 @@ from astrbot.core.star.star_manager import PluginManager
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
PLUGIN_UPDATE_CONCURRENCY = (
3 # limit concurrent updates to avoid overwhelming plugin sources
)
class PluginRoute(Route): class PluginRoute(Route):
def __init__( def __init__(
@@ -38,7 +33,6 @@ class PluginRoute(Route):
"/plugin/install": ("POST", self.install_plugin), "/plugin/install": ("POST", self.install_plugin),
"/plugin/install-upload": ("POST", self.install_plugin_upload), "/plugin/install-upload": ("POST", self.install_plugin_upload),
"/plugin/update": ("POST", self.update_plugin), "/plugin/update": ("POST", self.update_plugin),
"/plugin/update-all": ("POST", self.update_all_plugins),
"/plugin/uninstall": ("POST", self.uninstall_plugin), "/plugin/uninstall": ("POST", self.uninstall_plugin),
"/plugin/market_list": ("GET", self.get_online_plugins), "/plugin/market_list": ("GET", self.get_online_plugins),
"/plugin/off": ("POST", self.off_plugin), "/plugin/off": ("POST", self.off_plugin),
@@ -69,7 +63,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
data = await request.get_json() data = await request.json
plugin_name = data.get("name", None) plugin_name = data.get("name", None)
try: try:
success, message = await self.plugin_manager.reload(plugin_name) success, message = await self.plugin_manager.reload(plugin_name)
@@ -352,7 +346,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
repo_url = post_data["url"] repo_url = post_data["url"]
proxy: str = post_data.get("proxy", None) proxy: str = post_data.get("proxy", None)
@@ -399,7 +393,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
delete_config = post_data.get("delete_config", False) delete_config = post_data.get("delete_config", False)
delete_data = post_data.get("delete_data", False) delete_data = post_data.get("delete_data", False)
@@ -424,7 +418,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
proxy: str = post_data.get("proxy", None) proxy: str = post_data.get("proxy", None)
try: try:
@@ -438,59 +432,6 @@ class PluginRoute(Route):
logger.error(f"/api/plugin/update: {traceback.format_exc()}") logger.error(f"/api/plugin/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__ return Response().error(str(e)).__dict__
async def update_all_plugins(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.get_json()
plugin_names: list[str] = post_data.get("names") or []
proxy: str = post_data.get("proxy", "")
if not isinstance(plugin_names, list) or not plugin_names:
return Response().error("插件列表不能为空").__dict__
results = []
sem = asyncio.Semaphore(PLUGIN_UPDATE_CONCURRENCY)
async def _update_one(name: str):
async with sem:
try:
logger.info(f"批量更新插件 {name}")
await self.plugin_manager.update_plugin(name, proxy)
return {"name": name, "status": "ok", "message": "更新成功"}
except Exception as e:
logger.error(
f"/api/plugin/update-all: 更新插件 {name} 失败: {traceback.format_exc()}",
)
return {"name": name, "status": "error", "message": str(e)}
raw_results = await asyncio.gather(
*(_update_one(name) for name in plugin_names),
return_exceptions=True,
)
for name, result in zip(plugin_names, raw_results):
if isinstance(result, asyncio.CancelledError):
raise result
if isinstance(result, BaseException):
results.append(
{"name": name, "status": "error", "message": str(result)}
)
else:
results.append(result)
failed = [r for r in results if r["status"] == "error"]
message = (
"批量更新完成,全部成功。"
if not failed
else f"批量更新完成,其中 {len(failed)}/{len(results)} 个插件失败。"
)
return Response().ok({"results": results}, message).__dict__
async def off_plugin(self): async def off_plugin(self):
if DEMO_MODE: if DEMO_MODE:
return ( return (
@@ -499,7 +440,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
try: try:
await self.plugin_manager.turn_off_plugin(plugin_name) await self.plugin_manager.turn_off_plugin(plugin_name)
@@ -517,7 +458,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
try: try:
await self.plugin_manager.turn_on_plugin(plugin_name) await self.plugin_manager.turn_on_plugin(plugin_name)

View File

@@ -1,25 +0,0 @@
## What's Changed
1. 修复使用非默认配置文件情况下时,第三方 Agent Runner (Dify、Coze、阿里云百炼应用等)无法正常工作的问题
2. 修复当“聊天模型”未设置,并且模型提供商中仅有 Agent Runner 时,无法正常使用 Agent Runner 的问题
3. 修复部分情况下报错 `pydantic_core._pydantic_core.ValidationError: 1 validation error for Message content` 的问题
4. 新增群聊模式下的专用图片转述模型配置 ([#3822](https://github.com/AstrBotDevs/AstrBot/issues/3822))
---
重构:
- 将 Dify、Coze、阿里云百炼应用等 LLMOps 提供商迁移到 Agent 执行器层,理清和本地 Agent 执行器的边界。详见:[Agent 执行器](https://docs.astrbot.app/use/agent-runner.html)
- 将「会话管理」功能重构为「自定义规则」功能,理清和多配置文件功能的边界。详见:[自定义规则](https://docs.astrbot.app/use/custom-rules.html)
优化:
- Dify、阿里云百炼应用支持流式输出
- 防止分段回复正则表达式解析错误导致消息不发送
- 群聊上下文感知记录 At 信息
- 优化模型提供商页面的测试提供商功能
新增:
- 支持在配置文件页面快速测试对话
- 为配置文件配置项内容添加国际化支持
修复:
- 在更新 MCP Server 配置后MCP 无法正常重启的问题

View File

@@ -159,10 +159,6 @@
"prompt_prefix": { "prompt_prefix": {
"description": "User Prompt", "description": "User Prompt",
"hint": "You can use {{prompt}} as a placeholder for user input. If no placeholder is provided, it will be added before the user input." "hint": "You can use {{prompt}} as a placeholder for user input. If no placeholder is provided, it will be added before the user input."
},
"reachability_check": {
"description": "Provider Reachability Check",
"hint": "When running the /provider command, test provider connectivity in parallel. This actively pings models and may consume extra tokens."
} }
}, },
"provider_tts_settings": { "provider_tts_settings": {
@@ -383,11 +379,7 @@
}, },
"image_caption": { "image_caption": {
"description": "Auto-understand Images", "description": "Auto-understand Images",
"hint": "Requires setting a group chat image caption model." "hint": "Requires setting a default image caption model."
},
"image_caption_provider_id": {
"description": "Group Chat Image Caption Model",
"hint": "Used for image understanding in group chat context awareness, configured separately from the default image caption model."
}, },
"active_reply": { "active_reply": {
"enable": { "enable": {
@@ -457,4 +449,4 @@
} }
} }
} }
} }

View File

@@ -32,8 +32,7 @@
"actions": "Actions", "actions": "Actions",
"back": "Back", "back": "Back",
"selectFile": "Select File", "selectFile": "Select File",
"refresh": "Refresh", "refresh": "Refresh"
"updateAll": "Update All"
}, },
"status": { "status": {
"enabled": "Enabled", "enabled": "Enabled",
@@ -142,9 +141,7 @@
"confirmDelete": "Are you sure you want to delete this extension?", "confirmDelete": "Are you sure you want to delete this extension?",
"fillUrlOrFile": "Please fill in extension URL or upload extension file", "fillUrlOrFile": "Please fill in extension URL or upload extension file",
"dontFillBoth": "Please don't fill in both extension URL and upload file", "dontFillBoth": "Please don't fill in both extension URL and upload file",
"supportedFormats": "Supports .zip extension files", "supportedFormats": "Supports .zip extension files"
"updateAllSuccess": "All upgradable extensions have been updated!",
"updateAllFailed": "{failed} of {total} extensions failed to update:"
}, },
"upload": { "upload": {
"fromFile": "Install from File", "fromFile": "Install from File",

View File

@@ -159,10 +159,6 @@
"prompt_prefix": { "prompt_prefix": {
"description": "用户提示词", "description": "用户提示词",
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。" "hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。"
},
"reachability_check": {
"description": "提供商可达性检测",
"hint": "/provider 命令列出模型时并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。"
} }
}, },
"provider_tts_settings": { "provider_tts_settings": {
@@ -383,11 +379,7 @@
}, },
"image_caption": { "image_caption": {
"description": "自动理解图片", "description": "自动理解图片",
"hint": "需要设置群聊图片转述模型。" "hint": "需要设置默认图片转述模型。"
},
"image_caption_provider_id": {
"description": "群聊图片转述模型",
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。"
}, },
"active_reply": { "active_reply": {
"enable": { "enable": {

View File

@@ -32,8 +32,7 @@
"actions": "操作", "actions": "操作",
"back": "返回", "back": "返回",
"selectFile": "选择文件", "selectFile": "选择文件",
"refresh": "刷新", "refresh": "刷新"
"updateAll": "更新全部插件"
}, },
"status": { "status": {
"enabled": "启用", "enabled": "启用",
@@ -142,9 +141,7 @@
"confirmDelete": "确定要删除插件吗?", "confirmDelete": "确定要删除插件吗?",
"fillUrlOrFile": "请填写插件链接或上传插件文件", "fillUrlOrFile": "请填写插件链接或上传插件文件",
"dontFillBoth": "请不要同时填写插件链接和上传文件", "dontFillBoth": "请不要同时填写插件链接和上传文件",
"supportedFormats": "支持 .zip 格式的插件文件", "supportedFormats": "支持 .zip 格式的插件文件"
"updateAllSuccess": "所有可更新的插件都已更新!",
"updateAllFailed": "有 {failed}/{total} 个插件更新失败:"
}, },
"upload": { "upload": {
"fromFile": "从文件安装", "fromFile": "从文件安装",

View File

@@ -42,7 +42,6 @@ const loadingDialog = reactive({
const showPluginInfoDialog = ref(false); const showPluginInfoDialog = ref(false);
const selectedPlugin = ref({}); const selectedPlugin = ref({});
const curr_namespace = ref(""); const curr_namespace = ref("");
const updatingAll = ref(false);
const readmeDialog = reactive({ const readmeDialog = reactive({
show: false, show: false,
@@ -227,10 +226,6 @@ const paginatedPlugins = computed(() => {
return sortedPlugins.value.slice(start, end); return sortedPlugins.value.slice(start, end);
}); });
const updatableExtensions = computed(() => {
return extension_data?.data?.filter(ext => ext.has_update) || [];
});
// 方法 // 方法
const toggleShowReserved = () => { const toggleShowReserved = () => {
showReserved.value = !showReserved.value; showReserved.value = !showReserved.value;
@@ -377,56 +372,6 @@ const updateExtension = async (extension_name) => {
} }
}; };
const updateAllExtensions = async () => {
if (updatingAll.value || updatableExtensions.value.length === 0) return;
updatingAll.value = true;
loadingDialog.title = tm('status.loading');
loadingDialog.statusCode = 0;
loadingDialog.result = "";
loadingDialog.show = true;
const targets = updatableExtensions.value.map(ext => ext.name);
try {
const res = await axios.post('/api/plugin/update-all', {
names: targets,
proxy: localStorage.getItem('selectedGitHubProxy') || ""
});
if (res.data.status === "error") {
onLoadingDialogResult(2, res.data.message || tm('messages.updateAllFailed', {
failed: targets.length,
total: targets.length
}), -1);
return;
}
const results = res.data.data?.results || [];
const failures = results.filter(r => r.status !== 'ok');
try {
await getExtensions();
} catch (err) {
const errorMsg = err.response?.data?.message || err.message || String(err);
failures.push({ name: 'refresh', status: 'error', message: errorMsg });
}
if (failures.length === 0) {
onLoadingDialogResult(1, tm('messages.updateAllSuccess'));
} else {
const failureText = tm('messages.updateAllFailed', {
failed: failures.length,
total: targets.length
});
const detail = failures.map(f => `${f.name}: ${f.message}`).join('\n');
onLoadingDialogResult(2, `${failureText}\n${detail}`, -1);
}
} catch (err) {
const errorMsg = err.response?.data?.message || err.message || String(err);
onLoadingDialogResult(2, errorMsg, -1);
} finally {
updatingAll.value = false;
}
};
const pluginOn = async (extension) => { const pluginOn = async (extension) => {
try { try {
const res = await axios.post('/api/plugin/on', { name: extension.name }); const res = await axios.post('/api/plugin/on', { name: extension.name });
@@ -775,12 +720,6 @@ watch(marketSearch, (newVal) => {
{{ showReserved ? tm('buttons.hideSystemPlugins') : tm('buttons.showSystemPlugins') }} {{ showReserved ? tm('buttons.hideSystemPlugins') : tm('buttons.showSystemPlugins') }}
</v-btn> </v-btn>
<v-btn class="ml-2" color="warning" variant="tonal" :disabled="updatableExtensions.length === 0"
:loading="updatingAll" @click="updateAllExtensions">
<v-icon>mdi-update</v-icon>
{{ tm('buttons.updateAll') }}
</v-btn>
<v-btn class="ml-2" color="primary" variant="tonal" @click="dialog = true"> <v-btn class="ml-2" color="primary" variant="tonal" @click="dialog = true">
<v-icon>mdi-plus</v-icon> <v-icon>mdi-plus</v-icon>
{{ tm('buttons.install') }} {{ tm('buttons.install') }}

View File

@@ -1,7 +1,5 @@
import asyncio
import re import re
from astrbot import logger
from astrbot.api import star from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.entities import ProviderType
@@ -11,53 +9,6 @@ class ProviderCommands:
def __init__(self, context: star.Context): def __init__(self, context: star.Context):
self.context = context self.context = context
def _log_reachability_failure(
self,
provider,
provider_capability_type: ProviderType | None,
err_code: str,
err_reason: str,
):
"""记录不可达原因到日志。"""
meta = provider.meta()
logger.warning(
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
meta.id,
provider_capability_type.name if provider_capability_type else "unknown",
err_code,
err_reason,
)
async def _test_provider_capability(self, provider):
"""测试单个 provider 的可用性"""
meta = provider.meta()
provider_capability_type = meta.provider_type
try:
result = await provider.test()
if result:
return True, None, None
err_code = "TEST_FAILED"
err_reason = "Provider test returned False"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
except Exception as exc:
err_code = (
getattr(exc, "status_code", None)
or getattr(exc, "code", None)
or getattr(exc, "error_code", None)
)
err_reason = str(exc)
if not err_code:
err_code = exc.__class__.__name__
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
async def provider( async def provider(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
@@ -66,131 +17,46 @@ class ProviderCommands:
): ):
"""查看或者切换 LLM Provider""" """查看或者切换 LLM Provider"""
umo = event.unified_msg_origin umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {})
reachability_check_enabled = cfg.get("reachability_check", True)
if idx is None: if idx is None:
parts = ["## 载入的 LLM 提供商\n"] parts = ["## 载入的 LLM 提供商\n"]
for idx, llm in enumerate(self.context.get_all_providers()):
# 获取所有类型的提供商 id_ = llm.meta().id
llms = list(self.context.get_all_providers()) line = f"{idx + 1}. {id_} ({llm.meta().model})"
ttss = self.context.get_all_tts_providers()
stts = self.context.get_all_stt_providers()
# 构造待检测列表: [(provider, type_label), ...]
all_providers = []
all_providers.extend([(p, "llm") for p in llms])
all_providers.extend([(p, "tts") for p in ttss])
all_providers.extend([(p, "stt") for p in stts])
# 并发测试连通性
if reachability_check_enabled:
if all_providers:
await event.send(
MessageEventResult().message(
"正在进行提供商可达性测试,请稍候..."
)
)
check_results = await asyncio.gather(
*[self._test_provider_capability(p) for p, _ in all_providers],
return_exceptions=True,
)
else:
# 用 None 表示未检测
check_results = [None for _ in all_providers]
# 整合结果
display_data = []
for (p, p_type), reachable in zip(all_providers, check_results):
meta = p.meta()
id_ = meta.id
error_code = None
if isinstance(reachable, Exception):
# 异常情况下兜底处理,避免单个 provider 导致列表失败
self._log_reachability_failure(
p,
None,
reachable.__class__.__name__,
str(reachable),
)
reachable_flag = False
error_code = reachable.__class__.__name__
elif isinstance(reachable, tuple):
reachable_flag, error_code, _ = reachable
else:
reachable_flag = reachable
# 根据类型构建显示名称
if p_type == "llm":
info = f"{id_} ({meta.model})"
else:
info = f"{id_}"
# 确定状态标记
if reachable_flag is True:
mark = ""
elif reachable_flag is False:
if error_code:
mark = f" ❌(错误码: {error_code})"
else:
mark = ""
else:
mark = "" # 不支持检测时不显示标记
display_data.append(
{
"type": p_type,
"info": info,
"mark": mark,
"provider": p,
}
)
# 分组输出
# 1. LLM
llm_data = [d for d in display_data if d["type"] == "llm"]
for i, d in enumerate(llm_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
provider_using = self.context.get_using_provider(umo=umo) provider_using = self.context.get_using_provider(umo=umo)
if ( if provider_using and provider_using.meta().id == id_:
provider_using
and provider_using.meta().id == d["provider"].meta().id
):
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
# 2. TTS tts_providers = self.context.get_all_tts_providers()
tts_data = [d for d in display_data if d["type"] == "tts"] if tts_providers:
if tts_data:
parts.append("\n## 载入的 TTS 提供商\n") parts.append("\n## 载入的 TTS 提供商\n")
for i, d in enumerate(tts_data): for idx, tts in enumerate(tts_providers):
line = f"{i + 1}. {d['info']}{d['mark']}" id_ = tts.meta().id
line = f"{idx + 1}. {id_}"
tts_using = self.context.get_using_tts_provider(umo=umo) tts_using = self.context.get_using_tts_provider(umo=umo)
if tts_using and tts_using.meta().id == d["provider"].meta().id: if tts_using and tts_using.meta().id == id_:
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
# 3. STT stt_providers = self.context.get_all_stt_providers()
stt_data = [d for d in display_data if d["type"] == "stt"] if stt_providers:
if stt_data:
parts.append("\n## 载入的 STT 提供商\n") parts.append("\n## 载入的 STT 提供商\n")
for i, d in enumerate(stt_data): for idx, stt in enumerate(stt_providers):
line = f"{i + 1}. {d['info']}{d['mark']}" id_ = stt.meta().id
line = f"{idx + 1}. {id_}"
stt_using = self.context.get_using_stt_provider(umo=umo) stt_using = self.context.get_using_stt_provider(umo=umo)
if stt_using and stt_using.meta().id == d["provider"].meta().id: if stt_using and stt_using.meta().id == id_:
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。") parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
ret = "".join(parts) ret = "".join(parts)
if ttss: if tts_providers:
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
if stts: if stt_providers:
ret += "\n使用 /provider stt <序号> 切换 STT 提供商。" ret += "\n使用 /provider stt <切换> STT 提供商。"
if not reachability_check_enabled:
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
event.set_result(MessageEventResult().message(ret)) event.set_result(MessageEventResult().message(ret))
elif idx == "tts": elif idx == "tts":

View File

@@ -30,13 +30,16 @@ class LongTermMemory:
except BaseException as e: except BaseException as e:
logger.error(e) logger.error(e)
max_cnt = 300 max_cnt = 300
image_caption = (
True
if cfg["provider_settings"]["default_image_caption_provider_id"]
and cfg["provider_ltm_settings"]["image_caption"]
else False
)
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"] image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
image_caption_provider_id = cfg["provider_ltm_settings"].get( image_caption_provider_id = cfg["provider_settings"][
"image_caption_provider_id" "default_image_caption_provider_id"
) ]
image_caption = cfg["provider_ltm_settings"]["image_caption"] and bool(
image_caption_provider_id
)
active_reply = cfg["provider_ltm_settings"]["active_reply"] active_reply = cfg["provider_ltm_settings"]["active_reply"]
enable_active_reply = active_reply.get("enable", False) enable_active_reply = active_reply.get("enable", False)
ar_method = active_reply["method"] ar_method = active_reply["method"]

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "AstrBot" name = "AstrBot"
version = "4.7.3" version = "4.7.1"
description = "Easy-to-use multi-platform LLM chatbot and development framework" description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"