Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
26467fbe22 feat: implement unified provider availability testing across components
- Added a `test` method to each provider class to standardize availability checks.
- Updated the dashboard and command routes to utilize the new `test` method for provider reachability verification, simplifying the logic and improving maintainability.
- Removed redundant reachability check logic from the command handler.
2025-12-01 00:39:14 +08:00
16 changed files with 102 additions and 261 deletions

View File

@@ -1 +1 @@
__version__ = "4.7.4" __version__ = "4.7.3"

View File

@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.7.4" VERSION = "4.7.3"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置 # 默认配置
@@ -76,11 +76,6 @@ DEFAULT_CONFIG = {
"reachability_check": False, "reachability_check": False,
"max_agent_step": 30, "max_agent_step": 30,
"tool_call_timeout": 60, "tool_call_timeout": 60,
"file_extract": {
"enable": False,
"provider": "moonshotai",
"moonshotai_api_key": "",
},
}, },
"provider_stt_settings": { "provider_stt_settings": {
"enable": False, "enable": False,
@@ -2074,20 +2069,6 @@ CONFIG_METADATA_2 = {
"tool_call_timeout": { "tool_call_timeout": {
"type": "int", "type": "int",
}, },
"file_extract": {
"type": "object",
"items": {
"enable": {
"type": "bool",
},
"provider": {
"type": "string",
},
"moonshotai_api_key": {
"type": "string",
},
},
},
}, },
}, },
"provider_stt_settings": { "provider_stt_settings": {
@@ -2422,36 +2403,6 @@ CONFIG_METADATA_3 = {
"provider_settings.enable": True, "provider_settings.enable": True,
}, },
}, },
# "file_extract": {
# "description": "文档解析能力 [beta]",
# "type": "object",
# "items": {
# "provider_settings.file_extract.enable": {
# "description": "启用文档解析能力",
# "type": "bool",
# },
# "provider_settings.file_extract.provider": {
# "description": "文档解析提供商",
# "type": "string",
# "options": ["moonshotai"],
# "condition": {
# "provider_settings.file_extract.enable": True,
# },
# },
# "provider_settings.file_extract.moonshotai_api_key": {
# "description": "Moonshot AI API Key",
# "type": "string",
# "condition": {
# "provider_settings.file_extract.provider": "moonshotai",
# "provider_settings.file_extract.enable": True,
# },
# },
# },
# "condition": {
# "provider_settings.agent_runner_type": "local",
# "provider_settings.enable": True,
# },
# },
"others": { "others": {
"description": "其他配置", "description": "其他配置",
"type": "object", "type": "object",

View File

@@ -722,12 +722,7 @@ class File(BaseMessageComponent):
"""下载文件""" """下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp") download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True) os.makedirs(download_dir, exist_ok=True)
if self.name: file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
name, ext = os.path.splitext(self.name)
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
else:
filename = f"{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path) await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path) self.file_ = os.path.abspath(file_path)

View File

@@ -9,7 +9,7 @@ from astrbot.core import logger
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Reply from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
MessageEventResult, MessageEventResult,
@@ -22,7 +22,6 @@ from astrbot.core.provider.entities import (
ProviderRequest, ProviderRequest,
) )
from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.metrics import Metric from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager from astrbot.core.utils.session_lock import session_lock_manager
@@ -57,13 +56,6 @@ class InternalAgentSubStage(Stage):
self.show_reasoning = settings.get("display_reasoning_text", False) self.show_reasoning = settings.get("display_reasoning_text", False)
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False) self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
file_extract_conf: dict = settings.get("file_extract", {})
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
self.file_extract_msh_api_key: str = file_extract_conf.get(
"moonshotai_api_key", ""
)
self.conv_manager = ctx.plugin_manager.context.conversation_manager self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent): def _select_provider(self, event: AstrMessageEvent):
@@ -122,50 +114,6 @@ class InternalAgentSubStage(Stage):
req.func_tool = ToolSet() req.func_tool = ToolSet()
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL) req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
async def _apply_file_extract(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""Apply file extract to the provider request"""
file_paths = []
file_names = []
for comp in event.message_obj.message:
if isinstance(comp, File):
file_paths.append(await comp.get_file())
file_names.append(comp.name)
elif isinstance(comp, Reply) and comp.chain:
for reply_comp in comp.chain:
if isinstance(reply_comp, File):
file_paths.append(await reply_comp.get_file())
file_names.append(reply_comp.name)
if not file_paths:
return
if not req.prompt:
req.prompt = "总结一下文件里面讲了什么?"
if self.file_extract_prov == "moonshotai":
if not self.file_extract_msh_api_key:
logger.error("Moonshot AI API key for file extract is not set")
return
file_contents = await asyncio.gather(
*[
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
for file_path in file_paths
]
)
else:
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
return
# add file extract results to contexts
for file_content, file_name in zip(file_contents, file_names):
req.contexts.append(
{
"role": "system",
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
},
)
def _truncate_contexts( def _truncate_contexts(
self, self,
contexts: list[dict], contexts: list[dict],
@@ -398,17 +346,6 @@ class InternalAgentSubStage(Stage):
event.set_extra("provider_request", req) event.set_extra("provider_request", req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# apply file extract
if self.file_extract_enabled:
try:
await self._apply_file_extract(event, req)
except Exception as e:
logger.error(f"Error occurred while applying file extract: {e}")
if not req.prompt and not req.image_urls: if not req.prompt and not req.image_urls:
return return
@@ -419,6 +356,10 @@ class InternalAgentSubStage(Stage):
# apply knowledge base feature # apply knowledge base feature
await self._apply_kb(event, req) await self._apply_kb(event, req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# truncate contexts to fit max length # truncate contexts to fit max length
if req.contexts: if req.contexts:
req.contexts = self._truncate_contexts(req.contexts) req.contexts = self._truncate_contexts(req.contexts)

View File

@@ -246,13 +246,7 @@ class AiocqhttpAdapter(Platform):
if m["data"].get("url") and m["data"].get("url").startswith("http"): if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange # Lagrange
logger.info("guessing lagrange") logger.info("guessing lagrange")
# 检查多个可能的文件名字段 file_name = m["data"].get("file_name", "file")
file_name = (
m["data"].get("file_name", "")
or m["data"].get("name", "")
or m["data"].get("file", "")
or "file"
)
abm.message.append(File(name=file_name, url=m["data"]["url"])) abm.message.append(File(name=file_name, url=m["data"]["url"]))
else: else:
try: try:
@@ -271,14 +265,7 @@ class AiocqhttpAdapter(Platform):
) )
if ret and "url" in ret: if ret and "url" in ret:
file_url = ret["url"] # https file_url = ret["url"] # https
# 优先从 API 返回值获取文件名,其次从原始消息数据获取 a = File(name="", url=file_url)
file_name = (
ret.get("file_name", "")
or ret.get("name", "")
or m["data"].get("file", "")
or m["data"].get("file_name", "")
)
a = File(name=file_name, url=file_url)
abm.message.append(a) abm.message.append(a)
else: else:
logger.error(f"获取文件失败: {ret}") logger.error(f"获取文件失败: {ret}")

View File

@@ -381,9 +381,7 @@ class TelegramPlatformAdapter(Platform):
f"Telegram document file_path is None, cannot save the file {file_name}.", f"Telegram document file_path is None, cannot save the file {file_name}.",
) )
else: else:
message.message.append( message.message.append(Comp.File(file=file_path, name=file_name))
Comp.File(file=file_path, name=file_name, url=file_path)
)
elif update.message.video: elif update.message.video:
file = await update.message.video.get_file() file = await update.message.video.get_file()

View File

@@ -45,13 +45,13 @@ class AbstractProvider(abc.ABC):
) )
return meta return meta
async def test(self): async def test(self) -> bool:
"""test the provider is a """test the provider is a
raises: Returns:
Exception: if the provider is not available bool: the provider is available
""" """
... return True
class Provider(AbstractProvider): class Provider(AbstractProvider):
@@ -175,11 +175,15 @@ class Provider(AbstractProvider):
return dicts return dicts
async def test(self, timeout: float = 45.0): async def test(self, timeout: float = 45.0) -> bool:
await asyncio.wait_for( try:
self.text_chat(prompt="REPLY `PONG` ONLY"), response = await asyncio.wait_for(
timeout=timeout, self.text_chat(prompt="REPLY `PONG` ONLY"),
) timeout=timeout,
)
return response is not None
except Exception:
return False
class STTProvider(AbstractProvider): class STTProvider(AbstractProvider):
@@ -193,13 +197,19 @@ class STTProvider(AbstractProvider):
"""获取音频的文本""" """获取音频的文本"""
raise NotImplementedError raise NotImplementedError
async def test(self): async def test(self) -> bool:
sample_audio_path = os.path.join( try:
get_astrbot_path(), sample_audio_path = os.path.join(
"samples", get_astrbot_path(),
"stt_health_check.wav", "samples",
) "stt_health_check.wav",
await self.get_text(sample_audio_path) )
if not os.path.exists(sample_audio_path):
return False
text_result = await self.get_text(sample_audio_path)
return isinstance(text_result, str) and bool(text_result)
except Exception:
return False
class TTSProvider(AbstractProvider): class TTSProvider(AbstractProvider):
@@ -213,8 +223,12 @@ class TTSProvider(AbstractProvider):
"""获取文本的音频,返回音频文件路径""" """获取文本的音频,返回音频文件路径"""
raise NotImplementedError raise NotImplementedError
async def test(self): async def test(self) -> bool:
await self.get_audio("hi") try:
audio_result = await self.get_audio("hi")
return isinstance(audio_result, str) and bool(audio_result)
except Exception:
return False
class EmbeddingProvider(AbstractProvider): class EmbeddingProvider(AbstractProvider):
@@ -238,8 +252,14 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度""" """获取向量的维度"""
... ...
async def test(self): async def test(self) -> bool:
await self.get_embedding("astrbot") try:
embedding_result = await self.get_embedding("health_check")
return isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
)
except Exception:
return False
async def get_embeddings_batch( async def get_embeddings_batch(
self, self,
@@ -325,7 +345,9 @@ class RerankProvider(AbstractProvider):
"""获取查询和文档的重排序分数""" """获取查询和文档的重排序分数"""
... ...
async def test(self): async def test(self) -> bool:
result = await self.rerank("Apple", documents=["apple", "banana"]) try:
if not result: await self.rerank("Apple", documents=["apple", "banana"])
raise Exception("Rerank provider test failed, no results returned") return True
except Exception:
return False

View File

@@ -1,23 +0,0 @@
from pathlib import Path
from openai import AsyncOpenAI
async def extract_file_moonshotai(file_path: str, api_key: str) -> str:
"""Extract text from a file using Moonshot AI API"""
"""
Args:
file_path: The path to the file to extract text from
api_key: The API key to use to extract text from the file
Returns:
The text extracted from the file
"""
client = AsyncOpenAI(
api_key=api_key,
base_url="https://api.moonshot.cn/v1",
)
file_object = await client.files.create(
file=Path(file_path),
purpose="file-extract", # type: ignore
)
return (await client.files.content(file_id=file_object.id)).text

View File

@@ -354,11 +354,17 @@ class ConfigRoute(Route):
) )
try: try:
await provider.test() result = await provider.test()
status_info["status"] = "available" if result:
logger.info( status_info["status"] = "available"
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", logger.info(
) f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
)
else:
status_info["error"] = "Provider test returned False."
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.",
)
except Exception as e: except Exception as e:
error_message = str(e) error_message = str(e)
status_info["error"] = error_message status_info["error"] = error_message

View File

@@ -1,7 +0,0 @@
## What's Changed
1. 修复assistant message 中 tool_call 存在但 content 不存在时,导致验证错误的问题 ([#3862](https://github.com/AstrBotDevs/AstrBot/issues/3862))
2. 修复fix: aiocqhttp 适配器 NapCat 文件名获取为空 ([#3853](https://github.com/AstrBotDevs/AstrBot/issues/3853))
3. 新增:升级所有插件按钮
4. 新增:/provider 指令支持同时测试提供商可用性
5. 优化:主动回复的 prompt

View File

@@ -109,22 +109,6 @@
} }
} }
}, },
"file_extract": {
"description": "File Extract",
"provider_settings": {
"file_extract": {
"enable": {
"description": "Enable File Extract"
},
"provider": {
"description": "File Extract Provider"
},
"moonshotai_api_key": {
"description": "Moonshot AI API Key"
}
}
}
},
"others": { "others": {
"description": "Other Settings", "description": "Other Settings",
"provider_settings": { "provider_settings": {

View File

@@ -11,12 +11,7 @@
}, },
"agent_runner_type": { "agent_runner_type": {
"description": "执行器", "description": "执行器",
"labels": [ "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"]
"内置 Agent",
"Dify",
"Coze",
"阿里云百炼应用"
]
}, },
"coze_agent_runner_provider_id": { "coze_agent_runner_provider_id": {
"description": "Coze Agent 执行器提供商 ID" "description": "Coze Agent 执行器提供商 ID"
@@ -114,22 +109,6 @@
} }
} }
}, },
"file_extract": {
"description": "文档解析能力",
"provider_settings": {
"file_extract": {
"enable": {
"description": "启用文档解析能力"
},
"provider": {
"description": "文档解析提供商"
},
"moonshotai_api_key": {
"description": "Moonshot AI API Key"
}
}
}
},
"others": { "others": {
"description": "其他配置", "description": "其他配置",
"provider_settings": { "provider_settings": {
@@ -163,10 +142,7 @@
"unsupported_streaming_strategy": { "unsupported_streaming_strategy": {
"description": "不支持流式回复的平台", "description": "不支持流式回复的平台",
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容", "hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
"labels": [ "labels": ["实时分段回复", "关闭流式回复"]
"实时分段回复",
"关闭流式回复"
]
}, },
"max_context_length": { "max_context_length": {
"description": "最多携带对话轮数", "description": "最多携带对话轮数",
@@ -481,4 +457,4 @@
} }
} }
} }
} }

View File

@@ -34,11 +34,25 @@ class ProviderCommands:
provider_capability_type = meta.provider_type provider_capability_type = meta.provider_type
try: try:
await provider.test() result = await provider.test()
return True, None, None if result:
except Exception as e: return True, None, None
err_code = "TEST_FAILED" err_code = "TEST_FAILED"
err_reason = str(e) err_reason = "Provider test returned False"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
except Exception as exc:
err_code = (
getattr(exc, "status_code", None)
or getattr(exc, "code", None)
or getattr(exc, "error_code", None)
)
err_reason = str(exc)
if not err_code:
err_code = exc.__class__.__name__
self._log_reachability_failure( self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason provider, provider_capability_type, err_code, err_reason
) )

View File

@@ -8,7 +8,7 @@ from astrbot.api import star
from astrbot.api.event import AstrMessageEvent from astrbot.api.event import AstrMessageEvent
from astrbot.api.message_components import At, Image, Plain from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import MessageType from astrbot.api.platform import MessageType
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
""" """
@@ -158,12 +158,8 @@ class LongTermMemory:
cfg = self.cfg(event) cfg = self.cfg(event)
if cfg["enable_active_reply"]: if cfg["enable_active_reply"]:
prompt = req.prompt prompt = req.prompt
req.prompt = ( req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
f"You are now in a chatroom. The chat history is as follows:\n{chats_str}" req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
f"\nNow, a new message is coming: `{prompt}`. "
"Please react to it. Only output your response and do not output any other information. "
"You MUST use the SAME language as the chatroom is using."
)
req.contexts = [] # 清空上下文当使用了主动回复所有聊天记录都在一个prompt中。 req.contexts = [] # 清空上下文当使用了主动回复所有聊天记录都在一个prompt中。
else: else:
req.system_prompt += ( req.system_prompt += (
@@ -171,15 +167,13 @@ class LongTermMemory:
) )
req.system_prompt += chats_str req.system_prompt += chats_str
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse): async def after_req_llm(self, event: AstrMessageEvent):
if event.unified_msg_origin not in self.session_chats: if event.unified_msg_origin not in self.session_chats:
return return
if llm_resp.completion_text: if event.get_result() and event.get_result().is_llm_result():
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}" final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
logger.debug( logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
)
self.session_chats[event.unified_msg_origin].append(final_message) self.session_chats[event.unified_msg_origin].append(final_message)
cfg = self.cfg(event) cfg = self.cfg(event)
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:

View File

@@ -322,7 +322,7 @@ class Main(star.Star):
@filter.on_llm_response() @filter.on_llm_response()
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse): async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话""" """在 LLM 响应后基于配置注入思考过程文本"""
umo = event.unified_msg_origin umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {}) cfg = self.context.get_config(umo).get("provider_settings", {})
show_reasoning = cfg.get("display_reasoning_text", False) show_reasoning = cfg.get("display_reasoning_text", False)
@@ -331,9 +331,12 @@ class Main(star.Star):
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}" f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
) )
@filter.after_message_sent()
async def after_llm_req(self, event: AstrMessageEvent):
"""在 LLM 请求后记录对话"""
if self.ltm and self.ltm_enabled(event): if self.ltm and self.ltm_enabled(event):
try: try:
await self.ltm.after_req_llm(event, resp) await self.ltm.after_req_llm(event)
except Exception as e: except Exception as e:
logger.error(f"ltm: {e}") logger.error(f"ltm: {e}")

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "AstrBot" name = "AstrBot"
version = "4.7.4" version = "4.7.3"
description = "Easy-to-use multi-platform LLM chatbot and development framework" description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"