Compare commits
9 Commits
perf/provi
...
feat/file-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1de377e749 | ||
|
|
6aa6963ab5 | ||
|
|
d3001d8148 | ||
|
|
380c4faf17 | ||
|
|
bd2a88783c | ||
|
|
17d7f822e7 | ||
|
|
0e034f0fbd | ||
|
|
2a7d03f9e1 | ||
|
|
72fac4b9f1 |
@@ -76,6 +76,11 @@ DEFAULT_CONFIG = {
|
|||||||
"reachability_check": False,
|
"reachability_check": False,
|
||||||
"max_agent_step": 30,
|
"max_agent_step": 30,
|
||||||
"tool_call_timeout": 60,
|
"tool_call_timeout": 60,
|
||||||
|
"file_extract": {
|
||||||
|
"enable": False,
|
||||||
|
"provider": "moonshotai",
|
||||||
|
"moonshotai_api_key": "",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"provider_stt_settings": {
|
"provider_stt_settings": {
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -2069,6 +2074,20 @@ CONFIG_METADATA_2 = {
|
|||||||
"tool_call_timeout": {
|
"tool_call_timeout": {
|
||||||
"type": "int",
|
"type": "int",
|
||||||
},
|
},
|
||||||
|
"file_extract": {
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"enable": {
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"moonshotai_api_key": {
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_stt_settings": {
|
"provider_stt_settings": {
|
||||||
@@ -2403,6 +2422,36 @@ CONFIG_METADATA_3 = {
|
|||||||
"provider_settings.enable": True,
|
"provider_settings.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"file_extract": {
|
||||||
|
"description": "文档解析能力",
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"provider_settings.file_extract.enable": {
|
||||||
|
"description": "启用文档解析能力",
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
|
"provider_settings.file_extract.provider": {
|
||||||
|
"description": "文档解析提供商",
|
||||||
|
"type": "string",
|
||||||
|
"options": ["moonshotai"],
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.file_extract.enable": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"provider_settings.file_extract.moonshotai_api_key": {
|
||||||
|
"description": "Moonshot AI API Key",
|
||||||
|
"type": "string",
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.file_extract.provider": "moonshotai",
|
||||||
|
"provider_settings.file_extract.enable": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.agent_runner_type": "local",
|
||||||
|
"provider_settings.enable": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
"others": {
|
"others": {
|
||||||
"description": "其他配置",
|
"description": "其他配置",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|||||||
@@ -722,7 +722,12 @@ class File(BaseMessageComponent):
|
|||||||
"""下载文件"""
|
"""下载文件"""
|
||||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
os.makedirs(download_dir, exist_ok=True)
|
os.makedirs(download_dir, exist_ok=True)
|
||||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
if self.name:
|
||||||
|
name, ext = os.path.splitext(self.name)
|
||||||
|
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
|
||||||
|
else:
|
||||||
|
filename = f"{uuid.uuid4().hex}"
|
||||||
|
file_path = os.path.join(download_dir, filename)
|
||||||
await download_file(self.url, file_path)
|
await download_file(self.url, file_path)
|
||||||
self.file_ = os.path.abspath(file_path)
|
self.file_ = os.path.abspath(file_path)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from astrbot.core import logger
|
|||||||
from astrbot.core.agent.tool import ToolSet
|
from astrbot.core.agent.tool import ToolSet
|
||||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||||
from astrbot.core.conversation_mgr import Conversation
|
from astrbot.core.conversation_mgr import Conversation
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.message.components import File, Image, Reply
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageChain,
|
MessageChain,
|
||||||
MessageEventResult,
|
MessageEventResult,
|
||||||
@@ -22,6 +22,7 @@ from astrbot.core.provider.entities import (
|
|||||||
ProviderRequest,
|
ProviderRequest,
|
||||||
)
|
)
|
||||||
from astrbot.core.star.star_handler import EventType, star_map
|
from astrbot.core.star.star_handler import EventType, star_map
|
||||||
|
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from astrbot.core.utils.session_lock import session_lock_manager
|
from astrbot.core.utils.session_lock import session_lock_manager
|
||||||
|
|
||||||
@@ -56,6 +57,13 @@ class InternalAgentSubStage(Stage):
|
|||||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||||
|
|
||||||
|
file_extract_conf: dict = settings.get("file_extract", {})
|
||||||
|
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
|
||||||
|
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
|
||||||
|
self.file_extract_msh_api_key: str = file_extract_conf.get(
|
||||||
|
"moonshotai_api_key", ""
|
||||||
|
)
|
||||||
|
|
||||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||||
|
|
||||||
def _select_provider(self, event: AstrMessageEvent):
|
def _select_provider(self, event: AstrMessageEvent):
|
||||||
@@ -114,6 +122,50 @@ class InternalAgentSubStage(Stage):
|
|||||||
req.func_tool = ToolSet()
|
req.func_tool = ToolSet()
|
||||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||||
|
|
||||||
|
async def _apply_file_extract(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
):
|
||||||
|
"""Apply file extract to the provider request"""
|
||||||
|
file_paths = []
|
||||||
|
file_names = []
|
||||||
|
for comp in event.message_obj.message:
|
||||||
|
if isinstance(comp, File):
|
||||||
|
file_paths.append(await comp.get_file())
|
||||||
|
file_names.append(comp.name)
|
||||||
|
elif isinstance(comp, Reply) and comp.chain:
|
||||||
|
for reply_comp in comp.chain:
|
||||||
|
if isinstance(reply_comp, File):
|
||||||
|
file_paths.append(await reply_comp.get_file())
|
||||||
|
file_names.append(reply_comp.name)
|
||||||
|
if not file_paths:
|
||||||
|
return
|
||||||
|
if not req.prompt:
|
||||||
|
req.prompt = "总结一下文件里面讲了什么?"
|
||||||
|
if self.file_extract_prov == "moonshotai":
|
||||||
|
if not self.file_extract_msh_api_key:
|
||||||
|
logger.error("Moonshot AI API key for file extract is not set")
|
||||||
|
return
|
||||||
|
file_contents = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
|
||||||
|
for file_path in file_paths
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# add file extract results to contexts
|
||||||
|
for file_content, file_name in zip(file_contents, file_names):
|
||||||
|
req.contexts.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def _truncate_contexts(
|
def _truncate_contexts(
|
||||||
self,
|
self,
|
||||||
contexts: list[dict],
|
contexts: list[dict],
|
||||||
@@ -346,6 +398,17 @@ class InternalAgentSubStage(Stage):
|
|||||||
|
|
||||||
event.set_extra("provider_request", req)
|
event.set_extra("provider_request", req)
|
||||||
|
|
||||||
|
# fix contexts json str
|
||||||
|
if isinstance(req.contexts, str):
|
||||||
|
req.contexts = json.loads(req.contexts)
|
||||||
|
|
||||||
|
# apply file extract
|
||||||
|
if self.file_extract_enabled:
|
||||||
|
try:
|
||||||
|
await self._apply_file_extract(event, req)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error occurred while applying file extract: {e}")
|
||||||
|
|
||||||
if not req.prompt and not req.image_urls:
|
if not req.prompt and not req.image_urls:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -356,10 +419,6 @@ class InternalAgentSubStage(Stage):
|
|||||||
# apply knowledge base feature
|
# apply knowledge base feature
|
||||||
await self._apply_kb(event, req)
|
await self._apply_kb(event, req)
|
||||||
|
|
||||||
# fix contexts json str
|
|
||||||
if isinstance(req.contexts, str):
|
|
||||||
req.contexts = json.loads(req.contexts)
|
|
||||||
|
|
||||||
# truncate contexts to fit max length
|
# truncate contexts to fit max length
|
||||||
if req.contexts:
|
if req.contexts:
|
||||||
req.contexts = self._truncate_contexts(req.contexts)
|
req.contexts = self._truncate_contexts(req.contexts)
|
||||||
|
|||||||
@@ -246,7 +246,13 @@ class AiocqhttpAdapter(Platform):
|
|||||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||||
# Lagrange
|
# Lagrange
|
||||||
logger.info("guessing lagrange")
|
logger.info("guessing lagrange")
|
||||||
file_name = m["data"].get("file_name", "file")
|
# 检查多个可能的文件名字段
|
||||||
|
file_name = (
|
||||||
|
m["data"].get("file_name", "")
|
||||||
|
or m["data"].get("name", "")
|
||||||
|
or m["data"].get("file", "")
|
||||||
|
or "file"
|
||||||
|
)
|
||||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@@ -265,7 +271,14 @@ class AiocqhttpAdapter(Platform):
|
|||||||
)
|
)
|
||||||
if ret and "url" in ret:
|
if ret and "url" in ret:
|
||||||
file_url = ret["url"] # https
|
file_url = ret["url"] # https
|
||||||
a = File(name="", url=file_url)
|
# 优先从 API 返回值获取文件名,其次从原始消息数据获取
|
||||||
|
file_name = (
|
||||||
|
ret.get("file_name", "")
|
||||||
|
or ret.get("name", "")
|
||||||
|
or m["data"].get("file", "")
|
||||||
|
or m["data"].get("file_name", "")
|
||||||
|
)
|
||||||
|
a = File(name=file_name, url=file_url)
|
||||||
abm.message.append(a)
|
abm.message.append(a)
|
||||||
else:
|
else:
|
||||||
logger.error(f"获取文件失败: {ret}")
|
logger.error(f"获取文件失败: {ret}")
|
||||||
|
|||||||
@@ -381,7 +381,9 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
f"Telegram document file_path is None, cannot save the file {file_name}.",
|
f"Telegram document file_path is None, cannot save the file {file_name}.",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
message.message.append(
|
||||||
|
Comp.File(file=file_path, name=file_name, url=file_path)
|
||||||
|
)
|
||||||
|
|
||||||
elif update.message.video:
|
elif update.message.video:
|
||||||
file = await update.message.video.get_file()
|
file = await update.message.video.get_file()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
from astrbot.core.agent.message import Message
|
from astrbot.core.agent.message import Message
|
||||||
@@ -11,6 +12,7 @@ from astrbot.core.provider.entities import (
|
|||||||
ToolCallsResult,
|
ToolCallsResult,
|
||||||
)
|
)
|
||||||
from astrbot.core.provider.register import provider_cls_map
|
from astrbot.core.provider.register import provider_cls_map
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||||
|
|
||||||
|
|
||||||
class AbstractProvider(abc.ABC):
|
class AbstractProvider(abc.ABC):
|
||||||
@@ -43,6 +45,14 @@ class AbstractProvider(abc.ABC):
|
|||||||
)
|
)
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
|
async def test(self) -> bool:
|
||||||
|
"""test the provider is a
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: the provider is available
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class Provider(AbstractProvider):
|
class Provider(AbstractProvider):
|
||||||
"""Chat Provider"""
|
"""Chat Provider"""
|
||||||
@@ -165,6 +175,16 @@ class Provider(AbstractProvider):
|
|||||||
|
|
||||||
return dicts
|
return dicts
|
||||||
|
|
||||||
|
async def test(self, timeout: float = 45.0) -> bool:
|
||||||
|
try:
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
self.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
return response is not None
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class STTProvider(AbstractProvider):
|
class STTProvider(AbstractProvider):
|
||||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||||
@@ -177,6 +197,20 @@ class STTProvider(AbstractProvider):
|
|||||||
"""获取音频的文本"""
|
"""获取音频的文本"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def test(self) -> bool:
|
||||||
|
try:
|
||||||
|
sample_audio_path = os.path.join(
|
||||||
|
get_astrbot_path(),
|
||||||
|
"samples",
|
||||||
|
"stt_health_check.wav",
|
||||||
|
)
|
||||||
|
if not os.path.exists(sample_audio_path):
|
||||||
|
return False
|
||||||
|
text_result = await self.get_text(sample_audio_path)
|
||||||
|
return isinstance(text_result, str) and bool(text_result)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class TTSProvider(AbstractProvider):
|
class TTSProvider(AbstractProvider):
|
||||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||||
@@ -189,6 +223,13 @@ class TTSProvider(AbstractProvider):
|
|||||||
"""获取文本的音频,返回音频文件路径"""
|
"""获取文本的音频,返回音频文件路径"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def test(self) -> bool:
|
||||||
|
try:
|
||||||
|
audio_result = await self.get_audio("hi")
|
||||||
|
return isinstance(audio_result, str) and bool(audio_result)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingProvider(AbstractProvider):
|
class EmbeddingProvider(AbstractProvider):
|
||||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||||
@@ -211,6 +252,15 @@ class EmbeddingProvider(AbstractProvider):
|
|||||||
"""获取向量的维度"""
|
"""获取向量的维度"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def test(self) -> bool:
|
||||||
|
try:
|
||||||
|
embedding_result = await self.get_embedding("health_check")
|
||||||
|
return isinstance(embedding_result, list) and (
|
||||||
|
not embedding_result or isinstance(embedding_result[0], float)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
async def get_embeddings_batch(
|
async def get_embeddings_batch(
|
||||||
self,
|
self,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@@ -294,3 +344,10 @@ class RerankProvider(AbstractProvider):
|
|||||||
) -> list[RerankResult]:
|
) -> list[RerankResult]:
|
||||||
"""获取查询和文档的重排序分数"""
|
"""获取查询和文档的重排序分数"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def test(self) -> bool:
|
||||||
|
try:
|
||||||
|
await self.rerank("Apple", documents=["apple", "banana"])
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|||||||
23
astrbot/core/utils/file_extract.py
Normal file
23
astrbot/core/utils/file_extract.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_file_moonshotai(file_path: str, api_key: str) -> str:
|
||||||
|
"""Extract text from a file using Moonshot AI API"""
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
file_path: The path to the file to extract text from
|
||||||
|
api_key: The API key to use to extract text from the file
|
||||||
|
Returns:
|
||||||
|
The text extracted from the file
|
||||||
|
"""
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url="https://api.moonshot.cn/v1",
|
||||||
|
)
|
||||||
|
file_object = await client.files.create(
|
||||||
|
file=Path(file_path),
|
||||||
|
purpose="file-extract", # type: ignore
|
||||||
|
)
|
||||||
|
return (await client.files.content(file_id=file_object.id)).text
|
||||||
@@ -18,11 +18,8 @@ from astrbot.core.config.i18n_utils import ConfigMetadataI18n
|
|||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
from astrbot.core.platform.register import platform_cls_map, platform_registry
|
from astrbot.core.platform.register import platform_cls_map, platform_registry
|
||||||
from astrbot.core.provider import Provider
|
from astrbot.core.provider import Provider
|
||||||
from astrbot.core.provider.entities import ProviderType
|
|
||||||
from astrbot.core.provider.provider import RerankProvider
|
|
||||||
from astrbot.core.provider.register import provider_registry
|
from astrbot.core.provider.register import provider_registry
|
||||||
from astrbot.core.star.star import star_registry
|
from astrbot.core.star.star import star_registry
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
|
||||||
|
|
||||||
from .route import Response, Route, RouteContext
|
from .route import Response, Route, RouteContext
|
||||||
|
|
||||||
@@ -356,169 +353,26 @@ class ConfigRoute(Route):
|
|||||||
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
|
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider_capability_type == ProviderType.CHAT_COMPLETION:
|
try:
|
||||||
try:
|
result = await provider.test()
|
||||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
if result:
|
||||||
response = await asyncio.wait_for(
|
|
||||||
provider.text_chat(prompt="REPLY `PONG` ONLY"),
|
|
||||||
timeout=45.0,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"Received response from {status_info['name']}: {response}",
|
|
||||||
)
|
|
||||||
if response is not None:
|
|
||||||
status_info["status"] = "available"
|
|
||||||
response_text_snippet = ""
|
|
||||||
if (
|
|
||||||
hasattr(response, "completion_text")
|
|
||||||
and response.completion_text
|
|
||||||
):
|
|
||||||
response_text_snippet = (
|
|
||||||
response.completion_text[:70] + "..."
|
|
||||||
if len(response.completion_text) > 70
|
|
||||||
else response.completion_text
|
|
||||||
)
|
|
||||||
elif hasattr(response, "result_chain") and response.result_chain:
|
|
||||||
try:
|
|
||||||
response_text_snippet = (
|
|
||||||
response.result_chain.get_plain_text()[:70] + "..."
|
|
||||||
if len(response.result_chain.get_plain_text()) > 70
|
|
||||||
else response.result_chain.get_plain_text()
|
|
||||||
)
|
|
||||||
except Exception as _:
|
|
||||||
pass
|
|
||||||
logger.info(
|
|
||||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
status_info["error"] = (
|
|
||||||
"Test call returned None, but expected an LLMResponse object."
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
|
|
||||||
)
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
status_info["error"] = (
|
|
||||||
"Connection timed out after 45 seconds during test call."
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
error_message = str(e)
|
|
||||||
status_info["error"] = error_message
|
|
||||||
logger.warning(
|
|
||||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
|
|
||||||
)
|
|
||||||
|
|
||||||
elif provider_capability_type == ProviderType.EMBEDDING:
|
|
||||||
try:
|
|
||||||
# For embedding, we can call the get_embedding method with a short prompt.
|
|
||||||
embedding_result = await provider.get_embedding("health_check")
|
|
||||||
if isinstance(embedding_result, list) and (
|
|
||||||
not embedding_result or isinstance(embedding_result[0], float)
|
|
||||||
):
|
|
||||||
status_info["status"] = "available"
|
|
||||||
else:
|
|
||||||
status_info["status"] = "unavailable"
|
|
||||||
status_info["error"] = (
|
|
||||||
f"Embedding test failed: unexpected result type {type(embedding_result)}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error testing embedding provider {provider_name}: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
status_info["status"] = "unavailable"
|
|
||||||
status_info["error"] = f"Embedding test failed: {e!s}"
|
|
||||||
|
|
||||||
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
|
|
||||||
try:
|
|
||||||
# For TTS, we can call the get_audio method with a short prompt.
|
|
||||||
audio_result = await provider.get_audio("你好")
|
|
||||||
if isinstance(audio_result, str) and audio_result:
|
|
||||||
status_info["status"] = "available"
|
|
||||||
else:
|
|
||||||
status_info["status"] = "unavailable"
|
|
||||||
status_info["error"] = (
|
|
||||||
f"TTS test failed: unexpected result type {type(audio_result)}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error testing TTS provider {provider_name}: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
status_info["status"] = "unavailable"
|
|
||||||
status_info["error"] = f"TTS test failed: {e!s}"
|
|
||||||
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"Sending health check audio to provider: {status_info['name']}",
|
|
||||||
)
|
|
||||||
sample_audio_path = os.path.join(
|
|
||||||
get_astrbot_path(),
|
|
||||||
"samples",
|
|
||||||
"stt_health_check.wav",
|
|
||||||
)
|
|
||||||
if not os.path.exists(sample_audio_path):
|
|
||||||
status_info["status"] = "unavailable"
|
|
||||||
status_info["error"] = (
|
|
||||||
"STT test failed: sample audio file not found."
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
text_result = await provider.get_text(sample_audio_path)
|
|
||||||
if isinstance(text_result, str) and text_result:
|
|
||||||
status_info["status"] = "available"
|
|
||||||
snippet = (
|
|
||||||
text_result[:70] + "..."
|
|
||||||
if len(text_result) > 70
|
|
||||||
else text_result
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
status_info["status"] = "unavailable"
|
|
||||||
status_info["error"] = (
|
|
||||||
f"STT test failed: unexpected result type {type(text_result)}"
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error testing STT provider {provider_name}: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
status_info["status"] = "unavailable"
|
|
||||||
status_info["error"] = f"STT test failed: {e!s}"
|
|
||||||
elif provider_capability_type == ProviderType.RERANK:
|
|
||||||
try:
|
|
||||||
assert isinstance(provider, RerankProvider)
|
|
||||||
await provider.rerank("Apple", documents=["apple", "banana"])
|
|
||||||
status_info["status"] = "available"
|
status_info["status"] = "available"
|
||||||
except Exception as e:
|
logger.info(
|
||||||
logger.error(
|
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
|
||||||
f"Error testing rerank provider {provider_name}: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
)
|
||||||
status_info["status"] = "unavailable"
|
else:
|
||||||
status_info["error"] = f"Rerank test failed: {e!s}"
|
status_info["error"] = "Provider test returned False."
|
||||||
|
logger.warning(
|
||||||
else:
|
f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.",
|
||||||
logger.debug(
|
)
|
||||||
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}",
|
except Exception as e:
|
||||||
|
error_message = str(e)
|
||||||
|
status_info["error"] = error_message
|
||||||
|
logger.warning(
|
||||||
|
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
|
||||||
)
|
)
|
||||||
status_info["status"] = "available"
|
logger.debug(
|
||||||
status_info["error"] = (
|
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
|
||||||
"This provider type is not tested and is assumed to be available."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return status_info
|
return status_info
|
||||||
|
|||||||
@@ -109,6 +109,22 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"file_extract": {
|
||||||
|
"description": "File Extract",
|
||||||
|
"provider_settings": {
|
||||||
|
"file_extract": {
|
||||||
|
"enable": {
|
||||||
|
"description": "Enable File Extract"
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"description": "File Extract Provider"
|
||||||
|
},
|
||||||
|
"moonshotai_api_key": {
|
||||||
|
"description": "Moonshot AI API Key"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"others": {
|
"others": {
|
||||||
"description": "Other Settings",
|
"description": "Other Settings",
|
||||||
"provider_settings": {
|
"provider_settings": {
|
||||||
|
|||||||
@@ -11,7 +11,12 @@
|
|||||||
},
|
},
|
||||||
"agent_runner_type": {
|
"agent_runner_type": {
|
||||||
"description": "执行器",
|
"description": "执行器",
|
||||||
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"]
|
"labels": [
|
||||||
|
"内置 Agent",
|
||||||
|
"Dify",
|
||||||
|
"Coze",
|
||||||
|
"阿里云百炼应用"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"coze_agent_runner_provider_id": {
|
"coze_agent_runner_provider_id": {
|
||||||
"description": "Coze Agent 执行器提供商 ID"
|
"description": "Coze Agent 执行器提供商 ID"
|
||||||
@@ -109,6 +114,22 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"file_extract": {
|
||||||
|
"description": "文档解析能力",
|
||||||
|
"provider_settings": {
|
||||||
|
"file_extract": {
|
||||||
|
"enable": {
|
||||||
|
"description": "启用文档解析能力"
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"description": "文档解析提供商"
|
||||||
|
},
|
||||||
|
"moonshotai_api_key": {
|
||||||
|
"description": "Moonshot AI API Key"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"others": {
|
"others": {
|
||||||
"description": "其他配置",
|
"description": "其他配置",
|
||||||
"provider_settings": {
|
"provider_settings": {
|
||||||
@@ -142,7 +163,10 @@
|
|||||||
"unsupported_streaming_strategy": {
|
"unsupported_streaming_strategy": {
|
||||||
"description": "不支持流式回复的平台",
|
"description": "不支持流式回复的平台",
|
||||||
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
|
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
|
||||||
"labels": ["实时分段回复", "关闭流式回复"]
|
"labels": [
|
||||||
|
"实时分段回复",
|
||||||
|
"关闭流式回复"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"max_context_length": {
|
"max_context_length": {
|
||||||
"description": "最多携带对话轮数",
|
"description": "最多携带对话轮数",
|
||||||
@@ -457,4 +481,4 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,15 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api import star
|
from astrbot.api import star
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||||
from astrbot.core.provider.entities import ProviderType
|
from astrbot.core.provider.entities import ProviderType
|
||||||
from astrbot.core.provider.provider import RerankProvider
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
|
||||||
|
|
||||||
REACHABILITY_CHECK_TIMEOUT = 30.0
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderCommands:
|
class ProviderCommands:
|
||||||
@@ -34,121 +29,20 @@ class ProviderCommands:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _test_provider_capability(self, provider):
|
async def _test_provider_capability(self, provider):
|
||||||
"""测试单个 provider 的可用性 (复用 Dashboard 的检测逻辑)"""
|
"""测试单个 provider 的可用性"""
|
||||||
meta = provider.meta()
|
meta = provider.meta()
|
||||||
provider_capability_type = meta.provider_type
|
provider_capability_type = meta.provider_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if provider_capability_type == ProviderType.CHAT_COMPLETION:
|
result = await provider.test()
|
||||||
# 发送 "Ping" 测试对话
|
if result:
|
||||||
response = await asyncio.wait_for(
|
return True, None, None
|
||||||
provider.text_chat(prompt="REPLY `PONG` ONLY"),
|
err_code = "TEST_FAILED"
|
||||||
timeout=REACHABILITY_CHECK_TIMEOUT,
|
err_reason = "Provider test returned False"
|
||||||
)
|
self._log_reachability_failure(
|
||||||
if response is not None:
|
provider, provider_capability_type, err_code, err_reason
|
||||||
return True, None, None
|
)
|
||||||
err_code = "EMPTY_RESPONSE"
|
return False, err_code, err_reason
|
||||||
err_reason = "Provider returned empty response"
|
|
||||||
self._log_reachability_failure(
|
|
||||||
provider, provider_capability_type, err_code, err_reason
|
|
||||||
)
|
|
||||||
return False, err_code, err_reason
|
|
||||||
|
|
||||||
elif provider_capability_type == ProviderType.EMBEDDING:
|
|
||||||
# 测试 Embedding
|
|
||||||
embedding_result = await asyncio.wait_for(
|
|
||||||
provider.get_embedding("health_check"),
|
|
||||||
timeout=REACHABILITY_CHECK_TIMEOUT,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
isinstance(embedding_result, list)
|
|
||||||
and embedding_result
|
|
||||||
and all(isinstance(x, (int, float)) for x in embedding_result)
|
|
||||||
):
|
|
||||||
return True, None, None
|
|
||||||
err_code = "INVALID_EMBEDDING"
|
|
||||||
err_reason = "Provider returned invalid embedding"
|
|
||||||
self._log_reachability_failure(
|
|
||||||
provider, provider_capability_type, err_code, err_reason
|
|
||||||
)
|
|
||||||
return False, err_code, err_reason
|
|
||||||
|
|
||||||
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
|
|
||||||
# 测试 TTS
|
|
||||||
audio_result = await asyncio.wait_for(
|
|
||||||
provider.get_audio("你好"),
|
|
||||||
timeout=REACHABILITY_CHECK_TIMEOUT,
|
|
||||||
)
|
|
||||||
if isinstance(audio_result, str) and audio_result:
|
|
||||||
# 清理检测生成的临时音频文件,避免频繁检测时堆积
|
|
||||||
if os.path.isfile(audio_result):
|
|
||||||
try:
|
|
||||||
os.remove(audio_result)
|
|
||||||
except OSError as e:
|
|
||||||
logger.debug(
|
|
||||||
"Failed to cleanup TTS health check file %s: %s",
|
|
||||||
audio_result,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
return True, None, None
|
|
||||||
err_code = "INVALID_AUDIO"
|
|
||||||
err_reason = "Provider returned invalid audio"
|
|
||||||
self._log_reachability_failure(
|
|
||||||
provider, provider_capability_type, err_code, err_reason
|
|
||||||
)
|
|
||||||
return False, err_code, err_reason
|
|
||||||
|
|
||||||
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
|
|
||||||
# 测试 STT
|
|
||||||
sample_audio_path = os.path.join(
|
|
||||||
get_astrbot_path(),
|
|
||||||
"samples",
|
|
||||||
"stt_health_check.wav",
|
|
||||||
)
|
|
||||||
if not os.path.exists(sample_audio_path):
|
|
||||||
# 如果样本文件不存在,降级为检查是否实现了方法
|
|
||||||
return hasattr(provider, "get_text"), None, None
|
|
||||||
|
|
||||||
text_result = await asyncio.wait_for(
|
|
||||||
provider.get_text(sample_audio_path),
|
|
||||||
timeout=REACHABILITY_CHECK_TIMEOUT,
|
|
||||||
)
|
|
||||||
if isinstance(text_result, str) and text_result:
|
|
||||||
return True, None, None
|
|
||||||
err_code = "INVALID_TEXT"
|
|
||||||
err_reason = "Provider returned invalid text"
|
|
||||||
self._log_reachability_failure(
|
|
||||||
provider, provider_capability_type, err_code, err_reason
|
|
||||||
)
|
|
||||||
return False, err_code, err_reason
|
|
||||||
|
|
||||||
elif provider_capability_type == ProviderType.RERANK:
|
|
||||||
# 测试 Rerank
|
|
||||||
if isinstance(provider, RerankProvider):
|
|
||||||
await asyncio.wait_for(
|
|
||||||
provider.rerank("Apple", documents=["apple", "banana"]),
|
|
||||||
timeout=REACHABILITY_CHECK_TIMEOUT,
|
|
||||||
)
|
|
||||||
return True, None, None
|
|
||||||
err_code = "NOT_RERANK_PROVIDER"
|
|
||||||
err_reason = "Provider is not RerankProvider"
|
|
||||||
self._log_reachability_failure(
|
|
||||||
provider, provider_capability_type, err_code, err_reason
|
|
||||||
)
|
|
||||||
return False, err_code, err_reason
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 其他类型暂时视为通过,或者回退到 get_models
|
|
||||||
if hasattr(provider, "get_models"):
|
|
||||||
await asyncio.wait_for(
|
|
||||||
provider.get_models(), timeout=REACHABILITY_CHECK_TIMEOUT
|
|
||||||
)
|
|
||||||
return True, None, None
|
|
||||||
return True, None, None # 未知类型默认通过
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
err_code = "TIMEOUT"
|
|
||||||
err_reason = "Reachability check timed out"
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
err_code = (
|
err_code = (
|
||||||
getattr(exc, "status_code", None)
|
getattr(exc, "status_code", None)
|
||||||
@@ -159,10 +53,10 @@ class ProviderCommands:
|
|||||||
if not err_code:
|
if not err_code:
|
||||||
err_code = exc.__class__.__name__
|
err_code = exc.__class__.__name__
|
||||||
|
|
||||||
self._log_reachability_failure(
|
self._log_reachability_failure(
|
||||||
provider, provider_capability_type, err_code, err_reason
|
provider, provider_capability_type, err_code, err_reason
|
||||||
)
|
)
|
||||||
return False, err_code, err_reason
|
return False, err_code, err_reason
|
||||||
|
|
||||||
async def provider(
|
async def provider(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from astrbot.api import star
|
|||||||
from astrbot.api.event import AstrMessageEvent
|
from astrbot.api.event import AstrMessageEvent
|
||||||
from astrbot.api.message_components import At, Image, Plain
|
from astrbot.api.message_components import At, Image, Plain
|
||||||
from astrbot.api.platform import MessageType
|
from astrbot.api.platform import MessageType
|
||||||
from astrbot.api.provider import Provider, ProviderRequest
|
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
|
||||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -158,8 +158,12 @@ class LongTermMemory:
|
|||||||
cfg = self.cfg(event)
|
cfg = self.cfg(event)
|
||||||
if cfg["enable_active_reply"]:
|
if cfg["enable_active_reply"]:
|
||||||
prompt = req.prompt
|
prompt = req.prompt
|
||||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
req.prompt = (
|
||||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||||
|
f"\nNow, a new message is coming: `{prompt}`. "
|
||||||
|
"Please react to it. Only output your response and do not output any other information. "
|
||||||
|
"You MUST use the SAME language as the chatroom is using."
|
||||||
|
)
|
||||||
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
||||||
else:
|
else:
|
||||||
req.system_prompt += (
|
req.system_prompt += (
|
||||||
@@ -167,13 +171,15 @@ class LongTermMemory:
|
|||||||
)
|
)
|
||||||
req.system_prompt += chats_str
|
req.system_prompt += chats_str
|
||||||
|
|
||||||
async def after_req_llm(self, event: AstrMessageEvent):
|
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
|
||||||
if event.unified_msg_origin not in self.session_chats:
|
if event.unified_msg_origin not in self.session_chats:
|
||||||
return
|
return
|
||||||
|
|
||||||
if event.get_result() and event.get_result().is_llm_result():
|
if llm_resp.completion_text:
|
||||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}"
|
||||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
logger.debug(
|
||||||
|
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
|
||||||
|
)
|
||||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||||
cfg = self.cfg(event)
|
cfg = self.cfg(event)
|
||||||
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
||||||
|
|||||||
@@ -322,7 +322,7 @@ class Main(star.Star):
|
|||||||
|
|
||||||
@filter.on_llm_response()
|
@filter.on_llm_response()
|
||||||
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||||
"""在 LLM 响应后基于配置注入思考过程文本"""
|
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
|
||||||
umo = event.unified_msg_origin
|
umo = event.unified_msg_origin
|
||||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||||
show_reasoning = cfg.get("display_reasoning_text", False)
|
show_reasoning = cfg.get("display_reasoning_text", False)
|
||||||
@@ -331,12 +331,9 @@ class Main(star.Star):
|
|||||||
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@filter.after_message_sent()
|
|
||||||
async def after_llm_req(self, event: AstrMessageEvent):
|
|
||||||
"""在 LLM 请求后记录对话"""
|
|
||||||
if self.ltm and self.ltm_enabled(event):
|
if self.ltm and self.ltm_enabled(event):
|
||||||
try:
|
try:
|
||||||
await self.ltm.after_req_llm(event)
|
await self.ltm.after_req_llm(event, resp)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"ltm: {e}")
|
logger.error(f"ltm: {e}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user