Compare commits

..

12 Commits

42 changed files with 1142 additions and 863 deletions
-2
View File
@@ -48,5 +48,3 @@ astrbot.lock
chroma chroma
venv/* venv/*
pytest.ini pytest.ini
AGENTS.md
IFLOW.md
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.7.4" __version__ = "4.7.1"
+4 -21
View File
@@ -3,7 +3,7 @@
from typing import Any, ClassVar, Literal, cast from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema from pydantic_core import core_schema
@@ -145,39 +145,22 @@ class Message(BaseModel):
"tool", "tool",
] ]
content: str | list[ContentPart] | None = None content: str | list[ContentPart]
"""The content of the message.""" """The content of the message."""
tool_calls: list[ToolCall] | list[dict] | None = None
"""The tool calls of the message."""
tool_call_id: str | None = None
"""The ID of the tool call."""
@model_validator(mode="after")
def check_content_required(self):
# assistant + tool_calls is not None: allow content to be None
if self.role == "assistant" and self.tool_calls is not None:
return self
# other all cases: content is required
if self.content is None:
raise ValueError(
"content is required unless role='assistant' and tool_calls is not None"
)
return self
class AssistantMessageSegment(Message): class AssistantMessageSegment(Message):
"""A message segment from the assistant.""" """A message segment from the assistant."""
role: Literal["assistant"] = "assistant" role: Literal["assistant"] = "assistant"
tool_calls: list[ToolCall] | list[dict] | None = None
class ToolCallMessageSegment(Message): class ToolCallMessageSegment(Message):
"""A message segment representing a tool call.""" """A message segment representing a tool call."""
role: Literal["tool"] = "tool" role: Literal["tool"] = "tool"
tool_call_id: str
class UserMessageSegment(Message): class UserMessageSegment(Message):
+2 -70
View File
@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.7.4" VERSION = "4.7.1"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置 # 默认配置
@@ -73,14 +73,8 @@ DEFAULT_CONFIG = {
"coze_agent_runner_provider_id": "", "coze_agent_runner_provider_id": "",
"dashscope_agent_runner_provider_id": "", "dashscope_agent_runner_provider_id": "",
"unsupported_streaming_strategy": "realtime_segmenting", "unsupported_streaming_strategy": "realtime_segmenting",
"reachability_check": False,
"max_agent_step": 30, "max_agent_step": 30,
"tool_call_timeout": 60, "tool_call_timeout": 60,
"file_extract": {
"enable": False,
"provider": "moonshotai",
"moonshotai_api_key": "",
},
}, },
"provider_stt_settings": { "provider_stt_settings": {
"enable": False, "enable": False,
@@ -96,7 +90,6 @@ DEFAULT_CONFIG = {
"group_icl_enable": False, "group_icl_enable": False,
"group_message_max_cnt": 300, "group_message_max_cnt": 300,
"image_caption": False, "image_caption": False,
"image_caption_provider_id": "",
"active_reply": { "active_reply": {
"enable": False, "enable": False,
"method": "possibility_reply", "method": "possibility_reply",
@@ -2074,20 +2067,6 @@ CONFIG_METADATA_2 = {
"tool_call_timeout": { "tool_call_timeout": {
"type": "int", "type": "int",
}, },
"file_extract": {
"type": "object",
"items": {
"enable": {
"type": "bool",
},
"provider": {
"type": "string",
},
"moonshotai_api_key": {
"type": "string",
},
},
},
}, },
}, },
"provider_stt_settings": { "provider_stt_settings": {
@@ -2130,9 +2109,6 @@ CONFIG_METADATA_2 = {
"image_caption": { "image_caption": {
"type": "bool", "type": "bool",
}, },
"image_caption_provider_id": {
"type": "string",
},
"image_caption_prompt": { "image_caption_prompt": {
"type": "string", "type": "string",
}, },
@@ -2422,36 +2398,6 @@ CONFIG_METADATA_3 = {
"provider_settings.enable": True, "provider_settings.enable": True,
}, },
}, },
# "file_extract": {
# "description": "文档解析能力 [beta]",
# "type": "object",
# "items": {
# "provider_settings.file_extract.enable": {
# "description": "启用文档解析能力",
# "type": "bool",
# },
# "provider_settings.file_extract.provider": {
# "description": "文档解析提供商",
# "type": "string",
# "options": ["moonshotai"],
# "condition": {
# "provider_settings.file_extract.enable": True,
# },
# },
# "provider_settings.file_extract.moonshotai_api_key": {
# "description": "Moonshot AI API Key",
# "type": "string",
# "condition": {
# "provider_settings.file_extract.provider": "moonshotai",
# "provider_settings.file_extract.enable": True,
# },
# },
# },
# "condition": {
# "provider_settings.agent_runner_type": "local",
# "provider_settings.enable": True,
# },
# },
"others": { "others": {
"description": "其他配置", "description": "其他配置",
"type": "object", "type": "object",
@@ -2546,11 +2492,6 @@ CONFIG_METADATA_3 = {
"description": "开启 TTS 时同时输出语音和文字内容", "description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool", "type": "bool",
}, },
"provider_settings.reachability_check": {
"description": "提供商可达性检测",
"type": "bool",
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
},
}, },
"condition": { "condition": {
"provider_settings.enable": True, "provider_settings.enable": True,
@@ -2844,16 +2785,7 @@ CONFIG_METADATA_3 = {
"provider_ltm_settings.image_caption": { "provider_ltm_settings.image_caption": {
"description": "自动理解图片", "description": "自动理解图片",
"type": "bool", "type": "bool",
"hint": "需要设置群聊图片转述模型。", "hint": "需要设置默认图片转述模型。",
},
"provider_ltm_settings.image_caption_provider_id": {
"description": "群聊图片转述模型",
"type": "string",
"_special": "select_provider",
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。",
"condition": {
"provider_ltm_settings.image_caption": True,
},
}, },
"provider_ltm_settings.active_reply.enable": { "provider_ltm_settings.active_reply.enable": {
"description": "主动回复", "description": "主动回复",
+21
View File
@@ -213,6 +213,27 @@ class BaseDatabase(abc.ABC):
"""Get an attachment by its ID.""" """Get an attachment by its ID."""
... ...
@abc.abstractmethod
async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]:
"""Get multiple attachments by their IDs."""
...
@abc.abstractmethod
async def delete_attachment(self, attachment_id: str) -> bool:
"""Delete an attachment by its ID.
Returns True if the attachment was deleted, False if it was not found.
"""
...
@abc.abstractmethod
async def delete_attachments(self, attachment_ids: list[str]) -> int:
"""Delete multiple attachments by their IDs.
Returns the number of attachments deleted.
"""
...
@abc.abstractmethod @abc.abstractmethod
async def insert_persona( async def insert_persona(
self, self,
+42
View File
@@ -470,6 +470,48 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query) result = await session.execute(query)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def get_attachments(self, attachment_ids: list[str]) -> list:
"""Get multiple attachments by their IDs."""
if not attachment_ids:
return []
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(
Attachment.attachment_id.in_(attachment_ids)
)
result = await session.execute(query)
return list(result.scalars().all())
async def delete_attachment(self, attachment_id: str) -> bool:
"""Delete an attachment by its ID.
Returns True if the attachment was deleted, False if it was not found.
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = delete(Attachment).where(
Attachment.attachment_id == attachment_id
)
result = await session.execute(query)
return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int:
"""Delete multiple attachments by their IDs.
Returns the number of attachments deleted.
"""
if not attachment_ids:
return 0
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = delete(Attachment).where(
Attachment.attachment_id.in_(attachment_ids)
)
result = await session.execute(query)
return result.rowcount
async def insert_persona( async def insert_persona(
self, self,
persona_id, persona_id,
+1 -6
View File
@@ -722,12 +722,7 @@ class File(BaseMessageComponent):
"""下载文件""" """下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp") download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True) os.makedirs(download_dir, exist_ok=True)
if self.name: file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
name, ext = os.path.splitext(self.name)
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
else:
filename = f"{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path) await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path) self.file_ = os.path.abspath(file_path)
@@ -9,7 +9,7 @@ from astrbot.core import logger
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Reply from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
MessageEventResult, MessageEventResult,
@@ -22,7 +22,6 @@ from astrbot.core.provider.entities import (
ProviderRequest, ProviderRequest,
) )
from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.metrics import Metric from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager from astrbot.core.utils.session_lock import session_lock_manager
@@ -57,13 +56,6 @@ class InternalAgentSubStage(Stage):
self.show_reasoning = settings.get("display_reasoning_text", False) self.show_reasoning = settings.get("display_reasoning_text", False)
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False) self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
file_extract_conf: dict = settings.get("file_extract", {})
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
self.file_extract_msh_api_key: str = file_extract_conf.get(
"moonshotai_api_key", ""
)
self.conv_manager = ctx.plugin_manager.context.conversation_manager self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent): def _select_provider(self, event: AstrMessageEvent):
@@ -122,50 +114,6 @@ class InternalAgentSubStage(Stage):
req.func_tool = ToolSet() req.func_tool = ToolSet()
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL) req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
async def _apply_file_extract(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""Apply file extract to the provider request"""
file_paths = []
file_names = []
for comp in event.message_obj.message:
if isinstance(comp, File):
file_paths.append(await comp.get_file())
file_names.append(comp.name)
elif isinstance(comp, Reply) and comp.chain:
for reply_comp in comp.chain:
if isinstance(reply_comp, File):
file_paths.append(await reply_comp.get_file())
file_names.append(reply_comp.name)
if not file_paths:
return
if not req.prompt:
req.prompt = "总结一下文件里面讲了什么?"
if self.file_extract_prov == "moonshotai":
if not self.file_extract_msh_api_key:
logger.error("Moonshot AI API key for file extract is not set")
return
file_contents = await asyncio.gather(
*[
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
for file_path in file_paths
]
)
else:
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
return
# add file extract results to contexts
for file_content, file_name in zip(file_contents, file_names):
req.contexts.append(
{
"role": "system",
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
},
)
def _truncate_contexts( def _truncate_contexts(
self, self,
contexts: list[dict], contexts: list[dict],
@@ -398,17 +346,6 @@ class InternalAgentSubStage(Stage):
event.set_extra("provider_request", req) event.set_extra("provider_request", req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# apply file extract
if self.file_extract_enabled:
try:
await self._apply_file_extract(event, req)
except Exception as e:
logger.error(f"Error occurred while applying file extract: {e}")
if not req.prompt and not req.image_urls: if not req.prompt and not req.image_urls:
return return
@@ -419,6 +356,10 @@ class InternalAgentSubStage(Stage):
# apply knowledge base feature # apply knowledge base feature
await self._apply_kb(event, req) await self._apply_kb(event, req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# truncate contexts to fit max length # truncate contexts to fit max length
if req.contexts: if req.contexts:
req.contexts = self._truncate_contexts(req.contexts) req.contexts = self._truncate_contexts(req.contexts)
@@ -2,7 +2,7 @@ import asyncio
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from astrbot.core import astrbot_config, logger from astrbot.core import logger
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner, DashscopeAgentRunner,
@@ -88,15 +88,12 @@ class ThirdPartyAgentSubStage(Stage):
return return
self.prov_cfg: dict = next( self.prov_cfg: dict = next(
(p for p in astrbot_config["provider"] if p["id"] == self.prov_id), (p for p in self.conf["provider"] if p["id"] == self.prov_id),
{}, {},
) )
if not self.prov_id: if not self.prov_id or not self.prov_cfg:
logger.error("没有填写 Agent Runner 提供商 ID,请前往配置页面配置。")
return
if not self.prov_cfg:
logger.error( logger.error(
f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。" "Third Party Agent Runner provider ID is not configured properly."
) )
return return
@@ -1,5 +1,6 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import ProviderRequest from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.star.star_handler import StarHandlerMetadata
@@ -62,5 +63,12 @@ class ProcessStage(Stage):
if ( if (
event.get_result() and not event.get_result().is_stopped() event.get_result() and not event.get_result().is_stopped()
) or not event.get_result(): ) or not event.get_result():
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
async for _ in self.agent_sub_stage.process(event): async for _ in self.agent_sub_stage.process(event):
yield yield
@@ -246,13 +246,7 @@ class AiocqhttpAdapter(Platform):
if m["data"].get("url") and m["data"].get("url").startswith("http"): if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange # Lagrange
logger.info("guessing lagrange") logger.info("guessing lagrange")
# 检查多个可能的文件名字段 file_name = m["data"].get("file_name", "file")
file_name = (
m["data"].get("file_name", "")
or m["data"].get("name", "")
or m["data"].get("file", "")
or "file"
)
abm.message.append(File(name=file_name, url=m["data"]["url"])) abm.message.append(File(name=file_name, url=m["data"]["url"]))
else: else:
try: try:
@@ -271,14 +265,7 @@ class AiocqhttpAdapter(Platform):
) )
if ret and "url" in ret: if ret and "url" in ret:
file_url = ret["url"] # https file_url = ret["url"] # https
# 优先从 API 返回值获取文件名,其次从原始消息数据获取 a = File(name="", url=file_url)
file_name = (
ret.get("file_name", "")
or ret.get("name", "")
or m["data"].get("file", "")
or m["data"].get("file_name", "")
)
a = File(name=file_name, url=file_url)
abm.message.append(a) abm.message.append(a)
else: else:
logger.error(f"获取文件失败: {ret}") logger.error(f"获取文件失败: {ret}")
@@ -250,7 +250,7 @@ class DingtalkPlatformAdapter(Platform):
async def terminate(self): async def terminate(self):
def monkey_patch_close(): def monkey_patch_close():
raise KeyboardInterrupt("Graceful shutdown") raise Exception("Graceful shutdown")
self.client_.open_connection = monkey_patch_close self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown") await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
@@ -381,9 +381,7 @@ class TelegramPlatformAdapter(Platform):
f"Telegram document file_path is None, cannot save the file {file_name}.", f"Telegram document file_path is None, cannot save the file {file_name}.",
) )
else: else:
message.message.append( message.message.append(Comp.File(file=file_path, name=file_name))
Comp.File(file=file_path, name=file_name, url=file_path)
)
elif update.message.video: elif update.message.video:
file = await update.message.video.get_file() file = await update.message.video.get_file()
@@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
from astrbot import logger from astrbot import logger
from astrbot.core.message.components import Image, Plain, Record from astrbot.core.message.components import File, Image, Plain, Record, Video
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import ( from astrbot.core.platform import (
AstrBotMessage, AstrBotMessage,
@@ -112,26 +112,19 @@ class WebChatAdapter(Platform):
if payload["message"]: if payload["message"]:
abm.message.append(Plain(payload["message"])) abm.message.append(Plain(payload["message"]))
if payload["image_url"]:
if isinstance(payload["image_url"], list): # 处理 files
for img in payload["image_url"]: files_info = payload.get("files", [])
abm.message.append( for file_info in files_info:
Image.fromFileSystem(os.path.join(self.imgs_dir, img)), if file_info["type"] == "image":
) abm.message.append(Image.fromFileSystem(file_info["path"]))
else: elif file_info["type"] == "record":
abm.message.append( abm.message.append(Record.fromFileSystem(file_info["path"]))
Image.fromFileSystem( elif file_info["type"] == "file":
os.path.join(self.imgs_dir, payload["image_url"]), filename = os.path.basename(file_info["path"])
), abm.message.append(File(name=filename, file=file_info["path"]))
) elif file_info["type"] == "video":
if payload["audio_url"]: abm.message.append(Video.fromFileSystem(file_info["path"]))
if isinstance(payload["audio_url"], list):
for audio in payload["audio_url"]:
path = os.path.join(self.imgs_dir, audio)
abm.message.append(Record(file=path, path=path))
else:
path = os.path.join(self.imgs_dir, payload["audio_url"])
abm.message.append(Record(file=path, path=path))
logger.debug(f"WebChatAdapter: {abm.message}") logger.debug(f"WebChatAdapter: {abm.message}")
@@ -1,12 +1,12 @@
import base64 import base64
import os import os
import shutil
import uuid import uuid
from astrbot.api import logger from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Image, Plain, Record from astrbot.api.message_components import File, Image, Plain, Record
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_image_by_url
from .webchat_queue_mgr import webchat_queue_mgr from .webchat_queue_mgr import webchat_queue_mgr
@@ -19,7 +19,9 @@ class WebChatMessageEvent(AstrMessageEvent):
os.makedirs(imgs_dir, exist_ok=True) os.makedirs(imgs_dir, exist_ok=True)
@staticmethod @staticmethod
async def _send(message: MessageChain, session_id: str, streaming: bool = False): async def _send(
message: MessageChain | None, session_id: str, streaming: bool = False
) -> str | None:
cid = session_id.split("!")[-1] cid = session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
if not message: if not message:
@@ -30,7 +32,7 @@ class WebChatMessageEvent(AstrMessageEvent):
"streaming": False, "streaming": False,
}, # end means this request is finished }, # end means this request is finished
) )
return "" return
data = "" data = ""
for comp in message.chain: for comp in message.chain:
@@ -47,24 +49,11 @@ class WebChatMessageEvent(AstrMessageEvent):
) )
elif isinstance(comp, Image): elif isinstance(comp, Image):
# save image to local # save image to local
filename = str(uuid.uuid4()) + ".jpg" filename = f"{str(uuid.uuid4())}.jpg"
path = os.path.join(imgs_dir, filename) path = os.path.join(imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"): image_base64 = await comp.convert_to_base64()
ph = comp.file[8:] with open(path, "wb") as f:
with open(path, "wb") as f: f.write(base64.b64decode(image_base64))
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file.startswith("base64://"):
base64_str = comp.file[9:]
image_data = base64.b64decode(base64_str)
with open(path, "wb") as f:
f.write(image_data)
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
data = f"[IMAGE]{filename}" data = f"[IMAGE]{filename}"
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
@@ -76,19 +65,11 @@ class WebChatMessageEvent(AstrMessageEvent):
) )
elif isinstance(comp, Record): elif isinstance(comp, Record):
# save record to local # save record to local
filename = str(uuid.uuid4()) + ".wav" filename = f"{str(uuid.uuid4())}.wav"
path = os.path.join(imgs_dir, filename) path = os.path.join(imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"): record_base64 = await comp.convert_to_base64()
ph = comp.file[8:] with open(path, "wb") as f:
with open(path, "wb") as f: f.write(base64.b64decode(record_base64))
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
data = f"[RECORD]{filename}" data = f"[RECORD]{filename}"
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
@@ -98,6 +79,23 @@ class WebChatMessageEvent(AstrMessageEvent):
"streaming": streaming, "streaming": streaming,
}, },
) )
elif isinstance(comp, File):
# save file to local
file_path = await comp.get_file()
original_name = comp.name or os.path.basename(file_path)
ext = os.path.splitext(original_name)[1] or ""
filename = f"{uuid.uuid4()!s}{ext}"
dest_path = os.path.join(imgs_dir, filename)
shutil.copy2(file_path, dest_path)
data = f"[FILE]{filename}|{original_name}"
await web_chat_back_queue.put(
{
"type": "file",
"cid": cid,
"data": data,
"streaming": streaming,
},
)
else: else:
logger.debug(f"webchat 忽略: {comp.type}") logger.debug(f"webchat 忽略: {comp.type}")
@@ -131,6 +129,8 @@ class WebChatMessageEvent(AstrMessageEvent):
session_id=self.session_id, session_id=self.session_id,
streaming=True, streaming=True,
) )
if not r:
continue
if chain.type == "reasoning": if chain.type == "reasoning":
reasoning_content += chain.get_plain_text() reasoning_content += chain.get_plain_text()
else: else:
+1 -1
View File
@@ -10,7 +10,7 @@ class PlatformMessageHistoryManager:
self, self,
platform_id: str, platform_id: str,
user_id: str, user_id: str,
content: list[dict], # TODO: parse from message chain content: dict, # TODO: parse from message chain
sender_id: str | None = None, sender_id: str | None = None,
sender_name: str | None = None, sender_name: str | None = None,
): ):
-35
View File
@@ -1,6 +1,5 @@
import abc import abc
import asyncio import asyncio
import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from astrbot.core.agent.message import Message from astrbot.core.agent.message import Message
@@ -12,7 +11,6 @@ from astrbot.core.provider.entities import (
ToolCallsResult, ToolCallsResult,
) )
from astrbot.core.provider.register import provider_cls_map from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
class AbstractProvider(abc.ABC): class AbstractProvider(abc.ABC):
@@ -45,14 +43,6 @@ class AbstractProvider(abc.ABC):
) )
return meta return meta
async def test(self):
"""test the provider is a
raises:
Exception: if the provider is not available
"""
...
class Provider(AbstractProvider): class Provider(AbstractProvider):
"""Chat Provider""" """Chat Provider"""
@@ -175,12 +165,6 @@ class Provider(AbstractProvider):
return dicts return dicts
async def test(self, timeout: float = 45.0):
await asyncio.wait_for(
self.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=timeout,
)
class STTProvider(AbstractProvider): class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -193,14 +177,6 @@ class STTProvider(AbstractProvider):
"""获取音频的文本""" """获取音频的文本"""
raise NotImplementedError raise NotImplementedError
async def test(self):
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
await self.get_text(sample_audio_path)
class TTSProvider(AbstractProvider): class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -213,9 +189,6 @@ class TTSProvider(AbstractProvider):
"""获取文本的音频,返回音频文件路径""" """获取文本的音频,返回音频文件路径"""
raise NotImplementedError raise NotImplementedError
async def test(self):
await self.get_audio("hi")
class EmbeddingProvider(AbstractProvider): class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -238,9 +211,6 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度""" """获取向量的维度"""
... ...
async def test(self):
await self.get_embedding("astrbot")
async def get_embeddings_batch( async def get_embeddings_batch(
self, self,
texts: list[str], texts: list[str],
@@ -324,8 +294,3 @@ class RerankProvider(AbstractProvider):
) -> list[RerankResult]: ) -> list[RerankResult]:
"""获取查询和文档的重排序分数""" """获取查询和文档的重排序分数"""
... ...
async def test(self):
result = await self.rerank("Apple", documents=["apple", "banana"])
if not result:
raise Exception("Rerank provider test failed, no results returned")
-23
View File
@@ -1,23 +0,0 @@
from pathlib import Path
from openai import AsyncOpenAI
async def extract_file_moonshotai(file_path: str, api_key: str) -> str:
"""Extract text from a file using Moonshot AI API"""
"""
Args:
file_path: The path to the file to extract text from
api_key: The API key to use to extract text from the file
Returns:
The text extracted from the file
"""
client = AsyncOpenAI(
api_key=api_key,
base_url="https://api.moonshot.cn/v1",
)
file_object = await client.files.create(
file=Path(file_path),
purpose="file-extract", # type: ignore
)
return (await client.files.content(file_id=file_object.id)).text
+266 -64
View File
@@ -1,15 +1,16 @@
import asyncio import asyncio
import json import json
import mimetypes
import os import os
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from quart import Response as QuartResponse from quart import g, make_response, request, send_file
from quart import g, make_response, request
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Attachment
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -44,7 +45,7 @@ class ChatRoute(Route):
self.update_session_display_name, self.update_session_display_name,
), ),
"/chat/get_file": ("GET", self.get_file), "/chat/get_file": ("GET", self.get_file),
"/chat/post_image": ("POST", self.post_image), "/chat/get_attachment": ("GET", self.get_attachment),
"/chat/post_file": ("POST", self.post_file), "/chat/post_file": ("POST", self.post_file),
} }
self.core_lifecycle = core_lifecycle self.core_lifecycle = core_lifecycle
@@ -73,52 +74,176 @@ class ChatRoute(Route):
if not real_file_path.startswith(real_imgs_dir): if not real_file_path.startswith(real_imgs_dir):
return Response().error("Invalid file path").__dict__ return Response().error("Invalid file path").__dict__
with open(real_file_path, "rb") as f: filename_ext = os.path.splitext(filename)[1].lower()
filename_ext = os.path.splitext(filename)[1].lower() if filename_ext == ".wav":
return await send_file(real_file_path, mimetype="audio/wav")
if filename_ext == ".wav": if filename_ext[1:] in self.supported_imgs:
return QuartResponse(f.read(), mimetype="audio/wav") return await send_file(real_file_path, mimetype="image/jpeg")
if filename_ext[1:] in self.supported_imgs: return await send_file(real_file_path)
return QuartResponse(f.read(), mimetype="image/jpeg")
return QuartResponse(f.read())
except (FileNotFoundError, OSError): except (FileNotFoundError, OSError):
return Response().error("File access error").__dict__ return Response().error("File access error").__dict__
async def post_image(self): async def get_attachment(self):
post_data = await request.files """Get attachment file by attachment_id."""
if "file" not in post_data: attachment_id = request.args.get("attachment_id")
return Response().error("Missing key: file").__dict__ if not attachment_id:
return Response().error("Missing key: attachment_id").__dict__
file = post_data["file"] try:
filename = str(uuid.uuid4()) + ".jpg" attachment = await self.db.get_attachment_by_id(attachment_id)
path = os.path.join(self.imgs_dir, filename) if not attachment:
await file.save(path) return Response().error("Attachment not found").__dict__
return Response().ok(data={"filename": filename}).__dict__ file_path = attachment.path
real_file_path = os.path.realpath(file_path)
return await send_file(real_file_path, mimetype=attachment.mime_type)
except (FileNotFoundError, OSError):
return Response().error("File access error").__dict__
async def post_file(self): async def post_file(self):
"""Upload a file and create an attachment record, return attachment_id."""
post_data = await request.files post_data = await request.files
if "file" not in post_data: if "file" not in post_data:
return Response().error("Missing key: file").__dict__ return Response().error("Missing key: file").__dict__
file = post_data["file"] file = post_data["file"]
filename = f"{uuid.uuid4()!s}" filename = file.filename or f"{uuid.uuid4()!s}"
# 通过文件格式判断文件类型 content_type = file.content_type or "application/octet-stream"
if file.content_type.startswith("audio"):
filename += ".wav" # 根据 content_type 判断文件类型并添加扩展名
if content_type.startswith("image"):
attach_type = "image"
elif content_type.startswith("audio"):
attach_type = "record"
elif content_type.startswith("video"):
attach_type = "video"
else:
attach_type = "file"
path = os.path.join(self.imgs_dir, filename) path = os.path.join(self.imgs_dir, filename)
await file.save(path) await file.save(path)
return Response().ok(data={"filename": filename}).__dict__ # 创建 attachment 记录
attachment = await self.db.insert_attachment(
path=path,
type=attach_type,
mime_type=content_type,
)
if not attachment:
return Response().error("Failed to create attachment").__dict__
filename = os.path.basename(attachment.path)
return (
Response()
.ok(
data={
"attachment_id": attachment.attachment_id,
"filename": filename,
"type": attach_type,
}
)
.__dict__
)
async def _build_user_message_parts(
self,
message: str,
attachments: list[Attachment],
) -> list:
"""构建用户消息的部分列表
Args:
message: 文本消息
files: attachment_id 列表
"""
parts = []
if message:
parts.append({"type": "plain", "text": message})
if attachments:
for attachment in attachments:
parts.append(
{
"type": attachment.type,
"attachment_id": attachment.attachment_id,
"filename": os.path.basename(attachment.path),
}
)
return parts
async def _create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
"""从本地文件创建 attachment 并返回消息部分
用于处理 bot 回复中的媒体文件
Args:
filename: 存储的文件名
attach_type: 附件类型 (image, record, file, video)
"""
file_path = os.path.join(self.imgs_dir, os.path.basename(filename))
if not os.path.exists(file_path):
return None
# guess mime type
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"
# insert attachment
attachment = await self.db.insert_attachment(
path=file_path,
type=attach_type,
mime_type=mime_type,
)
if not attachment:
return None
return {
"type": attach_type,
"attachment_id": attachment.attachment_id,
"filename": os.path.basename(file_path),
}
async def _save_bot_message(
self,
webchat_conv_id: str,
text: str,
media_parts: list,
reasoning: str,
):
"""保存 bot 消息到历史记录"""
bot_message_parts = []
if text:
bot_message_parts.append({"type": "plain", "text": text})
bot_message_parts.extend(media_parts)
new_his = {"type": "bot", "message": bot_message_parts}
if reasoning:
new_his["reasoning"] = reasoning
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
async def chat(self): async def chat(self):
username = g.get("username", "guest") username = g.get("username", "guest")
post_data = await request.json post_data = await request.json
if "message" not in post_data and "image_url" not in post_data: if "message" not in post_data and "files" not in post_data:
return Response().error("Missing key: message or image_url").__dict__ return Response().error("Missing key: message or files").__dict__
if "session_id" not in post_data and "conversation_id" not in post_data: if "session_id" not in post_data and "conversation_id" not in post_data:
return ( return (
@@ -126,44 +251,44 @@ class ChatRoute(Route):
) )
message = post_data["message"] message = post_data["message"]
# conversation_id = post_data["conversation_id"]
session_id = post_data.get("session_id", post_data.get("conversation_id")) session_id = post_data.get("session_id", post_data.get("conversation_id"))
image_url = post_data.get("image_url") files = post_data.get("files") # list of attachment_id
audio_url = post_data.get("audio_url")
selected_provider = post_data.get("selected_provider") selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model") selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True) # 默认为 True enable_streaming = post_data.get("enable_streaming", True)
if not message and not image_url and not audio_url: if not message and not files:
return ( return Response().error("Message and files are both empty").__dict__
Response()
.error("Message and image_url and audio_url are empty")
.__dict__
)
if not session_id: if not session_id:
return Response().error("session_id is empty").__dict__ return Response().error("session_id is empty").__dict__
# 追加用户消息
webchat_conv_id = session_id webchat_conv_id = session_id
# 获取会话特定的队列
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
new_his = {"type": "user", "message": message} # 构建并保存用户消息
if image_url: attachments = await self.db.get_attachments(files)
new_his["image_url"] = image_url message_parts = await self._build_user_message_parts(message, attachments)
if audio_url: files_info = [
new_his["audio_url"] = audio_url {
"type": attachment.type,
"path": attachment.path,
}
for attachment in attachments
]
await self.platform_history_mgr.insert( await self.platform_history_mgr.insert(
platform_id="webchat", platform_id="webchat",
user_id=webchat_conv_id, user_id=webchat_conv_id,
content=new_his, content={"type": "user", "message": message_parts},
sender_id=username, sender_id=username,
sender_name=username, sender_name=username,
) )
async def stream(): async def stream():
client_disconnected = False client_disconnected = False
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
try: try:
async with track_conversation(self.running_convs, webchat_conv_id): async with track_conversation(self.running_convs, webchat_conv_id):
@@ -182,16 +307,17 @@ class ChatRoute(Route):
continue continue
result_text = result["data"] result_text = result["data"]
type = result.get("type") msg_type = result.get("type")
streaming = result.get("streaming", False) streaming = result.get("streaming", False)
# 发送 SSE 数据
try: try:
if not client_disconnected: if not client_disconnected:
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
except Exception as e: except Exception as e:
if not client_disconnected: if not client_disconnected:
logger.debug( logger.debug(
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}", f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
) )
client_disconnected = True client_disconnected = True
@@ -202,24 +328,55 @@ class ChatRoute(Route):
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True client_disconnected = True
if type == "end": # 累积消息部分
if msg_type == "plain":
chain_type = result.get("chain_type", "normal")
if chain_type == "reasoning":
accumulated_reasoning += result_text
else:
accumulated_text += result_text
elif msg_type == "image":
filename = result_text.replace("[IMAGE]", "")
part = await self._create_attachment_from_file(
filename, "image"
)
if part:
accumulated_parts.append(part)
elif msg_type == "record":
filename = result_text.replace("[RECORD]", "")
part = await self._create_attachment_from_file(
filename, "record"
)
if part:
accumulated_parts.append(part)
elif msg_type == "file":
# 格式: [FILE]filename
filename = result_text.replace("[FILE]", "")
part = await self._create_attachment_from_file(
filename, "file"
)
if part:
accumulated_parts.append(part)
# 消息结束处理
if msg_type == "end":
break break
elif ( elif (
(streaming and type == "complete") (streaming and msg_type == "complete")
or not streaming or not streaming
or type == "break" or msg_type == "break"
): ):
# 追加机器人消息 await self._save_bot_message(
new_his = {"type": "bot", "message": result_text} webchat_conv_id,
if "reasoning" in result: accumulated_text,
new_his["reasoning"] = result["reasoning"] accumulated_parts,
await self.platform_history_mgr.insert( accumulated_reasoning,
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
) )
# 重置累积变量 (对于 break 后的下一段消息)
if msg_type == "break":
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
except BaseException as e: except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
@@ -231,8 +388,7 @@ class ChatRoute(Route):
webchat_conv_id, webchat_conv_id,
{ {
"message": message, "message": message,
"image_url": image_url, # list "files": files_info,
"audio_url": audio_url,
"selected_provider": selected_provider, "selected_provider": selected_provider,
"selected_model": selected_model, "selected_model": selected_model,
"enable_streaming": enable_streaming, "enable_streaming": enable_streaming,
@@ -249,7 +405,7 @@ class ChatRoute(Route):
"Connection": "keep-alive", "Connection": "keep-alive",
}, },
) )
response.timeout = None # fix SSE auto disconnect issue response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue]
return response return response
async def delete_webchat_session(self): async def delete_webchat_session(self):
@@ -271,6 +427,17 @@ class ChatRoute(Route):
unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}" unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}"
await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
# 获取消息历史中的所有附件 ID 并删除附件
history_list = await self.platform_history_mgr.get(
platform_id=session.platform_id,
user_id=session_id,
page=1,
page_size=100000, # 获取足够多的记录
)
attachment_ids = self._extract_attachment_ids(history_list)
if attachment_ids:
await self._delete_attachments(attachment_ids)
# 删除消息历史 # 删除消息历史
await self.platform_history_mgr.delete( await self.platform_history_mgr.delete(
platform_id=session.platform_id, platform_id=session.platform_id,
@@ -297,6 +464,41 @@ class ChatRoute(Route):
return Response().ok().__dict__ return Response().ok().__dict__
def _extract_attachment_ids(self, history_list) -> list[str]:
"""从消息历史中提取所有 attachment_id"""
attachment_ids = []
for history in history_list:
content = history.content
if not content or "message" not in content:
continue
message_parts = content.get("message", [])
for part in message_parts:
if isinstance(part, dict) and "attachment_id" in part:
attachment_ids.append(part["attachment_id"])
return attachment_ids
async def _delete_attachments(self, attachment_ids: list[str]):
"""删除附件(包括数据库记录和磁盘文件)"""
try:
attachments = await self.db.get_attachments(attachment_ids)
for attachment in attachments:
if not os.path.exists(attachment.path):
continue
try:
os.remove(attachment.path)
except OSError as e:
logger.warning(
f"Failed to delete attachment file {attachment.path}: {e}"
)
except Exception as e:
logger.warning(f"Failed to get attachments: {e}")
# 批量删除数据库记录
try:
await self.db.delete_attachments(attachment_ids)
except Exception as e:
logger.warning(f"Failed to delete attachments: {e}")
async def new_session(self): async def new_session(self):
"""Create a new Platform session (default: webchat).""" """Create a new Platform session (default: webchat)."""
username = g.get("username", "guest") username = g.get("username", "guest")
+165 -13
View File
@@ -18,8 +18,11 @@ from astrbot.core.config.i18n_utils import ConfigMetadataI18n
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.provider.register import provider_registry from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry from astrbot.core.star.star import star_registry
from astrbot.core.utils.astrbot_path import get_astrbot_path
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
@@ -353,20 +356,169 @@ class ConfigRoute(Route):
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
) )
try: if provider_capability_type == ProviderType.CHAT_COMPLETION:
await provider.test() try:
status_info["status"] = "available" logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
logger.info( response = await asyncio.wait_for(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", provider.text_chat(prompt="REPLY `PONG` ONLY"),
) timeout=45.0,
except Exception as e: )
error_message = str(e) logger.debug(
status_info["error"] = error_message f"Received response from {status_info['name']}: {response}",
logger.warning( )
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}", if response is not None:
) status_info["status"] = "available"
response_text_snippet = ""
if (
hasattr(response, "completion_text")
and response.completion_text
):
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
else response.completion_text
)
elif hasattr(response, "result_chain") and response.result_chain:
try:
response_text_snippet = (
response.result_chain.get_plain_text()[:70] + "..."
if len(response.result_chain.get_plain_text()) > 70
else response.result_chain.get_plain_text()
)
except Exception as _:
pass
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
)
else:
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
)
except asyncio.TimeoutError:
status_info["error"] = (
"Connection timed out after 45 seconds during test call."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
)
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
)
elif provider_capability_type == ProviderType.EMBEDDING:
try:
# For embedding, we can call the get_embedding method with a short prompt.
embedding_result = await provider.get_embedding("health_check")
if isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"Embedding test failed: unexpected result type {type(embedding_result)}"
)
except Exception as e:
logger.error(
f"Error testing embedding provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: {e!s}"
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
try:
# For TTS, we can call the get_audio method with a short prompt.
audio_result = await provider.get_audio("你好")
if isinstance(audio_result, str) and audio_result:
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"TTS test failed: unexpected result type {type(audio_result)}"
)
except Exception as e:
logger.error(
f"Error testing TTS provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: {e!s}"
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
try:
logger.debug(
f"Sending health check audio to provider: {status_info['name']}",
)
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable"
status_info["error"] = (
"STT test failed: sample audio file not found."
)
logger.warning(
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
)
else:
text_result = await provider.get_text(sample_audio_path)
if isinstance(text_result, str) and text_result:
status_info["status"] = "available"
snippet = (
text_result[:70] + "..."
if len(text_result) > 70
else text_result
)
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
)
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"STT test failed: unexpected result type {type(text_result)}"
)
logger.warning(
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
)
except Exception as e:
logger.error(
f"Error testing STT provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {e!s}"
elif provider_capability_type == ProviderType.RERANK:
try:
assert isinstance(provider, RerankProvider)
await provider.rerank("Apple", documents=["apple", "banana"])
status_info["status"] = "available"
except Exception as e:
logger.error(
f"Error testing rerank provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"Rerank test failed: {e!s}"
else:
logger.debug( logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}", f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}",
)
status_info["status"] = "available"
status_info["error"] = (
"This provider type is not tested and is assumed to be available."
) )
return status_info return status_info
+6 -65
View File
@@ -1,4 +1,3 @@
import asyncio
import json import json
import os import os
import ssl import ssl
@@ -20,10 +19,6 @@ from astrbot.core.star.star_manager import PluginManager
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
PLUGIN_UPDATE_CONCURRENCY = (
3 # limit concurrent updates to avoid overwhelming plugin sources
)
class PluginRoute(Route): class PluginRoute(Route):
def __init__( def __init__(
@@ -38,7 +33,6 @@ class PluginRoute(Route):
"/plugin/install": ("POST", self.install_plugin), "/plugin/install": ("POST", self.install_plugin),
"/plugin/install-upload": ("POST", self.install_plugin_upload), "/plugin/install-upload": ("POST", self.install_plugin_upload),
"/plugin/update": ("POST", self.update_plugin), "/plugin/update": ("POST", self.update_plugin),
"/plugin/update-all": ("POST", self.update_all_plugins),
"/plugin/uninstall": ("POST", self.uninstall_plugin), "/plugin/uninstall": ("POST", self.uninstall_plugin),
"/plugin/market_list": ("GET", self.get_online_plugins), "/plugin/market_list": ("GET", self.get_online_plugins),
"/plugin/off": ("POST", self.off_plugin), "/plugin/off": ("POST", self.off_plugin),
@@ -69,7 +63,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
data = await request.get_json() data = await request.json
plugin_name = data.get("name", None) plugin_name = data.get("name", None)
try: try:
success, message = await self.plugin_manager.reload(plugin_name) success, message = await self.plugin_manager.reload(plugin_name)
@@ -352,7 +346,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
repo_url = post_data["url"] repo_url = post_data["url"]
proxy: str = post_data.get("proxy", None) proxy: str = post_data.get("proxy", None)
@@ -399,7 +393,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
delete_config = post_data.get("delete_config", False) delete_config = post_data.get("delete_config", False)
delete_data = post_data.get("delete_data", False) delete_data = post_data.get("delete_data", False)
@@ -424,7 +418,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
proxy: str = post_data.get("proxy", None) proxy: str = post_data.get("proxy", None)
try: try:
@@ -438,59 +432,6 @@ class PluginRoute(Route):
logger.error(f"/api/plugin/update: {traceback.format_exc()}") logger.error(f"/api/plugin/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__ return Response().error(str(e)).__dict__
async def update_all_plugins(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.get_json()
plugin_names: list[str] = post_data.get("names") or []
proxy: str = post_data.get("proxy", "")
if not isinstance(plugin_names, list) or not plugin_names:
return Response().error("插件列表不能为空").__dict__
results = []
sem = asyncio.Semaphore(PLUGIN_UPDATE_CONCURRENCY)
async def _update_one(name: str):
async with sem:
try:
logger.info(f"批量更新插件 {name}")
await self.plugin_manager.update_plugin(name, proxy)
return {"name": name, "status": "ok", "message": "更新成功"}
except Exception as e:
logger.error(
f"/api/plugin/update-all: 更新插件 {name} 失败: {traceback.format_exc()}",
)
return {"name": name, "status": "error", "message": str(e)}
raw_results = await asyncio.gather(
*(_update_one(name) for name in plugin_names),
return_exceptions=True,
)
for name, result in zip(plugin_names, raw_results):
if isinstance(result, asyncio.CancelledError):
raise result
if isinstance(result, BaseException):
results.append(
{"name": name, "status": "error", "message": str(result)}
)
else:
results.append(result)
failed = [r for r in results if r["status"] == "error"]
message = (
"批量更新完成,全部成功。"
if not failed
else f"批量更新完成,其中 {len(failed)}/{len(results)} 个插件失败。"
)
return Response().ok({"results": results}, message).__dict__
async def off_plugin(self): async def off_plugin(self):
if DEMO_MODE: if DEMO_MODE:
return ( return (
@@ -499,7 +440,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
try: try:
await self.plugin_manager.turn_off_plugin(plugin_name) await self.plugin_manager.turn_off_plugin(plugin_name)
@@ -517,7 +458,7 @@ class PluginRoute(Route):
.__dict__ .__dict__
) )
post_data = await request.get_json() post_data = await request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
try: try:
await self.plugin_manager.turn_on_plugin(plugin_name) await self.plugin_manager.turn_on_plugin(plugin_name)
-25
View File
@@ -1,25 +0,0 @@
## What's Changed
1. 修复使用非默认配置文件情况下时,第三方 Agent Runner (Dify、Coze、阿里云百炼应用等)无法正常工作的问题
2. 修复当“聊天模型”未设置,并且模型提供商中仅有 Agent Runner 时,无法正常使用 Agent Runner 的问题
3. 修复部分情况下报错 `pydantic_core._pydantic_core.ValidationError: 1 validation error for Message content` 的问题
4. 新增群聊模式下的专用图片转述模型配置 ([#3822](https://github.com/AstrBotDevs/AstrBot/issues/3822))
---
重构:
- 将 Dify、Coze、阿里云百炼应用等 LLMOps 提供商迁移到 Agent 执行器层,理清和本地 Agent 执行器的边界。详见:[Agent 执行器](https://docs.astrbot.app/use/agent-runner.html)
- 将「会话管理」功能重构为「自定义规则」功能,理清和多配置文件功能的边界。详见:[自定义规则](https://docs.astrbot.app/use/custom-rules.html)
优化:
- Dify、阿里云百炼应用支持流式输出
- 防止分段回复正则表达式解析错误导致消息不发送
- 群聊上下文感知记录 At 信息
- 优化模型提供商页面的测试提供商功能
新增:
- 支持在配置文件页面快速测试对话
- 为配置文件配置项内容添加国际化支持
修复:
- 在更新 MCP Server 配置后,MCP 无法正常重启的问题
-7
View File
@@ -1,7 +0,0 @@
## What's Changed
1. 修复:assistant message 中 tool_call 存在但 content 不存在时,导致验证错误的问题 ([#3862](https://github.com/AstrBotDevs/AstrBot/issues/3862))
2. 修复:fix: aiocqhttp 适配器 NapCat 文件名获取为空 ([#3853](https://github.com/AstrBotDevs/AstrBot/issues/3853))
3. 新增:升级所有插件按钮
4. 新增:/provider 指令支持同时测试提供商可用性
5. 优化:主动回复的 prompt
+20 -5
View File
@@ -84,6 +84,7 @@
v-model:prompt="prompt" v-model:prompt="prompt"
:stagedImagesUrl="stagedImagesUrl" :stagedImagesUrl="stagedImagesUrl"
:stagedAudioUrl="stagedAudioUrl" :stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles"
:disabled="isStreaming" :disabled="isStreaming"
:enableStreaming="enableStreaming" :enableStreaming="enableStreaming"
:isRecording="isRecording" :isRecording="isRecording"
@@ -93,6 +94,7 @@
@toggleStreaming="toggleStreaming" @toggleStreaming="toggleStreaming"
@removeImage="removeImage" @removeImage="removeImage"
@removeAudio="removeAudio" @removeAudio="removeAudio"
@removeFile="removeFile"
@startRecording="handleStartRecording" @startRecording="handleStartRecording"
@stopRecording="handleStopRecording" @stopRecording="handleStopRecording"
@pasteImage="handlePaste" @pasteImage="handlePaste"
@@ -189,14 +191,17 @@ const {
} = useSessions(props.chatboxMode); } = useSessions(props.chatboxMode);
const { const {
stagedImagesName,
stagedImagesUrl, stagedImagesUrl,
stagedAudioUrl, stagedAudioUrl,
stagedFiles,
stagedNonImageFiles,
getMediaFile, getMediaFile,
processAndUploadImage, processAndUploadImage,
processAndUploadFile,
handlePaste, handlePaste,
removeImage, removeImage,
removeAudio, removeAudio,
removeFile,
clearStaged, clearStaged,
cleanupMediaCache cleanupMediaCache
} = useMediaHandling(); } = useMediaHandling();
@@ -295,13 +300,18 @@ async function handleStopRecording() {
} }
async function handleFileSelect(files: FileList) { async function handleFileSelect(files: FileList) {
const imageTypes = ['image/jpeg', 'image/png', 'image/gif', 'image/webp'];
for (const file of files) { for (const file of files) {
await processAndUploadImage(file); if (imageTypes.includes(file.type)) {
await processAndUploadImage(file);
} else {
await processAndUploadFile(file);
}
} }
} }
async function handleSendMessage() { async function handleSendMessage() {
if (!prompt.value.trim() && stagedImagesName.value.length === 0 && !stagedAudioUrl.value) { if (!prompt.value.trim() && stagedFiles.value.length === 0 && !stagedAudioUrl.value) {
return; return;
} }
@@ -310,8 +320,13 @@ async function handleSendMessage() {
} }
const promptToSend = prompt.value.trim(); const promptToSend = prompt.value.trim();
const imageNamesToSend = [...stagedImagesName.value];
const audioNameToSend = stagedAudioUrl.value; const audioNameToSend = stagedAudioUrl.value;
const filesToSend = stagedFiles.value.map(f => ({
attachment_id: f.attachment_id,
url: f.url,
original_name: f.original_name,
type: f.type
}));
// //
prompt.value = ''; prompt.value = '';
@@ -324,7 +339,7 @@ async function handleSendMessage() {
await sendMsg( await sendMsg(
promptToSend, promptToSend,
imageNamesToSend, filesToSend,
audioNameToSend, audioNameToSend,
selectedProviderId, selectedProviderId,
selectedModelName selectedModelName
+36 -7
View File
@@ -30,7 +30,7 @@
</v-tooltip> </v-tooltip>
</div> </div>
<div style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;"> <div style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;">
<input type="file" ref="imageInputRef" @change="handleFileSelect" accept="image/*" <input type="file" ref="imageInputRef" @change="handleFileSelect"
style="display: none" multiple /> style="display: none" multiple />
<v-progress-circular v-if="disabled" indeterminate size="16" class="mr-1" width="1.5" /> <v-progress-circular v-if="disabled" indeterminate size="16" class="mr-1" width="1.5" />
<v-btn @click="triggerImageInput" icon="mdi-plus" variant="text" color="deep-purple" <v-btn @click="triggerImageInput" icon="mdi-plus" variant="text" color="deep-purple"
@@ -45,8 +45,8 @@
</div> </div>
<!-- 附件预览区 --> <!-- 附件预览区 -->
<div class="attachments-preview" v-if="stagedImagesUrl.length > 0 || stagedAudioUrl"> <div class="attachments-preview" v-if="stagedImagesUrl.length > 0 || stagedAudioUrl || (stagedFiles && stagedFiles.length > 0)">
<div v-for="(img, index) in stagedImagesUrl" :key="index" class="image-preview"> <div v-for="(img, index) in stagedImagesUrl" :key="'img-' + index" class="image-preview">
<img :src="img" class="preview-image" /> <img :src="img" class="preview-image" />
<v-btn @click="$emit('removeImage', index)" class="remove-attachment-btn" icon="mdi-close" <v-btn @click="$emit('removeImage', index)" class="remove-attachment-btn" icon="mdi-close"
size="small" color="error" variant="text" /> size="small" color="error" variant="text" />
@@ -60,6 +60,15 @@
<v-btn @click="$emit('removeAudio')" class="remove-attachment-btn" icon="mdi-close" size="small" <v-btn @click="$emit('removeAudio')" class="remove-attachment-btn" icon="mdi-close" size="small"
color="error" variant="text" /> color="error" variant="text" />
</div> </div>
<div v-for="(file, index) in stagedFiles" :key="'file-' + index" class="file-preview">
<v-chip color="blue-grey-lighten-4" class="file-chip">
<v-icon start icon="mdi-file-document-outline" size="small"></v-icon>
<span class="file-name-preview">{{ file.original_name }}</span>
</v-chip>
<v-btn @click="$emit('removeFile', index)" class="remove-attachment-btn" icon="mdi-close" size="small"
color="error" variant="text" />
</div>
</div> </div>
</div> </div>
</template> </template>
@@ -71,10 +80,19 @@ import ProviderModelSelector from './ProviderModelSelector.vue';
import ConfigSelector from './ConfigSelector.vue'; import ConfigSelector from './ConfigSelector.vue';
import type { Session } from '@/composables/useSessions'; import type { Session } from '@/composables/useSessions';
interface StagedFileInfo {
attachment_id: string;
filename: string;
original_name: string;
url: string;
type: string;
}
interface Props { interface Props {
prompt: string; prompt: string;
stagedImagesUrl: string[]; stagedImagesUrl: string[];
stagedAudioUrl: string; stagedAudioUrl: string;
stagedFiles?: StagedFileInfo[];
disabled: boolean; disabled: boolean;
enableStreaming: boolean; enableStreaming: boolean;
isRecording: boolean; isRecording: boolean;
@@ -86,7 +104,8 @@ interface Props {
const props = withDefaults(defineProps<Props>(), { const props = withDefaults(defineProps<Props>(), {
sessionId: null, sessionId: null,
currentSession: null, currentSession: null,
configId: null configId: null,
stagedFiles: () => []
}); });
const emit = defineEmits<{ const emit = defineEmits<{
@@ -95,6 +114,7 @@ const emit = defineEmits<{
toggleStreaming: []; toggleStreaming: [];
removeImage: [index: number]; removeImage: [index: number];
removeAudio: []; removeAudio: [];
removeFile: [index: number];
startRecording: []; startRecording: [];
stopRecording: []; stopRecording: [];
pasteImage: [event: ClipboardEvent]; pasteImage: [event: ClipboardEvent];
@@ -117,7 +137,7 @@ const sessionPlatformId = computed(() => props.currentSession?.platform_id || 'w
const sessionIsGroup = computed(() => Boolean(props.currentSession?.is_group)); const sessionIsGroup = computed(() => Boolean(props.currentSession?.is_group));
const canSend = computed(() => { const canSend = computed(() => {
return (props.prompt && props.prompt.trim()) || props.stagedImagesUrl.length > 0 || props.stagedAudioUrl; return (props.prompt && props.prompt.trim()) || props.stagedImagesUrl.length > 0 || props.stagedAudioUrl || (props.stagedFiles && props.stagedFiles.length > 0);
}); });
// Ctrl+B // Ctrl+B
@@ -239,7 +259,8 @@ defineExpose({
} }
.image-preview, .image-preview,
.audio-preview { .audio-preview,
.file-preview {
position: relative; position: relative;
display: inline-flex; display: inline-flex;
} }
@@ -252,11 +273,19 @@ defineExpose({
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
} }
.audio-chip { .audio-chip,
.file-chip {
height: 36px; height: 36px;
border-radius: 18px; border-radius: 18px;
} }
.file-name-preview {
max-width: 120px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.remove-attachment-btn { .remove-attachment-btn {
position: absolute; position: absolute;
top: -8px; top: -8px;
+174 -5
View File
@@ -24,6 +24,22 @@
{{ t('messages.errors.browser.audioNotSupported') }} {{ t('messages.errors.browser.audioNotSupported') }}
</audio> </audio>
</div> </div>
<!-- 文件附件 -->
<div class="file-attachments" v-if="msg.content.file_url && msg.content.file_url.length > 0">
<div v-for="(file, fileIdx) in msg.content.file_url" :key="fileIdx" class="file-attachment">
<a v-if="file.url" :href="file.url" :download="file.filename" class="file-link">
<v-icon size="small" class="file-icon">mdi-file-document-outline</v-icon>
<span class="file-name">{{ file.filename }}</span>
</a>
<a v-else @click="downloadFile(file)" class="file-link file-link-download">
<v-icon size="small" class="file-icon">mdi-file-document-outline</v-icon>
<span class="file-name">{{ file.filename }}</span>
<v-icon v-if="downloadingFiles.has(file.attachment_id)" size="small" class="download-icon">mdi-loading mdi-spin</v-icon>
<v-icon v-else size="small" class="download-icon">mdi-download</v-icon>
</a>
</div>
</div>
</div> </div>
</div> </div>
@@ -77,10 +93,29 @@
{{ t('messages.errors.browser.audioNotSupported') }} {{ t('messages.errors.browser.audioNotSupported') }}
</audio> </audio>
</div> </div>
<!-- Files -->
<div class="embedded-files"
v-if="msg.content.embedded_files && msg.content.embedded_files.length > 0">
<div v-for="(file, fileIndex) in msg.content.embedded_files" :key="fileIndex"
class="embedded-file">
<a v-if="file.url" :href="file.url" :download="file.filename" class="file-link">
<v-icon size="small" class="file-icon">mdi-file-document-outline</v-icon>
<span class="file-name">{{ file.filename }}</span>
</a>
<a v-else @click="downloadFile(file)" class="file-link file-link-download">
<v-icon size="small" class="file-icon">mdi-file-document-outline</v-icon>
<span class="file-name">{{ file.filename }}</span>
<v-icon v-if="downloadingFiles.has(file.attachment_id)" size="small" class="download-icon">mdi-loading mdi-spin</v-icon>
<v-icon v-else size="small" class="download-icon">mdi-download</v-icon>
</a>
</div>
</div>
</template> </template>
</div> </div>
<div class="message-actions" v-if="!msg.content.isLoading"> <div class="message-actions" v-if="!msg.content.isLoading">
<v-btn :icon="getCopyIcon(index)" size="small" variant="text" class="copy-message-btn" <span class="message-time" v-if="msg.created_at">{{ formatMessageTime(msg.created_at) }}</span>
<v-btn :icon="getCopyIcon(index)" size="x-small" variant="text" class="copy-message-btn"
:class="{ 'copy-success': isCopySuccess(index) }" :class="{ 'copy-success': isCopySuccess(index) }"
@click="copyBotMessage(msg.content.message, index)" :title="t('core.common.copy')" /> @click="copyBotMessage(msg.content.message, index)" :title="t('core.common.copy')" />
</div> </div>
@@ -96,6 +131,7 @@ import { useI18n, useModuleI18n } from '@/i18n/composables';
import MarkdownIt from 'markdown-it'; import MarkdownIt from 'markdown-it';
import hljs from 'highlight.js'; import hljs from 'highlight.js';
import 'highlight.js/styles/github.css'; import 'highlight.js/styles/github.css';
import axios from 'axios';
const md = new MarkdownIt({ const md = new MarkdownIt({
html: false, html: false,
@@ -147,6 +183,7 @@ export default {
scrollThreshold: 1, scrollThreshold: 1,
scrollTimer: null, scrollTimer: null,
expandedReasoning: new Set(), // Track which reasoning blocks are expanded expandedReasoning: new Set(), // Track which reasoning blocks are expanded
downloadingFiles: new Set(), // Track which files are being downloaded
}; };
}, },
mounted() { mounted() {
@@ -179,6 +216,35 @@ export default {
return this.expandedReasoning.has(messageIndex); return this.expandedReasoning.has(messageIndex);
}, },
//
async downloadFile(file) {
if (!file.attachment_id) return;
//
this.downloadingFiles.add(file.attachment_id);
this.downloadingFiles = new Set(this.downloadingFiles);
try {
const response = await axios.get(`/api/chat/get_attachment?attachment_id=${file.attachment_id}`, {
responseType: 'blob'
});
const url = URL.createObjectURL(response.data);
const a = document.createElement('a');
a.href = url;
a.download = file.filename || 'file';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
setTimeout(() => URL.revokeObjectURL(url), 100);
} catch (err) {
console.error('Download file failed:', err);
} finally {
this.downloadingFiles.delete(file.attachment_id);
this.downloadingFiles = new Set(this.downloadingFiles);
}
},
// //
copyCodeToClipboard(code) { copyCodeToClipboard(code) {
navigator.clipboard.writeText(code).then(() => { navigator.clipboard.writeText(code).then(() => {
@@ -375,6 +441,37 @@ export default {
clearTimeout(this.scrollTimer); clearTimeout(this.scrollTimer);
this.scrollTimer = null; this.scrollTimer = null;
} }
},
//
formatMessageTime(dateStr) {
if (!dateStr) return '';
const date = new Date(dateStr);
const now = new Date();
//
const dateDay = new Date(date.getFullYear(), date.getMonth(), date.getDate());
const todayDay = new Date(now.getFullYear(), now.getMonth(), now.getDate());
const yesterdayDay = new Date(todayDay);
yesterdayDay.setDate(yesterdayDay.getDate() - 1);
// HH:MM
const hours = date.getHours().toString().padStart(2, '0');
const minutes = date.getMinutes().toString().padStart(2, '0');
const timeStr = `${hours}:${minutes}`;
//
if (dateDay.getTime() === todayDay.getTime()) {
return `${this.tm('time.today')} ${timeStr}`;
} else if (dateDay.getTime() === yesterdayDay.getTime()) {
return `${this.tm('time.yesterday')} ${timeStr}`;
} else {
//
const month = (date.getMonth() + 1).toString().padStart(2, '0');
const day = date.getDate().toString().padStart(2, '0');
return `${month}-${day} ${timeStr}`;
}
} }
} }
} }
@@ -413,7 +510,7 @@ export default {
} }
.message-item { .message-item {
margin-bottom: 24px; margin-bottom: 12px;
animation: fadeIn 0.3s ease-out; animation: fadeIn 0.3s ease-out;
} }
@@ -441,10 +538,18 @@ export default {
.message-actions { .message-actions {
display: flex; display: flex;
gap: 4px; align-items: center;
gap: 8px;
opacity: 0; opacity: 0;
transition: opacity 0.2s ease; transition: opacity 0.2s ease;
margin-left: 8px; margin-left: 16px;
}
.message-time {
font-size: 12px;
color: var(--v-theme-secondaryText);
opacity: 0.7;
white-space: nowrap;
} }
.bot-message:hover .message-actions { .bot-message:hover .message-actions {
@@ -553,7 +658,6 @@ export default {
width: auto; width: auto;
height: auto; height: auto;
border-radius: 8px; border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
cursor: pointer; cursor: pointer;
transition: transform 0.2s ease; transition: transform 0.2s ease;
} }
@@ -568,6 +672,71 @@ export default {
max-width: 300px; max-width: 300px;
} }
/* 文件附件样式 */
.file-attachments,
.embedded-files {
margin-top: 8px;
display: flex;
flex-direction: column;
gap: 6px;
}
.file-attachment,
.embedded-file {
display: flex;
align-items: center;
}
.file-link {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 8px 12px;
background-color: rgba(var(--v-theme-primary), 0.08);
border: 1px solid rgba(var(--v-theme-primary), 0.2);
border-radius: 8px;
color: rgb(var(--v-theme-primary));
text-decoration: none;
font-size: 14px;
transition: all 0.2s ease;
max-width: 300px;
}
.file-link-download {
cursor: pointer;
}
.download-icon {
margin-left: 4px;
opacity: 0.7;
}
.file-icon {
flex-shrink: 0;
color: rgb(var(--v-theme-primary));
}
.file-name {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.v-theme--dark .file-link {
background-color: rgba(255, 255, 255, 0.05);
border-color: rgba(255, 255, 255, 0.1);
color: var(--v-theme-secondary);
}
.v-theme--dark .file-link:hover {
background-color: rgba(255, 255, 255, 0.1);
border-color: rgba(255, 255, 255, 0.2);
}
.v-theme--dark .file-icon {
color: var(--v-theme-secondary);
}
/* 动画类 */ /* 动画类 */
.fade-in { .fade-in {
animation: fadeIn 0.3s ease-in-out; animation: fadeIn 0.3s ease-in-out;
@@ -110,9 +110,9 @@ function getSessions() {
} }
const { const {
stagedImagesName,
stagedImagesUrl, stagedImagesUrl,
stagedAudioUrl, stagedAudioUrl,
stagedFiles,
getMediaFile, getMediaFile,
processAndUploadImage, processAndUploadImage,
handlePaste, handlePaste,
@@ -164,7 +164,7 @@ async function handleFileSelect(files: FileList) {
} }
async function handleSendMessage() { async function handleSendMessage() {
if (!prompt.value.trim() && stagedImagesName.value.length === 0 && !stagedAudioUrl.value) { if (!prompt.value.trim() && stagedFiles.value.length === 0 && !stagedAudioUrl.value) {
return; return;
} }
@@ -174,8 +174,13 @@ async function handleSendMessage() {
} }
const promptToSend = prompt.value.trim(); const promptToSend = prompt.value.trim();
const imageNamesToSend = [...stagedImagesName.value];
const audioNameToSend = stagedAudioUrl.value; const audioNameToSend = stagedAudioUrl.value;
const filesToSend = stagedFiles.value.map(f => ({
attachment_id: f.attachment_id,
url: f.url,
original_name: f.original_name,
type: f.type
}));
// //
prompt.value = ''; prompt.value = '';
@@ -188,7 +193,7 @@ async function handleSendMessage() {
await sendMsg( await sendMsg(
promptToSend, promptToSend,
imageNamesToSend, filesToSend,
audioNameToSend, audioNameToSend,
selectedProviderId, selectedProviderId,
selectedModelName selectedModelName
+96 -16
View File
@@ -1,10 +1,17 @@
import { ref } from 'vue'; import { ref, computed } from 'vue';
import axios from 'axios'; import axios from 'axios';
export interface StagedFileInfo {
attachment_id: string;
filename: string;
original_name: string;
url: string; // blob URL for preview
type: string; // image, record, file, video
}
export function useMediaHandling() { export function useMediaHandling() {
const stagedImagesName = ref<string[]>([]);
const stagedImagesUrl = ref<string[]>([]);
const stagedAudioUrl = ref<string>(''); const stagedAudioUrl = ref<string>('');
const stagedFiles = ref<StagedFileInfo[]>([]);
const mediaCache = ref<Record<string, string>>({}); const mediaCache = ref<Record<string, string>>({});
async function getMediaFile(filename: string): Promise<string> { async function getMediaFile(filename: string): Promise<string> {
@@ -32,20 +39,49 @@ export function useMediaHandling() {
formData.append('file', file); formData.append('file', file);
try { try {
const response = await axios.post('/api/chat/post_image', formData, { const response = await axios.post('/api/chat/post_file', formData, {
headers: { headers: {
'Content-Type': 'multipart/form-data' 'Content-Type': 'multipart/form-data'
} }
}); });
const img = response.data.data.filename; const { attachment_id, filename, type } = response.data.data;
stagedImagesName.value.push(img); stagedFiles.value.push({
stagedImagesUrl.value.push(URL.createObjectURL(file)); attachment_id,
filename,
original_name: file.name,
url: URL.createObjectURL(file),
type
});
} catch (err) { } catch (err) {
console.error('Error uploading image:', err); console.error('Error uploading image:', err);
} }
} }
async function processAndUploadFile(file: File) {
const formData = new FormData();
formData.append('file', file);
try {
const response = await axios.post('/api/chat/post_file', formData, {
headers: {
'Content-Type': 'multipart/form-data'
}
});
const { attachment_id, filename, type } = response.data.data;
stagedFiles.value.push({
attachment_id,
filename,
original_name: file.name,
url: URL.createObjectURL(file),
type
});
} catch (err) {
console.error('Error uploading file:', err);
}
}
async function handlePaste(event: ClipboardEvent) { async function handlePaste(event: ClipboardEvent) {
const items = event.clipboardData?.items; const items = event.clipboardData?.items;
if (!items) return; if (!items) return;
@@ -61,23 +97,54 @@ export function useMediaHandling() {
} }
function removeImage(index: number) { function removeImage(index: number) {
const urlToRevoke = stagedImagesUrl.value[index]; // 找到第 index 个图片类型的文件
if (urlToRevoke && urlToRevoke.startsWith('blob:')) { let imageCount = 0;
URL.revokeObjectURL(urlToRevoke); for (let i = 0; i < stagedFiles.value.length; i++) {
if (stagedFiles.value[i].type === 'image') {
if (imageCount === index) {
const fileToRemove = stagedFiles.value[i];
if (fileToRemove.url.startsWith('blob:')) {
URL.revokeObjectURL(fileToRemove.url);
}
stagedFiles.value.splice(i, 1);
return;
}
imageCount++;
}
} }
stagedImagesName.value.splice(index, 1);
stagedImagesUrl.value.splice(index, 1);
} }
function removeAudio() { function removeAudio() {
stagedAudioUrl.value = ''; stagedAudioUrl.value = '';
} }
function removeFile(index: number) {
// 找到第 index 个非图片类型的文件
let fileCount = 0;
for (let i = 0; i < stagedFiles.value.length; i++) {
if (stagedFiles.value[i].type !== 'image') {
if (fileCount === index) {
const fileToRemove = stagedFiles.value[i];
if (fileToRemove.url.startsWith('blob:')) {
URL.revokeObjectURL(fileToRemove.url);
}
stagedFiles.value.splice(i, 1);
return;
}
fileCount++;
}
}
}
function clearStaged() { function clearStaged() {
stagedImagesName.value = [];
stagedImagesUrl.value = [];
stagedAudioUrl.value = ''; stagedAudioUrl.value = '';
// 清理文件的 blob URLs
stagedFiles.value.forEach(file => {
if (file.url.startsWith('blob:')) {
URL.revokeObjectURL(file.url);
}
});
stagedFiles.value = [];
} }
function cleanupMediaCache() { function cleanupMediaCache() {
@@ -89,15 +156,28 @@ export function useMediaHandling() {
mediaCache.value = {}; mediaCache.value = {};
} }
// 计算属性:获取图片的 URL 列表(用于预览)
const stagedImagesUrl = computed(() =>
stagedFiles.value.filter(f => f.type === 'image').map(f => f.url)
);
// 计算属性:获取非图片文件列表
const stagedNonImageFiles = computed(() =>
stagedFiles.value.filter(f => f.type !== 'image')
);
return { return {
stagedImagesName,
stagedImagesUrl, stagedImagesUrl,
stagedAudioUrl, stagedAudioUrl,
stagedFiles,
stagedNonImageFiles,
getMediaFile, getMediaFile,
processAndUploadImage, processAndUploadImage,
processAndUploadFile,
handlePaste, handlePaste,
removeImage, removeImage,
removeAudio, removeAudio,
removeFile,
clearStaged, clearStaged,
cleanupMediaCache cleanupMediaCache
}; };
+161 -46
View File
@@ -2,19 +2,37 @@ import { ref, reactive, type Ref } from 'vue';
import axios from 'axios'; import axios from 'axios';
import { useToast } from '@/utils/toast'; import { useToast } from '@/utils/toast';
// 新格式消息部分的类型定义
export interface MessagePart {
type: 'plain' | 'image' | 'record' | 'file' | 'video';
text?: string; // for plain
attachment_id?: string; // for image, record, file, video
filename?: string; // for file (filename from backend)
}
// 文件信息结构
export interface FileInfo {
url?: string; // blob URL (可选,点击时才加载)
filename: string;
attachment_id?: string; // 用于按需下载
}
export interface MessageContent { export interface MessageContent {
type: string; type: string;
message: string; message: string | MessagePart[]; // 支持旧格式(string)和新格式(MessagePart[])
reasoning?: string; reasoning?: string;
image_url?: string[]; image_url?: string[];
audio_url?: string; audio_url?: string;
file_url?: FileInfo[];
embedded_images?: string[]; embedded_images?: string[];
embedded_audio?: string; embedded_audio?: string;
embedded_files?: FileInfo[];
isLoading?: boolean; isLoading?: boolean;
} }
export interface Message { export interface Message {
content: MessageContent; content: MessageContent;
created_at?: string;
} }
export function useMessages( export function useMessages(
@@ -29,6 +47,7 @@ export function useMessages(
const isToastedRunningInfo = ref(false); const isToastedRunningInfo = ref(false);
const activeSSECount = ref(0); const activeSSECount = ref(0);
const enableStreaming = ref(true); const enableStreaming = ref(true);
const attachmentCache = new Map<string, string>(); // attachment_id -> blob URL
// 从 localStorage 读取流式响应开关状态 // 从 localStorage 读取流式响应开关状态
const savedStreamingState = localStorage.getItem('enableStreaming'); const savedStreamingState = localStorage.getItem('enableStreaming');
@@ -41,6 +60,68 @@ export function useMessages(
localStorage.setItem('enableStreaming', JSON.stringify(enableStreaming.value)); localStorage.setItem('enableStreaming', JSON.stringify(enableStreaming.value));
} }
// 获取 attachment 文件并返回 blob URL
async function getAttachment(attachmentId: string): Promise<string> {
if (attachmentCache.has(attachmentId)) {
return attachmentCache.get(attachmentId)!;
}
try {
const response = await axios.get(`/api/chat/get_attachment?attachment_id=${attachmentId}`, {
responseType: 'blob'
});
const blobUrl = URL.createObjectURL(response.data);
attachmentCache.set(attachmentId, blobUrl);
return blobUrl;
} catch (err) {
console.error('Failed to get attachment:', attachmentId, err);
return '';
}
}
// 解析新格式消息为旧格式兼容的结构 (用于显示)
async function parseMessageContent(content: any): Promise<void> {
const message = content.message;
// 如果 message 是数组 (新格式)
if (Array.isArray(message)) {
let textParts: string[] = [];
let imageUrls: string[] = [];
let audioUrl: string | undefined;
let fileInfos: FileInfo[] = [];
for (const part of message as MessagePart[]) {
if (part.type === 'plain' && part.text) {
textParts.push(part.text);
} else if (part.type === 'image' && part.attachment_id) {
const url = await getAttachment(part.attachment_id);
if (url) imageUrls.push(url);
} else if (part.type === 'record' && part.attachment_id) {
audioUrl = await getAttachment(part.attachment_id);
} else if (part.type === 'file' && part.attachment_id) {
// file 类型不预加载,保留 attachment_id 以便点击时下载
fileInfos.push({
attachment_id: part.attachment_id,
filename: part.filename || 'file'
});
}
// video 类型可以后续扩展
}
// 转换为旧格式兼容的结构
content.message = textParts.join('\n');
if (content.type === 'user') {
content.image_url = imageUrls.length > 0 ? imageUrls : undefined;
content.audio_url = audioUrl;
content.file_url = fileInfos.length > 0 ? fileInfos : undefined;
} else {
content.embedded_images = imageUrls.length > 0 ? imageUrls : undefined;
content.embedded_audio = audioUrl;
content.embedded_files = fileInfos.length > 0 ? fileInfos : undefined;
}
}
// 如果 message 是字符串 (旧格式),保持原有处理逻辑
}
async function getSessionMessages(sessionId: string, router: any) { async function getSessionMessages(sessionId: string, router: any) {
if (!sessionId) return; if (!sessionId) return;
@@ -64,35 +145,45 @@ export function useMessages(
// 处理历史消息中的媒体文件 // 处理历史消息中的媒体文件
for (let i = 0; i < history.length; i++) { for (let i = 0; i < history.length; i++) {
let content = history[i].content; let content = history[i].content;
if (content.message?.startsWith('[IMAGE]')) { // 首先尝试解析新格式消息
let img = content.message.replace('[IMAGE]', ''); await parseMessageContent(content);
const imageUrl = await getMediaFile(img);
if (!content.embedded_images) { // 以下是旧格式的兼容处理 (message 是字符串的情况)
content.embedded_images = []; if (typeof content.message === 'string') {
if (content.message?.startsWith('[IMAGE]')) {
let img = content.message.replace('[IMAGE]', '');
const imageUrl = await getMediaFile(img);
if (!content.embedded_images) {
content.embedded_images = [];
}
content.embedded_images.push(imageUrl);
content.message = '';
}
if (content.message?.startsWith('[RECORD]')) {
let audio = content.message.replace('[RECORD]', '');
const audioUrl = await getMediaFile(audio);
content.embedded_audio = audioUrl;
content.message = '';
} }
content.embedded_images.push(imageUrl);
content.message = '';
}
if (content.message?.startsWith('[RECORD]')) {
let audio = content.message.replace('[RECORD]', '');
const audioUrl = await getMediaFile(audio);
content.embedded_audio = audioUrl;
content.message = '';
} }
// 旧格式中的 image_url 和 audio_url 字段处理
if (content.image_url && content.image_url.length > 0) { if (content.image_url && content.image_url.length > 0) {
for (let j = 0; j < content.image_url.length; j++) { for (let j = 0; j < content.image_url.length; j++) {
content.image_url[j] = await getMediaFile(content.image_url[j]); // 检查是否已经是 blob URL (新格式解析后的结果)
if (!content.image_url[j].startsWith('blob:')) {
content.image_url[j] = await getMediaFile(content.image_url[j]);
}
} }
} }
if (content.audio_url) { if (content.audio_url && !content.audio_url.startsWith('blob:')) {
content.audio_url = await getMediaFile(content.audio_url); content.audio_url = await getMediaFile(content.audio_url);
} }
} }
messages.value = history; messages.value = history;
} catch (err) { } catch (err) {
console.error(err); console.error(err);
@@ -101,7 +192,7 @@ export function useMessages(
async function sendMessage( async function sendMessage(
prompt: string, prompt: string,
imageNames: string[], stagedFiles: { attachment_id: string; url: string; original_name: string; type: string }[],
audioName: string, audioName: string,
selectedProviderId: string, selectedProviderId: string,
selectedModelName: string selectedModelName: string
@@ -111,27 +202,33 @@ export function useMessages(
type: 'user', type: 'user',
message: prompt, message: prompt,
image_url: [], image_url: [],
audio_url: undefined audio_url: undefined,
file_url: []
}; };
// Convert image filenames to blob URLs // 分离图片和文件
if (imageNames.length > 0) { const imageFiles = stagedFiles.filter(f => f.type === 'image');
const imagePromises = imageNames.map(name => { const nonImageFiles = stagedFiles.filter(f => f.type !== 'image');
if (!name.startsWith('blob:')) {
return getMediaFile(name); // 使用 attachment_id 获取图片内容(避免 blob URL 被 revoke 后 404
} if (imageFiles.length > 0) {
return Promise.resolve(name); const imageUrls = await Promise.all(
}); imageFiles.map(f => getAttachment(f.attachment_id))
userMessage.image_url = await Promise.all(imagePromises); );
userMessage.image_url = imageUrls.filter(url => url !== '');
} }
// Convert audio filename to blob URL // 使用 blob URL 作为音频预览(录音不走 attachment)
if (audioName) { if (audioName) {
if (!audioName.startsWith('blob:')) { userMessage.audio_url = audioName;
userMessage.audio_url = await getMediaFile(audioName); }
} else {
userMessage.audio_url = audioName; // 文件不预加载,只显示文件名和 attachment_id
} if (nonImageFiles.length > 0) {
userMessage.file_url = nonImageFiles.map(f => ({
filename: f.original_name,
attachment_id: f.attachment_id
}));
} }
messages.value.push({ content: userMessage }); messages.value.push({ content: userMessage });
@@ -151,6 +248,9 @@ export function useMessages(
isConvRunning.value = true; isConvRunning.value = true;
} }
// 收集所有 attachment_id
const files = stagedFiles.map(f => f.attachment_id);
const response = await fetch('/api/chat/send', { const response = await fetch('/api/chat/send', {
method: 'POST', method: 'POST',
headers: { headers: {
@@ -160,8 +260,7 @@ export function useMessages(
body: JSON.stringify({ body: JSON.stringify({
message: prompt, message: prompt,
session_id: currSessionId.value, session_id: currSessionId.value,
image_url: imageNames, files: files,
audio_url: audioName ? [audioName] : [],
selected_provider: selectedProviderId, selected_provider: selectedProviderId,
selected_model: selectedModelName, selected_model: selectedModelName,
enable_streaming: enableStreaming.value enable_streaming: enableStreaming.value
@@ -207,6 +306,11 @@ export function useMessages(
continue; continue;
} }
const lastMsg = messages.value[messages.value.length - 1];
if (lastMsg?.content?.isLoading) {
messages.value.pop();
}
if (chunk_json.type === 'error') { if (chunk_json.type === 'error') {
console.error('Error received:', chunk_json.data); console.error('Error received:', chunk_json.data);
continue; continue;
@@ -230,16 +334,26 @@ export function useMessages(
embedded_audio: audioUrl embedded_audio: audioUrl
}; };
messages.value.push({ content: bot_resp }); messages.value.push({ content: bot_resp });
} else if (chunk_json.type === 'file') {
// 格式: [FILE]filename|original_name
let fileData = chunk_json.data.replace('[FILE]', '');
let [filename, originalName] = fileData.includes('|')
? fileData.split('|', 2)
: [fileData, fileData];
const fileUrl = await getMediaFile(filename);
let bot_resp: MessageContent = {
type: 'bot',
message: '',
embedded_files: [{
url: fileUrl,
filename: originalName
}]
};
messages.value.push({ content: bot_resp });
} else if (chunk_json.type === 'plain') { } else if (chunk_json.type === 'plain') {
const chain_type = chunk_json.chain_type || 'normal'; const chain_type = chunk_json.chain_type || 'normal';
if (!in_streaming) { if (!in_streaming) {
// 移除加载占位符
const lastMsg = messages.value[messages.value.length - 1];
if (lastMsg?.content?.isLoading) {
messages.value.pop();
}
message_obj = reactive({ message_obj = reactive({
type: 'bot', type: 'bot',
message: chain_type === 'reasoning' ? '' : chunk_json.data, message: chain_type === 'reasoning' ? '' : chunk_json.data,
@@ -298,7 +412,8 @@ export function useMessages(
enableStreaming, enableStreaming,
getSessionMessages, getSessionMessages,
sendMessage, sendMessage,
toggleStreaming toggleStreaming,
getAttachment
}; };
} }
@@ -71,6 +71,10 @@
"reasoning": { "reasoning": {
"thinking": "Thinking Process" "thinking": "Thinking Process"
}, },
"time": {
"today": "Today",
"yesterday": "Yesterday"
},
"connection": { "connection": {
"title": "Connection Status Notice", "title": "Connection Status Notice",
"message": "The system detected that the chat connection needs to be re-established.", "message": "The system detected that the chat connection needs to be re-established.",
@@ -109,22 +109,6 @@
} }
} }
}, },
"file_extract": {
"description": "File Extract",
"provider_settings": {
"file_extract": {
"enable": {
"description": "Enable File Extract"
},
"provider": {
"description": "File Extract Provider"
},
"moonshotai_api_key": {
"description": "Moonshot AI API Key"
}
}
}
},
"others": { "others": {
"description": "Other Settings", "description": "Other Settings",
"provider_settings": { "provider_settings": {
@@ -175,10 +159,6 @@
"prompt_prefix": { "prompt_prefix": {
"description": "User Prompt", "description": "User Prompt",
"hint": "You can use {{prompt}} as a placeholder for user input. If no placeholder is provided, it will be added before the user input." "hint": "You can use {{prompt}} as a placeholder for user input. If no placeholder is provided, it will be added before the user input."
},
"reachability_check": {
"description": "Provider Reachability Check",
"hint": "When running the /provider command, test provider connectivity in parallel. This actively pings models and may consume extra tokens."
} }
}, },
"provider_tts_settings": { "provider_tts_settings": {
@@ -399,11 +379,7 @@
}, },
"image_caption": { "image_caption": {
"description": "Auto-understand Images", "description": "Auto-understand Images",
"hint": "Requires setting a group chat image caption model." "hint": "Requires setting a default image caption model."
},
"image_caption_provider_id": {
"description": "Group Chat Image Caption Model",
"hint": "Used for image understanding in group chat context awareness, configured separately from the default image caption model."
}, },
"active_reply": { "active_reply": {
"enable": { "enable": {
@@ -473,4 +449,4 @@
} }
} }
} }
} }
@@ -32,8 +32,7 @@
"actions": "Actions", "actions": "Actions",
"back": "Back", "back": "Back",
"selectFile": "Select File", "selectFile": "Select File",
"refresh": "Refresh", "refresh": "Refresh"
"updateAll": "Update All"
}, },
"status": { "status": {
"enabled": "Enabled", "enabled": "Enabled",
@@ -142,9 +141,7 @@
"confirmDelete": "Are you sure you want to delete this extension?", "confirmDelete": "Are you sure you want to delete this extension?",
"fillUrlOrFile": "Please fill in extension URL or upload extension file", "fillUrlOrFile": "Please fill in extension URL or upload extension file",
"dontFillBoth": "Please don't fill in both extension URL and upload file", "dontFillBoth": "Please don't fill in both extension URL and upload file",
"supportedFormats": "Supports .zip extension files", "supportedFormats": "Supports .zip extension files"
"updateAllSuccess": "All upgradable extensions have been updated!",
"updateAllFailed": "{failed} of {total} extensions failed to update:"
}, },
"upload": { "upload": {
"fromFile": "Install from File", "fromFile": "Install from File",
@@ -71,6 +71,10 @@
"reasoning": { "reasoning": {
"thinking": "思考过程" "thinking": "思考过程"
}, },
"time": {
"today": "今天",
"yesterday": "昨天"
},
"connection": { "connection": {
"title": "连接状态提醒", "title": "连接状态提醒",
"message": "系统检测到聊天连接需要重新建立。", "message": "系统检测到聊天连接需要重新建立。",
@@ -11,12 +11,7 @@
}, },
"agent_runner_type": { "agent_runner_type": {
"description": "执行器", "description": "执行器",
"labels": [ "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"]
"内置 Agent",
"Dify",
"Coze",
"阿里云百炼应用"
]
}, },
"coze_agent_runner_provider_id": { "coze_agent_runner_provider_id": {
"description": "Coze Agent 执行器提供商 ID" "description": "Coze Agent 执行器提供商 ID"
@@ -114,22 +109,6 @@
} }
} }
}, },
"file_extract": {
"description": "文档解析能力",
"provider_settings": {
"file_extract": {
"enable": {
"description": "启用文档解析能力"
},
"provider": {
"description": "文档解析提供商"
},
"moonshotai_api_key": {
"description": "Moonshot AI API Key"
}
}
}
},
"others": { "others": {
"description": "其他配置", "description": "其他配置",
"provider_settings": { "provider_settings": {
@@ -163,10 +142,7 @@
"unsupported_streaming_strategy": { "unsupported_streaming_strategy": {
"description": "不支持流式回复的平台", "description": "不支持流式回复的平台",
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容", "hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
"labels": [ "labels": ["实时分段回复", "关闭流式回复"]
"实时分段回复",
"关闭流式回复"
]
}, },
"max_context_length": { "max_context_length": {
"description": "最多携带对话轮数", "description": "最多携带对话轮数",
@@ -183,10 +159,6 @@
"prompt_prefix": { "prompt_prefix": {
"description": "用户提示词", "description": "用户提示词",
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。" "hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。"
},
"reachability_check": {
"description": "提供商可达性检测",
"hint": "/provider 命令列出模型时并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。"
} }
}, },
"provider_tts_settings": { "provider_tts_settings": {
@@ -407,11 +379,7 @@
}, },
"image_caption": { "image_caption": {
"description": "自动理解图片", "description": "自动理解图片",
"hint": "需要设置群聊图片转述模型。" "hint": "需要设置默认图片转述模型。"
},
"image_caption_provider_id": {
"description": "群聊图片转述模型",
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。"
}, },
"active_reply": { "active_reply": {
"enable": { "enable": {
@@ -481,4 +449,4 @@
} }
} }
} }
} }
@@ -32,8 +32,7 @@
"actions": "操作", "actions": "操作",
"back": "返回", "back": "返回",
"selectFile": "选择文件", "selectFile": "选择文件",
"refresh": "刷新", "refresh": "刷新"
"updateAll": "更新全部插件"
}, },
"status": { "status": {
"enabled": "启用", "enabled": "启用",
@@ -142,9 +141,7 @@
"confirmDelete": "确定要删除插件吗?", "confirmDelete": "确定要删除插件吗?",
"fillUrlOrFile": "请填写插件链接或上传插件文件", "fillUrlOrFile": "请填写插件链接或上传插件文件",
"dontFillBoth": "请不要同时填写插件链接和上传文件", "dontFillBoth": "请不要同时填写插件链接和上传文件",
"supportedFormats": "支持 .zip 格式的插件文件", "supportedFormats": "支持 .zip 格式的插件文件"
"updateAllSuccess": "所有可更新的插件都已更新!",
"updateAllFailed": "有 {failed}/{total} 个插件更新失败:"
}, },
"upload": { "upload": {
"fromFile": "从文件安装", "fromFile": "从文件安装",
+8 -66
View File
@@ -42,7 +42,6 @@ const loadingDialog = reactive({
const showPluginInfoDialog = ref(false); const showPluginInfoDialog = ref(false);
const selectedPlugin = ref({}); const selectedPlugin = ref({});
const curr_namespace = ref(""); const curr_namespace = ref("");
const updatingAll = ref(false);
const readmeDialog = reactive({ const readmeDialog = reactive({
show: false, show: false,
@@ -138,10 +137,11 @@ const pluginMarketHeaders = computed(() => [
// //
const filteredExtensions = computed(() => { const filteredExtensions = computed(() => {
const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
if (!showReserved.value) { if (!showReserved.value) {
return extension_data?.data?.filter(ext => !ext.reserved) || []; return data.filter(ext => !ext.reserved);
} }
return extension_data.data || []; return data;
}); });
// //
@@ -227,10 +227,6 @@ const paginatedPlugins = computed(() => {
return sortedPlugins.value.slice(start, end); return sortedPlugins.value.slice(start, end);
}); });
const updatableExtensions = computed(() => {
return extension_data?.data?.filter(ext => ext.has_update) || [];
});
// //
const toggleShowReserved = () => { const toggleShowReserved = () => {
showReserved.value = !showReserved.value; showReserved.value = !showReserved.value;
@@ -280,7 +276,8 @@ const checkUpdate = () => {
onlinePluginsNameMap.set(plugin.name, plugin); onlinePluginsNameMap.set(plugin.name, plugin);
}); });
extension_data.data.forEach(extension => { const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
data.forEach(extension => {
const repoKey = extension.repo?.toLowerCase(); const repoKey = extension.repo?.toLowerCase();
const onlinePlugin = repoKey ? onlinePluginsMap.get(repoKey) : null; const onlinePlugin = repoKey ? onlinePluginsMap.get(repoKey) : null;
const onlinePluginByName = onlinePluginsNameMap.get(extension.name); const onlinePluginByName = onlinePluginsNameMap.get(extension.name);
@@ -377,56 +374,6 @@ const updateExtension = async (extension_name) => {
} }
}; };
const updateAllExtensions = async () => {
if (updatingAll.value || updatableExtensions.value.length === 0) return;
updatingAll.value = true;
loadingDialog.title = tm('status.loading');
loadingDialog.statusCode = 0;
loadingDialog.result = "";
loadingDialog.show = true;
const targets = updatableExtensions.value.map(ext => ext.name);
try {
const res = await axios.post('/api/plugin/update-all', {
names: targets,
proxy: localStorage.getItem('selectedGitHubProxy') || ""
});
if (res.data.status === "error") {
onLoadingDialogResult(2, res.data.message || tm('messages.updateAllFailed', {
failed: targets.length,
total: targets.length
}), -1);
return;
}
const results = res.data.data?.results || [];
const failures = results.filter(r => r.status !== 'ok');
try {
await getExtensions();
} catch (err) {
const errorMsg = err.response?.data?.message || err.message || String(err);
failures.push({ name: 'refresh', status: 'error', message: errorMsg });
}
if (failures.length === 0) {
onLoadingDialogResult(1, tm('messages.updateAllSuccess'));
} else {
const failureText = tm('messages.updateAllFailed', {
failed: failures.length,
total: targets.length
});
const detail = failures.map(f => `${f.name}: ${f.message}`).join('\n');
onLoadingDialogResult(2, `${failureText}\n${detail}`, -1);
}
} catch (err) {
const errorMsg = err.response?.data?.message || err.message || String(err);
onLoadingDialogResult(2, errorMsg, -1);
} finally {
updatingAll.value = false;
}
};
const pluginOn = async (extension) => { const pluginOn = async (extension) => {
try { try {
const res = await axios.post('/api/plugin/on', { name: extension.name }); const res = await axios.post('/api/plugin/on', { name: extension.name });
@@ -562,8 +509,9 @@ const trimExtensionName = () => {
}; };
const checkAlreadyInstalled = () => { const checkAlreadyInstalled = () => {
const installedRepos = new Set(extension_data.data.map(ext => ext.repo?.toLowerCase())); const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
const installedNames = new Set(extension_data.data.map(ext => ext.name)); const installedRepos = new Set(data.map(ext => ext.repo?.toLowerCase()));
const installedNames = new Set(data.map(ext => ext.name));
for (let i = 0; i < pluginMarketData.value.length; i++) { for (let i = 0; i < pluginMarketData.value.length; i++) {
const plugin = pluginMarketData.value[i]; const plugin = pluginMarketData.value[i];
@@ -775,12 +723,6 @@ watch(marketSearch, (newVal) => {
{{ showReserved ? tm('buttons.hideSystemPlugins') : tm('buttons.showSystemPlugins') }} {{ showReserved ? tm('buttons.hideSystemPlugins') : tm('buttons.showSystemPlugins') }}
</v-btn> </v-btn>
<v-btn class="ml-2" color="warning" variant="tonal" :disabled="updatableExtensions.length === 0"
:loading="updatingAll" @click="updateAllExtensions">
<v-icon>mdi-update</v-icon>
{{ tm('buttons.updateAll') }}
</v-btn>
<v-btn class="ml-2" color="primary" variant="tonal" @click="dialog = true"> <v-btn class="ml-2" color="primary" variant="tonal" @click="dialog = true">
<v-icon>mdi-plus</v-icon> <v-icon>mdi-plus</v-icon>
{{ tm('buttons.install') }} {{ tm('buttons.install') }}
+1 -2
View File
@@ -35,7 +35,7 @@
</div> </div>
<!-- 日志部分 --> <!-- 日志部分 -->
<v-card elevation="0" class="mt-4 mb-10"> <v-card elevation="0" class="mt-4">
<v-card-title class="d-flex align-center py-3 px-4"> <v-card-title class="d-flex align-center py-3 px-4">
<v-icon class="me-2">mdi-console-line</v-icon> <v-icon class="me-2">mdi-console-line</v-icon>
<span class="text-h4">{{ tm('logs.title') }}</span> <span class="text-h4">{{ tm('logs.title') }}</span>
@@ -233,6 +233,5 @@ export default {
.platform-page { .platform-page {
padding: 20px; padding: 20px;
padding-top: 8px; padding-top: 8px;
padding-bottom: 40px;
} }
</style> </style>
+1 -2
View File
@@ -148,7 +148,7 @@
</div> </div>
<!-- 日志部分 --> <!-- 日志部分 -->
<v-card elevation="0" class="mt-4 mb-10"> <v-card elevation="0" class="mt-4">
<v-card-title class="d-flex align-center py-3 px-4"> <v-card-title class="d-flex align-center py-3 px-4">
<v-icon class="me-2">mdi-console-line</v-icon> <v-icon class="me-2">mdi-console-line</v-icon>
<span class="text-h4">{{ tm('logs.title') }}</span> <span class="text-h4">{{ tm('logs.title') }}</span>
@@ -849,7 +849,6 @@ export default {
.provider-page { .provider-page {
padding: 20px; padding: 20px;
padding-top: 8px; padding-top: 8px;
padding-bottom: 40px;
} }
.status-card { .status-card {
+19 -139
View File
@@ -1,7 +1,5 @@
import asyncio
import re import re
from astrbot import logger
from astrbot.api import star from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.entities import ProviderType
@@ -11,39 +9,6 @@ class ProviderCommands:
def __init__(self, context: star.Context): def __init__(self, context: star.Context):
self.context = context self.context = context
def _log_reachability_failure(
self,
provider,
provider_capability_type: ProviderType | None,
err_code: str,
err_reason: str,
):
"""记录不可达原因到日志。"""
meta = provider.meta()
logger.warning(
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
meta.id,
provider_capability_type.name if provider_capability_type else "unknown",
err_code,
err_reason,
)
async def _test_provider_capability(self, provider):
"""测试单个 provider 的可用性"""
meta = provider.meta()
provider_capability_type = meta.provider_type
try:
await provider.test()
return True, None, None
except Exception as e:
err_code = "TEST_FAILED"
err_reason = str(e)
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
async def provider( async def provider(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
@@ -52,131 +17,46 @@ class ProviderCommands:
): ):
"""查看或者切换 LLM Provider""" """查看或者切换 LLM Provider"""
umo = event.unified_msg_origin umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {})
reachability_check_enabled = cfg.get("reachability_check", True)
if idx is None: if idx is None:
parts = ["## 载入的 LLM 提供商\n"] parts = ["## 载入的 LLM 提供商\n"]
for idx, llm in enumerate(self.context.get_all_providers()):
# 获取所有类型的提供商 id_ = llm.meta().id
llms = list(self.context.get_all_providers()) line = f"{idx + 1}. {id_} ({llm.meta().model})"
ttss = self.context.get_all_tts_providers()
stts = self.context.get_all_stt_providers()
# 构造待检测列表: [(provider, type_label), ...]
all_providers = []
all_providers.extend([(p, "llm") for p in llms])
all_providers.extend([(p, "tts") for p in ttss])
all_providers.extend([(p, "stt") for p in stts])
# 并发测试连通性
if reachability_check_enabled:
if all_providers:
await event.send(
MessageEventResult().message(
"正在进行提供商可达性测试,请稍候..."
)
)
check_results = await asyncio.gather(
*[self._test_provider_capability(p) for p, _ in all_providers],
return_exceptions=True,
)
else:
# 用 None 表示未检测
check_results = [None for _ in all_providers]
# 整合结果
display_data = []
for (p, p_type), reachable in zip(all_providers, check_results):
meta = p.meta()
id_ = meta.id
error_code = None
if isinstance(reachable, Exception):
# 异常情况下兜底处理,避免单个 provider 导致列表失败
self._log_reachability_failure(
p,
None,
reachable.__class__.__name__,
str(reachable),
)
reachable_flag = False
error_code = reachable.__class__.__name__
elif isinstance(reachable, tuple):
reachable_flag, error_code, _ = reachable
else:
reachable_flag = reachable
# 根据类型构建显示名称
if p_type == "llm":
info = f"{id_} ({meta.model})"
else:
info = f"{id_}"
# 确定状态标记
if reachable_flag is True:
mark = ""
elif reachable_flag is False:
if error_code:
mark = f" ❌(错误码: {error_code})"
else:
mark = ""
else:
mark = "" # 不支持检测时不显示标记
display_data.append(
{
"type": p_type,
"info": info,
"mark": mark,
"provider": p,
}
)
# 分组输出
# 1. LLM
llm_data = [d for d in display_data if d["type"] == "llm"]
for i, d in enumerate(llm_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
provider_using = self.context.get_using_provider(umo=umo) provider_using = self.context.get_using_provider(umo=umo)
if ( if provider_using and provider_using.meta().id == id_:
provider_using
and provider_using.meta().id == d["provider"].meta().id
):
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
# 2. TTS tts_providers = self.context.get_all_tts_providers()
tts_data = [d for d in display_data if d["type"] == "tts"] if tts_providers:
if tts_data:
parts.append("\n## 载入的 TTS 提供商\n") parts.append("\n## 载入的 TTS 提供商\n")
for i, d in enumerate(tts_data): for idx, tts in enumerate(tts_providers):
line = f"{i + 1}. {d['info']}{d['mark']}" id_ = tts.meta().id
line = f"{idx + 1}. {id_}"
tts_using = self.context.get_using_tts_provider(umo=umo) tts_using = self.context.get_using_tts_provider(umo=umo)
if tts_using and tts_using.meta().id == d["provider"].meta().id: if tts_using and tts_using.meta().id == id_:
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
# 3. STT stt_providers = self.context.get_all_stt_providers()
stt_data = [d for d in display_data if d["type"] == "stt"] if stt_providers:
if stt_data:
parts.append("\n## 载入的 STT 提供商\n") parts.append("\n## 载入的 STT 提供商\n")
for i, d in enumerate(stt_data): for idx, stt in enumerate(stt_providers):
line = f"{i + 1}. {d['info']}{d['mark']}" id_ = stt.meta().id
line = f"{idx + 1}. {id_}"
stt_using = self.context.get_using_stt_provider(umo=umo) stt_using = self.context.get_using_stt_provider(umo=umo)
if stt_using and stt_using.meta().id == d["provider"].meta().id: if stt_using and stt_using.meta().id == id_:
line += " (当前使用)" line += " (当前使用)"
parts.append(line + "\n") parts.append(line + "\n")
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。") parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
ret = "".join(parts) ret = "".join(parts)
if ttss: if tts_providers:
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
if stts: if stt_providers:
ret += "\n使用 /provider stt <序号> 切换 STT 提供商。" ret += "\n使用 /provider stt <切换> STT 提供商。"
if not reachability_check_enabled:
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
event.set_result(MessageEventResult().message(ret)) event.set_result(MessageEventResult().message(ret))
elif idx == "tts": elif idx == "tts":
+16 -19
View File
@@ -8,7 +8,7 @@ from astrbot.api import star
from astrbot.api.event import AstrMessageEvent from astrbot.api.event import AstrMessageEvent
from astrbot.api.message_components import At, Image, Plain from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import MessageType from astrbot.api.platform import MessageType
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
""" """
@@ -30,13 +30,16 @@ class LongTermMemory:
except BaseException as e: except BaseException as e:
logger.error(e) logger.error(e)
max_cnt = 300 max_cnt = 300
image_caption = (
True
if cfg["provider_settings"]["default_image_caption_provider_id"]
and cfg["provider_ltm_settings"]["image_caption"]
else False
)
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"] image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
image_caption_provider_id = cfg["provider_ltm_settings"].get( image_caption_provider_id = cfg["provider_settings"][
"image_caption_provider_id" "default_image_caption_provider_id"
) ]
image_caption = cfg["provider_ltm_settings"]["image_caption"] and bool(
image_caption_provider_id
)
active_reply = cfg["provider_ltm_settings"]["active_reply"] active_reply = cfg["provider_ltm_settings"]["active_reply"]
enable_active_reply = active_reply.get("enable", False) enable_active_reply = active_reply.get("enable", False)
ar_method = active_reply["method"] ar_method = active_reply["method"]
@@ -158,12 +161,8 @@ class LongTermMemory:
cfg = self.cfg(event) cfg = self.cfg(event)
if cfg["enable_active_reply"]: if cfg["enable_active_reply"]:
prompt = req.prompt prompt = req.prompt
req.prompt = ( req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
f"You are now in a chatroom. The chat history is as follows:\n{chats_str}" req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
f"\nNow, a new message is coming: `{prompt}`. "
"Please react to it. Only output your response and do not output any other information. "
"You MUST use the SAME language as the chatroom is using."
)
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。 req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
else: else:
req.system_prompt += ( req.system_prompt += (
@@ -171,15 +170,13 @@ class LongTermMemory:
) )
req.system_prompt += chats_str req.system_prompt += chats_str
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse): async def after_req_llm(self, event: AstrMessageEvent):
if event.unified_msg_origin not in self.session_chats: if event.unified_msg_origin not in self.session_chats:
return return
if llm_resp.completion_text: if event.get_result() and event.get_result().is_llm_result():
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}" final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
logger.debug( logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
)
self.session_chats[event.unified_msg_origin].append(final_message) self.session_chats[event.unified_msg_origin].append(final_message)
cfg = self.cfg(event) cfg = self.cfg(event)
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
+5 -2
View File
@@ -322,7 +322,7 @@ class Main(star.Star):
@filter.on_llm_response() @filter.on_llm_response()
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse): async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话""" """在 LLM 响应后基于配置注入思考过程文本"""
umo = event.unified_msg_origin umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {}) cfg = self.context.get_config(umo).get("provider_settings", {})
show_reasoning = cfg.get("display_reasoning_text", False) show_reasoning = cfg.get("display_reasoning_text", False)
@@ -331,9 +331,12 @@ class Main(star.Star):
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}" f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
) )
@filter.after_message_sent()
async def after_llm_req(self, event: AstrMessageEvent):
"""在 LLM 请求后记录对话"""
if self.ltm and self.ltm_enabled(event): if self.ltm and self.ltm_enabled(event):
try: try:
await self.ltm.after_req_llm(event, resp) await self.ltm.after_req_llm(event)
except Exception as e: except Exception as e:
logger.error(f"ltm: {e}") logger.error(f"ltm: {e}")
+1 -1
View File
@@ -1,6 +1,6 @@
[project] [project]
name = "AstrBot" name = "AstrBot"
version = "4.7.4" version = "4.7.1"
description = "Easy-to-use multi-platform LLM chatbot and development framework" description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"