Compare commits

..

13 Commits

Author SHA1 Message Date
Soulter 1de377e749 fix: enhance file naming logic in File component and update prompt handling in InternalAgentSubStage 2025-12-01 18:12:03 +08:00
Soulter 6aa6963ab5 feat: add condition settings for local agent runner in default configuration 2025-12-01 17:41:14 +08:00
Soulter d3001d8148 fix: update file name handling in InternalAgentSubStage to correctly associate file names with extracted content 2025-12-01 17:40:30 +08:00
Soulter 380c4faf17 fix: add error handling for file extract application in InternalAgentSubStage 2025-12-01 14:58:18 +08:00
Soulter bd2a88783c fix: correct indentation in default configuration file 2025-12-01 14:56:32 +08:00
Soulter 17d7f822e7 feat: introduce file extract capability
powered by MoonshotAI
2025-12-01 14:54:25 +08:00
雪語 0e034f0fbd fix: aiocqhttp 适配器 NapCat 文件名获取为空 (#3853)
* aiocqhttp 适配器 NapCat 文件名获取为空

修复使用 NapCat 时,文件消息的 File.name 为空的问题。原代码硬编码 name="",导致下游插件无法获取文件名和扩展名

* Enhance file name retrieval from message data

Updated file name extraction logic to check multiple fields for better accuracy.
2025-12-01 13:36:19 +08:00
Soulter 2a7d03f9e1 fix: fit language and log AI responses more clearly (#3864)
* fix: fit language and log AI responses more clearly

* chore: ruff format
2025-12-01 13:24:52 +08:00
Soulter 72fac4b9f1 feat: implement unified provider availability testing across components (#3865)
- Added a `test` method to each provider class to standardize availability checks.
- Updated the dashboard and command routes to utilize the new `test` method for provider reachability verification, simplifying the logic and improving maintainability.
- Removed redundant reachability check logic from the command handler.
2025-12-01 13:17:20 +08:00
Soulter 38281ba2cf refactor: restore reachability check configuration in default settings and localization files 2025-12-01 00:38:30 +08:00
Soulter 21aa3174f4 fix: disable reachability check in default configuration 2025-12-01 00:16:11 +08:00
邹永赫 dcda871fc0 feat: provider availability reachability improvements (#3708) 2025-12-01 01:06:10 +09:00
Soulter c13c51f499 fix: assistant message validation error when tool_call exists but content not exists (#3862)
* fix: assistant message validation error when tool_call exists but content not exists

* fix: enhance content validation in Message model to allow None for assistant role with tool_calls
2025-11-30 23:42:37 +08:00
14 changed files with 462 additions and 207 deletions
+2
View File
@@ -48,3 +48,5 @@ astrbot.lock
chroma chroma
venv/* venv/*
pytest.ini pytest.ini
AGENTS.md
IFLOW.md
+55
View File
@@ -73,8 +73,14 @@ DEFAULT_CONFIG = {
"coze_agent_runner_provider_id": "", "coze_agent_runner_provider_id": "",
"dashscope_agent_runner_provider_id": "", "dashscope_agent_runner_provider_id": "",
"unsupported_streaming_strategy": "realtime_segmenting", "unsupported_streaming_strategy": "realtime_segmenting",
"reachability_check": False,
"max_agent_step": 30, "max_agent_step": 30,
"tool_call_timeout": 60, "tool_call_timeout": 60,
"file_extract": {
"enable": False,
"provider": "moonshotai",
"moonshotai_api_key": "",
},
}, },
"provider_stt_settings": { "provider_stt_settings": {
"enable": False, "enable": False,
@@ -2068,6 +2074,20 @@ CONFIG_METADATA_2 = {
"tool_call_timeout": { "tool_call_timeout": {
"type": "int", "type": "int",
}, },
"file_extract": {
"type": "object",
"items": {
"enable": {
"type": "bool",
},
"provider": {
"type": "string",
},
"moonshotai_api_key": {
"type": "string",
},
},
},
}, },
}, },
"provider_stt_settings": { "provider_stt_settings": {
@@ -2402,6 +2422,36 @@ CONFIG_METADATA_3 = {
"provider_settings.enable": True, "provider_settings.enable": True,
}, },
}, },
"file_extract": {
"description": "文档解析能力",
"type": "object",
"items": {
"provider_settings.file_extract.enable": {
"description": "启用文档解析能力",
"type": "bool",
},
"provider_settings.file_extract.provider": {
"description": "文档解析提供商",
"type": "string",
"options": ["moonshotai"],
"condition": {
"provider_settings.file_extract.enable": True,
},
},
"provider_settings.file_extract.moonshotai_api_key": {
"description": "Moonshot AI API Key",
"type": "string",
"condition": {
"provider_settings.file_extract.provider": "moonshotai",
"provider_settings.file_extract.enable": True,
},
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"others": { "others": {
"description": "其他配置", "description": "其他配置",
"type": "object", "type": "object",
@@ -2496,6 +2546,11 @@ CONFIG_METADATA_3 = {
"description": "开启 TTS 时同时输出语音和文字内容", "description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool", "type": "bool",
}, },
"provider_settings.reachability_check": {
"description": "提供商可达性检测",
"type": "bool",
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
},
}, },
"condition": { "condition": {
"provider_settings.enable": True, "provider_settings.enable": True,
+6 -1
View File
@@ -722,7 +722,12 @@ class File(BaseMessageComponent):
"""下载文件""" """下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp") download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True) os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}") if self.name:
name, ext = os.path.splitext(self.name)
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
else:
filename = f"{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path) await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path) self.file_ = os.path.abspath(file_path)
@@ -9,7 +9,7 @@ from astrbot.core import logger
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import Image from astrbot.core.message.components import File, Image, Reply
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
MessageEventResult, MessageEventResult,
@@ -22,6 +22,7 @@ from astrbot.core.provider.entities import (
ProviderRequest, ProviderRequest,
) )
from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.metrics import Metric from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager from astrbot.core.utils.session_lock import session_lock_manager
@@ -56,6 +57,13 @@ class InternalAgentSubStage(Stage):
self.show_reasoning = settings.get("display_reasoning_text", False) self.show_reasoning = settings.get("display_reasoning_text", False)
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False) self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
file_extract_conf: dict = settings.get("file_extract", {})
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
self.file_extract_msh_api_key: str = file_extract_conf.get(
"moonshotai_api_key", ""
)
self.conv_manager = ctx.plugin_manager.context.conversation_manager self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent): def _select_provider(self, event: AstrMessageEvent):
@@ -114,6 +122,50 @@ class InternalAgentSubStage(Stage):
req.func_tool = ToolSet() req.func_tool = ToolSet()
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL) req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
async def _apply_file_extract(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""Apply file extract to the provider request"""
file_paths = []
file_names = []
for comp in event.message_obj.message:
if isinstance(comp, File):
file_paths.append(await comp.get_file())
file_names.append(comp.name)
elif isinstance(comp, Reply) and comp.chain:
for reply_comp in comp.chain:
if isinstance(reply_comp, File):
file_paths.append(await reply_comp.get_file())
file_names.append(reply_comp.name)
if not file_paths:
return
if not req.prompt:
req.prompt = "总结一下文件里面讲了什么?"
if self.file_extract_prov == "moonshotai":
if not self.file_extract_msh_api_key:
logger.error("Moonshot AI API key for file extract is not set")
return
file_contents = await asyncio.gather(
*[
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
for file_path in file_paths
]
)
else:
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
return
# add file extract results to contexts
for file_content, file_name in zip(file_contents, file_names):
req.contexts.append(
{
"role": "system",
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
},
)
def _truncate_contexts( def _truncate_contexts(
self, self,
contexts: list[dict], contexts: list[dict],
@@ -346,6 +398,17 @@ class InternalAgentSubStage(Stage):
event.set_extra("provider_request", req) event.set_extra("provider_request", req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# apply file extract
if self.file_extract_enabled:
try:
await self._apply_file_extract(event, req)
except Exception as e:
logger.error(f"Error occurred while applying file extract: {e}")
if not req.prompt and not req.image_urls: if not req.prompt and not req.image_urls:
return return
@@ -356,10 +419,6 @@ class InternalAgentSubStage(Stage):
# apply knowledge base feature # apply knowledge base feature
await self._apply_kb(event, req) await self._apply_kb(event, req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# truncate contexts to fit max length # truncate contexts to fit max length
if req.contexts: if req.contexts:
req.contexts = self._truncate_contexts(req.contexts) req.contexts = self._truncate_contexts(req.contexts)
@@ -246,7 +246,13 @@ class AiocqhttpAdapter(Platform):
if m["data"].get("url") and m["data"].get("url").startswith("http"): if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange # Lagrange
logger.info("guessing lagrange") logger.info("guessing lagrange")
file_name = m["data"].get("file_name", "file") # 检查多个可能的文件名字段
file_name = (
m["data"].get("file_name", "")
or m["data"].get("name", "")
or m["data"].get("file", "")
or "file"
)
abm.message.append(File(name=file_name, url=m["data"]["url"])) abm.message.append(File(name=file_name, url=m["data"]["url"]))
else: else:
try: try:
@@ -265,7 +271,14 @@ class AiocqhttpAdapter(Platform):
) )
if ret and "url" in ret: if ret and "url" in ret:
file_url = ret["url"] # https file_url = ret["url"] # https
a = File(name="", url=file_url) # 优先从 API 返回值获取文件名,其次从原始消息数据获取
file_name = (
ret.get("file_name", "")
or ret.get("name", "")
or m["data"].get("file", "")
or m["data"].get("file_name", "")
)
a = File(name=file_name, url=file_url)
abm.message.append(a) abm.message.append(a)
else: else:
logger.error(f"获取文件失败: {ret}") logger.error(f"获取文件失败: {ret}")
@@ -381,7 +381,9 @@ class TelegramPlatformAdapter(Platform):
f"Telegram document file_path is None, cannot save the file {file_name}.", f"Telegram document file_path is None, cannot save the file {file_name}.",
) )
else: else:
message.message.append(Comp.File(file=file_path, name=file_name)) message.message.append(
Comp.File(file=file_path, name=file_name, url=file_path)
)
elif update.message.video: elif update.message.video:
file = await update.message.video.get_file() file = await update.message.video.get_file()
+57
View File
@@ -1,5 +1,6 @@
import abc import abc
import asyncio import asyncio
import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from astrbot.core.agent.message import Message from astrbot.core.agent.message import Message
@@ -11,6 +12,7 @@ from astrbot.core.provider.entities import (
ToolCallsResult, ToolCallsResult,
) )
from astrbot.core.provider.register import provider_cls_map from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
class AbstractProvider(abc.ABC): class AbstractProvider(abc.ABC):
@@ -43,6 +45,14 @@ class AbstractProvider(abc.ABC):
) )
return meta return meta
async def test(self) -> bool:
"""test the provider is a
Returns:
bool: the provider is available
"""
return True
class Provider(AbstractProvider): class Provider(AbstractProvider):
"""Chat Provider""" """Chat Provider"""
@@ -165,6 +175,16 @@ class Provider(AbstractProvider):
return dicts return dicts
async def test(self, timeout: float = 45.0) -> bool:
try:
response = await asyncio.wait_for(
self.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=timeout,
)
return response is not None
except Exception:
return False
class STTProvider(AbstractProvider): class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -177,6 +197,20 @@ class STTProvider(AbstractProvider):
"""获取音频的文本""" """获取音频的文本"""
raise NotImplementedError raise NotImplementedError
async def test(self) -> bool:
try:
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
return False
text_result = await self.get_text(sample_audio_path)
return isinstance(text_result, str) and bool(text_result)
except Exception:
return False
class TTSProvider(AbstractProvider): class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -189,6 +223,13 @@ class TTSProvider(AbstractProvider):
"""获取文本的音频,返回音频文件路径""" """获取文本的音频,返回音频文件路径"""
raise NotImplementedError raise NotImplementedError
async def test(self) -> bool:
try:
audio_result = await self.get_audio("hi")
return isinstance(audio_result, str) and bool(audio_result)
except Exception:
return False
class EmbeddingProvider(AbstractProvider): class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -211,6 +252,15 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度""" """获取向量的维度"""
... ...
async def test(self) -> bool:
try:
embedding_result = await self.get_embedding("health_check")
return isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
)
except Exception:
return False
async def get_embeddings_batch( async def get_embeddings_batch(
self, self,
texts: list[str], texts: list[str],
@@ -294,3 +344,10 @@ class RerankProvider(AbstractProvider):
) -> list[RerankResult]: ) -> list[RerankResult]:
"""获取查询和文档的重排序分数""" """获取查询和文档的重排序分数"""
... ...
async def test(self) -> bool:
try:
await self.rerank("Apple", documents=["apple", "banana"])
return True
except Exception:
return False
+23
View File
@@ -0,0 +1,23 @@
from pathlib import Path
from openai import AsyncOpenAI
async def extract_file_moonshotai(file_path: str, api_key: str) -> str:
"""Extract text from a file using Moonshot AI API"""
"""
Args:
file_path: The path to the file to extract text from
api_key: The API key to use to extract text from the file
Returns:
The text extracted from the file
"""
client = AsyncOpenAI(
api_key=api_key,
base_url="https://api.moonshot.cn/v1",
)
file_object = await client.files.create(
file=Path(file_path),
purpose="file-extract", # type: ignore
)
return (await client.files.content(file_id=file_object.id)).text
+17 -163
View File
@@ -18,11 +18,8 @@ from astrbot.core.config.i18n_utils import ConfigMetadataI18n
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.provider.register import provider_registry from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry from astrbot.core.star.star import star_registry
from astrbot.core.utils.astrbot_path import get_astrbot_path
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
@@ -356,169 +353,26 @@ class ConfigRoute(Route):
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
) )
if provider_capability_type == ProviderType.CHAT_COMPLETION: try:
try: result = await provider.test()
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") if result:
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=45.0,
)
logger.debug(
f"Received response from {status_info['name']}: {response}",
)
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if (
hasattr(response, "completion_text")
and response.completion_text
):
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
else response.completion_text
)
elif hasattr(response, "result_chain") and response.result_chain:
try:
response_text_snippet = (
response.result_chain.get_plain_text()[:70] + "..."
if len(response.result_chain.get_plain_text()) > 70
else response.result_chain.get_plain_text()
)
except Exception as _:
pass
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
)
else:
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
)
except asyncio.TimeoutError:
status_info["error"] = (
"Connection timed out after 45 seconds during test call."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
)
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
)
elif provider_capability_type == ProviderType.EMBEDDING:
try:
# For embedding, we can call the get_embedding method with a short prompt.
embedding_result = await provider.get_embedding("health_check")
if isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"Embedding test failed: unexpected result type {type(embedding_result)}"
)
except Exception as e:
logger.error(
f"Error testing embedding provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: {e!s}"
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
try:
# For TTS, we can call the get_audio method with a short prompt.
audio_result = await provider.get_audio("你好")
if isinstance(audio_result, str) and audio_result:
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"TTS test failed: unexpected result type {type(audio_result)}"
)
except Exception as e:
logger.error(
f"Error testing TTS provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: {e!s}"
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
try:
logger.debug(
f"Sending health check audio to provider: {status_info['name']}",
)
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable"
status_info["error"] = (
"STT test failed: sample audio file not found."
)
logger.warning(
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
)
else:
text_result = await provider.get_text(sample_audio_path)
if isinstance(text_result, str) and text_result:
status_info["status"] = "available"
snippet = (
text_result[:70] + "..."
if len(text_result) > 70
else text_result
)
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
)
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"STT test failed: unexpected result type {type(text_result)}"
)
logger.warning(
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
)
except Exception as e:
logger.error(
f"Error testing STT provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {e!s}"
elif provider_capability_type == ProviderType.RERANK:
try:
assert isinstance(provider, RerankProvider)
await provider.rerank("Apple", documents=["apple", "banana"])
status_info["status"] = "available" status_info["status"] = "available"
except Exception as e: logger.info(
logger.error( f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
f"Error testing rerank provider {provider_name}: {e}",
exc_info=True,
) )
status_info["status"] = "unavailable" else:
status_info["error"] = f"Rerank test failed: {e!s}" status_info["error"] = "Provider test returned False."
logger.warning(
else: f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.",
logger.debug( )
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}", except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
) )
status_info["status"] = "available" logger.debug(
status_info["error"] = ( f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
"This provider type is not tested and is assumed to be available."
) )
return status_info return status_info
@@ -109,6 +109,22 @@
} }
} }
}, },
"file_extract": {
"description": "File Extract",
"provider_settings": {
"file_extract": {
"enable": {
"description": "Enable File Extract"
},
"provider": {
"description": "File Extract Provider"
},
"moonshotai_api_key": {
"description": "Moonshot AI API Key"
}
}
}
},
"others": { "others": {
"description": "Other Settings", "description": "Other Settings",
"provider_settings": { "provider_settings": {
@@ -159,6 +175,10 @@
"prompt_prefix": { "prompt_prefix": {
"description": "User Prompt", "description": "User Prompt",
"hint": "You can use {{prompt}} as a placeholder for user input. If no placeholder is provided, it will be added before the user input." "hint": "You can use {{prompt}} as a placeholder for user input. If no placeholder is provided, it will be added before the user input."
},
"reachability_check": {
"description": "Provider Reachability Check",
"hint": "When running the /provider command, test provider connectivity in parallel. This actively pings models and may consume extra tokens."
} }
}, },
"provider_tts_settings": { "provider_tts_settings": {
@@ -453,4 +473,4 @@
} }
} }
} }
} }
@@ -11,7 +11,12 @@
}, },
"agent_runner_type": { "agent_runner_type": {
"description": "执行器", "description": "执行器",
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"] "labels": [
"内置 Agent",
"Dify",
"Coze",
"阿里云百炼应用"
]
}, },
"coze_agent_runner_provider_id": { "coze_agent_runner_provider_id": {
"description": "Coze Agent 执行器提供商 ID" "description": "Coze Agent 执行器提供商 ID"
@@ -109,6 +114,22 @@
} }
} }
}, },
"file_extract": {
"description": "文档解析能力",
"provider_settings": {
"file_extract": {
"enable": {
"description": "启用文档解析能力"
},
"provider": {
"description": "文档解析提供商"
},
"moonshotai_api_key": {
"description": "Moonshot AI API Key"
}
}
}
},
"others": { "others": {
"description": "其他配置", "description": "其他配置",
"provider_settings": { "provider_settings": {
@@ -142,7 +163,10 @@
"unsupported_streaming_strategy": { "unsupported_streaming_strategy": {
"description": "不支持流式回复的平台", "description": "不支持流式回复的平台",
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容", "hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
"labels": ["实时分段回复", "关闭流式回复"] "labels": [
"实时分段回复",
"关闭流式回复"
]
}, },
"max_context_length": { "max_context_length": {
"description": "最多携带对话轮数", "description": "最多携带对话轮数",
@@ -159,6 +183,10 @@
"prompt_prefix": { "prompt_prefix": {
"description": "用户提示词", "description": "用户提示词",
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。" "hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。"
},
"reachability_check": {
"description": "提供商可达性检测",
"hint": "/provider 命令列出模型时并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。"
} }
}, },
"provider_tts_settings": { "provider_tts_settings": {
@@ -453,4 +481,4 @@
} }
} }
} }
} }
+153 -19
View File
@@ -1,5 +1,7 @@
import asyncio
import re import re
from astrbot import logger
from astrbot.api import star from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.entities import ProviderType
@@ -9,6 +11,53 @@ class ProviderCommands:
def __init__(self, context: star.Context): def __init__(self, context: star.Context):
self.context = context self.context = context
def _log_reachability_failure(
self,
provider,
provider_capability_type: ProviderType | None,
err_code: str,
err_reason: str,
):
"""记录不可达原因到日志。"""
meta = provider.meta()
logger.warning(
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
meta.id,
provider_capability_type.name if provider_capability_type else "unknown",
err_code,
err_reason,
)
async def _test_provider_capability(self, provider):
"""测试单个 provider 的可用性"""
meta = provider.meta()
provider_capability_type = meta.provider_type
try:
result = await provider.test()
if result:
return True, None, None
err_code = "TEST_FAILED"
err_reason = "Provider test returned False"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
except Exception as exc:
err_code = (
getattr(exc, "status_code", None)
or getattr(exc, "code", None)
or getattr(exc, "error_code", None)
)
err_reason = str(exc)
if not err_code:
err_code = exc.__class__.__name__
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
async def provider( async def provider(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
@@ -17,46 +66,131 @@ class ProviderCommands:
): ):
"""查看或者切换 LLM Provider""" """查看或者切换 LLM Provider"""
umo = event.unified_msg_origin umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {})
reachability_check_enabled = cfg.get("reachability_check", True)
if idx is None: if idx is None:
parts = ["## 载入的 LLM 提供商\n"] parts = ["## 载入的 LLM 提供商\n"]
for idx, llm in enumerate(self.context.get_all_providers()):
id_ = llm.meta().id # 获取所有类型的提供商
line = f"{idx + 1}. {id_} ({llm.meta().model})" llms = list(self.context.get_all_providers())
ttss = self.context.get_all_tts_providers()
stts = self.context.get_all_stt_providers()
# 构造待检测列表: [(provider, type_label), ...]
all_providers = []
all_providers.extend([(p, "llm") for p in llms])
all_providers.extend([(p, "tts") for p in ttss])
all_providers.extend([(p, "stt") for p in stts])
# 并发测试连通性
if reachability_check_enabled:
if all_providers:
await event.send(
MessageEventResult().message(
"正在进行提供商可达性测试,请稍候..."
)
)
check_results = await asyncio.gather(
*[self._test_provider_capability(p) for p, _ in all_providers],
return_exceptions=True,
)
else:
# 用 None 表示未检测
check_results = [None for _ in all_providers]
# 整合结果
display_data = []
for (p, p_type), reachable in zip(all_providers, check_results):
meta = p.meta()
id_ = meta.id
error_code = None
if isinstance(reachable, Exception):
# 异常情况下兜底处理,避免单个 provider 导致列表失败
self._log_reachability_failure(
p,
None,
reachable.__class__.__name__,
str(reachable),
)
reachable_flag = False
error_code = reachable.__class__.__name__
elif isinstance(reachable, tuple):
reachable_flag, error_code, _ = reachable
else:
reachable_flag = reachable
# 根据类型构建显示名称
if p_type == "llm":
info = f"{id_} ({meta.model})"
else:
info = f"{id_}"
# 确定状态标记
if reachable_flag is True:
mark = ""
elif reachable_flag is False:
if error_code:
mark = f" ❌(错误码: {error_code})"
else:
mark = ""
else:
mark = "" # 不支持检测时不显示标记
display_data.append(
{
"type": p_type,
"info": info,
"mark": mark,
"provider": p,
}
)
# 分组输出
# 1. LLM
llm_data = [d for d in display_data if d["type"] == "llm"]
for i, d in enumerate(llm_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
provider_using = self.context.get_using_provider(umo=umo) provider_using = self.context.get_using_provider(umo=umo)
if provider_using and provider_using.meta().id == id_: if (
provider_using
and provider_using.meta().id == d["provider"].meta().id
):
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
tts_providers = self.context.get_all_tts_providers() # 2. TTS
if tts_providers: tts_data = [d for d in display_data if d["type"] == "tts"]
if tts_data:
parts.append("\n## 载入的 TTS 提供商\n") parts.append("\n## 载入的 TTS 提供商\n")
for idx, tts in enumerate(tts_providers): for i, d in enumerate(tts_data):
id_ = tts.meta().id line = f"{i + 1}. {d['info']}{d['mark']}"
line = f"{idx + 1}. {id_}"
tts_using = self.context.get_using_tts_provider(umo=umo) tts_using = self.context.get_using_tts_provider(umo=umo)
if tts_using and tts_using.meta().id == id_: if tts_using and tts_using.meta().id == d["provider"].meta().id:
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
stt_providers = self.context.get_all_stt_providers() # 3. STT
if stt_providers: stt_data = [d for d in display_data if d["type"] == "stt"]
if stt_data:
parts.append("\n## 载入的 STT 提供商\n") parts.append("\n## 载入的 STT 提供商\n")
for idx, stt in enumerate(stt_providers): for i, d in enumerate(stt_data):
id_ = stt.meta().id line = f"{i + 1}. {d['info']}{d['mark']}"
line = f"{idx + 1}. {id_}"
stt_using = self.context.get_using_stt_provider(umo=umo) stt_using = self.context.get_using_stt_provider(umo=umo)
if stt_using and stt_using.meta().id == id_: if stt_using and stt_using.meta().id == d["provider"].meta().id:
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。") parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
ret = "".join(parts) ret = "".join(parts)
if tts_providers: if ttss:
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
if stt_providers: if stts:
ret += "\n使用 /provider stt <切换> STT 提供商。" ret += "\n使用 /provider stt <序号> 切换 STT 提供商。"
if not reachability_check_enabled:
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
event.set_result(MessageEventResult().message(ret)) event.set_result(MessageEventResult().message(ret))
elif idx == "tts": elif idx == "tts":
+13 -7
View File
@@ -8,7 +8,7 @@ from astrbot.api import star
from astrbot.api.event import AstrMessageEvent from astrbot.api.event import AstrMessageEvent
from astrbot.api.message_components import At, Image, Plain from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import MessageType from astrbot.api.platform import MessageType
from astrbot.api.provider import Provider, ProviderRequest from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
""" """
@@ -158,8 +158,12 @@ class LongTermMemory:
cfg = self.cfg(event) cfg = self.cfg(event)
if cfg["enable_active_reply"]: if cfg["enable_active_reply"]:
prompt = req.prompt prompt = req.prompt
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}" req.prompt = (
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information." f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
f"\nNow, a new message is coming: `{prompt}`. "
"Please react to it. Only output your response and do not output any other information. "
"You MUST use the SAME language as the chatroom is using."
)
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。 req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
else: else:
req.system_prompt += ( req.system_prompt += (
@@ -167,13 +171,15 @@ class LongTermMemory:
) )
req.system_prompt += chats_str req.system_prompt += chats_str
async def after_req_llm(self, event: AstrMessageEvent): async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
if event.unified_msg_origin not in self.session_chats: if event.unified_msg_origin not in self.session_chats:
return return
if event.get_result() and event.get_result().is_llm_result(): if llm_resp.completion_text:
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}" final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}"
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") logger.debug(
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
)
self.session_chats[event.unified_msg_origin].append(final_message) self.session_chats[event.unified_msg_origin].append(final_message)
cfg = self.cfg(event) cfg = self.cfg(event)
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
+2 -5
View File
@@ -322,7 +322,7 @@ class Main(star.Star):
@filter.on_llm_response() @filter.on_llm_response()
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse): async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
"""在 LLM 响应后基于配置注入思考过程文本""" """在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
umo = event.unified_msg_origin umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {}) cfg = self.context.get_config(umo).get("provider_settings", {})
show_reasoning = cfg.get("display_reasoning_text", False) show_reasoning = cfg.get("display_reasoning_text", False)
@@ -331,12 +331,9 @@ class Main(star.Star):
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}" f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
) )
@filter.after_message_sent()
async def after_llm_req(self, event: AstrMessageEvent):
"""在 LLM 请求后记录对话"""
if self.ltm and self.ltm_enabled(event): if self.ltm and self.ltm_enabled(event):
try: try:
await self.ltm.after_req_llm(event) await self.ltm.after_req_llm(event, resp)
except Exception as e: except Exception as e:
logger.error(f"ltm: {e}") logger.error(f"ltm: {e}")