Compare commits
40 Commits
refactor-2
...
feat/file-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1de377e749 | ||
|
|
6aa6963ab5 | ||
|
|
d3001d8148 | ||
|
|
380c4faf17 | ||
|
|
bd2a88783c | ||
|
|
17d7f822e7 | ||
|
|
0e034f0fbd | ||
|
|
2a7d03f9e1 | ||
|
|
72fac4b9f1 | ||
|
|
38281ba2cf | ||
|
|
21aa3174f4 | ||
|
|
dcda871fc0 | ||
|
|
c13c51f499 | ||
|
|
a130db5cf4 | ||
|
|
7faeb5cea8 | ||
|
|
8d3ff61e0d | ||
|
|
4c03e82570 | ||
|
|
e7e8664ab4 | ||
|
|
1dd1623e7d | ||
|
|
80d8161d58 | ||
|
|
fc80d7d681 | ||
|
|
c2f036b27c | ||
|
|
4087bbb512 | ||
|
|
e1c728582d | ||
|
|
93c69a639a | ||
|
|
a7fdc98b29 | ||
|
|
85b7f104df | ||
|
|
d76d1bd7fe | ||
|
|
df4412aa80 | ||
|
|
ab2c94e19a | ||
|
|
37cc4e2121 | ||
|
|
60dfdd0a66 | ||
|
|
bb8b2cb194 | ||
|
|
4e29684aa3 | ||
|
|
0e17e3553d | ||
|
|
0a55060e89 | ||
|
|
77859c7daa | ||
|
|
ba39c393a0 | ||
|
|
6a50d316d9 | ||
|
|
88c1d77f0b |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -34,6 +34,7 @@ dashboard/node_modules/
|
||||
dashboard/dist/
|
||||
package-lock.json
|
||||
package.json
|
||||
yarn.lock
|
||||
|
||||
# Operating System
|
||||
**/.DS_Store
|
||||
@@ -47,3 +48,5 @@ astrbot.lock
|
||||
chroma
|
||||
venv/*
|
||||
pytest.ini
|
||||
AGENTS.md
|
||||
IFLOW.md
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "3.5.23"
|
||||
__version__ = "4.7.3"
|
||||
|
||||
@@ -345,9 +345,6 @@ class MCPClient:
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources including old exit stacks from reconnections"""
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
|
||||
# Close current exit stack
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
@@ -359,6 +356,9 @@ class MCPClient:
|
||||
# Just clear the list to release references
|
||||
self._old_exit_stacks.clear()
|
||||
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
|
||||
|
||||
class MCPTool(FunctionTool, Generic[TContext]):
|
||||
"""A function tool that calls an MCP service."""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -145,22 +145,39 @@ class Message(BaseModel):
|
||||
"tool",
|
||||
]
|
||||
|
||||
content: str | list[ContentPart]
|
||||
content: str | list[ContentPart] | None = None
|
||||
"""The content of the message."""
|
||||
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
"""The tool calls of the message."""
|
||||
|
||||
tool_call_id: str | None = None
|
||||
"""The ID of the tool call."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content_required(self):
|
||||
# assistant + tool_calls is not None: allow content to be None
|
||||
if self.role == "assistant" and self.tool_calls is not None:
|
||||
return self
|
||||
|
||||
# other all cases: content is required
|
||||
if self.content is None:
|
||||
raise ValueError(
|
||||
"content is required unless role='assistant' and tool_calls is not None"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
|
||||
|
||||
class ToolCallMessageSegment(Message):
|
||||
"""A message segment representing a tool call."""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class UserMessageSegment(Message):
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.6.1"
|
||||
VERSION = "4.7.3"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -73,8 +73,14 @@ DEFAULT_CONFIG = {
|
||||
"coze_agent_runner_provider_id": "",
|
||||
"dashscope_agent_runner_provider_id": "",
|
||||
"unsupported_streaming_strategy": "realtime_segmenting",
|
||||
"reachability_check": False,
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
"file_extract": {
|
||||
"enable": False,
|
||||
"provider": "moonshotai",
|
||||
"moonshotai_api_key": "",
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -90,6 +96,7 @@ DEFAULT_CONFIG = {
|
||||
"group_icl_enable": False,
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_provider_id": "",
|
||||
"active_reply": {
|
||||
"enable": False,
|
||||
"method": "possibility_reply",
|
||||
@@ -2067,6 +2074,20 @@ CONFIG_METADATA_2 = {
|
||||
"tool_call_timeout": {
|
||||
"type": "int",
|
||||
},
|
||||
"file_extract": {
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"type": "bool",
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
},
|
||||
"moonshotai_api_key": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
@@ -2109,6 +2130,9 @@ CONFIG_METADATA_2 = {
|
||||
"image_caption": {
|
||||
"type": "bool",
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"type": "string",
|
||||
},
|
||||
@@ -2398,6 +2422,36 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"file_extract": {
|
||||
"description": "文档解析能力",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.file_extract.enable": {
|
||||
"description": "启用文档解析能力",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.file_extract.provider": {
|
||||
"description": "文档解析提供商",
|
||||
"type": "string",
|
||||
"options": ["moonshotai"],
|
||||
"condition": {
|
||||
"provider_settings.file_extract.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.file_extract.moonshotai_api_key": {
|
||||
"description": "Moonshot AI API Key",
|
||||
"type": "string",
|
||||
"condition": {
|
||||
"provider_settings.file_extract.provider": "moonshotai",
|
||||
"provider_settings.file_extract.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2492,6 +2546,11 @@ CONFIG_METADATA_3 = {
|
||||
"description": "开启 TTS 时同时输出语音和文字内容",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.reachability_check": {
|
||||
"description": "提供商可达性检测",
|
||||
"type": "bool",
|
||||
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
@@ -2785,7 +2844,16 @@ CONFIG_METADATA_3 = {
|
||||
"provider_ltm_settings.image_caption": {
|
||||
"description": "自动理解图片",
|
||||
"type": "bool",
|
||||
"hint": "需要设置默认图片转述模型。",
|
||||
"hint": "需要设置群聊图片转述模型。",
|
||||
},
|
||||
"provider_ltm_settings.image_caption_provider_id": {
|
||||
"description": "群聊图片转述模型",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。",
|
||||
"condition": {
|
||||
"provider_ltm_settings.image_caption": True,
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings.active_reply.enable": {
|
||||
"description": "主动回复",
|
||||
|
||||
@@ -722,7 +722,12 @@ class File(BaseMessageComponent):
|
||||
"""下载文件"""
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
if self.name:
|
||||
name, ext = os.path.splitext(self.name)
|
||||
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
else:
|
||||
filename = f"{uuid.uuid4().hex}"
|
||||
file_path = os.path.join(download_dir, filename)
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from astrbot.core import logger
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.components import File, Image, Reply
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
@@ -22,6 +22,7 @@ from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -56,6 +57,13 @@ class InternalAgentSubStage(Stage):
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||
|
||||
file_extract_conf: dict = settings.get("file_extract", {})
|
||||
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
|
||||
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
|
||||
self.file_extract_msh_api_key: str = file_extract_conf.get(
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -114,6 +122,50 @@ class InternalAgentSubStage(Stage):
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
async def _apply_file_extract(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply file extract to the provider request"""
|
||||
file_paths = []
|
||||
file_names = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
file_paths.append(await comp.get_file())
|
||||
file_names.append(comp.name)
|
||||
elif isinstance(comp, Reply) and comp.chain:
|
||||
for reply_comp in comp.chain:
|
||||
if isinstance(reply_comp, File):
|
||||
file_paths.append(await reply_comp.get_file())
|
||||
file_names.append(reply_comp.name)
|
||||
if not file_paths:
|
||||
return
|
||||
if not req.prompt:
|
||||
req.prompt = "总结一下文件里面讲了什么?"
|
||||
if self.file_extract_prov == "moonshotai":
|
||||
if not self.file_extract_msh_api_key:
|
||||
logger.error("Moonshot AI API key for file extract is not set")
|
||||
return
|
||||
file_contents = await asyncio.gather(
|
||||
*[
|
||||
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
|
||||
for file_path in file_paths
|
||||
]
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
|
||||
return
|
||||
|
||||
# add file extract results to contexts
|
||||
for file_content, file_name in zip(file_contents, file_names):
|
||||
req.contexts.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
@@ -346,6 +398,17 @@ class InternalAgentSubStage(Stage):
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
@@ -356,10 +419,6 @@ class InternalAgentSubStage(Stage):
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
|
||||
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
|
||||
DashscopeAgentRunner,
|
||||
@@ -88,12 +88,15 @@ class ThirdPartyAgentSubStage(Stage):
|
||||
return
|
||||
|
||||
self.prov_cfg: dict = next(
|
||||
(p for p in self.conf["provider"] if p["id"] == self.prov_id),
|
||||
(p for p in astrbot_config["provider"] if p["id"] == self.prov_id),
|
||||
{},
|
||||
)
|
||||
if not self.prov_id or not self.prov_cfg:
|
||||
if not self.prov_id:
|
||||
logger.error("没有填写 Agent Runner 提供商 ID,请前往配置页面配置。")
|
||||
return
|
||||
if not self.prov_cfg:
|
||||
logger.error(
|
||||
"Third Party Agent Runner provider ID is not configured properly."
|
||||
f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
@@ -63,12 +62,5 @@ class ProcessStage(Stage):
|
||||
if (
|
||||
event.get_result() and not event.get_result().is_stopped()
|
||||
) or not event.get_result():
|
||||
# 事件没有终止传播
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||
return
|
||||
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -246,7 +246,13 @@ class AiocqhttpAdapter(Platform):
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
# 检查多个可能的文件名字段
|
||||
file_name = (
|
||||
m["data"].get("file_name", "")
|
||||
or m["data"].get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or "file"
|
||||
)
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
@@ -265,7 +271,14 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
a = File(name="", url=file_url)
|
||||
# 优先从 API 返回值获取文件名,其次从原始消息数据获取
|
||||
file_name = (
|
||||
ret.get("file_name", "")
|
||||
or ret.get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or m["data"].get("file_name", "")
|
||||
)
|
||||
a = File(name=file_name, url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
|
||||
@@ -250,7 +250,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
async def terminate(self):
|
||||
def monkey_patch_close():
|
||||
raise Exception("Graceful shutdown")
|
||||
raise KeyboardInterrupt("Graceful shutdown")
|
||||
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
|
||||
@@ -381,7 +381,9 @@ class TelegramPlatformAdapter(Platform):
|
||||
f"Telegram document file_path is None, cannot save the file {file_name}.",
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
||||
message.message.append(
|
||||
Comp.File(file=file_path, name=file_name, url=file_path)
|
||||
)
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core import astrbot_config, logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
@@ -24,6 +24,7 @@ class ProviderManager:
|
||||
db_helper: BaseDatabase,
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.reload_lock = asyncio.Lock()
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
@@ -226,6 +227,7 @@ class ProviderManager:
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
if not provider_config["enable"]:
|
||||
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
||||
return
|
||||
if provider_config.get("provider_type", "") == "agent_runner":
|
||||
return
|
||||
@@ -434,40 +436,46 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
async def reload(self, provider_config: dict):
|
||||
await self.terminate_provider(provider_config["id"])
|
||||
if provider_config["enable"]:
|
||||
await self.load_provider(provider_config)
|
||||
async with self.reload_lock:
|
||||
await self.terminate_provider(provider_config["id"])
|
||||
if provider_config["enable"]:
|
||||
await self.load_provider(provider_config)
|
||||
|
||||
# 和配置文件保持同步
|
||||
config_ids = [provider["id"] for provider in self.providers_config]
|
||||
logger.debug(f"providers in user's config: {config_ids}")
|
||||
for key in list(self.inst_map.keys()):
|
||||
if key not in config_ids:
|
||||
await self.terminate_provider(key)
|
||||
# 和配置文件保持同步
|
||||
self.providers_config = astrbot_config["provider"]
|
||||
config_ids = [provider["id"] for provider in self.providers_config]
|
||||
logger.info(f"providers in user's config: {config_ids}")
|
||||
for key in list(self.inst_map.keys()):
|
||||
if key not in config_ids:
|
||||
await self.terminate_provider(key)
|
||||
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
||||
)
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
||||
)
|
||||
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
||||
)
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif (
|
||||
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
|
||||
):
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
||||
)
|
||||
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
||||
)
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif (
|
||||
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
|
||||
):
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
||||
)
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import abc
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
@@ -11,6 +12,7 @@ from astrbot.core.provider.entities import (
|
||||
ToolCallsResult,
|
||||
)
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
@@ -43,6 +45,14 @@ class AbstractProvider(abc.ABC):
|
||||
)
|
||||
return meta
|
||||
|
||||
async def test(self) -> bool:
|
||||
"""test the provider is a
|
||||
|
||||
Returns:
|
||||
bool: the provider is available
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
class Provider(AbstractProvider):
|
||||
"""Chat Provider"""
|
||||
@@ -165,6 +175,16 @@ class Provider(AbstractProvider):
|
||||
|
||||
return dicts
|
||||
|
||||
async def test(self, timeout: float = 45.0) -> bool:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||
timeout=timeout,
|
||||
)
|
||||
return response is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class STTProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -177,6 +197,20 @@ class STTProvider(AbstractProvider):
|
||||
"""获取音频的文本"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def test(self) -> bool:
|
||||
try:
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"samples",
|
||||
"stt_health_check.wav",
|
||||
)
|
||||
if not os.path.exists(sample_audio_path):
|
||||
return False
|
||||
text_result = await self.get_text(sample_audio_path)
|
||||
return isinstance(text_result, str) and bool(text_result)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class TTSProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -189,6 +223,13 @@ class TTSProvider(AbstractProvider):
|
||||
"""获取文本的音频,返回音频文件路径"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def test(self) -> bool:
|
||||
try:
|
||||
audio_result = await self.get_audio("hi")
|
||||
return isinstance(audio_result, str) and bool(audio_result)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class EmbeddingProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -211,6 +252,15 @@ class EmbeddingProvider(AbstractProvider):
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
async def test(self) -> bool:
|
||||
try:
|
||||
embedding_result = await self.get_embedding("health_check")
|
||||
return isinstance(embedding_result, list) and (
|
||||
not embedding_result or isinstance(embedding_result[0], float)
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_embeddings_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
@@ -294,3 +344,10 @@ class RerankProvider(AbstractProvider):
|
||||
) -> list[RerankResult]:
|
||||
"""获取查询和文档的重排序分数"""
|
||||
...
|
||||
|
||||
async def test(self) -> bool:
|
||||
try:
|
||||
await self.rerank("Apple", documents=["apple", "banana"])
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -290,7 +290,7 @@ class ProviderAnthropic(Provider):
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -111,9 +111,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...",
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
|
||||
)
|
||||
# logger.error(
|
||||
# f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
|
||||
# )
|
||||
raise e
|
||||
|
||||
async def _prepare_query_config(
|
||||
|
||||
@@ -433,7 +433,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
payloads.pop("tools", None)
|
||||
return False, chosen_key, available_api_keys, payloads, context_query, None
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
# logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
|
||||
@@ -171,110 +171,3 @@ class SessionServiceManager:
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_session_status(session_id: str, enabled: bool) -> None:
|
||||
"""设置会话的整体启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["session_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_session_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理会话请求(会话整体启停检查)
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_session_enabled(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话命名相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_custom_name(session_id: str) -> str | None:
|
||||
"""获取会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 自定义名称,如果没有设置则返回None
|
||||
|
||||
"""
|
||||
session_services = sp.get(
|
||||
"session_service_config",
|
||||
{},
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
return session_services.get("custom_name")
|
||||
|
||||
@staticmethod
|
||||
def set_session_custom_name(session_id: str, custom_name: str) -> None:
|
||||
"""设置会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
custom_name: 自定义名称,可以为空字符串来清除名称
|
||||
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
if custom_name and custom_name.strip():
|
||||
session_config["custom_name"] = custom_name.strip()
|
||||
else:
|
||||
# 如果传入空名称,则删除自定义名称
|
||||
session_config.pop("custom_name", None)
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_display_name(session_id: str) -> str:
|
||||
"""获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 显示名称
|
||||
|
||||
"""
|
||||
custom_name = SessionServiceManager.get_session_custom_name(session_id)
|
||||
if custom_name:
|
||||
return custom_name
|
||||
|
||||
# 如果没有自定义名称,返回session_id的最后一段
|
||||
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||
|
||||
@@ -42,87 +42,6 @@ class SessionPluginManager:
|
||||
# 如果都没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_plugin_status_for_session(
|
||||
session_id: str,
|
||||
plugin_name: str,
|
||||
enabled: bool,
|
||||
) -> None:
|
||||
"""设置插件在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
enabled: True表示启用,False表示禁用
|
||||
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config",
|
||||
{},
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
if session_id not in session_plugin_config:
|
||||
session_plugin_config[session_id] = {
|
||||
"enabled_plugins": [],
|
||||
"disabled_plugins": [],
|
||||
}
|
||||
|
||||
session_config = session_plugin_config[session_id]
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
if enabled:
|
||||
# 启用插件
|
||||
if plugin_name in disabled_plugins:
|
||||
disabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in enabled_plugins:
|
||||
enabled_plugins.append(plugin_name)
|
||||
else:
|
||||
# 禁用插件
|
||||
if plugin_name in enabled_plugins:
|
||||
enabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in disabled_plugins:
|
||||
disabled_plugins.append(plugin_name)
|
||||
|
||||
# 保存配置
|
||||
session_config["enabled_plugins"] = enabled_plugins
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put(
|
||||
"session_plugin_config",
|
||||
session_plugin_config,
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_plugin_config(session_id: str) -> dict[str, list[str]]:
|
||||
"""获取指定会话的插件配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||
|
||||
"""
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config",
|
||||
{},
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
return session_plugin_config.get(
|
||||
session_id,
|
||||
{"enabled_plugins": [], "disabled_plugins": []},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
|
||||
"""根据会话配置过滤处理器列表
|
||||
|
||||
23
astrbot/core/utils/file_extract.py
Normal file
23
astrbot/core/utils/file_extract.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pathlib import Path
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
async def extract_file_moonshotai(file_path: str, api_key: str) -> str:
|
||||
"""Extract text from a file using Moonshot AI API"""
|
||||
"""
|
||||
Args:
|
||||
file_path: The path to the file to extract text from
|
||||
api_key: The API key to use to extract text from the file
|
||||
Returns:
|
||||
The text extracted from the file
|
||||
"""
|
||||
client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url="https://api.moonshot.cn/v1",
|
||||
)
|
||||
file_object = await client.files.create(
|
||||
file=Path(file_path),
|
||||
purpose="file-extract", # type: ignore
|
||||
)
|
||||
return (await client.files.content(file_id=file_object.id)).text
|
||||
@@ -18,11 +18,8 @@ from astrbot.core.config.i18n_utils import ConfigMetadataI18n
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_cls_map, platform_registry
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -356,169 +353,26 @@ class ConfigRoute(Route):
|
||||
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
|
||||
)
|
||||
|
||||
if provider_capability_type == ProviderType.CHAT_COMPLETION:
|
||||
try:
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||
timeout=45.0,
|
||||
)
|
||||
logger.debug(
|
||||
f"Received response from {status_info['name']}: {response}",
|
||||
)
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
if (
|
||||
hasattr(response, "completion_text")
|
||||
and response.completion_text
|
||||
):
|
||||
response_text_snippet = (
|
||||
response.completion_text[:70] + "..."
|
||||
if len(response.completion_text) > 70
|
||||
else response.completion_text
|
||||
)
|
||||
elif hasattr(response, "result_chain") and response.result_chain:
|
||||
try:
|
||||
response_text_snippet = (
|
||||
response.result_chain.get_plain_text()[:70] + "..."
|
||||
if len(response.result_chain.get_plain_text()) > 70
|
||||
else response.result_chain.get_plain_text()
|
||||
)
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
|
||||
)
|
||||
else:
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
status_info["error"] = (
|
||||
"Connection timed out after 45 seconds during test call."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
|
||||
)
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
|
||||
)
|
||||
|
||||
elif provider_capability_type == ProviderType.EMBEDDING:
|
||||
try:
|
||||
# For embedding, we can call the get_embedding method with a short prompt.
|
||||
embedding_result = await provider.get_embedding("health_check")
|
||||
if isinstance(embedding_result, list) and (
|
||||
not embedding_result or isinstance(embedding_result[0], float)
|
||||
):
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"Embedding test failed: unexpected result type {type(embedding_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing embedding provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Embedding test failed: {e!s}"
|
||||
|
||||
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
|
||||
try:
|
||||
# For TTS, we can call the get_audio method with a short prompt.
|
||||
audio_result = await provider.get_audio("你好")
|
||||
if isinstance(audio_result, str) and audio_result:
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"TTS test failed: unexpected result type {type(audio_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing TTS provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"TTS test failed: {e!s}"
|
||||
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
|
||||
try:
|
||||
logger.debug(
|
||||
f"Sending health check audio to provider: {status_info['name']}",
|
||||
)
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"samples",
|
||||
"stt_health_check.wav",
|
||||
)
|
||||
if not os.path.exists(sample_audio_path):
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
"STT test failed: sample audio file not found."
|
||||
)
|
||||
logger.warning(
|
||||
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
|
||||
)
|
||||
else:
|
||||
text_result = await provider.get_text(sample_audio_path)
|
||||
if isinstance(text_result, str) and text_result:
|
||||
status_info["status"] = "available"
|
||||
snippet = (
|
||||
text_result[:70] + "..."
|
||||
if len(text_result) > 70
|
||||
else text_result
|
||||
)
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
|
||||
)
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"STT test failed: unexpected result type {type(text_result)}"
|
||||
)
|
||||
logger.warning(
|
||||
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing STT provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"STT test failed: {e!s}"
|
||||
elif provider_capability_type == ProviderType.RERANK:
|
||||
try:
|
||||
assert isinstance(provider, RerankProvider)
|
||||
await provider.rerank("Apple", documents=["apple", "banana"])
|
||||
try:
|
||||
result = await provider.test()
|
||||
if result:
|
||||
status_info["status"] = "available"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing rerank provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Rerank test failed: {e!s}"
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}",
|
||||
else:
|
||||
status_info["error"] = "Provider test returned False."
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.",
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
|
||||
)
|
||||
status_info["status"] = "available"
|
||||
status_info["error"] = (
|
||||
"This provider type is not tested and is assumed to be available."
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
|
||||
)
|
||||
|
||||
return status_info
|
||||
|
||||
@@ -60,10 +60,6 @@ class KnowledgeBaseRoute(Route):
|
||||
# "/kb/media/delete": ("POST", self.delete_media),
|
||||
# 检索
|
||||
"/kb/retrieve": ("POST", self.retrieve),
|
||||
# 会话知识库配置
|
||||
"/kb/session/config/get": ("GET", self.get_session_kb_config),
|
||||
"/kb/session/config/set": ("POST", self.set_session_kb_config),
|
||||
"/kb/session/config/delete": ("POST", self.delete_session_kb_config),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
@@ -920,158 +916,6 @@ class KnowledgeBaseRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"检索失败: {e!s}").__dict__
|
||||
|
||||
# ===== 会话知识库配置 API =====
|
||||
|
||||
async def get_session_kb_config(self):
|
||||
"""获取会话的知识库配置
|
||||
|
||||
Query 参数:
|
||||
- session_id: 会话 ID (必填)
|
||||
|
||||
返回:
|
||||
- kb_ids: 知识库 ID 列表
|
||||
- top_k: 返回结果数量
|
||||
- enable_rerank: 是否启用重排序
|
||||
"""
|
||||
try:
|
||||
from astrbot.core import sp
|
||||
|
||||
session_id = request.args.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少参数 session_id").__dict__
|
||||
|
||||
# 从 SharedPreferences 获取配置
|
||||
config = await sp.session_get(session_id, "kb_config", default={})
|
||||
|
||||
logger.debug(f"[KB配置] 读取到配置: session_id={session_id}")
|
||||
|
||||
# 如果没有配置,返回默认值
|
||||
if not config:
|
||||
config = {"kb_ids": [], "top_k": 5, "enable_rerank": True}
|
||||
|
||||
return Response().ok(config).__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True)
|
||||
return Response().error(f"获取会话知识库配置失败: {e!s}").__dict__
|
||||
|
||||
async def set_session_kb_config(self):
|
||||
"""设置会话的知识库配置
|
||||
|
||||
Body:
|
||||
- scope: 配置范围 (目前只支持 "session")
|
||||
- scope_id: 会话 ID (必填)
|
||||
- kb_ids: 知识库 ID 列表 (必填)
|
||||
- top_k: 返回结果数量 (可选, 默认 5)
|
||||
- enable_rerank: 是否启用重排序 (可选, 默认 true)
|
||||
"""
|
||||
try:
|
||||
from astrbot.core import sp
|
||||
|
||||
data = await request.json
|
||||
|
||||
scope = data.get("scope")
|
||||
scope_id = data.get("scope_id")
|
||||
kb_ids = data.get("kb_ids", [])
|
||||
top_k = data.get("top_k", 5)
|
||||
enable_rerank = data.get("enable_rerank", True)
|
||||
|
||||
# 验证参数
|
||||
if scope != "session":
|
||||
return Response().error("目前仅支持 session 范围的配置").__dict__
|
||||
|
||||
if not scope_id:
|
||||
return Response().error("缺少参数 scope_id").__dict__
|
||||
|
||||
if not isinstance(kb_ids, list):
|
||||
return Response().error("kb_ids 必须是列表").__dict__
|
||||
|
||||
# 验证知识库是否存在
|
||||
kb_mgr = self._get_kb_manager()
|
||||
invalid_ids = []
|
||||
valid_ids = []
|
||||
for kb_id in kb_ids:
|
||||
kb_helper = await kb_mgr.get_kb(kb_id)
|
||||
if kb_helper:
|
||||
valid_ids.append(kb_id)
|
||||
else:
|
||||
invalid_ids.append(kb_id)
|
||||
logger.warning(f"[KB配置] 知识库不存在: {kb_id}")
|
||||
|
||||
if invalid_ids:
|
||||
logger.warning(f"[KB配置] 以下知识库ID无效: {invalid_ids}")
|
||||
|
||||
# 允许保存空列表,表示明确不使用任何知识库
|
||||
if kb_ids and not valid_ids:
|
||||
# 只有当用户提供了 kb_ids 但全部无效时才报错
|
||||
return Response().error(f"所有提供的知识库ID都无效: {kb_ids}").__dict__
|
||||
|
||||
# 如果 kb_ids 为空列表,表示用户想清空配置
|
||||
if not kb_ids:
|
||||
valid_ids = []
|
||||
|
||||
# 构建配置对象(只保存有效的ID)
|
||||
config = {
|
||||
"kb_ids": valid_ids,
|
||||
"top_k": top_k,
|
||||
"enable_rerank": enable_rerank,
|
||||
}
|
||||
|
||||
# 保存到 SharedPreferences
|
||||
await sp.session_put(scope_id, "kb_config", config)
|
||||
|
||||
# 立即验证是否保存成功
|
||||
verify_config = await sp.session_get(scope_id, "kb_config", default={})
|
||||
|
||||
if verify_config == config:
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{"valid_ids": valid_ids, "invalid_ids": invalid_ids},
|
||||
"保存知识库配置成功",
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
logger.error("[KB配置] 配置保存失败,验证不匹配")
|
||||
return Response().error("配置保存失败").__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True)
|
||||
return Response().error(f"设置会话知识库配置失败: {e!s}").__dict__
|
||||
|
||||
async def delete_session_kb_config(self):
|
||||
"""删除会话的知识库配置
|
||||
|
||||
Body:
|
||||
- scope: 配置范围 (目前只支持 "session")
|
||||
- scope_id: 会话 ID (必填)
|
||||
"""
|
||||
try:
|
||||
from astrbot.core import sp
|
||||
|
||||
data = await request.json
|
||||
|
||||
scope = data.get("scope")
|
||||
scope_id = data.get("scope_id")
|
||||
|
||||
# 验证参数
|
||||
if scope != "session":
|
||||
return Response().error("目前仅支持 session 范围的配置").__dict__
|
||||
|
||||
if not scope_id:
|
||||
return Response().error("缺少参数 scope_id").__dict__
|
||||
|
||||
# 从 SharedPreferences 删除配置
|
||||
await sp.session_remove(scope_id, "kb_config")
|
||||
|
||||
return Response().ok(message="删除知识库配置成功").__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除会话知识库配置失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除会话知识库配置失败: {e!s}").__dict__
|
||||
|
||||
async def upload_document_from_url(self):
|
||||
"""从 URL 上传文档
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
@@ -19,6 +20,10 @@ from astrbot.core.star.star_manager import PluginManager
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
PLUGIN_UPDATE_CONCURRENCY = (
|
||||
3 # limit concurrent updates to avoid overwhelming plugin sources
|
||||
)
|
||||
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(
|
||||
@@ -33,6 +38,7 @@ class PluginRoute(Route):
|
||||
"/plugin/install": ("POST", self.install_plugin),
|
||||
"/plugin/install-upload": ("POST", self.install_plugin_upload),
|
||||
"/plugin/update": ("POST", self.update_plugin),
|
||||
"/plugin/update-all": ("POST", self.update_all_plugins),
|
||||
"/plugin/uninstall": ("POST", self.uninstall_plugin),
|
||||
"/plugin/market_list": ("GET", self.get_online_plugins),
|
||||
"/plugin/off": ("POST", self.off_plugin),
|
||||
@@ -63,7 +69,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
data = await request.json
|
||||
data = await request.get_json()
|
||||
plugin_name = data.get("name", None)
|
||||
try:
|
||||
success, message = await self.plugin_manager.reload(plugin_name)
|
||||
@@ -346,7 +352,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
repo_url = post_data["url"]
|
||||
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
@@ -393,7 +399,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
plugin_name = post_data["name"]
|
||||
delete_config = post_data.get("delete_config", False)
|
||||
delete_data = post_data.get("delete_data", False)
|
||||
@@ -418,7 +424,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
plugin_name = post_data["name"]
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
try:
|
||||
@@ -432,6 +438,59 @@ class PluginRoute(Route):
|
||||
logger.error(f"/api/plugin/update: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def update_all_plugins(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.get_json()
|
||||
plugin_names: list[str] = post_data.get("names") or []
|
||||
proxy: str = post_data.get("proxy", "")
|
||||
|
||||
if not isinstance(plugin_names, list) or not plugin_names:
|
||||
return Response().error("插件列表不能为空").__dict__
|
||||
|
||||
results = []
|
||||
sem = asyncio.Semaphore(PLUGIN_UPDATE_CONCURRENCY)
|
||||
|
||||
async def _update_one(name: str):
|
||||
async with sem:
|
||||
try:
|
||||
logger.info(f"批量更新插件 {name}")
|
||||
await self.plugin_manager.update_plugin(name, proxy)
|
||||
return {"name": name, "status": "ok", "message": "更新成功"}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/plugin/update-all: 更新插件 {name} 失败: {traceback.format_exc()}",
|
||||
)
|
||||
return {"name": name, "status": "error", "message": str(e)}
|
||||
|
||||
raw_results = await asyncio.gather(
|
||||
*(_update_one(name) for name in plugin_names),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for name, result in zip(plugin_names, raw_results):
|
||||
if isinstance(result, asyncio.CancelledError):
|
||||
raise result
|
||||
if isinstance(result, BaseException):
|
||||
results.append(
|
||||
{"name": name, "status": "error", "message": str(result)}
|
||||
)
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
failed = [r for r in results if r["status"] == "error"]
|
||||
message = (
|
||||
"批量更新完成,全部成功。"
|
||||
if not failed
|
||||
else f"批量更新完成,其中 {len(failed)}/{len(results)} 个插件失败。"
|
||||
)
|
||||
|
||||
return Response().ok({"results": results}, message).__dict__
|
||||
|
||||
async def off_plugin(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
@@ -440,7 +499,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
await self.plugin_manager.turn_off_plugin(plugin_name)
|
||||
@@ -458,7 +517,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
await self.plugin_manager.turn_on_plugin(plugin_name)
|
||||
|
||||
@@ -74,7 +74,10 @@ class SessionManagementRoute(Route):
|
||||
umo_id = pref.scope_id
|
||||
if umo_id not in umo_rules:
|
||||
umo_rules[umo_id] = {}
|
||||
umo_rules[umo_id][pref.key] = pref.value["val"]
|
||||
if pref.key == "session_plugin_config" and umo_id in pref.value["val"]:
|
||||
umo_rules[umo_id][pref.key] = pref.value["val"][umo_id]
|
||||
else:
|
||||
umo_rules[umo_id][pref.key] = pref.value["val"]
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
@@ -185,6 +188,35 @@ class SessionManagementRoute(Route):
|
||||
for p in provider_manager.tts_provider_insts
|
||||
]
|
||||
|
||||
# 获取可用的插件列表(排除 reserved 的系统插件)
|
||||
plugin_manager = self.core_lifecycle.plugin_manager
|
||||
available_plugins = [
|
||||
{
|
||||
"name": p.name,
|
||||
"display_name": p.display_name or p.name,
|
||||
"desc": p.desc,
|
||||
}
|
||||
for p in plugin_manager.context.get_all_stars()
|
||||
if not p.reserved and p.name
|
||||
]
|
||||
|
||||
# 获取可用的知识库列表
|
||||
available_kbs = []
|
||||
kb_manager = self.core_lifecycle.kb_manager
|
||||
if kb_manager:
|
||||
try:
|
||||
kbs = await kb_manager.list_kbs()
|
||||
available_kbs = [
|
||||
{
|
||||
"kb_id": kb.kb_id,
|
||||
"kb_name": kb.kb_name,
|
||||
"emoji": kb.emoji,
|
||||
}
|
||||
for kb in kbs
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"获取知识库列表失败: {e!s}")
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
@@ -197,6 +229,8 @@ class SessionManagementRoute(Route):
|
||||
"available_chat_providers": available_chat_providers,
|
||||
"available_stt_providers": available_stt_providers,
|
||||
"available_tts_providers": available_tts_providers,
|
||||
"available_plugins": available_plugins,
|
||||
"available_kbs": available_kbs,
|
||||
"available_rule_keys": AVAILABLE_SESSION_RULE_KEYS,
|
||||
}
|
||||
)
|
||||
@@ -229,6 +263,11 @@ class SessionManagementRoute(Route):
|
||||
if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
|
||||
return Response().error(f"不支持的规则键: {rule_key}").__dict__
|
||||
|
||||
if rule_key == "session_plugin_config":
|
||||
rule_value = {
|
||||
umo: rule_value,
|
||||
}
|
||||
|
||||
# 使用 shared preferences 更新规则
|
||||
await sp.session_put(umo, rule_key, rule_value)
|
||||
|
||||
|
||||
18
changelogs/v4.7.0.md
Normal file
18
changelogs/v4.7.0.md
Normal file
@@ -0,0 +1,18 @@
|
||||
## What's Changed
|
||||
|
||||
重构:
|
||||
- 将 Dify、Coze、阿里云百炼应用等 LLMOps 提供商迁移到 Agent 执行器层,理清和本地 Agent 执行器的边界
|
||||
- 将「会话管理」功能重构为「自定义规则」功能,理清和多配置文件功能的边界。详见:[自定义规则](https://docs.astrbot.app/use/custom-rules.html)
|
||||
|
||||
优化:
|
||||
- Dify、阿里云百炼应用支持流式输出
|
||||
- 防止分段回复正则表达式解析错误导致消息不发送
|
||||
- 群聊上下文感知记录 At 信息
|
||||
- 优化模型提供商页面的测试提供商功能
|
||||
|
||||
新增:
|
||||
- 支持在配置文件页面快速测试对话
|
||||
- 为配置文件配置项内容添加国际化支持
|
||||
|
||||
修复:
|
||||
- 在更新 MCP Server 配置后,MCP 无法正常重启的问题
|
||||
22
changelogs/v4.7.1.md
Normal file
22
changelogs/v4.7.1.md
Normal file
@@ -0,0 +1,22 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复了自定义规则页面无法设置插件和知识库的规则的问题
|
||||
|
||||
---
|
||||
|
||||
重构:
|
||||
- 将 Dify、Coze、阿里云百炼应用等 LLMOps 提供商迁移到 Agent 执行器层,理清和本地 Agent 执行器的边界。详见:[Agent 执行器](https://docs.astrbot.app/use/agent-runner.html)
|
||||
- 将「会话管理」功能重构为「自定义规则」功能,理清和多配置文件功能的边界。详见:[自定义规则](https://docs.astrbot.app/use/custom-rules.html)
|
||||
|
||||
优化:
|
||||
- Dify、阿里云百炼应用支持流式输出
|
||||
- 防止分段回复正则表达式解析错误导致消息不发送
|
||||
- 群聊上下文感知记录 At 信息
|
||||
- 优化模型提供商页面的测试提供商功能
|
||||
|
||||
新增:
|
||||
- 支持在配置文件页面快速测试对话
|
||||
- 为配置文件配置项内容添加国际化支持
|
||||
|
||||
修复:
|
||||
- 在更新 MCP Server 配置后,MCP 无法正常重启的问题
|
||||
25
changelogs/v4.7.3.md
Normal file
25
changelogs/v4.7.3.md
Normal file
@@ -0,0 +1,25 @@
|
||||
## What's Changed
|
||||
|
||||
1. 修复使用非默认配置文件情况下时,第三方 Agent Runner (Dify、Coze、阿里云百炼应用等)无法正常工作的问题
|
||||
2. 修复当“聊天模型”未设置,并且模型提供商中仅有 Agent Runner 时,无法正常使用 Agent Runner 的问题
|
||||
3. 修复部分情况下报错 `pydantic_core._pydantic_core.ValidationError: 1 validation error for Message content` 的问题
|
||||
4. 新增群聊模式下的专用图片转述模型配置 ([#3822](https://github.com/AstrBotDevs/AstrBot/issues/3822))
|
||||
|
||||
---
|
||||
|
||||
重构:
|
||||
- 将 Dify、Coze、阿里云百炼应用等 LLMOps 提供商迁移到 Agent 执行器层,理清和本地 Agent 执行器的边界。详见:[Agent 执行器](https://docs.astrbot.app/use/agent-runner.html)
|
||||
- 将「会话管理」功能重构为「自定义规则」功能,理清和多配置文件功能的边界。详见:[自定义规则](https://docs.astrbot.app/use/custom-rules.html)
|
||||
|
||||
优化:
|
||||
- Dify、阿里云百炼应用支持流式输出
|
||||
- 防止分段回复正则表达式解析错误导致消息不发送
|
||||
- 群聊上下文感知记录 At 信息
|
||||
- 优化模型提供商页面的测试提供商功能
|
||||
|
||||
新增:
|
||||
- 支持在配置文件页面快速测试对话
|
||||
- 为配置文件配置项内容添加国际化支持
|
||||
|
||||
修复:
|
||||
- 在更新 MCP Server 配置后,MCP 无法正常重启的问题
|
||||
@@ -84,7 +84,7 @@
|
||||
v-model:prompt="prompt"
|
||||
:stagedImagesUrl="stagedImagesUrl"
|
||||
:stagedAudioUrl="stagedAudioUrl"
|
||||
:disabled="isStreaming || isConvRunning"
|
||||
:disabled="isStreaming"
|
||||
:enableStreaming="enableStreaming"
|
||||
:isRecording="isRecording"
|
||||
:session-id="currSessionId || null"
|
||||
|
||||
@@ -549,7 +549,7 @@ export default {
|
||||
}
|
||||
|
||||
.bot-embedded-image {
|
||||
max-width: 80%;
|
||||
max-width: 40%;
|
||||
width: auto;
|
||||
height: auto;
|
||||
border-radius: 8px;
|
||||
@@ -558,10 +558,6 @@ export default {
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.bot-embedded-image:hover {
|
||||
transform: scale(1.02);
|
||||
}
|
||||
|
||||
.embedded-audio {
|
||||
width: 300px;
|
||||
margin-top: 8px;
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
v-model:prompt="prompt"
|
||||
:stagedImagesUrl="stagedImagesUrl"
|
||||
:stagedAudioUrl="stagedAudioUrl"
|
||||
:disabled="isStreaming || isConvRunning"
|
||||
:disabled="isStreaming"
|
||||
:enableStreaming="enableStreaming"
|
||||
:isRecording="isRecording"
|
||||
:session-id="currSessionId || null"
|
||||
|
||||
@@ -7,8 +7,8 @@ import { useCommonStore } from '@/stores/common';
|
||||
<!-- 添加筛选级别控件 -->
|
||||
<div class="filter-controls mb-2" v-if="showLevelBtns">
|
||||
<v-chip-group v-model="selectedLevels" column multiple>
|
||||
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter
|
||||
:text-color="level === 'DEBUG' || level === 'INFO' ? 'black' : 'white'">
|
||||
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter variant="flat" size="small"
|
||||
:text-color="level === 'DEBUG' || level === 'INFO' ? 'black' : 'white'" class="font-weight-medium">
|
||||
{{ level }}
|
||||
</v-chip>
|
||||
</v-chip-group>
|
||||
@@ -168,6 +168,7 @@ export default {
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
margin-bottom: 8px;
|
||||
margin-left: 20px;
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
|
||||
@@ -109,6 +109,22 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"file_extract": {
|
||||
"description": "File Extract",
|
||||
"provider_settings": {
|
||||
"file_extract": {
|
||||
"enable": {
|
||||
"description": "Enable File Extract"
|
||||
},
|
||||
"provider": {
|
||||
"description": "File Extract Provider"
|
||||
},
|
||||
"moonshotai_api_key": {
|
||||
"description": "Moonshot AI API Key"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "Other Settings",
|
||||
"provider_settings": {
|
||||
@@ -159,6 +175,10 @@
|
||||
"prompt_prefix": {
|
||||
"description": "User Prompt",
|
||||
"hint": "You can use {{prompt}} as a placeholder for user input. If no placeholder is provided, it will be added before the user input."
|
||||
},
|
||||
"reachability_check": {
|
||||
"description": "Provider Reachability Check",
|
||||
"hint": "When running the /provider command, test provider connectivity in parallel. This actively pings models and may consume extra tokens."
|
||||
}
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
@@ -379,7 +399,11 @@
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "Auto-understand Images",
|
||||
"hint": "Requires setting a default image caption model."
|
||||
"hint": "Requires setting a group chat image caption model."
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"description": "Group Chat Image Caption Model",
|
||||
"hint": "Used for image understanding in group chat context awareness, configured separately from the default image caption model."
|
||||
},
|
||||
"active_reply": {
|
||||
"enable": {
|
||||
@@ -449,4 +473,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,8 @@
|
||||
"actions": "Actions",
|
||||
"back": "Back",
|
||||
"selectFile": "Select File",
|
||||
"refresh": "Refresh"
|
||||
"refresh": "Refresh",
|
||||
"updateAll": "Update All"
|
||||
},
|
||||
"status": {
|
||||
"enabled": "Enabled",
|
||||
@@ -141,7 +142,9 @@
|
||||
"confirmDelete": "Are you sure you want to delete this extension?",
|
||||
"fillUrlOrFile": "Please fill in extension URL or upload extension file",
|
||||
"dontFillBoth": "Please don't fill in both extension URL and upload file",
|
||||
"supportedFormats": "Supports .zip extension files"
|
||||
"supportedFormats": "Supports .zip extension files",
|
||||
"updateAllSuccess": "All upgradable extensions have been updated!",
|
||||
"updateAllFailed": "{failed} of {total} extensions failed to update:"
|
||||
},
|
||||
"upload": {
|
||||
"fromFile": "Install from File",
|
||||
|
||||
@@ -73,6 +73,17 @@
|
||||
"title": "Persona Configuration",
|
||||
"selectPersona": "Select Persona",
|
||||
"hint": "Persona settings affect the conversation style and behavior of the LLM"
|
||||
},
|
||||
"pluginConfig": {
|
||||
"title": "Plugin Configuration",
|
||||
"disabledPlugins": "Disabled Plugins",
|
||||
"hint": "Select plugins to disable for this session. Unselected plugins will remain enabled."
|
||||
},
|
||||
"kbConfig": {
|
||||
"title": "Knowledge Base Configuration",
|
||||
"selectKbs": "Select Knowledge Bases",
|
||||
"topK": "Top K Results",
|
||||
"enableRerank": "Enable Reranking"
|
||||
}
|
||||
},
|
||||
"deleteConfirm": {
|
||||
|
||||
@@ -11,7 +11,12 @@
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"description": "执行器",
|
||||
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"]
|
||||
"labels": [
|
||||
"内置 Agent",
|
||||
"Dify",
|
||||
"Coze",
|
||||
"阿里云百炼应用"
|
||||
]
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent 执行器提供商 ID"
|
||||
@@ -109,6 +114,22 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"file_extract": {
|
||||
"description": "文档解析能力",
|
||||
"provider_settings": {
|
||||
"file_extract": {
|
||||
"enable": {
|
||||
"description": "启用文档解析能力"
|
||||
},
|
||||
"provider": {
|
||||
"description": "文档解析提供商"
|
||||
},
|
||||
"moonshotai_api_key": {
|
||||
"description": "Moonshot AI API Key"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"provider_settings": {
|
||||
@@ -142,7 +163,10 @@
|
||||
"unsupported_streaming_strategy": {
|
||||
"description": "不支持流式回复的平台",
|
||||
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
|
||||
"labels": ["实时分段回复", "关闭流式回复"]
|
||||
"labels": [
|
||||
"实时分段回复",
|
||||
"关闭流式回复"
|
||||
]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
@@ -159,6 +183,10 @@
|
||||
"prompt_prefix": {
|
||||
"description": "用户提示词",
|
||||
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。"
|
||||
},
|
||||
"reachability_check": {
|
||||
"description": "提供商可达性检测",
|
||||
"hint": "/provider 命令列出模型时并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。"
|
||||
}
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
@@ -379,7 +407,11 @@
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "自动理解图片",
|
||||
"hint": "需要设置默认图片转述模型。"
|
||||
"hint": "需要设置群聊图片转述模型。"
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"description": "群聊图片转述模型",
|
||||
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。"
|
||||
},
|
||||
"active_reply": {
|
||||
"enable": {
|
||||
@@ -449,4 +481,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,7 +32,8 @@
|
||||
"actions": "操作",
|
||||
"back": "返回",
|
||||
"selectFile": "选择文件",
|
||||
"refresh": "刷新"
|
||||
"refresh": "刷新",
|
||||
"updateAll": "更新全部插件"
|
||||
},
|
||||
"status": {
|
||||
"enabled": "启用",
|
||||
@@ -141,7 +142,9 @@
|
||||
"confirmDelete": "确定要删除插件吗?",
|
||||
"fillUrlOrFile": "请填写插件链接或上传插件文件",
|
||||
"dontFillBoth": "请不要同时填写插件链接和上传文件",
|
||||
"supportedFormats": "支持 .zip 格式的插件文件"
|
||||
"supportedFormats": "支持 .zip 格式的插件文件",
|
||||
"updateAllSuccess": "所有可更新的插件都已更新!",
|
||||
"updateAllFailed": "有 {failed}/{total} 个插件更新失败:"
|
||||
},
|
||||
"upload": {
|
||||
"fromFile": "从文件安装",
|
||||
|
||||
@@ -73,6 +73,17 @@
|
||||
"title": "人格配置",
|
||||
"selectPersona": "选择人格",
|
||||
"hint": "应用人格配置后,将会强制该来源的所有对话使用该人格。"
|
||||
},
|
||||
"pluginConfig": {
|
||||
"title": "插件配置",
|
||||
"disabledPlugins": "禁用的插件",
|
||||
"hint": "选择要在此会话中禁用的插件。未选择的插件将保持启用状态。"
|
||||
},
|
||||
"kbConfig": {
|
||||
"title": "知识库配置",
|
||||
"selectKbs": "选择知识库",
|
||||
"topK": "返回结果数量 (Top K)",
|
||||
"enableRerank": "启用重排序"
|
||||
}
|
||||
},
|
||||
"deleteConfirm": {
|
||||
|
||||
@@ -42,6 +42,7 @@ const loadingDialog = reactive({
|
||||
const showPluginInfoDialog = ref(false);
|
||||
const selectedPlugin = ref({});
|
||||
const curr_namespace = ref("");
|
||||
const updatingAll = ref(false);
|
||||
|
||||
const readmeDialog = reactive({
|
||||
show: false,
|
||||
@@ -226,6 +227,10 @@ const paginatedPlugins = computed(() => {
|
||||
return sortedPlugins.value.slice(start, end);
|
||||
});
|
||||
|
||||
const updatableExtensions = computed(() => {
|
||||
return extension_data?.data?.filter(ext => ext.has_update) || [];
|
||||
});
|
||||
|
||||
// 方法
|
||||
const toggleShowReserved = () => {
|
||||
showReserved.value = !showReserved.value;
|
||||
@@ -372,6 +377,56 @@ const updateExtension = async (extension_name) => {
|
||||
}
|
||||
};
|
||||
|
||||
const updateAllExtensions = async () => {
|
||||
if (updatingAll.value || updatableExtensions.value.length === 0) return;
|
||||
updatingAll.value = true;
|
||||
loadingDialog.title = tm('status.loading');
|
||||
loadingDialog.statusCode = 0;
|
||||
loadingDialog.result = "";
|
||||
loadingDialog.show = true;
|
||||
|
||||
const targets = updatableExtensions.value.map(ext => ext.name);
|
||||
try {
|
||||
const res = await axios.post('/api/plugin/update-all', {
|
||||
names: targets,
|
||||
proxy: localStorage.getItem('selectedGitHubProxy') || ""
|
||||
});
|
||||
|
||||
if (res.data.status === "error") {
|
||||
onLoadingDialogResult(2, res.data.message || tm('messages.updateAllFailed', {
|
||||
failed: targets.length,
|
||||
total: targets.length
|
||||
}), -1);
|
||||
return;
|
||||
}
|
||||
|
||||
const results = res.data.data?.results || [];
|
||||
const failures = results.filter(r => r.status !== 'ok');
|
||||
try {
|
||||
await getExtensions();
|
||||
} catch (err) {
|
||||
const errorMsg = err.response?.data?.message || err.message || String(err);
|
||||
failures.push({ name: 'refresh', status: 'error', message: errorMsg });
|
||||
}
|
||||
|
||||
if (failures.length === 0) {
|
||||
onLoadingDialogResult(1, tm('messages.updateAllSuccess'));
|
||||
} else {
|
||||
const failureText = tm('messages.updateAllFailed', {
|
||||
failed: failures.length,
|
||||
total: targets.length
|
||||
});
|
||||
const detail = failures.map(f => `${f.name}: ${f.message}`).join('\n');
|
||||
onLoadingDialogResult(2, `${failureText}\n${detail}`, -1);
|
||||
}
|
||||
} catch (err) {
|
||||
const errorMsg = err.response?.data?.message || err.message || String(err);
|
||||
onLoadingDialogResult(2, errorMsg, -1);
|
||||
} finally {
|
||||
updatingAll.value = false;
|
||||
}
|
||||
};
|
||||
|
||||
const pluginOn = async (extension) => {
|
||||
try {
|
||||
const res = await axios.post('/api/plugin/on', { name: extension.name });
|
||||
@@ -720,6 +775,12 @@ watch(marketSearch, (newVal) => {
|
||||
{{ showReserved ? tm('buttons.hideSystemPlugins') : tm('buttons.showSystemPlugins') }}
|
||||
</v-btn>
|
||||
|
||||
<v-btn class="ml-2" color="warning" variant="tonal" :disabled="updatableExtensions.length === 0"
|
||||
:loading="updatingAll" @click="updateAllExtensions">
|
||||
<v-icon>mdi-update</v-icon>
|
||||
{{ tm('buttons.updateAll') }}
|
||||
</v-btn>
|
||||
|
||||
<v-btn class="ml-2" color="primary" variant="tonal" @click="dialog = true">
|
||||
<v-icon>mdi-plus</v-icon>
|
||||
{{ tm('buttons.install') }}
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
</div>
|
||||
|
||||
<!-- 日志部分 -->
|
||||
<v-card elevation="0" class="mt-4">
|
||||
<v-card elevation="0" class="mt-4 mb-10">
|
||||
<v-card-title class="d-flex align-center py-3 px-4">
|
||||
<v-icon class="me-2">mdi-console-line</v-icon>
|
||||
<span class="text-h4">{{ tm('logs.title') }}</span>
|
||||
@@ -233,5 +233,6 @@ export default {
|
||||
.platform-page {
|
||||
padding: 20px;
|
||||
padding-top: 8px;
|
||||
padding-bottom: 40px;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -69,6 +69,25 @@
|
||||
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
|
||||
@copy="copyProvider" :show-copy-button="true">
|
||||
<template #item-details="{ item }">
|
||||
<!-- 测试状态 chip -->
|
||||
<v-tooltip v-if="getProviderStatus(item.id)" location="top" max-width="300">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-chip v-bind="props" :color="getStatusColor(getProviderStatus(item.id).status)" size="small">
|
||||
<v-icon start size="small">
|
||||
{{ getProviderStatus(item.id).status === 'available' ? 'mdi-check-circle' :
|
||||
getProviderStatus(item.id).status === 'unavailable' ? 'mdi-alert-circle' :
|
||||
'mdi-clock-outline' }}
|
||||
</v-icon>
|
||||
{{ getStatusText(getProviderStatus(item.id).status) }}
|
||||
</v-chip>
|
||||
</template>
|
||||
<span v-if="getProviderStatus(item.id).status === 'unavailable'">
|
||||
{{ getProviderStatus(item.id).error }}
|
||||
</span>
|
||||
<span v-else>{{ getStatusText(getProviderStatus(item.id).status) }}</span>
|
||||
</v-tooltip>
|
||||
</template>
|
||||
<template #actions="{ item }">
|
||||
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
|
||||
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
|
||||
@@ -96,75 +115,40 @@
|
||||
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
|
||||
@copy="copyProvider" :show-copy-button="true">
|
||||
|
||||
<template #item-details="{ item }">
|
||||
<!-- 测试状态 chip -->
|
||||
<v-tooltip v-if="getProviderStatus(item.id)" location="top" max-width="300">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-chip v-bind="props" :color="getStatusColor(getProviderStatus(item.id).status)" size="small">
|
||||
<v-icon start size="small">
|
||||
{{ getProviderStatus(item.id).status === 'available' ? 'mdi-check-circle' :
|
||||
getProviderStatus(item.id).status === 'unavailable' ? 'mdi-alert-circle' :
|
||||
'mdi-clock-outline' }}
|
||||
</v-icon>
|
||||
{{ getStatusText(getProviderStatus(item.id).status) }}
|
||||
</v-chip>
|
||||
</template>
|
||||
<span v-if="getProviderStatus(item.id).status === 'unavailable'">
|
||||
{{ getProviderStatus(item.id).error }}
|
||||
</span>
|
||||
<span v-else>{{ getStatusText(getProviderStatus(item.id).status) }}</span>
|
||||
</v-tooltip>
|
||||
</template>
|
||||
<template #actions="{ item }">
|
||||
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
|
||||
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
|
||||
{{ tm('availability.test') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<template v-slot:details="{ item }">
|
||||
</template>
|
||||
</item-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- 供应商状态部分 -->
|
||||
<v-card elevation="0" class="mt-4">
|
||||
<v-card-title class="d-flex align-center py-3 px-4">
|
||||
<v-icon class="me-2">mdi-heart-pulse</v-icon>
|
||||
<span class="text-h4">{{ tm('availability.title') }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="primary" variant="tonal" :loading="testingProviders.length > 0" @click="fetchProviderStatus">
|
||||
<v-icon left>mdi-refresh</v-icon>
|
||||
{{ tm('availability.refresh') }}
|
||||
</v-btn>
|
||||
<v-btn variant="text" color="primary" @click="showStatus = !showStatus" style="margin-left: 8px;">
|
||||
{{ showStatus ? tm('logs.collapse') : tm('logs.expand') }}
|
||||
<v-icon>{{ showStatus ? 'mdi-chevron-up' : 'mdi-chevron-down' }}</v-icon>
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-expand-transition>
|
||||
<v-card-text class="pa-0" v-if="showStatus">
|
||||
<v-card-text class="px-4 py-3">
|
||||
<v-alert v-if="providerStatuses.length === 0" type="info" variant="tonal">
|
||||
{{ tm('availability.noData') }}
|
||||
</v-alert>
|
||||
|
||||
<v-container v-else class="pa-0">
|
||||
<v-row>
|
||||
<v-col v-for="status in providerStatuses" :key="status.id" cols="12" sm="6" md="4">
|
||||
<v-card variant="outlined" class="status-card" :class="`status-${status.status}`">
|
||||
<v-card-item>
|
||||
<v-icon v-if="status.status === 'available'" color="success"
|
||||
class="me-2">mdi-check-circle</v-icon>
|
||||
<v-icon v-else-if="status.status === 'unavailable'" color="error"
|
||||
class="me-2">mdi-alert-circle</v-icon>
|
||||
<v-progress-circular v-else-if="status.status === 'pending'" indeterminate color="primary"
|
||||
size="20" width="2" class="me-2"></v-progress-circular>
|
||||
|
||||
<span class="font-weight-bold">{{ status.id }}</span>
|
||||
|
||||
<v-chip :color="getStatusColor(status.status)" size="small" class="ml-2">
|
||||
{{ getStatusText(status.status) }}
|
||||
</v-chip>
|
||||
</v-card-item>
|
||||
<v-card-text v-if="status.status === 'unavailable'" class="text-caption text-medium-emphasis">
|
||||
<span class="font-weight-bold">{{ tm('availability.errorMessage') }}:</span> {{ status.error }}
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
</v-card-text>
|
||||
</v-expand-transition>
|
||||
</v-card>
|
||||
|
||||
<!-- 日志部分 -->
|
||||
<v-card elevation="0" class="mt-4">
|
||||
<v-card elevation="0" class="mt-4 mb-10">
|
||||
<v-card-title class="d-flex align-center py-3 px-4">
|
||||
<v-icon class="me-2">mdi-console-line</v-icon>
|
||||
<span class="text-h4">{{ tm('logs.title') }}</span>
|
||||
@@ -751,11 +735,14 @@ export default {
|
||||
return this.testingProviders.includes(providerId);
|
||||
},
|
||||
|
||||
getProviderStatus(providerId) {
|
||||
return this.providerStatuses.find(s => s.id === providerId);
|
||||
},
|
||||
|
||||
async testSingleProvider(provider) {
|
||||
if (this.isProviderTesting(provider.id)) return;
|
||||
|
||||
this.testingProviders.push(provider.id);
|
||||
this.showStatus = true; // 自动展开状态部分
|
||||
|
||||
// 更新UI为pending状态
|
||||
const statusIndex = this.providerStatuses.findIndex(s => s.id === provider.id);
|
||||
@@ -862,6 +849,7 @@ export default {
|
||||
.provider-page {
|
||||
padding: 20px;
|
||||
padding-top: 8px;
|
||||
padding-bottom: 40px;
|
||||
}
|
||||
|
||||
.status-card {
|
||||
|
||||
@@ -143,11 +143,11 @@
|
||||
</v-dialog>
|
||||
|
||||
<!-- 规则编辑对话框 -->
|
||||
<v-dialog v-model="ruleDialog" max-width="700" scrollable>
|
||||
<v-dialog v-model="ruleDialog" max-width="550" scrollable>
|
||||
<v-card v-if="selectedUmo" class="d-flex flex-column" height="600">
|
||||
<v-card-title class="py-3 px-6 d-flex align-center border-b">
|
||||
<span>{{ tm('ruleEditor.title') }}</span>
|
||||
<v-chip size="small" class="ml-4 font-weight-regular" variant="outlined">
|
||||
<v-chip size="x-small" class="ml-2 font-weight-regular" variant="outlined">
|
||||
{{ selectedUmo.umo }}
|
||||
</v-chip>
|
||||
<v-spacer></v-spacer>
|
||||
@@ -241,6 +241,59 @@
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- Plugin Config Section -->
|
||||
<div class="d-flex align-center mb-4 mt-4">
|
||||
<h3 class="font-weight-bold mb-0">{{ tm('ruleEditor.pluginConfig.title') }}</h3>
|
||||
</div>
|
||||
|
||||
<v-row dense>
|
||||
<v-col cols="12">
|
||||
<v-select v-model="pluginConfig.disabled_plugins" :items="pluginOptions" item-title="label"
|
||||
item-value="value" :label="tm('ruleEditor.pluginConfig.disabledPlugins')" variant="outlined"
|
||||
hide-details multiple chips closable-chips clearable />
|
||||
</v-col>
|
||||
<v-col cols="12">
|
||||
<v-alert type="info" variant="tonal" class="mt-2" icon="mdi-information-outline">
|
||||
{{ tm('ruleEditor.pluginConfig.hint') }}
|
||||
</v-alert>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<div class="d-flex justify-end mt-4">
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="savePluginConfig" :loading="saving"
|
||||
prepend-icon="mdi-content-save">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- KB Config Section -->
|
||||
<div class="d-flex align-center mb-4 mt-4">
|
||||
<h3 class="font-weight-bold mb-0">{{ tm('ruleEditor.kbConfig.title') }}</h3>
|
||||
</div>
|
||||
|
||||
<v-row dense>
|
||||
<v-col cols="12">
|
||||
<v-select v-model="kbConfig.kb_ids" :items="kbOptions" item-title="label" item-value="value" :disabled="availableKbs.length === 0"
|
||||
:label="tm('ruleEditor.kbConfig.selectKbs')" variant="outlined" hide-details multiple chips
|
||||
closable-chips clearable />
|
||||
</v-col>
|
||||
<v-col cols="12" md="6">
|
||||
<v-text-field v-model.number="kbConfig.top_k" :label="tm('ruleEditor.kbConfig.topK')"
|
||||
variant="outlined" hide-details type="number" min="1" max="20" class="mt-3"/>
|
||||
</v-col>
|
||||
<v-col cols="12" md="6">
|
||||
<v-checkbox v-model="kbConfig.enable_rerank" :label="tm('ruleEditor.kbConfig.enableRerank')"
|
||||
color="primary" hide-details class="mt-3"/>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<div class="d-flex justify-end mt-4">
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="saveKbConfig" :loading="saving"
|
||||
prepend-icon="mdi-content-save">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
@@ -347,6 +400,8 @@ export default {
|
||||
availableChatProviders: [],
|
||||
availableSttProviders: [],
|
||||
availableTtsProviders: [],
|
||||
availablePlugins: [],
|
||||
availableKbs: [],
|
||||
|
||||
// 添加规则
|
||||
addRuleDialog: false,
|
||||
@@ -374,6 +429,19 @@ export default {
|
||||
text_to_speech: null,
|
||||
},
|
||||
|
||||
// 插件配置
|
||||
pluginConfig: {
|
||||
enabled_plugins: [],
|
||||
disabled_plugins: [],
|
||||
},
|
||||
|
||||
// 知识库配置
|
||||
kbConfig: {
|
||||
kb_ids: [],
|
||||
top_k: 5,
|
||||
enable_rerank: true,
|
||||
},
|
||||
|
||||
// 删除确认
|
||||
deleteDialog: false,
|
||||
deleteTarget: null,
|
||||
@@ -447,6 +515,20 @@ export default {
|
||||
}))
|
||||
]
|
||||
},
|
||||
|
||||
pluginOptions() {
|
||||
return this.availablePlugins.map(p => ({
|
||||
label: p.display_name || p.name,
|
||||
value: p.name
|
||||
}))
|
||||
},
|
||||
|
||||
kbOptions() {
|
||||
return this.availableKbs.map(kb => ({
|
||||
label: `${kb.emoji || '📚'} ${kb.kb_name}`,
|
||||
value: kb.kb_id
|
||||
}))
|
||||
},
|
||||
},
|
||||
|
||||
watch: {
|
||||
@@ -492,6 +574,8 @@ export default {
|
||||
this.availableChatProviders = data.available_chat_providers
|
||||
this.availableSttProviders = data.available_stt_providers
|
||||
this.availableTtsProviders = data.available_tts_providers
|
||||
this.availablePlugins = data.available_plugins || []
|
||||
this.availableKbs = data.available_kbs || []
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadError'))
|
||||
}
|
||||
@@ -589,6 +673,21 @@ export default {
|
||||
text_to_speech: this.editingRules['provider_perf_text_to_speech'] || null,
|
||||
}
|
||||
|
||||
// 初始化插件配置
|
||||
const pluginCfg = this.editingRules.session_plugin_config || {}
|
||||
this.pluginConfig = {
|
||||
enabled_plugins: pluginCfg.enabled_plugins || [],
|
||||
disabled_plugins: pluginCfg.disabled_plugins || [],
|
||||
}
|
||||
|
||||
// 初始化知识库配置
|
||||
const kbCfg = this.editingRules.kb_config || {}
|
||||
this.kbConfig = {
|
||||
kb_ids: kbCfg.kb_ids || [],
|
||||
top_k: kbCfg.top_k ?? 5,
|
||||
enable_rerank: kbCfg.enable_rerank !== false,
|
||||
}
|
||||
|
||||
this.ruleDialog = true
|
||||
},
|
||||
|
||||
@@ -708,6 +807,117 @@ export default {
|
||||
this.saving = false
|
||||
},
|
||||
|
||||
async savePluginConfig() {
|
||||
if (!this.selectedUmo) return
|
||||
|
||||
this.saving = true
|
||||
try {
|
||||
const config = {
|
||||
enabled_plugins: this.pluginConfig.enabled_plugins,
|
||||
disabled_plugins: this.pluginConfig.disabled_plugins,
|
||||
}
|
||||
|
||||
// 如果两个列表都为空,删除配置
|
||||
if (config.enabled_plugins.length === 0 && config.disabled_plugins.length === 0) {
|
||||
if (this.editingRules.session_plugin_config) {
|
||||
await axios.post('/api/session/delete-rule', {
|
||||
umo: this.selectedUmo.umo,
|
||||
rule_key: 'session_plugin_config'
|
||||
})
|
||||
delete this.editingRules.session_plugin_config
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
if (item) delete item.rules.session_plugin_config
|
||||
}
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
} else {
|
||||
const response = await axios.post('/api/session/update-rule', {
|
||||
umo: this.selectedUmo.umo,
|
||||
rule_key: 'session_plugin_config',
|
||||
rule_value: config
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
this.editingRules.session_plugin_config = config
|
||||
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
if (item) {
|
||||
item.rules.session_plugin_config = config
|
||||
} else {
|
||||
this.rulesList.push({
|
||||
umo: this.selectedUmo.umo,
|
||||
platform: this.selectedUmo.platform,
|
||||
message_type: this.selectedUmo.message_type,
|
||||
session_id: this.selectedUmo.session_id,
|
||||
rules: { session_plugin_config: config }
|
||||
})
|
||||
}
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.saveError'))
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
|
||||
}
|
||||
this.saving = false
|
||||
},
|
||||
|
||||
async saveKbConfig() {
|
||||
if (!this.selectedUmo) return
|
||||
|
||||
this.saving = true
|
||||
try {
|
||||
const config = {
|
||||
kb_ids: this.kbConfig.kb_ids,
|
||||
top_k: this.kbConfig.top_k,
|
||||
enable_rerank: this.kbConfig.enable_rerank,
|
||||
}
|
||||
|
||||
// 如果 kb_ids 为空,删除配置
|
||||
if (config.kb_ids.length === 0) {
|
||||
if (this.editingRules.kb_config) {
|
||||
await axios.post('/api/session/delete-rule', {
|
||||
umo: this.selectedUmo.umo,
|
||||
rule_key: 'kb_config'
|
||||
})
|
||||
delete this.editingRules.kb_config
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
if (item) delete item.rules.kb_config
|
||||
}
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
} else {
|
||||
const response = await axios.post('/api/session/update-rule', {
|
||||
umo: this.selectedUmo.umo,
|
||||
rule_key: 'kb_config',
|
||||
rule_value: config
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
this.editingRules.kb_config = config
|
||||
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
if (item) {
|
||||
item.rules.kb_config = config
|
||||
} else {
|
||||
this.rulesList.push({
|
||||
umo: this.selectedUmo.umo,
|
||||
platform: this.selectedUmo.platform,
|
||||
message_type: this.selectedUmo.message_type,
|
||||
session_id: this.selectedUmo.session_id,
|
||||
rules: { kb_config: config }
|
||||
})
|
||||
}
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.saveError'))
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
|
||||
}
|
||||
this.saving = false
|
||||
},
|
||||
|
||||
confirmDeleteRules(item) {
|
||||
this.deleteTarget = item
|
||||
this.deleteDialog = true
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
@@ -9,6 +11,53 @@ class ProviderCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
def _log_reachability_failure(
|
||||
self,
|
||||
provider,
|
||||
provider_capability_type: ProviderType | None,
|
||||
err_code: str,
|
||||
err_reason: str,
|
||||
):
|
||||
"""记录不可达原因到日志。"""
|
||||
meta = provider.meta()
|
||||
logger.warning(
|
||||
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
|
||||
meta.id,
|
||||
provider_capability_type.name if provider_capability_type else "unknown",
|
||||
err_code,
|
||||
err_reason,
|
||||
)
|
||||
|
||||
async def _test_provider_capability(self, provider):
|
||||
"""测试单个 provider 的可用性"""
|
||||
meta = provider.meta()
|
||||
provider_capability_type = meta.provider_type
|
||||
|
||||
try:
|
||||
result = await provider.test()
|
||||
if result:
|
||||
return True, None, None
|
||||
err_code = "TEST_FAILED"
|
||||
err_reason = "Provider test returned False"
|
||||
self._log_reachability_failure(
|
||||
provider, provider_capability_type, err_code, err_reason
|
||||
)
|
||||
return False, err_code, err_reason
|
||||
except Exception as exc:
|
||||
err_code = (
|
||||
getattr(exc, "status_code", None)
|
||||
or getattr(exc, "code", None)
|
||||
or getattr(exc, "error_code", None)
|
||||
)
|
||||
err_reason = str(exc)
|
||||
if not err_code:
|
||||
err_code = exc.__class__.__name__
|
||||
|
||||
self._log_reachability_failure(
|
||||
provider, provider_capability_type, err_code, err_reason
|
||||
)
|
||||
return False, err_code, err_reason
|
||||
|
||||
async def provider(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -17,46 +66,131 @@ class ProviderCommands:
|
||||
):
|
||||
"""查看或者切换 LLM Provider"""
|
||||
umo = event.unified_msg_origin
|
||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||
reachability_check_enabled = cfg.get("reachability_check", True)
|
||||
|
||||
if idx is None:
|
||||
parts = ["## 载入的 LLM 提供商\n"]
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
id_ = llm.meta().id
|
||||
line = f"{idx + 1}. {id_} ({llm.meta().model})"
|
||||
|
||||
# 获取所有类型的提供商
|
||||
llms = list(self.context.get_all_providers())
|
||||
ttss = self.context.get_all_tts_providers()
|
||||
stts = self.context.get_all_stt_providers()
|
||||
|
||||
# 构造待检测列表: [(provider, type_label), ...]
|
||||
all_providers = []
|
||||
all_providers.extend([(p, "llm") for p in llms])
|
||||
all_providers.extend([(p, "tts") for p in ttss])
|
||||
all_providers.extend([(p, "stt") for p in stts])
|
||||
|
||||
# 并发测试连通性
|
||||
if reachability_check_enabled:
|
||||
if all_providers:
|
||||
await event.send(
|
||||
MessageEventResult().message(
|
||||
"正在进行提供商可达性测试,请稍候..."
|
||||
)
|
||||
)
|
||||
check_results = await asyncio.gather(
|
||||
*[self._test_provider_capability(p) for p, _ in all_providers],
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
# 用 None 表示未检测
|
||||
check_results = [None for _ in all_providers]
|
||||
|
||||
# 整合结果
|
||||
display_data = []
|
||||
for (p, p_type), reachable in zip(all_providers, check_results):
|
||||
meta = p.meta()
|
||||
id_ = meta.id
|
||||
error_code = None
|
||||
|
||||
if isinstance(reachable, Exception):
|
||||
# 异常情况下兜底处理,避免单个 provider 导致列表失败
|
||||
self._log_reachability_failure(
|
||||
p,
|
||||
None,
|
||||
reachable.__class__.__name__,
|
||||
str(reachable),
|
||||
)
|
||||
reachable_flag = False
|
||||
error_code = reachable.__class__.__name__
|
||||
elif isinstance(reachable, tuple):
|
||||
reachable_flag, error_code, _ = reachable
|
||||
else:
|
||||
reachable_flag = reachable
|
||||
|
||||
# 根据类型构建显示名称
|
||||
if p_type == "llm":
|
||||
info = f"{id_} ({meta.model})"
|
||||
else:
|
||||
info = f"{id_}"
|
||||
|
||||
# 确定状态标记
|
||||
if reachable_flag is True:
|
||||
mark = " ✅"
|
||||
elif reachable_flag is False:
|
||||
if error_code:
|
||||
mark = f" ❌(错误码: {error_code})"
|
||||
else:
|
||||
mark = " ❌"
|
||||
else:
|
||||
mark = "" # 不支持检测时不显示标记
|
||||
|
||||
display_data.append(
|
||||
{
|
||||
"type": p_type,
|
||||
"info": info,
|
||||
"mark": mark,
|
||||
"provider": p,
|
||||
}
|
||||
)
|
||||
|
||||
# 分组输出
|
||||
# 1. LLM
|
||||
llm_data = [d for d in display_data if d["type"] == "llm"]
|
||||
for i, d in enumerate(llm_data):
|
||||
line = f"{i + 1}. {d['info']}{d['mark']}"
|
||||
provider_using = self.context.get_using_provider(umo=umo)
|
||||
if provider_using and provider_using.meta().id == id_:
|
||||
if (
|
||||
provider_using
|
||||
and provider_using.meta().id == d["provider"].meta().id
|
||||
):
|
||||
line += " (当前使用)"
|
||||
parts.append(line + "\n")
|
||||
|
||||
tts_providers = self.context.get_all_tts_providers()
|
||||
if tts_providers:
|
||||
# 2. TTS
|
||||
tts_data = [d for d in display_data if d["type"] == "tts"]
|
||||
if tts_data:
|
||||
parts.append("\n## 载入的 TTS 提供商\n")
|
||||
for idx, tts in enumerate(tts_providers):
|
||||
id_ = tts.meta().id
|
||||
line = f"{idx + 1}. {id_}"
|
||||
for i, d in enumerate(tts_data):
|
||||
line = f"{i + 1}. {d['info']}{d['mark']}"
|
||||
tts_using = self.context.get_using_tts_provider(umo=umo)
|
||||
if tts_using and tts_using.meta().id == id_:
|
||||
if tts_using and tts_using.meta().id == d["provider"].meta().id:
|
||||
line += " (当前使用)"
|
||||
parts.append(line + "\n")
|
||||
|
||||
stt_providers = self.context.get_all_stt_providers()
|
||||
if stt_providers:
|
||||
# 3. STT
|
||||
stt_data = [d for d in display_data if d["type"] == "stt"]
|
||||
if stt_data:
|
||||
parts.append("\n## 载入的 STT 提供商\n")
|
||||
for idx, stt in enumerate(stt_providers):
|
||||
id_ = stt.meta().id
|
||||
line = f"{idx + 1}. {id_}"
|
||||
for i, d in enumerate(stt_data):
|
||||
line = f"{i + 1}. {d['info']}{d['mark']}"
|
||||
stt_using = self.context.get_using_stt_provider(umo=umo)
|
||||
if stt_using and stt_using.meta().id == id_:
|
||||
if stt_using and stt_using.meta().id == d["provider"].meta().id:
|
||||
line += " (当前使用)"
|
||||
parts.append(line + "\n")
|
||||
|
||||
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
|
||||
ret = "".join(parts)
|
||||
|
||||
if tts_providers:
|
||||
if ttss:
|
||||
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
|
||||
if stt_providers:
|
||||
ret += "\n使用 /provider stt <切换> STT 提供商。"
|
||||
if stts:
|
||||
ret += "\n使用 /provider stt <序号> 切换 STT 提供商。"
|
||||
if not reachability_check_enabled:
|
||||
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
elif idx == "tts":
|
||||
|
||||
@@ -6,9 +6,9 @@ from collections import defaultdict
|
||||
from astrbot import logger
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from astrbot.api.message_components import At, Image, Plain
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
|
||||
"""
|
||||
@@ -30,16 +30,13 @@ class LongTermMemory:
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
max_cnt = 300
|
||||
image_caption = (
|
||||
True
|
||||
if cfg["provider_settings"]["default_image_caption_provider_id"]
|
||||
and cfg["provider_ltm_settings"]["image_caption"]
|
||||
else False
|
||||
)
|
||||
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
|
||||
image_caption_provider_id = cfg["provider_settings"][
|
||||
"default_image_caption_provider_id"
|
||||
]
|
||||
image_caption_provider_id = cfg["provider_ltm_settings"].get(
|
||||
"image_caption_provider_id"
|
||||
)
|
||||
image_caption = cfg["provider_ltm_settings"]["image_caption"] and bool(
|
||||
image_caption_provider_id
|
||||
)
|
||||
active_reply = cfg["provider_ltm_settings"]["active_reply"]
|
||||
enable_active_reply = active_reply.get("enable", False)
|
||||
ar_method = active_reply["method"]
|
||||
@@ -142,6 +139,8 @@ class LongTermMemory:
|
||||
logger.error(f"获取图片描述失败: {e}")
|
||||
else:
|
||||
parts.append(" [Image]")
|
||||
elif isinstance(comp, At):
|
||||
parts.append(f" [At: {comp.name}]")
|
||||
|
||||
final_message = "".join(parts)
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
@@ -159,8 +158,12 @@ class LongTermMemory:
|
||||
cfg = self.cfg(event)
|
||||
if cfg["enable_active_reply"]:
|
||||
prompt = req.prompt
|
||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
||||
req.prompt = (
|
||||
f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
f"\nNow, a new message is coming: `{prompt}`. "
|
||||
"Please react to it. Only output your response and do not output any other information. "
|
||||
"You MUST use the SAME language as the chatroom is using."
|
||||
)
|
||||
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
||||
else:
|
||||
req.system_prompt += (
|
||||
@@ -168,13 +171,15 @@ class LongTermMemory:
|
||||
)
|
||||
req.system_prompt += chats_str
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent):
|
||||
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
if event.get_result() and event.get_result().is_llm_result():
|
||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
if llm_resp.completion_text:
|
||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}"
|
||||
logger.debug(
|
||||
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
|
||||
)
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
cfg = self.cfg(event)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
||||
|
||||
@@ -322,7 +322,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后基于配置注入思考过程文本"""
|
||||
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
|
||||
umo = event.unified_msg_origin
|
||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||
show_reasoning = cfg.get("display_reasoning_text", False)
|
||||
@@ -331,12 +331,9 @@ class Main(star.Star):
|
||||
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
||||
)
|
||||
|
||||
@filter.after_message_sent()
|
||||
async def after_llm_req(self, event: AstrMessageEvent):
|
||||
"""在 LLM 请求后记录对话"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.after_req_llm(event)
|
||||
await self.ltm.after_req_llm(event, resp)
|
||||
except Exception as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.6.1"
|
||||
version = "4.7.3"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user