Compare commits

...

45 Commits

Author SHA1 Message Date
Soulter
0beb11bade feat(chat): add error handling for message sending and session creation 2025-11-24 22:05:46 +08:00
Soulter
84c459dd77 feat(chat): add standalone chat component and integrate with config page for testing configurations 2025-11-24 16:07:42 +08:00
Soulter
3932b8f982 Merge pull request #3760 from AstrBotDevs/feat/agent-runner
refactor: transfer dify, coze and alibaba dashscope from chat provider to agent runner
2025-11-24 15:33:20 +08:00
Soulter
82488ca900 feat(api): enhance file upload method to support mime type and file name 2025-11-24 15:30:49 +08:00
Soulter
29d9b9b2d6 feat(config): add condition for display_reasoning_text based on agent_runner_type 2025-11-24 15:10:17 +08:00
Soulter
02215e9b7b feat(config): update hint for agent_runner execution method to clarify third-party integration 2025-11-24 15:07:33 +08:00
Soulter
7160b7a18b fix: dify workflow streaming mode 2025-11-24 15:04:15 +08:00
Soulter
ea8dac837a feat(config): enhance hint for agent_runner execution method in configuration 2025-11-24 14:42:36 +08:00
Soulter
e2a7a028bd feat(migration): enhance migration process with error handling and agent runner config updates 2025-11-24 14:37:25 +08:00
Soulter
70db8d264b fix(config): disable auto_save_history option in configuration 2025-11-24 14:25:14 +08:00
Soulter
0518e6d487 feat(config): add hint for agent_runner execution method in configuration 2025-11-24 14:23:53 +08:00
Soulter
39eb367866 perf: improve file structure
- Implemented CozeAPIClient for file upload, image download, chat messaging, and context management.
- Developed DashscopeAgentRunner for handling requests to the Dashscope API with streaming support.
- Created DifyAgentRunner to manage interactions with the Dify API, including file uploads and workflow execution.
- Introduced DifyAPIClient for making asynchronous requests to the Dify API.
- Updated third-party agent imports to reflect new module structure.
2025-11-24 14:00:16 +08:00
Soulter
f1d51a22ad feat(dashscope_agent_runner): refactor request payload construction and enhance streaming response handling 2025-11-24 13:21:34 +08:00
Soulter
77fb554e8f feat(dashscope_agent_runner): implement streaming response handling and request payload construction 2025-11-24 13:09:57 +08:00
Soulter
91f8a0ae09 fix(provider_manager): use get method for provider_type check in load_provider 2025-11-24 10:57:13 +08:00
Soulter
370cda7cf0 feat(dify_api_client): add docstring for file_upload method 2025-11-24 10:53:50 +08:00
Soulter
66b3eed273 fix: correct typo in agent state transition log message 2025-11-24 00:03:22 +08:00
Soulter
99b061a143 fix: make session properties required in Session interface 2025-11-23 23:25:29 +08:00
Soulter
5f3c7ed673 feat(conversation): update agent runner type configuration path to provider_settings 2025-11-23 23:05:36 +08:00
Soulter
a6dc458212 feat(third-party-agent): implement streaming response handling and enhance agent execution flow 2025-11-23 23:03:56 +08:00
Soulter
520f521887 feat(provider): enhance agent runner provider selection with subtype filtering 2025-11-23 22:23:23 +08:00
Soulter
01427d9969 feat(config): add hint for non-built-in agent execution model configuration 2025-11-23 22:13:52 +08:00
Soulter
34c03ce983 Merge remote-tracking branch 'origin/master' into feat/agent-runner 2025-11-23 22:06:52 +08:00
Soulter
95e9da42d6 fix(webchat): webchat session cannot be deleted (#3759) 2025-11-23 22:03:07 +08:00
Soulter
1338cab61b feat: add configuration selector for session management and enhance session handling in chat components 2025-11-23 21:53:56 +08:00
Soulter
7ba98c1e91 feat: enhance provider display with grouped categorization and improved filtering 2025-11-23 21:06:16 +08:00
Soulter
9a5f507cbe feat: enable agent runner providers in configuration 2025-11-23 20:58:18 +08:00
Soulter
d560671d1f feat: agent runner config migration 2025-11-23 20:54:19 +08:00
Soulter
82c9cf4db6 chore: remove legacy coze and dashscope provider 2025-11-23 20:18:51 +08:00
Soulter
910ec6c695 feat: implement third party agent sub stage and refactor provider management
- Added `ThirdPartyAgentSubStage` to handle interactions with third-party agent runners (Dify, Coze, Dashscope).
- Refactored `star_request.py` to ensure consistent return types in the `process` method.
- Updated `stage.py` to initialize and utilize the new `AgentRequestSubStage`.
- Modified `ProviderManager` to skip loading agent runner providers.
- Removed `Dify` source implementation as it is now handled by the new agent runner structure.
- Enhanced `DifyAPIClient` to support file uploads via both file path and file data.
- Cleaned up shared preferences handling to simplify session preference retrieval.
- Updated dashboard configuration to reflect changes in agent runner provider selection.
- Refactored conversation commands to accommodate the new agent runner structure and remove direct dependencies on Dify.
- Adjusted main application logic to ensure compatibility with the new conversation management approach.
2025-11-23 20:18:51 +08:00
Soulter
766d6f2bec fix(conversation): update session configuration retrieval to use unified message origin 2025-11-23 20:18:51 +08:00
Soulter
9f39140987 fix(conversation): update session configuration retrieval to use unified message origin 2025-11-23 19:59:21 +08:00
Soulter
89716ef4da Merge remote-tracking branch 'origin/master' into feat/agent-runner 2025-11-23 14:48:08 +08:00
Soulter
3c4ea5a339 chore: bump version to 4.6.1 2025-11-23 13:58:53 +08:00
Soulter
601846a8c1 docs: refine readme 2025-11-22 18:57:08 +08:00
Soulter
85d66c1056 fix(migration): update migration_done key for webchat session tracking (#3746) 2025-11-22 18:51:00 +08:00
Dt8333
b89d3f663c fix(core.db): fix webchat sessions not being migrated correctly after upgrade (#3745)
Not everyone calls it AstrBot

#3722
2025-11-22 18:37:39 +08:00
Soulter
0260d430d1 Merge pull request #3706 from piexian/master 2025-11-22 01:11:35 +08:00
piexian
2e608cdc09 refactor(bailian_rerank): fix an accidental deletion and improve top_n parameter handling
- Remove the unreasonable knowledge-base config reading logic
- Add the os module import (used to read environment variables)
- Extract helper functions: _build_payload(), _parse_results(), _log_usage()
- Add custom exception classes: BailianRerankError, BailianAPIError, BailianNetworkError
- Use .get() for safe access to API response fields, avoiding KeyError
- Use raise ... from e to preserve the exception chain
2025-11-21 05:34:18 +08:00
piexian
234ce93dc1 refactor(bailian_rerank): improve code quality and error handling
- Remove the unused os import
- Simplify the API key validation logic
- Improve top_n handling, preferring the value passed in
- Improve error handling, replacing generic Exception with RuntimeError
- Add exception chaining to preserve the original error context
2025-11-21 04:07:45 +08:00
piexian
2ada1deb9a Fix the document response reading issue 2025-11-20 08:31:50 +08:00
piexian
788ceb9721 Add the Alibaba Bailian rerank model 2025-11-20 08:05:42 +08:00
Soulter
61a68477d0 stage 2025-10-21 14:19:38 +08:00
Soulter
e74f626383 stage 2025-10-21 09:55:14 +08:00
Soulter
ef99f64291 feat(config): add agent runner type and related configuration support 2025-10-21 00:47:04 +08:00
48 changed files with 3037 additions and 1553 deletions

View File

@@ -32,7 +32,7 @@
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
</div>
AstrBot is an open-source, all-in-one Agent chatbot platform and development framework
AstrBot is an open-source, all-in-one Agent chatbot platform that plugs seamlessly into mainstream instant-messaging apps, giving individuals, developers, and teams reliable, extensible conversational-AI infrastructure. Whether you are building a personal AI companion, intelligent customer service, an automation assistant, or an enterprise knowledge base, AstrBot lets you quickly ship production-ready AI applications inside your messaging platform's workflow
## Key Features

View File

@@ -2,13 +2,12 @@ import abc
import typing as T
from enum import Enum, auto
from astrbot.core.provider import Provider
from astrbot import logger
from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
class AgentState(Enum):
@@ -24,9 +23,7 @@ class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(
self,
provider: Provider,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
@@ -60,3 +57,9 @@ class BaseAgentRunner(T.Generic[TContext]):
This method should be called after the agent is done.
"""
...
def _transition_state(self, new_state: AgentState) -> None:
"""Transition the agent state."""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
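The `_transition_state` helper added above implies a small state machine shared by all runners. A minimal self-contained sketch of how a runner uses it (toy `ToyRunner` class with a standalone enum and stdlib logger instead of the real AstrBot imports):

```python
import logging
from enum import Enum, auto

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("agent")

class AgentState(Enum):
    IDLE = auto()
    RUNNING = auto()
    DONE = auto()
    ERROR = auto()

class ToyRunner:
    """Toy stand-in for BaseAgentRunner's state handling."""

    def __init__(self) -> None:
        self._state = AgentState.IDLE

    def _transition_state(self, new_state: AgentState) -> None:
        # Mirrors the hook in the diff: log only on an actual change.
        if self._state != new_state:
            logger.debug(f"Agent state transition: {self._state} -> {new_state}")
            self._state = new_state

    def done(self) -> bool:
        # Same terminal-state check the concrete runners use.
        return self._state in (AgentState.DONE, AgentState.ERROR)

runner = ToyRunner()
runner._transition_state(AgentState.RUNNING)
runner._transition_state(AgentState.DONE)
print(runner.done())  # → True
```

Keeping the transition in one guarded helper means every runner gets consistent debug logging for free, which is what the typo-fix commit (`66b3eed273`) touched.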

View File

@@ -0,0 +1,367 @@
import base64
import json
import sys
import typing as T
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core import sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .coze_api_client import CozeAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class CozeAgentRunner(BaseAgentRunner[TContext]):
"""Coze Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://"),
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
# Create the API client
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
# Per-session caches
self.file_id_cache: dict[str, dict[str, str]] = {}
@override
async def step(self):
"""
Run one step of the Coze agent.
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# Start processing; transition to the running state
self._transition_state(AgentState.RUNNING)
try:
# Execute the Coze request and handle the results
async for response in self._execute_coze_request():
yield response
except Exception as e:
logger.error(f"Coze 请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"Coze 请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Coze 请求失败:{str(e)}")
),
)
finally:
await self.api_client.close()
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
async def _execute_coze_request(self):
"""Core logic for executing a Coze request."""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
contexts = self.req.contexts or []
system_prompt = self.req.system_prompt
# User ID parameter
user_id = session_id
# Get or create the conversation ID
conversation_id = await sp.get_async(
scope="umo",
scope_id=user_id,
key="coze_conversation_id",
default="",
)
# Build the messages
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{
"role": "system",
"content": system_prompt,
"content_type": "text",
},
)
# Process history contexts
if not self.auto_save_history and contexts:
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
# Handle images in the context
content = ctx["content"]
if isinstance(content, list):
# Multimodal content; images need processing
processed_content = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# Handle the image upload
try:
image_data = item.get("image_url", {})
url = image_data.get("url", "")
if url:
file_id = (
await self._download_and_upload_image(
url, session_id
)
)
processed_content.append(
{
"type": "file",
"file_id": file_id,
"file_url": url,
}
)
except Exception as e:
logger.warning(f"处理上下文图片失败: {e}")
continue
if processed_content:
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
}
)
else:
# Plain-text content
additional_messages.append(
{
"role": ctx["role"],
"content": content,
"content_type": "text",
}
)
# Build the current message
if prompt or image_urls:
if image_urls:
# Multimodal
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
# the url is a base64 string
try:
image_data = base64.b64decode(url)
file_id = await self.api_client.upload_file(image_data)
object_string_content.append(
{
"type": "image",
"file_id": file_id,
}
)
except Exception as e:
logger.warning(f"处理图片失败 {url}: {e}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
}
)
elif prompt:
# Plain text
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
},
)
# Execute the Coze API request
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
await sp.put_async(
scope="umo",
scope_id=user_id,
key="coze_conversation_id",
value=data["conversation_id"],
)
if event_type == "conversation.message.delta":
# Incremental message
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
accumulated_content += content
message_started = True
# For streaming responses, emit the incremental delta
if self.streaming:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(content)
),
)
elif event_type == "conversation.message.completed":
# Message completed
logger.debug("Coze message completed")
message_started = True
elif event_type == "conversation.chat.completed":
# Chat completed
logger.debug("Coze chat completed")
break
elif event_type == "error":
# Error handling
error_msg = data.get("msg", "未知错误")
error_code = data.get("code", "UNKNOWN")
logger.error(f"Coze 出现错误: {error_code} - {error_msg}")
raise Exception(f"Coze 出现错误: {error_code} - {error_msg}")
if not message_started and not accumulated_content:
logger.warning("Coze 未返回任何内容")
accumulated_content = ""
# Build the final response
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# Yield the final result
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def _download_and_upload_image(
self,
image_url: str,
session_id: str | None = None,
) -> str:
"""Download the image, upload it to Coze, and return the file_id"""
import hashlib
# Hash the URL to implement caching
cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest()
if session_id:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}")
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self.api_client.upload_file(image_data)
if session_id:
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {e!s}")
raise Exception(f"处理图片失败: {e!s}")
@override
def done(self) -> bool:
"""Check whether the agent has finished"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp
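`_download_and_upload_image` above caches uploaded file IDs per session, keyed by an MD5 hash of the image URL, so the same image is never uploaded to Coze twice for one session. The caching pattern in isolation (the `_upload` stub and class name are illustrative stand-ins for the real `CozeAPIClient` calls):

```python
import hashlib

class FileIdCache:
    """Toy version of the runner's per-session file_id cache."""

    def __init__(self) -> None:
        self.file_id_cache: dict[str, dict[str, str]] = {}
        self.uploads = 0  # counts simulated upload calls

    def _upload(self, image_url: str) -> str:
        # Stand-in for download_image() + upload_file(); returns a fake file_id.
        self.uploads += 1
        return f"file-{self.uploads}"

    def get_file_id(self, image_url: str, session_id: str) -> str:
        # Key the cache by the MD5 of the URL, as the runner does.
        cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest()
        session_cache = self.file_id_cache.setdefault(session_id, {})
        if cache_key in session_cache:
            return session_cache[cache_key]  # cache hit: no upload
        file_id = self._upload(image_url)
        session_cache[cache_key] = file_id
        return file_id

cache = FileIdCache()
a = cache.get_file_id("https://example.com/a.png", "s1")
b = cache.get_file_id("https://example.com/a.png", "s1")  # cache hit
print(a == b, cache.uploads)  # → True 1
```

Scoping the cache by session keeps file IDs from leaking across conversations while still deduplicating repeated images within one.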

View File

@@ -0,0 +1,403 @@
import asyncio
import functools
import queue
import re
import sys
import threading
import typing as T
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
"""Dashscope Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
raise Exception("阿里云百炼 API Key 不能为空。")
self.app_id = provider_config.get("dashscope_app_id", "")
if not self.app_id:
raise Exception("阿里云百炼 APP ID 不能为空。")
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空。")
self.variables: dict = provider_config.get("variables", {}) or {}
self.rag_options: dict = provider_config.get("rag_options", {})
self.output_reference = self.rag_options.get("output_reference", False)
self.rag_options = self.rag_options.copy()
self.rag_options.pop("output_reference", None)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
"""Check whether RAG options are configured
Returns:
bool: whether RAG options are present
"""
if self.rag_options and (
len(self.rag_options.get("pipeline_ids", [])) > 0
or len(self.rag_options.get("file_ids", [])) > 0
):
return True
return False
@override
async def step(self):
"""
Run one step of the Dashscope agent.
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# Start processing; transition to the running state
self._transition_state(AgentState.RUNNING)
try:
# Execute the Dashscope request and handle the results
async for response in self._execute_dashscope_request():
yield response
except Exception as e:
logger.error(f"阿里云百炼请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}")
),
)
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
def _consume_sync_generator(
self, response: T.Any, response_queue: queue.Queue
) -> None:
"""Consume a sync generator in a worker thread, putting results into a queue
Args:
response: the sync generator object
response_queue: queue used to hand data across threads
"""
try:
if self.streaming:
for chunk in response:
response_queue.put(("data", chunk))
else:
response_queue.put(("data", response))
except Exception as e:
response_queue.put(("error", e))
finally:
response_queue.put(("done", None))
async def _process_stream_chunk(
self, chunk: ApplicationResponse, output_text: str
) -> tuple[str, list | None, AgentResponse | None]:
"""Process a single chunk of the streaming response
Args:
chunk: a Dashscope response chunk
output_text: the output text accumulated so far
Returns:
(updated output_text, doc_references, AgentResponse or None)
"""
logger.debug(f"dashscope stream chunk: {chunk}")
if chunk.status_code != 200:
logger.error(
f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
)
self._transition_state(AgentState.ERROR)
error_msg = (
f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
)
self.final_llm_resp = LLMResponse(
role="err",
result_chain=MessageChain().message(error_msg),
)
return (
output_text,
None,
AgentResponse(
type="err",
data=AgentResponseData(chain=MessageChain().message(error_msg)),
),
)
chunk_text = chunk.output.get("text", "") or ""
# Normalize RAG citation markers
chunk_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", chunk_text)
response = None
if chunk_text:
output_text += chunk_text
response = AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=MessageChain().message(chunk_text)),
)
# Fetch document references
doc_references = chunk.output.get("doc_references", None)
return output_text, doc_references, response
def _format_doc_references(self, doc_references: list) -> str:
"""Format document references as text
Args:
doc_references: list of document references
Returns:
the formatted reference text
"""
ref_parts = []
for ref in doc_references:
ref_title = (
ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
)
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
ref_str = "".join(ref_parts)
return f"\n\n回答来源:\n{ref_str}"
async def _build_request_payload(
self, prompt: str, session_id: str, contexts: list, system_prompt: str
) -> dict:
"""Build the request payload
Args:
prompt: the user input
session_id: the session ID
contexts: the context list
system_prompt: the system prompt
Returns:
the request payload dict
"""
conversation_id = await sp.get_async(
scope="umo",
scope_id=session_id,
key="dashscope_conversation_id",
default="",
)
# Get session variables
payload_vars = self.variables.copy()
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
payload_vars.update(session_var)
if (
self.dashscope_app_type in ["agent", "dialog-workflow"]
and not self.has_rag_options()
):
# App types that support multi-turn conversation
p = {
"app_id": self.app_id,
"api_key": self.api_key,
"prompt": prompt,
"biz_params": payload_vars or None,
"stream": self.streaming,
"incremental_output": True,
}
if conversation_id:
p["session_id"] = conversation_id
return p
else:
# App types that do not support multi-turn conversation
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
"stream": self.streaming,
"incremental_output": True,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
return payload
async def _handle_streaming_response(
self, response: T.Any, session_id: str
) -> T.AsyncGenerator[AgentResponse, None]:
"""Handle the streaming response
Args:
response: the Dashscope streaming response generator
Yields:
AgentResponse objects
"""
response_queue = queue.Queue()
consumer_thread = threading.Thread(
target=self._consume_sync_generator,
args=(response, response_queue),
daemon=True,
)
consumer_thread.start()
output_text = ""
doc_references = None
while True:
try:
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
None, response_queue.get, True, 1
)
except queue.Empty:
continue
if item_type == "done":
break
elif item_type == "error":
raise item_data
elif item_type == "data":
chunk = item_data
assert isinstance(chunk, ApplicationResponse)
(
output_text,
chunk_doc_refs,
response,
) = await self._process_stream_chunk(chunk, output_text)
if response:
if response.type == "err":
yield response
return
yield response
if chunk_doc_refs:
doc_references = chunk_doc_refs
if chunk.output.session_id:
await sp.put_async(
scope="umo",
scope_id=session_id,
key="dashscope_conversation_id",
value=chunk.output.session_id,
)
# Append RAG references
if self.output_reference and doc_references:
ref_text = self._format_doc_references(doc_references)
output_text += ref_text
if self.streaming:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=MessageChain().message(ref_text)),
)
# Build the final response
chain = MessageChain(chain=[Comp.Plain(output_text)])
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# Yield the final result
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def _execute_dashscope_request(self):
"""Core logic for executing a Dashscope request."""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
contexts = self.req.contexts or []
system_prompt = self.req.system_prompt
# Check image input
if image_urls:
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
# Build the request payload
payload = await self._build_request_payload(
prompt, session_id, contexts, system_prompt
)
if not self.streaming:
payload["incremental_output"] = False
# Issue the request
partial = functools.partial(Application.call, **payload)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
async for resp in self._handle_streaming_response(response, session_id):
yield resp
@override
def done(self) -> bool:
"""Check whether the agent has finished"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp
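Because the Dashscope SDK's `Application.call` yields chunks from a blocking generator in streaming mode, `_consume_sync_generator` and `_handle_streaming_response` above bridge it onto the event loop with a daemon thread and a `queue.Queue`. The same pattern in isolation, stdlib only with a plain range standing in for the SDK generator:

```python
import asyncio
import queue
import threading
import typing as T

def consume_sync_generator(gen: T.Iterable, q: "queue.Queue") -> None:
    # Runs in a worker thread: drain the blocking generator into the queue.
    try:
        for chunk in gen:
            q.put(("data", chunk))
    except Exception as e:
        q.put(("error", e))
    finally:
        q.put(("done", None))

async def iter_async(gen: T.Iterable) -> T.AsyncGenerator:
    q: queue.Queue = queue.Queue()
    threading.Thread(target=consume_sync_generator, args=(gen, q), daemon=True).start()
    loop = asyncio.get_running_loop()
    while True:
        try:
            # Block in the executor with a timeout so the event loop stays responsive.
            kind, item = await loop.run_in_executor(None, q.get, True, 1)
        except queue.Empty:
            continue
        if kind == "done":
            break
        if kind == "error":
            raise item
        yield item

async def main() -> list[int]:
    return [x async for x in iter_async(iter(range(3)))]

print(asyncio.run(main()))  # → [0, 1, 2]
```

The `("data" | "error" | "done", payload)` tuples let the consumer distinguish normal chunks from thread-side exceptions, which is how the runner re-raises SDK errors on the event loop.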

View File

@@ -0,0 +1,336 @@
import base64
import os
import sys
import typing as T
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .dify_api_client import DifyAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DifyAgentRunner(BaseAgentRunner[TContext]):
"""Dify Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("dify_api_key", "")
self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_type = provider_config.get("dify_api_type", "chat")
self.workflow_output_key = provider_config.get(
"dify_workflow_output_key",
"astrbot_wf_output",
)
self.dify_query_input_key = provider_config.get(
"dify_query_input_key",
"astrbot_text_query",
)
self.variables: dict = provider_config.get("variables", {}) or {}
self.timeout = provider_config.get("timeout", 60)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.api_client = DifyAPIClient(self.api_key, self.api_base)
@override
async def step(self):
"""
Run one step of the Dify agent.
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# Start processing; transition to the running state
self._transition_state(AgentState.RUNNING)
try:
# Execute the Dify request and handle the results
async for response in self._execute_dify_request():
yield response
except Exception as e:
logger.error(f"Dify 请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"Dify 请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Dify 请求失败:{str(e)}")
),
)
finally:
await self.api_client.close()
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
async def _execute_dify_request(self):
"""Core logic for executing a Dify request."""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
system_prompt = self.req.system_prompt
conversation_id = await sp.get_async(
scope="umo",
scope_id=session_id,
key="dify_conversation_id",
default="",
)
result = ""
# Handle image uploads
files_payload = []
for image_url in image_urls:
# image_url is a base64 string
try:
image_data = base64.b64decode(image_url)
file_response = await self.api_client.file_upload(
file_data=image_data,
user=session_id,
mime_type="image/png",
file_name="image.png",
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
}
)
except Exception as e:
logger.warning(f"上传图片失败:{e}")
continue
# Get session variables
payload_vars = self.variables.copy()
# Dynamic variables
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
# Handle the different API types
match self.api_type:
case "chat" | "agent" | "chatflow":
if not prompt:
prompt = "请描述这张图片。"
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk["event"] == "message" or chunk["event"] == "agent_message":
result += chunk["answer"]
if not conversation_id:
await sp.put_async(
scope="umo",
scope_id=session_id,
key="dify_conversation_id",
value=chunk["conversation_id"],
)
conversation_id = chunk["conversation_id"]
# For streaming responses, emit the incremental delta
if self.streaming and chunk["answer"]:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(chunk["answer"])
),
)
elif chunk["event"] == "message_end":
logger.debug("Dify message end")
break
elif chunk["event"] == "error":
logger.error(f"Dify 出现错误:{chunk}")
raise Exception(
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
)
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**payload_vars,
},
user=session_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify workflow resp chunk: {chunk}")
match chunk["event"]:
case "workflow_started":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。"
)
case "node_finished":
logger.debug(
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。"
)
case "text_chunk":
if self.streaming and chunk["data"]["text"]:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(
chunk["data"]["text"]
)
),
)
case "workflow_finished":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
)
logger.debug(f"Dify 工作流结果:{chunk}")
if chunk["data"]["error"]:
logger.error(
f"Dify 工作流出现错误:{chunk['data']['error']}"
)
raise Exception(
f"Dify 工作流出现错误:{chunk['data']['error']}"
)
if self.workflow_output_key not in chunk["data"]["outputs"]:
raise Exception(
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
)
result = chunk
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
if not result:
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
# Parse the result
chain = await self.parse_dify_result(result)
# Build the final response
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# Yield the final result
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
"""解析 Dify 的响应结果"""
if isinstance(chunk, str):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict):
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# 仅支持 wav
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"{item['filename']}.wav")
await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"])
case "video":
return Comp.Video(file=item["url"])
case _:
return Comp.File(name=item["filename"], file=item["url"])
output = chunk["data"]["outputs"][self.workflow_output_key]
chains = []
if isinstance(output, str):
# Plain-text output
chains.append(Comp.Plain(output))
elif isinstance(output, list):
# Mainly for multimodal output from Dify's HTTP request node
for item in output:
# handle Array[File]
if (
not isinstance(item, dict)
or item.get("dify_model_identity", "") != "__dify__file__"
):
chains.append(Comp.Plain(str(output)))
break
else:
chains.append(Comp.Plain(str(output)))
# scan file
files = chunk["data"].get("files", [])
for item in files:
comp = await parse_file(item)
chains.append(comp)
return MessageChain(chain=chains)
@override
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp

View File

@@ -3,7 +3,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any
from aiohttp import ClientResponse, ClientSession
from aiohttp import ClientResponse, ClientSession, FormData
from astrbot.core import logger
@@ -101,21 +101,59 @@ class DifyAPIClient:
async def file_upload(
self,
file_path: str,
user: str,
file_path: str | None = None,
file_data: bytes | None = None,
file_name: str | None = None,
mime_type: str | None = None,
) -> dict[str, Any]:
"""Upload a file to Dify. Must provide either file_path or file_data.
Args:
user: The user ID.
file_path: The path to the file to upload.
file_data: The file data in bytes.
file_name: Optional file name when using file_data.
Returns:
A dictionary containing the uploaded file information.
"""
url = f"{self.api_base}/files/upload"
with open(file_path, "rb") as f:
payload = {
"user": user,
"file": f,
}
async with self.session.post(
url,
data=payload,
headers=self.headers,
) as resp:
return await resp.json() # {"id": "xxx", ...}
form = FormData()
form.add_field("user", user)
if file_data is not None:
# Use raw bytes data
form.add_field(
"file",
file_data,
filename=file_name or "uploaded_file",
content_type=mime_type or "application/octet-stream",
)
elif file_path is not None:
# Use a file path
import os
with open(file_path, "rb") as f:
file_content = f.read()
form.add_field(
"file",
file_content,
filename=os.path.basename(file_path),
content_type=mime_type or "application/octet-stream",
)
else:
raise ValueError("file_path 和 file_data 不能同时为 None")
async with self.session.post(
url,
data=form,
headers=self.headers,  # no Content-Type here; let aiohttp set the multipart boundary
) as resp:
if resp.status != 200 and resp.status != 201:
text = await resp.text()
raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()

View File

@@ -69,12 +69,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
self.run_context.messages = messages
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
if self.streaming:

View File

@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.6.0"
VERSION = "4.6.1"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# Default configuration
@@ -68,6 +68,10 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"agent_runner_type": "local",
"dify_agent_runner_provider_id": "",
"coze_agent_runner_provider_id": "",
"dashscope_agent_runner_provider_id": "",
"unsupported_streaming_strategy": "realtime_segmenting",
"max_agent_step": 30,
"tool_call_timeout": 60,
@@ -1011,7 +1015,7 @@ CONFIG_METADATA_2 = {
"id": "dify_app_default",
"provider": "dify",
"type": "dify",
"provider_type": "chat_completion",
"provider_type": "agent_runner",
"enable": True,
"dify_api_type": "chat",
"dify_api_key": "",
@@ -1025,20 +1029,20 @@ CONFIG_METADATA_2 = {
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "chat_completion",
"provider_type": "agent_runner",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
"auto_save_history": True,
# "auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
"type": "dashscope",
"provider_type": "chat_completion",
"provider_type": "agent_runner",
"enable": True,
"dashscope_app_type": "agent",
"dashscope_api_key": "",
@@ -1308,6 +1312,19 @@ CONFIG_METADATA_2 = {
"timeout": 20,
"launch_model_if_not_running": False,
},
"阿里云百炼重排序": {
"id": "bailian_rerank",
"type": "bailian_rerank",
"provider": "bailian",
"provider_type": "rerank",
"enable": True,
"rerank_api_key": "",
"rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
"rerank_model": "qwen3-rerank",
"timeout": 30,
"return_documents": False,
"instruct": "",
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
@@ -1342,6 +1359,16 @@ CONFIG_METADATA_2 = {
"description": "重排序模型名称",
"type": "string",
},
"return_documents": {
"description": "是否在排序结果中返回文档原文",
"type": "bool",
"hint": "默认值false以减少网络传输开销。",
},
"instruct": {
"description": "自定义排序任务类型说明",
"type": "string",
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。",
},
"launch_model_if_not_running": {
"description": "模型未运行时自动启动",
"type": "bool",
@@ -1884,7 +1911,6 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用",
"type": "bool",
"hint": "是否启用。",
},
"key": {
"description": "API Key",
@@ -2014,12 +2040,22 @@ CONFIG_METADATA_2 = {
"unsupported_streaming_strategy": {
"type": "string",
},
"agent_runner_type": {
"type": "string",
},
"dify_agent_runner_provider_id": {
"type": "string",
},
"coze_agent_runner_provider_id": {
"type": "string",
},
"dashscope_agent_runner_provider_id": {
"type": "string",
},
"max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
},
"tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
},
@@ -2157,30 +2193,75 @@ CONFIG_METADATA_3 = {
"ai_group": {
"name": "AI 配置",
"metadata": {
"ai": {
"description": "模型",
"agent_runner": {
"description": "Agent 执行方式",
"hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。",
"type": "object",
"items": {
"provider_settings.enable": {
"description": "启用大语言模型聊天",
"description": "启用",
"type": "bool",
"hint": "AI 对话总开关",
},
"provider_settings.agent_runner_type": {
"description": "执行器",
"type": "string",
"options": ["local", "dify", "coze", "dashscope"],
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"],
"condition": {
"provider_settings.enable": True,
},
},
"provider_settings.coze_agent_runner_provider_id": {
"description": "Coze Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:coze",
"condition": {
"provider_settings.agent_runner_type": "coze",
"provider_settings.enable": True,
},
},
"provider_settings.dify_agent_runner_provider_id": {
"description": "Dify Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:dify",
"condition": {
"provider_settings.agent_runner_type": "dify",
"provider_settings.enable": True,
},
},
"provider_settings.dashscope_agent_runner_provider_id": {
"description": "阿里云百炼应用 Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:dashscope",
"condition": {
"provider_settings.agent_runner_type": "dashscope",
"provider_settings.enable": True,
},
},
},
},
"ai": {
"description": "模型",
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"type": "object",
"items": {
"provider_settings.default_provider_id": {
"description": "默认聊天模型",
"type": "string",
"_special": "select_provider",
"hint": "留空时使用第一个模型",
"hint": "留空时使用第一个模型",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
"_special": "select_provider",
"hint": "留空代表不使用可用于不支持视觉模态的聊天模型",
"hint": "留空代表不使用可用于非多模态模型",
},
"provider_stt_settings.enable": {
"description": "启用语音转文本",
"type": "bool",
"hint": "STT 总开关",
"hint": "STT 总开关",
},
"provider_stt_settings.provider_id": {
"description": "默认语音转文本模型",
@@ -2194,12 +2275,11 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": {
"description": "启用文本转语音",
"type": "bool",
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
"hint": "TTS 总开关",
},
"provider_tts_settings.provider_id": {
"description": "默认文本转语音模型",
"type": "string",
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
"_special": "select_provider_tts",
"condition": {
"provider_tts_settings.enable": True,
@@ -2210,6 +2290,9 @@ CONFIG_METADATA_3 = {
"type": "text",
},
},
"condition": {
"provider_settings.enable": True,
},
},
"persona": {
"description": "人格",
@@ -2221,6 +2304,10 @@ CONFIG_METADATA_3 = {
"_special": "select_persona",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"knowledgebase": {
"description": "知识库",
@@ -2249,6 +2336,10 @@ CONFIG_METADATA_3 = {
"hint": "启用后,知识库检索将作为 LLM Tool由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"websearch": {
"description": "网页搜索",
@@ -2285,6 +2376,10 @@ CONFIG_METADATA_3 = {
"type": "bool",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"others": {
"description": "其他配置",
@@ -2293,34 +2388,51 @@ CONFIG_METADATA_3 = {
"provider_settings.display_reasoning_text": {
"description": "显示思考内容",
"type": "bool",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.identifier": {
"description": "用户识别",
"type": "bool",
"hint": "启用后,会在提示词前包含用户 ID 信息。",
},
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
"hint": "启用后,在支持的平台(OneBot v11)上会在提示词前包含群名称信息。",
},
"provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知",
"type": "bool",
"hint": "启用后,会在系统提示词中附带当前时间信息。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.show_tool_use_status": {
"description": "输出函数调用状态",
"type": "bool",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.streaming_response": {
"description": "流式回复",
"description": "流式输出",
"type": "bool",
},
"provider_settings.unsupported_streaming_strategy": {
@@ -2336,17 +2448,23 @@ CONFIG_METADATA_3 = {
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条-1 为不限制",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat则需要 /chat 才会触发 LLM 请求",
},
"provider_settings.prompt_prefix": {
"description": "用户提示词",
@@ -2358,6 +2476,9 @@ CONFIG_METADATA_3 = {
"type": "bool",
},
},
"condition": {
"provider_settings.enable": True,
},
},
},
},

View File

@@ -16,13 +16,12 @@ import time
import traceback
from asyncio import Queue
from astrbot.core import LogBroker, logger, sp
from astrbot.api import logger, sp
from astrbot.core import LogBroker
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
@@ -34,6 +33,7 @@ from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.migra_helper import migra
from . import astrbot_config, html_renderer
from .event_bus import EventBus
@@ -97,18 +97,16 @@ class AstrBotCoreLifecycle:
sp=sp,
)
# 4.5 to 4.6 migration for umop_config_router
# apply migration
try:
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
await migra(
self.db,
self.astrbot_config_mgr,
self.umop_config_router,
self.astrbot_config_mgr,
)
except Exception as e:
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# migration for webchat session
try:
await migrate_webchat_session(self.db)
except Exception as e:
logger.error(f"Migration for webchat session failed: {e!s}")
logger.error(f"AstrBot migration failed: {e!s}")
logger.error(traceback.format_exc())
# Initialize the event queue

View File

@@ -25,7 +25,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
"""
# Check whether the migration has already completed
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_webchat_session"
"global", "global", "migration_done_webchat_session_1"
)
if migration_done:
return
@@ -43,7 +43,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
func.max(PlatformMessageHistory.updated_at).label("latest"),
)
.where(col(PlatformMessageHistory.platform_id) == "webchat")
.where(col(PlatformMessageHistory.sender_id) == "astrbot")
.where(col(PlatformMessageHistory.sender_id) != "bot")
.group_by(col(PlatformMessageHistory.user_id))
)
@@ -53,7 +53,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
if not webchat_users:
logger.info("没有找到需要迁移的 WebChat 数据")
await sp.put_async(
"global", "global", "migration_done_webchat_session", True
"global", "global", "migration_done_webchat_session_1", True
)
return
@@ -124,7 +124,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
logger.info("没有新会话需要迁移")
# Mark the migration as done
await sp.put_async("global", "global", "migration_done_webchat_session", True)
await sp.put_async("global", "global", "migration_done_webchat_session_1", True)
except Exception as e:
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)

View File
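Renaming the preference key to `migration_done_webchat_session_1` makes the corrected migration run once more on upgraded installs while staying idempotent afterwards. The guard pattern, sketched against an in-memory store (the classes here are stand-ins for `db_helper`/`sp`, not AstrBot APIs):

```python
import asyncio

class MemoryPrefs:
    """In-memory stand-in for the preference store used by the migration."""
    def __init__(self) -> None:
        self._data: dict[tuple[str, str, str], object] = {}

    async def get(self, scope: str, scope_id: str, key: str):
        return self._data.get((scope, scope_id, key))

    async def put(self, scope: str, scope_id: str, key: str, value: object) -> None:
        self._data[(scope, scope_id, key)] = value

async def migrate_once(prefs: MemoryPrefs, runs: list[int]) -> None:
    # Versioned flag: bumping the suffix re-runs a fixed-up migration exactly once.
    if await prefs.get("global", "global", "migration_done_webchat_session_1"):
        return
    runs.append(1)  # ...the actual data migration would happen here...
    await prefs.put("global", "global", "migration_done_webchat_session_1", True)

async def main() -> int:
    prefs, runs = MemoryPrefs(), []
    await migrate_once(prefs, runs)
    await migrate_once(prefs, runs)  # second call is a no-op
    return len(runs)

times_run = asyncio.run(main())
```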

@@ -173,7 +173,7 @@ class PlatformSession(SQLModel, table=True):
max_length=100,
nullable=False,
unique=True,
default_factory=lambda: f"webchat_{uuid.uuid4()}",
default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(default="webchat", nullable=False)
"""Platform identifier (e.g., 'webchat', 'qq', 'discord')"""

View File

@@ -794,7 +794,7 @@ class SQLiteDatabase(BaseDatabase):
await session.execute(
update(PlatformSession)
.where(col(PlatformSession.session_id == session_id))
.where(col(PlatformSession.session_id) == session_id)
.values(**values),
)
@@ -805,6 +805,6 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(PlatformSession).where(
col(PlatformSession.session_id == session_id),
col(PlatformSession.session_id) == session_id,
),
)

View File
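The two fixes above move `== session_id` outside the `col()` call. In SQLAlchemy/SQLModel, `Model.field == value` builds a SQL expression object rather than a bool, so `col(PlatformSession.session_id == session_id)` hands `col()` an already-built expression instead of a column. A toy emulation of why operator placement matters (`FakeColumn` and this `col()` are ours; sqlmodel's real type check differs in detail):

```python
class Expression:
    """Toy stand-in for a SQLAlchemy binary expression."""
    def __init__(self, column: str, value: object) -> None:
        self.column, self.value = column, value

class FakeColumn:
    def __init__(self, name: str) -> None:
        self.name = name

    def __eq__(self, other: object) -> "Expression":  # type: ignore[override]
        # Like SQLAlchemy columns, == builds an expression instead of a bool.
        return Expression(self.name, other)

def col(c):
    """Toy col(): only accepts a column, rejecting pre-built expressions."""
    if not isinstance(c, FakeColumn):
        raise TypeError("col() expects a column, not an expression")
    return c

session_id_col = FakeColumn("session_id")

# Corrected order: wrap the column first, then compare.
expr = col(session_id_col) == "abc"

# Buggy original order: the comparison runs first and col() receives an Expression.
try:
    col(session_id_col == "abc")
    buggy_raises = False
except TypeError:
    buggy_raises = True
```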

@@ -0,0 +1,48 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_llm_manager import SessionServiceManager
from ...context import PipelineContext
from ..stage import Stage
from .agent_sub_stages.internal import InternalAgentSubStage
from .agent_sub_stages.third_party import ThirdPartyAgentSubStage
class AgentRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.bot_wake_prefixs: list[str] = self.config["wake_prefix"]
self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"]
for bwp in self.bot_wake_prefixs:
if self.prov_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
)
self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :]
agent_runner_type = self.config["provider_settings"]["agent_runner_type"]
if agent_runner_type == "local":
self.agent_sub_stage = InternalAgentSubStage()
else:
self.agent_sub_stage = ThirdPartyAgentSubStage()
await self.agent_sub_stage.initialize(ctx)
async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]:
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug(
"This pipeline does not enable AI capability, skip processing."
)
return
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
)
return
async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix):
yield resp

View File

@@ -21,27 +21,24 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
from ....astr_agent_context import AgentContextWrapper
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....astr_agent_run_util import AgentRunner, run_agent
from ....astr_agent_tool_exec import FunctionToolExecutor
from ...context import PipelineContext, call_event_hook
from ..stage import Stage
from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
from .....astr_agent_context import AgentContextWrapper
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from .....astr_agent_run_util import AgentRunner, run_agent
from .....astr_agent_tool_exec import FunctionToolExecutor
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
class LLMRequestSubStage(Stage):
class InternalAgentSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
self.provider_wake_prefix: str = settings["wake_prefix"] # str
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
@@ -59,13 +56,6 @@ class LLMRequestSubStage(Stage):
self.show_reasoning = settings.get("display_reasoning_text", False)
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
)
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
@@ -304,21 +294,10 @@ class LLMRequestSubStage(Stage):
return fixed_messages
async def process(
self,
event: AstrMessageEvent,
_nested: bool = False,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent, provider_wake_prefix: str
) -> AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# Check the session-level LLM on/off state
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
@@ -348,12 +327,12 @@ class LLMRequestSubStage(Stage):
req.image_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix and not event.message_str.startswith(
self.provider_wake_prefix
if provider_wake_prefix and not event.message_str.startswith(
provider_wake_prefix
):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
req.prompt = event.message_str[len(provider_wake_prefix) :]
# func_tool selection has moved into the packages/astrbot plugin.
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:

View File

@@ -0,0 +1,202 @@
import asyncio
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from astrbot.core import logger
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner,
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
ProviderRequest,
)
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
AGENT_RUNNER_TYPE_KEY = {
"dify": "dify_agent_runner_provider_id",
"coze": "coze_agent_runner_provider_id",
"dashscope": "dashscope_agent_runner_provider_id",
}
async def run_third_party_agent(
runner: "BaseAgentRunner",
stream_to_general: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""
Run a third-party agent runner and convert its response format.
Similar to run_agent, but dedicated to third-party agent runners.
"""
try:
async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
if resp.type == "streaming_delta":
if stream_to_general:
continue
yield resp.data["chain"]
elif resp.type == "llm_result":
if stream_to_general:
yield resp.data["chain"]
except Exception as e:
logger.error(f"Third party agent runner error: {e}")
err_msg = (
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
)
yield MessageChain().message(err_msg)
class ThirdPartyAgentSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.conf = ctx.astrbot_config
self.runner_type = self.conf["provider_settings"]["agent_runner_type"]
self.prov_id = self.conf["provider_settings"].get(
AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
"",
)
settings = ctx.astrbot_config["provider_settings"]
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
]
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
) -> AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if provider_wake_prefix and not event.message_str.startswith(
provider_wake_prefix
):
return
self.prov_cfg: dict = next(
(p for p in self.conf["provider"] if p["id"] == self.prov_id),
{},
)
if not self.prov_id or not self.prov_cfg:
logger.error(
"Third Party Agent Runner provider ID is not configured properly."
)
return
# make provider request
req = ProviderRequest()
req.session_id = event.unified_msg_origin
req.prompt = event.message_str[len(provider_wake_prefix) :]
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_base64()
req.image_urls.append(image_path)
if not req.prompt and not req.image_urls:
return
# call event hook
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if self.runner_type == "dify":
runner = DifyAgentRunner[AstrAgentContext]()
elif self.runner_type == "coze":
runner = CozeAgentRunner[AstrAgentContext]()
elif self.runner_type == "dashscope":
runner = DashscopeAgentRunner[AstrAgentContext]()
else:
raise ValueError(
f"Unsupported third party agent runner type: {self.runner_type}",
)
astr_agent_ctx = AstrAgentContext(
context=self.ctx.plugin_manager.context,
event=event,
)
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
await runner.reset(
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=60,
),
agent_hooks=MAIN_AGENT_HOOKS,
provider_config=self.prov_cfg,
streaming=streaming_response,
)
if streaming_response and not stream_to_general:
# Streaming response
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_third_party_agent(
runner,
stream_to_general=False,
),
),
)
yield
if runner.done():
final_resp = runner.get_final_llm_resp()
if final_resp and final_resp.result_chain:
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
# Non-streaming response, or streaming converted to a plain response
async for _ in run_third_party_agent(
runner,
stream_to_general=stream_to_general,
):
yield
final_resp = runner.get_final_llm_resp()
if not final_resp or not final_resp.result_chain:
logger.warning("Agent Runner 未返回最终结果。")
return
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.LLM_RESULT,
),
)
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=self.runner_type,
provider_type=self.runner_type,
),
)

View File
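`run_third_party_agent` above filters the runner's event stream: `streaming_delta` chunks pass through in streaming mode, while only the final `llm_result` chain is yielded when `stream_to_general` is set. The filtering logic in isolation, with plain tuples standing in for `AgentResponse` objects:

```python
import asyncio
from collections.abc import AsyncGenerator

async def agent_events() -> AsyncGenerator[tuple[str, str], None]:
    # (type, payload) pairs standing in for AgentResponse objects.
    yield ("streaming_delta", "Hel")
    yield ("streaming_delta", "lo")
    yield ("llm_result", "Hello")

async def filter_events(stream, stream_to_general: bool) -> list[str]:
    out: list[str] = []
    async for kind, payload in stream:
        if kind == "streaming_delta" and not stream_to_general:
            out.append(payload)  # pass deltas through for streaming platforms
        elif kind == "llm_result" and stream_to_general:
            out.append(payload)  # only the aggregated final result
    return out

deltas = asyncio.run(filter_events(agent_events(), stream_to_general=False))
final_only = asyncio.run(filter_events(agent_events(), stream_to_general=True))
```

In the streaming case the final chain is deliberately not re-yielded; the stage retrieves it afterwards via `runner.get_final_llm_resp()`.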

@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
) -> AsyncGenerator[None, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers",
)

View File

@@ -7,7 +7,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata
from ..context import PipelineContext
from ..stage import Stage, register_stage
from .method.llm_request import LLMRequestSubStage
from .method.agent_request import AgentRequestSubStage
from .method.star_request import StarRequestSubStage
@@ -17,9 +17,12 @@ class ProcessStage(Stage):
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.llm_request_sub_stage = LLMRequestSubStage()
await self.llm_request_sub_stage.initialize(ctx)
# initialize agent sub stage
self.agent_sub_stage = AgentRequestSubStage()
await self.agent_sub_stage.initialize(ctx)
# initialize star request sub stage
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
@@ -39,7 +42,7 @@ class ProcessStage(Stage):
# LLM request issued by a handler
event.set_extra("provider_request", resp)
_t = False
async for _ in self.llm_request_sub_stage.process(event):
async for _ in self.agent_sub_stage.process(event):
_t = True
yield
if not _t:
@@ -67,5 +70,5 @@ class ProcessStage(Stage):
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
async for _ in self.llm_request_sub_stage.process(event):
async for _ in self.agent_sub_stage.process(event):
yield

View File

@@ -227,6 +227,8 @@ class ProviderManager:
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
return
if provider_config.get("provider_type", "") == "agent_runner":
return
logger.info(
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...",
@@ -247,14 +249,6 @@ class ProviderManager:
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "coze":
from .sources.coze_source import ProviderCoze as ProviderCoze
case "dashscope":
from .sources.dashscope_source import (
ProviderDashscope as ProviderDashscope,
)
case "googlegenai_chat_completion":
from .sources.gemini_source import (
ProviderGoogleGenAI as ProviderGoogleGenAI,
@@ -331,6 +325,10 @@ class ProviderManager:
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
case "bailian_rerank":
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",

View File

@@ -0,0 +1,236 @@
import os
import aiohttp
from astrbot import logger
from ..entities import ProviderType, RerankResult
from ..provider import RerankProvider
from ..register import register_provider_adapter
class BailianRerankError(Exception):
"""Base exception for the Bailian rerank service."""
pass
class BailianAPIError(BailianRerankError):
"""The Bailian API returned an error."""
pass
class BailianNetworkError(BailianRerankError):
"""A network error occurred when calling Bailian."""
pass
@register_provider_adapter(
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
"""阿里云百炼文本重排序适配器."""
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
# API configuration
self.api_key = provider_config.get("rerank_api_key") or os.getenv(
"DASHSCOPE_API_KEY", ""
)
if not self.api_key:
raise ValueError("阿里云百炼 API Key 不能为空。")
self.model = provider_config.get("rerank_model", "qwen3-rerank")
self.timeout = provider_config.get("timeout", 30)
self.return_documents = provider_config.get("return_documents", False)
self.instruct = provider_config.get("instruct", "")
self.base_url = provider_config.get(
"rerank_api_base",
"https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
)
# Set up the HTTP client
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
self.client = aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
)
# Set the model name
self.set_model(self.model)
logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
def _build_payload(
self, query: str, documents: list[str], top_n: int | None
) -> dict:
"""构建请求载荷
Args:
query: 查询文本
documents: 文档列表
top_n: 返回前N个结果如果为None则返回所有结果
Returns:
请求载荷字典
"""
base = {"model": self.model, "input": {"query": query, "documents": documents}}
params = {
k: v
for k, v in [
("top_n", top_n if top_n is not None and top_n > 0 else None),
("return_documents", True if self.return_documents else None),
(
"instruct",
self.instruct
if self.instruct and self.model == "qwen3-rerank"
else None,
),
]
if v is not None
}
if params:
base["parameters"] = params
return base
def _parse_results(self, data: dict) -> list[RerankResult]:
"""解析API响应结果
Args:
data: API响应数据
Returns:
重排序结果列表
Raises:
BailianAPIError: API返回错误
KeyError: 结果缺少必要字段
"""
# 检查响应状态
if data.get("code", "200") != "200":
raise BailianAPIError(
f"百炼 API 错误: {data.get('code')} {data.get('message', '')}"
)
results = data.get("output", {}).get("results", [])
if not results:
logger.warning(f"百炼 Rerank 返回空结果: {data}")
return []
# 转换为RerankResult对象使用.get()避免KeyError
rerank_results = []
for idx, result in enumerate(results):
try:
index = result.get("index", idx)
relevance_score = result.get("relevance_score", 0.0)
if relevance_score is None:
logger.warning(f"结果 {idx} 缺少 relevance_score使用默认值 0.0")
relevance_score = 0.0
rerank_result = RerankResult(
index=index, relevance_score=relevance_score
)
rerank_results.append(rerank_result)
except Exception as e:
logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
continue
return rerank_results
def _log_usage(self, data: dict) -> None:
"""记录使用量信息
Args:
data: API响应数据
"""
tokens = data.get("usage", {}).get("total_tokens", 0)
if tokens > 0:
logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
async def rerank(
self,
query: str,
documents: list[str],
top_n: int | None = None,
) -> list[RerankResult]:
"""
对文档进行重排序
Args:
query: 查询文本
documents: 待排序的文档列表
top_n: 返回前N个结果如果为None则使用配置中的默认值
Returns:
重排序结果列表
"""
if not documents:
logger.warning("文档列表为空,返回空结果")
return []
if not query.strip():
logger.warning("查询文本为空,返回空结果")
return []
# 检查限制
if len(documents) > 500:
logger.warning(
f"文档数量({len(documents)})超过限制(500)将截断前500个文档"
)
documents = documents[:500]
try:
# 构建请求载荷如果top_n为None则返回所有重排序结果
payload = self._build_payload(query, documents, top_n)
logger.debug(
f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
)
# Send the request
async with self.client.post(self.base_url, json=payload) as response:
response.raise_for_status()
response_data = await response.json()
# Parse results and log usage
results = self._parse_results(response_data)
self._log_usage(response_data)
logger.debug(f"Bailian Rerank returned {len(results)} results")
return results
except aiohttp.ClientError as e:
error_msg = f"Network request failed: {e}"
logger.error(f"Bailian Rerank network request failed: {e}")
raise BailianNetworkError(error_msg) from e
except BailianRerankError:
raise
except Exception as e:
error_msg = f"Rerank failed: {e}"
logger.error(f"Bailian Rerank processing failed: {e}")
raise BailianRerankError(error_msg) from e
async def terminate(self) -> None:
"""关闭HTTP客户端会话."""
if self.client:
logger.info("关闭 百炼 Rerank 客户端会话")
try:
await self.client.close()
except Exception as e:
logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
finally:
self.client = None

View File

@@ -1,650 +0,0 @@
import base64
import hashlib
import json
import os
from collections.abc import AsyncGenerator
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse
from ..register import register_provider_adapter
from .coze_api_client import CozeAPIClient
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
def __init__(
self,
provider_config,
provider_settings,
) -> None:
super().__init__(
provider_config,
provider_settings,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://"),
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
self.conversation_ids: dict[str, str] = {}
self.file_id_cache: dict[str, dict[str, str]] = {}
# Create the API client
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
"""生成统一的缓存键
Args:
data: 图片数据或路径
is_base64: 是否是 base64 数据
Returns:
str: 缓存键
"""
try:
if is_base64 and data.startswith("data:image/"):
try:
header, encoded = data.split(",", 1)
image_bytes = base64.b64decode(encoded)
cache_key = hashlib.md5(image_bytes).hexdigest()
return cache_key
except Exception:
# Fall back to hashing the raw string; `encoded` may be unbound here if the split failed
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
elif data.startswith(("http://", "https://")):
# URL image: use the URL itself as the cache key
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
else:
clean_path = (
data.split("_")[0]
if "_" in data and len(data.split("_")) >= 3
else data
)
if os.path.exists(clean_path):
with open(clean_path, "rb") as f:
file_content = f.read()
cache_key = hashlib.md5(file_content).hexdigest()
return cache_key
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
return cache_key
except Exception as e:
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
return cache_key
async def _upload_file(
self,
file_data: bytes,
session_id: str | None = None,
cache_key: str | None = None,
) -> str:
"""上传文件到 Coze 并返回 file_id"""
# 使用 API 客户端上传文件
file_id = await self.api_client.upload_file(file_data)
# 缓存 file_id
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
async def _download_and_upload_image(
self,
image_url: str,
session_id: str | None = None,
) -> str:
"""下载图片并上传到 Coze返回 file_id"""
# 计算哈希实现缓存
cache_key = self._generate_cache_key(image_url) if session_id else None
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self._upload_file(image_data, session_id, cache_key)
if session_id and cache_key:
self.file_id_cache[session_id][cache_key] = file_id
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {e!s}")
raise Exception(f"处理图片失败: {e!s}")
async def _process_context_images(
self,
content: str | list,
session_id: str,
) -> str:
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
try:
if isinstance(content, str):
return content
processed_content = []
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
for item in content:
if not isinstance(item, dict):
processed_content.append(item)
continue
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# Image handling
if "file_id" in item:
# Already has a file_id
logger.debug(f"[Coze] Image already has a file_id: {item['file_id']}")
processed_content.append(item)
else:
# Extract the image data
image_data = ""
if "image_url" in item and isinstance(item["image_url"], dict):
image_data = item["image_url"].get("url", "")
elif "data" in item:
image_data = item.get("data", "")
elif "url" in item:
image_data = item.get("url", "")
if not image_data:
continue
# Hash for caching
cache_key = self._generate_cache_key(
image_data,
is_base64=image_data.startswith("data:image/"),
)
# Check the cache
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
processed_content.append(
{"type": "image", "file_id": file_id},
)
else:
# Upload the image and cache it
if image_data.startswith("data:image/"):
# base64 payload
_, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
elif image_data.startswith(("http://", "https://")):
# URL image
file_id = await self._download_and_upload_image(
image_data,
session_id,
)
# Cache URL images as well
self.file_id_cache[session_id][cache_key] = file_id
elif os.path.exists(image_data):
# Local file
with open(image_data, "rb") as f:
image_bytes = f.read()
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
else:
logger.warning(
f"无法处理的图片格式: {image_data[:50]}...",
)
continue
processed_content.append(
{"type": "image", "file_id": file_id},
)
result = json.dumps(processed_content, ensure_ascii=False)
return result
except Exception as e:
logger.error(f"处理上下文图片失败: {e!s}")
if isinstance(content, str):
return content
return json.dumps(content, ensure_ascii=False)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
"""文本对话, 内部使用流式接口实现非流式
Args:
prompt (str): 用户提示词
session_id (str): 会话ID
image_urls (List[str]): 图片URL列表
func_tool (FuncCall): 函数调用工具(不支持)
contexts (List): 上下文列表
system_prompt (str): 系统提示语
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
model (str): 模型名称(不支持)
Returns:
LLMResponse: LLM响应对象
"""
accumulated_content = ""
final_response = None
async for llm_response in self.text_chat_stream(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
model=model,
**kwargs,
):
if llm_response.is_chunk:
if llm_response.completion_text:
accumulated_content += llm_response.completion_text
else:
final_response = llm_response
if final_response:
return final_response
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
return LLMResponse(role="assistant", result_chain=chain)
return LLMResponse(role="assistant", completion_text="")
async def text_chat_stream(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话接口"""
# 用户ID参数(参考文档, 可以自定义)
user_id = session_id or kwargs.get("user", "default_user")
# 获取或创建会话ID
conversation_id = self.conversation_ids.get(user_id)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{
"role": "system",
"content": system_prompt,
"content_type": "text",
},
)
contexts = self._ensure_message_to_dicts(contexts)
if not self.auto_save_history and contexts:
# When auto history saving is disabled, pass the context explicitly
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
content = ctx["content"]
content_type = ctx.get("content_type", "text")
# The context may contain images
if (
content_type == "object_string"
or (isinstance(content, str) and content.startswith("["))
or (
isinstance(content, list)
and any(
isinstance(item, dict)
and item.get("type") == "image_url"
for item in content
)
)
):
processed_content = await self._process_context_images(
content,
user_id,
)
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
},
)
else:
# Plain text
additional_messages.append(
{
"role": ctx["role"],
"content": (
content
if isinstance(content, str)
else json.dumps(content, ensure_ascii=False)
),
"content_type": "text",
},
)
else:
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
if prompt or image_urls:
if image_urls:
# Multimodal
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
try:
if url.startswith(("http://", "https://")):
# Remote image
file_id = await self._download_and_upload_image(
url,
user_id,
)
else:
# Local file or base64
if url.startswith("data:image/"):
# base64
_, encoded = url.split(",", 1)
image_data = base64.b64decode(encoded)
cache_key = self._generate_cache_key(
url,
is_base64=True,
)
file_id = await self._upload_file(
image_data,
user_id,
cache_key,
)
# Local file
elif os.path.exists(url):
with open(url, "rb") as f:
image_data = f.read()
# Cache by file path plus mtime and size
file_stat = os.stat(url)
cache_key = self._generate_cache_key(
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
is_base64=False,
)
file_id = await self._upload_file(
image_data,
user_id,
cache_key,
)
else:
logger.warning(f"图片文件不存在: {url}")
continue
object_string_content.append(
{
"type": "image",
"file_id": file_id,
},
)
except Exception as e:
logger.error(f"处理图片失败 {url}: {e!s}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
},
)
# Plain text
elif prompt:
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
},
)
try:
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
self.conversation_ids[user_id] = data["conversation_id"]
elif event_type == "conversation.message.delta":
if isinstance(data, dict):
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
message_started = True
accumulated_content += content
yield LLMResponse(
role="assistant",
completion_text=content,
is_chunk=True,
)
elif event_type == "conversation.message.completed":
if isinstance(data, dict):
msg_type = data.get("type")
if msg_type == "answer" and data.get("role") == "assistant":
final_content = data.get("content", "")
if not accumulated_content and final_content:
chain = MessageChain(chain=[Comp.Plain(final_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
elif event_type == "conversation.chat.completed":
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
break
elif event_type == "done":
break
elif event_type == "error":
error_msg = (
data.get("message", "Unknown error")
if isinstance(data, dict)
else str(data)
)
logger.error(f"Coze streaming response error: {error_msg}")
yield LLMResponse(
role="err",
completion_text=f"Coze error: {error_msg}",
is_chunk=False,
)
break
if not message_started and not accumulated_content:
yield LLMResponse(
role="assistant",
completion_text="LLM 未响应任何内容。",
is_chunk=False,
)
elif message_started and accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
except Exception as e:
logger.error(f"Coze 流式请求失败: {e!s}")
yield LLMResponse(
role="err",
completion_text=f"Coze 流式请求失败: {e!s}",
is_chunk=False,
)
async def forget(self, session_id: str):
"""清空指定会话的上下文"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if user_id in self.file_id_cache:
self.file_id_cache.pop(user_id, None)
if not conversation_id:
return True
try:
response = await self.api_client.clear_context(conversation_id)
if "code" in response and response["code"] == 0:
self.conversation_ids.pop(user_id, None)
return True
logger.warning(f"清空 Coze 会话上下文失败: {response}")
return False
except Exception as e:
logger.error(f"清空 Coze 会话失败: {e!s}")
return False
async def get_current_key(self):
"""获取当前API Key"""
return self.api_key
async def set_key(self, key: str):
"""设置新的API Key"""
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
async def get_models(self):
"""获取可用模型列表"""
return [f"bot_{self.bot_id}"]
def get_model(self):
"""获取当前模型"""
return f"bot_{self.bot_id}"
def set_model(self, model: str):
"""设置模型在Coze中是Bot ID"""
if model.startswith("bot_"):
self.bot_id = model[4:]
else:
self.bot_id = model
async def get_human_readable_context(
self,
session_id: str,
page: int = 1,
page_size: int = 10,
):
"""获取人类可读的上下文历史"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if not conversation_id:
return []
try:
data = await self.api_client.get_message_list(
conversation_id=conversation_id,
order="desc",
limit=page_size,
offset=(page - 1) * page_size,
)
if data.get("code") != 0:
logger.warning(f"获取 Coze 消息历史失败: {data}")
return []
messages = data.get("data", {}).get("messages", [])
readable_history = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
msg_type = msg.get("type", "")
if role == "user":
readable_history.append(f"用户: {content}")
elif role == "assistant" and msg_type == "answer":
readable_history.append(f"助手: {content}")
return readable_history
except Exception as e:
logger.error(f"获取 Coze 消息历史失败: {e!s}")
return []
async def terminate(self):
"""清理资源"""
await self.api_client.close()
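The upload caching above is content-addressed: identical image bytes map to one md5 key, so the same image is never uploaded twice within a session. A minimal standalone sketch of that key derivation (the function name and sample data are illustrative, not part of the Coze client):

```python
import base64
import hashlib


def cache_key_for(data: str) -> str:
    """Derive a stable cache key from a data-URI, URL, or path string."""
    if data.startswith("data:image/"):
        try:
            # Hash the decoded bytes so equal images share a key
            _, encoded = data.split(",", 1)
            return hashlib.md5(base64.b64decode(encoded)).hexdigest()
        except Exception:
            # Malformed data URI: fall back to hashing the raw string
            return hashlib.md5(data.encode("utf-8")).hexdigest()
    # URLs and file paths fall back to hashing the string itself
    return hashlib.md5(data.encode("utf-8")).hexdigest()


png_pixel = "data:image/png;base64,iVBORw0KGgo="
assert cache_key_for(png_pixel) == cache_key_for(png_pixel)
assert cache_key_for("https://a/img.png") != cache_key_for("https://b/img.png")
```

The real adapter additionally hashes local file contents (or path plus mtime and size) so edited files get fresh keys.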

View File

@@ -1,207 +0,0 @@
import asyncio
import functools
import re
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from .. import Provider
from ..entities import LLMResponse
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
class ProviderDashscope(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
)
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
raise Exception("阿里云百炼 API Key 不能为空。")
self.app_id = provider_config.get("dashscope_app_id", "")
if not self.app_id:
raise Exception("阿里云百炼 APP ID 不能为空。")
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空。")
self.model_name = "dashscope"
self.variables: dict = provider_config.get("variables", {})
self.rag_options: dict = provider_config.get("rag_options", {})
self.output_reference = self.rag_options.get("output_reference", False)
self.rag_options = self.rag_options.copy()
self.rag_options.pop("output_reference", None)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
"""判断是否有 RAG 选项
Returns:
bool: 是否有 RAG 选项
"""
if self.rag_options and (
len(self.rag_options.get("pipeline_ids", [])) > 0
or len(self.rag_options.get("file_ids", [])) > 0
):
return True
return False
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
image_urls = []
if contexts is None:
contexts = []
# Get session variables
payload_vars = self.variables.copy()
# Dynamic variables
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
if (
self.dashscope_app_type in ["agent", "dialog-workflow"]
and not self.has_rag_options()
):
# App types that support multi-turn conversation
new_record = {"role": "user", "content": prompt}
if image_urls:
logger.warning("Alibaba Cloud Bailian does not support image input yet; images will be ignored.")
contexts_no_img = await self._remove_image_from_context(contexts)
context_query = [*contexts_no_img, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
# Call the Alibaba Cloud Bailian API
payload = {
"app_id": self.app_id,
"api_key": self.api_key,
"messages": context_query,
"biz_params": payload_vars or None,
}
partial = functools.partial(
Application.call,
**payload,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
else:
# App types without multi-turn conversation
# Call the Alibaba Cloud Bailian API
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
partial = functools.partial(
Application.call,
**payload,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
assert isinstance(response, ApplicationResponse)
logger.debug(f"dashscope resp: {response}")
if response.status_code != 200:
logger.error(
f"Alibaba Cloud Bailian request failed: request_id={response.request_id}, code={response.status_code}, message={response.message}; see https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
)
return LLMResponse(
role="err",
result_chain=MessageChain().message(
f"Alibaba Cloud Bailian request failed: message={response.message} code={response.status_code}",
),
)
output_text = response.output.get("text", "") or ""
# Normalize RAG citation markers
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
if self.output_reference and response.output.get("doc_references", None):
ref_parts = []
for ref in response.output.get("doc_references", []) or []:
ref_title = (
ref.get("title", "")
if ref.get("title")
else ref.get("doc_name", "")
)
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
ref_str = "".join(ref_parts)
output_text += f"\n\n回答来源:\n{ref_str}"
llm_response = LLMResponse("assistant")
llm_response.result_chain = MessageChain().message(output_text)
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# Simulate streaming by delegating to text_chat
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def forget(self, session_id):
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
async def terminate(self):
pass
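The citation handling above rewrites Bailian's `<ref>[n]</ref>` markers into plain `[n]` footnotes and then appends a numbered source list. A standalone sketch of that transformation (the function name and sample data are invented for illustration):

```python
import re


def format_citations(text: str, references: list[dict]) -> str:
    """Rewrite <ref>[n]</ref> markers to [n] and append a source list."""
    text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", text)
    if references:
        # Prefer the title, falling back to the document name
        ref_parts = [
            f"{ref['index_id']}. {ref.get('title') or ref.get('doc_name', '')}\n"
            for ref in references
        ]
        text += "\n\nSources:\n" + "".join(ref_parts)
    return text


out = format_citations(
    "Answer<ref>[1]</ref>.",
    [{"index_id": "1", "title": "Handbook"}],
)
print(out)
```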

View File

@@ -1,285 +0,0 @@
import os
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_file, download_image_by_url
from .. import Provider
from ..entities import LLMResponse
from ..register import register_provider_adapter
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
def __init__(
self,
provider_config,
provider_settings,
) -> None:
super().__init__(
provider_config,
provider_settings,
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:
raise Exception("Dify API Key 不能为空。")
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_type = provider_config.get("dify_api_type", "")
if not self.api_type:
raise Exception("Dify API 类型不能为空。")
self.model_name = "dify"
self.workflow_output_key = provider_config.get(
"dify_workflow_output_key",
"astrbot_wf_output",
)
self.dify_query_input_key = provider_config.get(
"dify_query_input_key",
"astrbot_text_query",
)
if not self.dify_query_input_key:
self.dify_query_input_key = "astrbot_text_query"
if not self.workflow_output_key:
self.workflow_output_key = "astrbot_wf_output"
self.variables: dict = provider_config.get("variables", {})
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.conversation_ids = {}
"""记录当前 session id 的对话 ID"""
self.api_client = DifyAPIClient(self.api_key, api_base)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
image_urls = []
result = ""
session_id = session_id or kwargs.get("user") or "unknown" # 1734
conversation_id = self.conversation_ids.get(session_id, "")
files_payload = []
for image_url in image_urls:
image_path = (
await download_image_by_url(image_url)
if image_url.startswith("http")
else image_url
)
file_response = await self.api_client.file_upload(
image_path,
user=session_id,
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。",
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
},
)
# Get session variables
payload_vars = self.variables.copy()
# Dynamic variables
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
try:
match self.api_type:
case "chat" | "agent" | "chatflow":
if not prompt:
prompt = "请描述这张图片。"
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify resp chunk: {chunk}")
if (
chunk["event"] == "message"
or chunk["event"] == "agent_message"
):
result += chunk["answer"]
if not conversation_id:
self.conversation_ids[session_id] = chunk[
"conversation_id"
]
conversation_id = chunk["conversation_id"]
elif chunk["event"] == "message_end":
logger.debug("Dify message end")
break
elif chunk["event"] == "error":
logger.error(f"Dify 出现错误:{chunk}")
raise Exception(
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}",
)
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**payload_vars,
},
user=session_id,
files=files_payload,
timeout=self.timeout,
):
match chunk["event"]:
case "workflow_started":
logger.info(
f"Dify workflow (ID: {chunk['workflow_run_id']}) started.",
)
case "node_finished":
logger.debug(
f"Dify workflow node (ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')}) finished.",
)
case "workflow_finished":
logger.info(
f"Dify workflow (ID: {chunk['workflow_run_id']}) finished.",
)
logger.debug(f"Dify workflow result: {chunk}")
if chunk["data"]["error"]:
logger.error(
f"Dify workflow error: {chunk['data']['error']}",
)
raise Exception(
f"Dify workflow error: {chunk['data']['error']}",
)
if (
self.workflow_output_key
not in chunk["data"]["outputs"]
):
raise Exception(
f"The Dify workflow output does not contain the configured key: {self.workflow_output_key}",
)
result = chunk
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
except Exception as e:
logger.error(f"Dify 请求失败:{e!s}")
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}")
if not result:
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
chain = await self.parse_dify_result(result)
return LLMResponse(role="assistant", result_chain=chain)
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# Simulate streaming by delegating to text_chat
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
if isinstance(chunk, str):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict):
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# Only wav is supported
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"{item['filename']}.wav")
await download_file(item["url"], path)
# Return the downloaded audio as a record, not an image
return Comp.Record(file=path)
case "video":
return Comp.Video(file=item["url"])
case _:
return Comp.File(name=item["filename"], file=item["url"])
output = chunk["data"]["outputs"][self.workflow_output_key]
chains = []
if isinstance(output, str):
# Plain text output
chains.append(Comp.Plain(output))
elif isinstance(output, list):
# Mainly adapts the multimodal output of Dify's HTTP Request node
for item in output:
# handle Array[File]
if (
not isinstance(item, dict)
or item.get("dify_model_identity", "") != "__dify__file__"
):
chains.append(Comp.Plain(str(output)))
break
else:
chains.append(Comp.Plain(str(output)))
# scan file
files = chunk["data"].get("files", [])
for item in files:
comp = await parse_file(item)
chains.append(comp)
return MessageChain(chain=chains)
async def forget(self, session_id):
self.conversation_ids[session_id] = ""
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("Dify 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 Dify 的历史消息记录。")
async def terminate(self):
await self.api_client.close()
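Both the Coze and Dify adapters build their non-streaming `text_chat` on top of a streaming generator: consume all chunks, concatenate the deltas, and fall back to the accumulated text when no final response arrives. A minimal standalone sketch of that pattern, with invented chunk dicts standing in for `LLMResponse`:

```python
import asyncio
from collections.abc import AsyncGenerator


async def fake_stream() -> AsyncGenerator[dict, None]:
    """Stand-in for a streaming chat API yielding text deltas."""
    for piece in ["Hel", "lo"]:
        yield {"is_chunk": True, "text": piece}


async def collect(stream: AsyncGenerator[dict, None]) -> str:
    """Drain the stream into a single non-streaming result."""
    accumulated = ""
    final = None
    async for chunk in stream:
        if chunk["is_chunk"]:
            accumulated += chunk["text"]  # concatenate deltas
        else:
            final = chunk["text"]  # a complete final chunk wins
    return final if final is not None else accumulated


print(asyncio.run(collect(fake_stream())))  # Hello
```

This keeps a single code path for both modes, at the cost of holding the full reply in memory before returning.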

View File

@@ -85,3 +85,22 @@ class UmopConfigRouter:
self.umop_to_conf_id[umo] = conf_id
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
async def delete_route(self, umo: str):
"""删除一条路由
Args:
umo (str): 需要删除的 UMO 字符串
Raises:
ValueError: 当 umo 格式不正确时抛出
"""
if not isinstance(umo, str) or len(umo.split(":")) != 3:
raise ValueError(
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
)
if umo in self.umop_to_conf_id:
del self.umop_to_conf_id[umo]
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)

View File

@@ -0,0 +1,73 @@
import traceback
from astrbot.core import astrbot_config, logger
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None:
"""
Migrate agent runner configs out of provider configs.
"""
try:
default_prov_id = conf["provider_settings"]["default_provider_id"]
if default_prov_id in ids_map:
conf["provider_settings"]["default_provider_id"] = ""
p = ids_map[default_prov_id]
if p["type"] == "dify":
conf["provider_settings"]["dify_agent_runner_provider_id"] = p["id"]
conf["provider_settings"]["agent_runner_type"] = "dify"
elif p["type"] == "coze":
conf["provider_settings"]["coze_agent_runner_provider_id"] = p["id"]
conf["provider_settings"]["agent_runner_type"] = "coze"
elif p["type"] == "dashscope":
conf["provider_settings"]["dashscope_agent_runner_provider_id"] = p[
"id"
]
conf["provider_settings"]["agent_runner_type"] = "dashscope"
conf.save_config()
except Exception as e:
logger.error(f"Migration for third party agent runner configs failed: {e!s}")
logger.error(traceback.format_exc())
async def migra(
db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager
) -> None:
"""
Stores the migration logic here.
btw, i really don't like migration :(
"""
# 4.5 to 4.6 migration for umop_config_router
try:
await migrate_45_to_46(astrbot_config_mgr, umop_config_router)
except Exception as e:
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# migration for webchat session
try:
await migrate_webchat_session(db)
except Exception as e:
logger.error(f"Migration for webchat session failed: {e!s}")
logger.error(traceback.format_exc())
# Migrate third-party agent runner configs
_c = False
providers = astrbot_config["provider"]
ids_map = {}
for prov in providers:
type_ = prov.get("type")
if type_ in ["dify", "coze", "dashscope"]:
prov["provider_type"] = "agent_runner"
ids_map[prov["id"]] = {
"type": type_,
"id": prov["id"],
}
_c = True
if _c:
astrbot_config.save_config()
for conf in acm.confs.values():
_migra_agent_runner_configs(conf, ids_map)
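The migration loop above can be sketched as a pure function: providers of the three third-party types are retagged as agent runners, and an id map is returned so dependent configs can be repointed afterwards. The sample provider dicts are invented for illustration:

```python
def remap_agent_runner_providers(providers: list[dict]) -> dict:
    """Tag dify/coze/dashscope providers as agent runners; return an id map."""
    ids_map = {}
    for prov in providers:
        type_ = prov.get("type")
        if type_ in ["dify", "coze", "dashscope"]:
            prov["provider_type"] = "agent_runner"  # mutate in place, like the migration
            ids_map[prov["id"]] = {"type": type_, "id": prov["id"]}
    return ids_map


providers = [
    {"id": "p1", "type": "dify"},
    {"id": "p2", "type": "openai"},
]
ids_map = remap_agent_runner_providers(providers)
print(ids_map)  # {'p1': {'type': 'dify', 'id': 'p1'}}
```

Other provider types pass through untouched, which is why the real code only saves the config when at least one provider was remapped.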

View File

@@ -40,9 +40,6 @@ class SharedPreferences:
else:
ret = default
return ret
raise ValueError(
"scope_id and key cannot be None when getting a specific preference.",
)
async def range_get_async(
self,
@@ -56,30 +53,6 @@ class SharedPreferences:
ret = await self.db_helper.get_preferences(scope, scope_id, key)
return ret
@overload
async def session_get(
self,
umo: None,
key: str,
default: Any = None,
) -> list[Preference]: ...
@overload
async def session_get(
self,
umo: str,
key: None,
default: Any = None,
) -> list[Preference]: ...
@overload
async def session_get(
self,
umo: None,
key: None,
default: Any = None,
) -> list[Preference]: ...
async def session_get(
self,
umo: str | None,
@@ -88,7 +61,7 @@ class SharedPreferences:
) -> _VT | list[Preference]:
"""Get session-scoped preferences
Note: when scope_id or key is None, a list of Preference objects is returned; each value attribute is a dict, and value["val"] holds the value.
Note: when umo or key is None, a list of Preference objects is returned; each value attribute is a dict, and value["val"] holds the value.
"""
if umo is None or key is None:
return await self.range_get_async("umo", umo, key)


@@ -56,6 +56,7 @@ class ChatRoute(Route):
self.conv_mgr = core_lifecycle.conversation_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.db = db
self.umop_config_router = core_lifecycle.umop_config_router
self.running_convs: dict[str, bool] = {}
@@ -266,7 +267,8 @@ class ChatRoute(Route):
return Response().error("Permission denied").__dict__
# Delete all conversations under this session
unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}"
message_type = "GroupMessage" if session.is_group else "FriendMessage"
unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}"
await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
# Delete message history
@@ -276,6 +278,16 @@ class ChatRoute(Route):
offset_sec=99999999,
)
# Delete the config route associated with this session
try:
await self.umop_config_router.delete_route(unified_msg_origin)
except ValueError as exc:
logger.warning(
"Failed to delete UMO route %s during session cleanup: %s",
unified_msg_origin,
exc,
)
# Clean up queues (webchat only)
if session.platform_id == "webchat":
webchat_queue_mgr.remove_queues(session_id)

changelogs/v4.6.1.md (new file)

@@ -0,0 +1,29 @@
## What's Changed
**hot fix of v4.6.0**
fix(core.db): fix webchat conversation data not being migrated correctly after upgrade ([#3745](https://github.com/AstrBotDevs/AstrBot/issues/3745))
---
1. 新增: 支持 gemini-3 系列的 thought signature ([#3698](https://github.com/AstrBotDevs/AstrBot/issues/3698))
2. 新增: 支持知识库的 Agentic 检索功能 ([#3667](https://github.com/AstrBotDevs/AstrBot/issues/3667))
3. 新增: 为知识库添加 URL 文档解析器 ([#3622](https://github.com/AstrBotDevs/AstrBot/issues/3622))
4. 修复(core.platform): 修复启用多个企业微信智能机器人适配器时消息混乱的问题 ([#3693](https://github.com/AstrBotDevs/AstrBot/issues/3693))
5. 修复: MCP Server 连接成功一段时间后,调用 mcp 工具时可能出现 `anyio.ClosedResourceError` 错误 ([#3700](https://github.com/AstrBotDevs/AstrBot/issues/3700))
6. 新增(chat): 重构聊天组件结构并添加新功能 ([#3701](https://github.com/AstrBotDevs/AstrBot/issues/3701))
7. 修复(dashboard.i18n): 完善缺失的英文国际化键值 ([#3699](https://github.com/AstrBotDevs/AstrBot/issues/3699))
8. 重构: 实现 WebChat 会话管理及从版本 4.6 迁移到 4.7
9. 持续集成(docker-build): 每日构建 Nightly 版本 Docker 镜像 ([#3120](https://github.com/AstrBotDevs/AstrBot/issues/3120))
---
1. feat: add supports for gemini-3 series thought signature ([#3698](https://github.com/AstrBotDevs/AstrBot/issues/3698))
2. feat: supports knowledge base agentic search ([#3667](https://github.com/AstrBotDevs/AstrBot/issues/3667))
3. feat: Add URL document parser for knowledge base ([#3622](https://github.com/AstrBotDevs/AstrBot/issues/3622))
4. fix(core.platform): fix message mix-up issue when enabling multiple WeCom AI Bot adapters ([#3693](https://github.com/AstrBotDevs/AstrBot/issues/3693))
5. fix: fix `anyio.ClosedResourceError` that may occur when calling mcp tools after a period of successful connection to MCP Server ([#3700](https://github.com/AstrBotDevs/AstrBot/issues/3700))
6. feat(chat): refactor chat component structure and add new features ([#3701](https://github.com/AstrBotDevs/AstrBot/issues/3701))
7. fix(dashboard.i18n): complete the missing i18n keys for en ([#3699](https://github.com/AstrBotDevs/AstrBot/issues/3699))
8. refactor: Implement WebChat session management and migration from version 4.6 to 4.7
9. ci(docker-build): build nightly image everyday ([#3120](https://github.com/AstrBotDevs/AstrBot/issues/3120))


@@ -87,6 +87,8 @@
:disabled="isStreaming || isConvRunning"
:enableStreaming="enableStreaming"
:isRecording="isRecording"
:session-id="currSessionId || null"
:current-session="getCurrentSession"
@send="handleSendMessage"
@toggleStreaming="toggleStreaming"
@removeImage="removeImage"


@@ -11,7 +11,14 @@
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
<div style="display: flex; justify-content: space-between; align-items: center; padding: 0px 12px;">
<div style="display: flex; justify-content: flex-start; margin-top: 4px; align-items: center; gap: 8px;">
<ProviderModelSelector ref="providerModelSelectorRef" />
<ConfigSelector
:session-id="sessionId || null"
:platform-id="sessionPlatformId"
:is-group="sessionIsGroup"
:initial-config-id="props.configId"
@config-changed="handleConfigChange"
/>
<ProviderModelSelector v-if="showProviderSelector" ref="providerModelSelectorRef" />
<v-tooltip :text="enableStreaming ? tm('streaming.enabled') : tm('streaming.disabled')" location="top">
<template v-slot:activator="{ props }">
@@ -58,9 +65,11 @@
</template>
<script setup lang="ts">
import { ref, computed, onMounted, onBeforeUnmount, watch } from 'vue';
import { ref, computed, onMounted, onBeforeUnmount } from 'vue';
import { useModuleI18n } from '@/i18n/composables';
import ProviderModelSelector from './ProviderModelSelector.vue';
import ConfigSelector from './ConfigSelector.vue';
import type { Session } from '@/composables/useSessions';
interface Props {
prompt: string;
@@ -69,9 +78,16 @@ interface Props {
disabled: boolean;
enableStreaming: boolean;
isRecording: boolean;
sessionId?: string | null;
currentSession?: Session | null;
configId?: string | null;
}
const props = defineProps<Props>();
const props = withDefaults(defineProps<Props>(), {
sessionId: null,
currentSession: null,
configId: null
});
const emit = defineEmits<{
'update:prompt': [value: string];
@@ -90,12 +106,16 @@ const { tm } = useModuleI18n('features/chat');
const inputField = ref<HTMLTextAreaElement | null>(null);
const imageInputRef = ref<HTMLInputElement | null>(null);
const providerModelSelectorRef = ref<InstanceType<typeof ProviderModelSelector> | null>(null);
const showProviderSelector = ref(true);
const localPrompt = computed({
get: () => props.prompt,
set: (value) => emit('update:prompt', value)
});
const sessionPlatformId = computed(() => props.currentSession?.platform_id || 'webchat');
const sessionIsGroup = computed(() => Boolean(props.currentSession?.is_group));
const canSend = computed(() => {
return (props.prompt && props.prompt.trim()) || props.stagedImagesUrl.length > 0 || props.stagedAudioUrl;
});
@@ -168,7 +188,16 @@ function handleRecordClick() {
}
}
function handleConfigChange(payload: { configId: string; agentRunnerType: string }) {
const runnerType = (payload.agentRunnerType || '').toLowerCase();
const isInternal = runnerType === 'internal' || runnerType === 'local';
showProviderSelector.value = isInternal;
}
function getCurrentSelection() {
if (!showProviderSelector.value) {
return null;
}
return providerModelSelectorRef.value?.getCurrentSelection();
}


@@ -0,0 +1,313 @@
<template>
<div>
<v-tooltip text="选择用于当前会话的配置文件" location="top">
<template #activator="{ props: tooltipProps }">
<v-chip
v-bind="tooltipProps"
class="text-none config-chip"
variant="tonal"
size="x-small"
rounded="lg"
@click="openDialog"
:disabled="loadingConfigs || saving"
>
<v-icon start size="14">mdi-cog</v-icon>
{{ selectedConfigLabel }}
</v-chip>
</template>
</v-tooltip>
<v-dialog v-model="dialog" max-width="480" persistent>
<v-card>
<v-card-title class="d-flex align-center justify-space-between">
<span>选择配置文件</span>
<v-btn icon variant="text" @click="closeDialog">
<v-icon>mdi-close</v-icon>
</v-btn>
</v-card-title>
<v-card-text>
<div v-if="loadingConfigs" class="text-center py-6">
<v-progress-circular indeterminate color="primary"></v-progress-circular>
</div>
<v-list v-else class="config-list" density="comfortable">
<v-list-item
v-for="config in configOptions"
:key="config.id"
:active="tempSelectedConfig === config.id"
rounded="lg"
variant="text"
@click="tempSelectedConfig = config.id"
>
<v-list-item-title>{{ config.name }}</v-list-item-title>
<v-list-item-subtitle class="text-caption text-grey">
{{ config.id }}
</v-list-item-subtitle>
<template #append>
<v-icon v-if="tempSelectedConfig === config.id" color="primary">mdi-check</v-icon>
</template>
</v-list-item>
<div v-if="configOptions.length === 0" class="text-center text-body-2 text-medium-emphasis">
暂无可选配置,请先在配置页创建
</div>
</v-list>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn variant="text" @click="closeDialog">取消</v-btn>
<v-btn
color="primary"
@click="confirmSelection"
:disabled="!tempSelectedConfig"
:loading="saving"
>
应用
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div>
</template>
<script setup lang="ts">
import { computed, onMounted, ref, watch } from 'vue';
import axios from 'axios';
import { useToast } from '@/utils/toast';
interface ConfigInfo {
id: string;
name: string;
}
interface ConfigChangedPayload {
configId: string;
agentRunnerType: string;
}
const STORAGE_KEY = 'chat.selectedConfigId';
const props = withDefaults(defineProps<{
sessionId?: string | null;
platformId?: string;
isGroup?: boolean;
initialConfigId?: string | null;
}>(), {
sessionId: null,
platformId: 'webchat',
isGroup: false,
initialConfigId: null
});
const emit = defineEmits<{ 'config-changed': [ConfigChangedPayload] }>();
const configOptions = ref<ConfigInfo[]>([]);
const loadingConfigs = ref(false);
const dialog = ref(false);
const tempSelectedConfig = ref('');
const selectedConfigId = ref('default');
const agentRunnerType = ref('local');
const saving = ref(false);
const pendingSync = ref(false);
const routingEntries = ref<Array<{ pattern: string; confId: string }>>([]);
const configCache = ref<Record<string, string>>({});
const toast = useToast();
const normalizedSessionId = computed(() => {
const id = props.sessionId?.trim();
return id ? id : null;
});
const hasActiveSession = computed(() => !!normalizedSessionId.value);
const messageType = computed(() => (props.isGroup ? 'GroupMessage' : 'FriendMessage'));
const username = computed(() => localStorage.getItem('user') || 'guest');
const sessionKey = computed(() => {
if (!normalizedSessionId.value) {
return null;
}
return `${props.platformId}!${username.value}!${normalizedSessionId.value}`;
});
const targetUmo = computed(() => {
if (!sessionKey.value) {
return null;
}
return `${props.platformId}:${messageType.value}:${sessionKey.value}`;
});
const selectedConfigLabel = computed(() => {
const target = configOptions.value.find((item) => item.id === selectedConfigId.value);
return target?.name || selectedConfigId.value || 'default';
});
function openDialog() {
tempSelectedConfig.value = selectedConfigId.value;
dialog.value = true;
}
function closeDialog() {
dialog.value = false;
}
async function fetchConfigList() {
loadingConfigs.value = true;
try {
const res = await axios.get('/api/config/abconfs');
configOptions.value = res.data.data?.info_list || [];
} catch (error) {
console.error('加载配置文件列表失败', error);
configOptions.value = [];
} finally {
loadingConfigs.value = false;
}
}
async function fetchRoutingEntries() {
try {
const res = await axios.get('/api/config/umo_abconf_routes');
const routing = res.data.data?.routing || {};
routingEntries.value = Object.entries(routing).map(([pattern, confId]) => ({
pattern,
confId: confId as string
}));
} catch (error) {
console.error('获取配置路由失败', error);
routingEntries.value = [];
}
}
function matchesPattern(pattern: string, target: string): boolean {
const parts = pattern.split(':');
const targetParts = target.split(':');
if (parts.length !== 3 || targetParts.length !== 3) {
return false;
}
return parts.every((part, index) => part === '' || part === '*' || part === targetParts[index]);
}
function resolveConfigId(umo: string | null): string {
if (!umo) {
return 'default';
}
for (const entry of routingEntries.value) {
if (matchesPattern(entry.pattern, umo)) {
return entry.confId;
}
}
return 'default';
}
async function getAgentRunnerType(confId: string): Promise<string> {
if (configCache.value[confId]) {
return configCache.value[confId];
}
try {
const res = await axios.get('/api/config/abconf', {
params: { id: confId }
});
const type = res.data.data?.config?.provider_settings?.agent_runner_type || 'local';
configCache.value[confId] = type;
return type;
} catch (error) {
console.error('获取配置文件详情失败', error);
return 'local';
}
}
async function setSelection(confId: string) {
const normalized = confId || 'default';
selectedConfigId.value = normalized;
const runnerType = await getAgentRunnerType(normalized);
agentRunnerType.value = runnerType;
emit('config-changed', {
configId: normalized,
agentRunnerType: runnerType
});
}
async function applySelectionToBackend(confId: string): Promise<boolean> {
if (!targetUmo.value) {
pendingSync.value = true;
return true;
}
saving.value = true;
try {
await axios.post('/api/config/umo_abconf_route/update', {
umo: targetUmo.value,
conf_id: confId
});
const filtered = routingEntries.value.filter((entry) => entry.pattern !== targetUmo.value);
filtered.push({ pattern: targetUmo.value, confId });
routingEntries.value = filtered;
return true;
} catch (error) {
const err = error as any;
console.error('更新配置文件失败', err);
toast.error(err?.response?.data?.message || '配置文件应用失败');
return false;
} finally {
saving.value = false;
}
}
async function confirmSelection() {
if (!tempSelectedConfig.value) {
return;
}
const previousId = selectedConfigId.value;
await setSelection(tempSelectedConfig.value);
localStorage.setItem(STORAGE_KEY, tempSelectedConfig.value);
const applied = await applySelectionToBackend(tempSelectedConfig.value);
if (!applied) {
localStorage.setItem(STORAGE_KEY, previousId);
await setSelection(previousId);
}
dialog.value = false;
}
async function syncSelectionForSession() {
if (!targetUmo.value) {
pendingSync.value = true;
return;
}
if (pendingSync.value) {
pendingSync.value = false;
await applySelectionToBackend(selectedConfigId.value);
return;
}
await fetchRoutingEntries();
const resolved = resolveConfigId(targetUmo.value);
await setSelection(resolved);
localStorage.setItem(STORAGE_KEY, resolved);
}
watch(
() => [props.sessionId, props.platformId, props.isGroup],
async () => {
await syncSelectionForSession();
}
);
onMounted(async () => {
await fetchConfigList();
const stored = props.initialConfigId || localStorage.getItem(STORAGE_KEY) || 'default';
selectedConfigId.value = stored;
await setSelection(stored);
await syncSelectionForSession();
});
</script>
<style scoped>
.config-chip {
cursor: pointer;
justify-content: flex-start;
}
.config-list {
max-height: 360px;
overflow-y: auto;
}
</style>


@@ -64,7 +64,7 @@
@click.stop="$emit('editTitle', item.session_id, item.display_name)" />
<v-btn icon="mdi-delete" size="x-small" variant="text"
class="delete-conversation-btn" color="error"
@click.stop="$emit('deleteConversation', item.session_id)" />
@click.stop="handleDeleteConversation(item)" />
</div>
</template>
</v-list-item>
@@ -85,7 +85,7 @@
<script setup lang="ts">
import { ref } from 'vue';
import { useI18n, useModuleI18n } from '@/i18n/composables';
import { useModuleI18n } from '@/i18n/composables';
import type { Session } from '@/composables/useSessions';
interface Props {
@@ -109,7 +109,6 @@ const emit = defineEmits<{
}>();
const { tm } = useModuleI18n('features/chat');
const { t } = useI18n();
const sidebarCollapsed = ref(true);
const sidebarHovered = ref(false);
@@ -159,6 +158,14 @@ function handleSidebarMouseLeave() {
}
sidebarHoverExpanded.value = false;
}
function handleDeleteConversation(session: Session) {
const sessionTitle = session.display_name || tm('conversation.newConversation');
const message = tm('conversation.confirmDelete', { name: sessionTitle });
if (window.confirm(message)) {
emit('deleteConversation', session.session_id);
}
}
</script>
<style scoped>


@@ -3,6 +3,7 @@
<!-- Provider/model selection button -->
<v-chip class="text-none" variant="tonal" size="x-small"
v-if="selectedProviderId && selectedModelName" @click="openDialog">
<v-icon start size="14">mdi-creation</v-icon>
{{ selectedProviderId }} / {{ selectedModelName }}
</v-chip>
<v-chip variant="tonal" rounded="xl" size="x-small" v-else @click="openDialog">


@@ -0,0 +1,319 @@
<template>
<v-card class="standalone-chat-card" elevation="0" rounded="0">
<v-card-text class="standalone-chat-container">
<div class="chat-layout">
<!-- Chat content area -->
<div class="chat-content-panel">
<MessageList v-if="messages && messages.length > 0" :messages="messages" :isDark="isDark"
:isStreaming="isStreaming || isConvRunning" @openImagePreview="openImagePreview"
ref="messageList" />
<div class="welcome-container fade-in" v-else>
<div class="welcome-title">
<span>Hello, I'm</span>
<span class="bot-name">AstrBot ⭐</span>
</div>
<p class="text-caption text-medium-emphasis mt-2">
测试配置: {{ configId || 'default' }}
</p>
</div>
<!-- Input area -->
<ChatInput
v-model:prompt="prompt"
:stagedImagesUrl="stagedImagesUrl"
:stagedAudioUrl="stagedAudioUrl"
:disabled="isStreaming || isConvRunning"
:enableStreaming="enableStreaming"
:isRecording="isRecording"
:session-id="currSessionId || null"
:current-session="getCurrentSession"
:config-id="configId"
@send="handleSendMessage"
@toggleStreaming="toggleStreaming"
@removeImage="removeImage"
@removeAudio="removeAudio"
@startRecording="handleStartRecording"
@stopRecording="handleStopRecording"
@pasteImage="handlePaste"
@fileSelect="handleFileSelect"
ref="chatInputRef"
/>
</div>
</div>
</v-card-text>
</v-card>
<!-- Image preview dialog -->
<v-dialog v-model="imagePreviewDialog" max-width="90vw" max-height="90vh">
<v-card class="image-preview-card" elevation="8">
<v-card-title class="d-flex justify-space-between align-center pa-4">
<span>{{ t('core.common.imagePreview') }}</span>
<v-btn icon="mdi-close" variant="text" @click="imagePreviewDialog = false" />
</v-card-title>
<v-card-text class="text-center pa-4">
<img :src="previewImageUrl" class="preview-image-large" />
</v-card-text>
</v-card>
</v-dialog>
</template>
<script setup lang="ts">
import { ref, computed, onMounted, onBeforeUnmount, nextTick } from 'vue';
import axios from 'axios';
import { useCustomizerStore } from '@/stores/customizer';
import { useI18n, useModuleI18n } from '@/i18n/composables';
import { useTheme } from 'vuetify';
import MessageList from '@/components/chat/MessageList.vue';
import ChatInput from '@/components/chat/ChatInput.vue';
import { useMessages } from '@/composables/useMessages';
import { useMediaHandling } from '@/composables/useMediaHandling';
import { useRecording } from '@/composables/useRecording';
import { useToast } from '@/utils/toast';
interface Props {
configId?: string | null;
}
const props = withDefaults(defineProps<Props>(), {
configId: null
});
const { t } = useI18n();
const { error: showError } = useToast();
// UI state
const imagePreviewDialog = ref(false);
const previewImageUrl = ref('');
// Session management (useSessions is not used here, to avoid route navigation)
const currSessionId = ref('');
const getCurrentSession = computed(() => null); // Standalone test mode does not need session info
async function newSession() {
try {
const response = await axios.get('/api/chat/new_session');
const sessionId = response.data.data.session_id;
currSessionId.value = sessionId;
return sessionId;
} catch (err) {
console.error(err);
throw err;
}
}
function updateSessionTitle(_sessionId: string, _title: string) {
// Standalone mode does not update session titles
}
function getSessions() {
// Standalone mode does not load the session list
}
const {
stagedImagesName,
stagedImagesUrl,
stagedAudioUrl,
getMediaFile,
processAndUploadImage,
handlePaste,
removeImage,
removeAudio,
clearStaged,
cleanupMediaCache
} = useMediaHandling();
const { isRecording, startRecording: startRec, stopRecording: stopRec } = useRecording();
const {
messages,
isStreaming,
isConvRunning,
enableStreaming,
getSessionMessages: getSessionMsg,
sendMessage: sendMsg,
toggleStreaming
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
// Component refs
const messageList = ref<InstanceType<typeof MessageList> | null>(null);
const chatInputRef = ref<InstanceType<typeof ChatInput> | null>(null);
// Input state
const prompt = ref('');
const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark');
function openImagePreview(imageUrl: string) {
previewImageUrl.value = imageUrl;
imagePreviewDialog.value = true;
}
async function handleStartRecording() {
await startRec();
}
async function handleStopRecording() {
const audioFilename = await stopRec();
stagedAudioUrl.value = audioFilename;
}
async function handleFileSelect(files: FileList) {
for (const file of files) {
await processAndUploadImage(file);
}
}
async function handleSendMessage() {
if (!prompt.value.trim() && stagedImagesName.value.length === 0 && !stagedAudioUrl.value) {
return;
}
const promptToSend = prompt.value.trim();
const imageNamesToSend = [...stagedImagesName.value];
const audioNameToSend = stagedAudioUrl.value;
try {
if (!currSessionId.value) {
await newSession();
}
// Clear the input and staged attachments
prompt.value = '';
clearStaged();
// Get the selected provider and model
const selection = chatInputRef.value?.getCurrentSelection();
const selectedProviderId = selection?.providerId || '';
const selectedModelName = selection?.modelName || '';
await sendMsg(
promptToSend,
imageNamesToSend,
audioNameToSend,
selectedProviderId,
selectedModelName
);
// Scroll to the bottom
nextTick(() => {
messageList.value?.scrollToBottom();
});
} catch (err) {
console.error('Failed to send message:', err);
showError(t('features.chat.errors.sendMessageFailed'));
// Restore the prompt so the user can retry.
// Note: attachments were already uploaded to the server, so they are not restored.
prompt.value = promptToSend;
}
}
onMounted(async () => {
// Standalone mode creates a new session on mount
try {
await newSession();
} catch (err) {
console.error('Failed to create initial session:', err);
showError(t('features.chat.errors.createSessionFailed'));
}
});
onBeforeUnmount(() => {
cleanupMediaCache();
});
</script>
<style scoped>
/* Base animation */
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.standalone-chat-card {
width: 100%;
height: 100%;
max-height: 100%;
overflow: hidden;
}
.standalone-chat-container {
width: 100%;
height: 100%;
max-height: 100%;
padding: 0;
overflow: hidden;
}
.chat-layout {
height: 100%;
max-height: 100%;
display: flex;
overflow: hidden;
}
.chat-content-panel {
height: 100%;
max-height: 100%;
width: 100%;
display: flex;
flex-direction: column;
overflow: hidden;
}
.conversation-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 8px;
padding-left: 16px;
border-bottom: 1px solid var(--v-theme-border);
width: 100%;
padding-right: 32px;
flex-shrink: 0;
}
.conversation-header-info h4 {
margin: 0;
font-weight: 500;
}
.conversation-header-actions {
display: flex;
gap: 8px;
align-items: center;
}
.welcome-container {
height: 100%;
display: flex;
justify-content: center;
align-items: center;
flex-direction: column;
}
.welcome-title {
font-size: 28px;
margin-bottom: 8px;
}
.bot-name {
font-weight: 700;
margin-left: 8px;
color: var(--v-theme-secondary);
}
.fade-in {
animation: fadeIn 0.3s ease-in-out;
}
.preview-image-large {
max-width: 100%;
max-height: 70vh;
object-fit: contain;
}
</style>


@@ -7,6 +7,10 @@
<v-icon start>mdi-message-text</v-icon>
{{ tm('dialogs.addProvider.tabs.basic') }}
</v-tab>
<v-tab value="agent_runner" class="font-weight-medium px-3">
<v-icon start>mdi-cogs</v-icon>
{{ tm('dialogs.addProvider.tabs.agentRunner') }}
</v-tab>
<v-tab value="speech_to_text" class="font-weight-medium px-3">
<v-icon start>mdi-microphone-message</v-icon>
{{ tm('dialogs.addProvider.tabs.speechToText') }}
@@ -27,7 +31,7 @@
<v-window v-model="activeProviderTab" class="mt-4">
<v-window-item
v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
v-for="tabType in ['chat_completion', 'agent_runner', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
:key="tabType" :value="tabType">
<v-row class="mt-1">
<v-col v-for="(template, name) in getTemplatesByType(tabType)" :key="name" cols="12" sm="6"
@@ -36,7 +40,7 @@
@click="selectProviderTemplate(name)">
<div class="provider-card-content">
<div class="provider-card-text">
<v-card-title class="provider-card-title">接入 {{ name }}</v-card-title>
<v-card-title class="provider-card-title">{{ name }}</v-card-title>
<v-card-text
class="text-caption text-medium-emphasis provider-card-description">
{{ getProviderDescription(template, name) }}
@@ -54,7 +58,7 @@
</v-col>
<v-col v-if="Object.keys(getTemplatesByType(tabType)).length === 0" cols="12">
<v-alert type="info" variant="tonal">
{{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
{{ t('dialogs.addProvider.noTemplates') }}
</v-alert>
</v-col>
</v-row>
@@ -104,19 +108,6 @@ export default {
this.$emit('update:show', value);
}
},
// Computed property for translated messages
messages() {
return {
tabTypes: {
'chat_completion': this.tm('providers.tabs.chatCompletion'),
'speech_to_text': this.tm('providers.tabs.speechToText'),
'text_to_speech': this.tm('providers.tabs.textToSpeech'),
'embedding': this.tm('providers.tabs.embedding'),
'rerank': this.tm('providers.tabs.rerank')
}
};
}
},
methods: {
closeDialog() {
@@ -140,11 +131,6 @@ export default {
// Imported from utility functions
getProviderIcon,
// Get the localized name of the tab type
getTabTypeName(tabType) {
return this.messages.tabTypes[tabType] || tabType;
},
// Get the provider description
getProviderDescription(template, name) {
return getProviderDescription(template, name, this.tm);


@@ -101,6 +101,21 @@ function shouldShowItem(itemMeta, itemKey) {
return true
}
// Check whether the outermost object should be shown
function shouldShowSection() {
const sectionMeta = props.metadata[props.metadataKey]
if (!sectionMeta?.condition) {
return true
}
for (const [conditionKey, expectedValue] of Object.entries(sectionMeta.condition)) {
const actualValue = getValueBySelector(props.iterable, conditionKey)
if (actualValue !== expectedValue) {
return false
}
}
return true
}
function hasVisibleItemsAfter(items, currentIndex) {
const itemEntries = Object.entries(items)
@@ -114,12 +129,33 @@ function hasVisibleItemsAfter(items, currentIndex) {
return false
}
function parseSpecialValue(value) {
if (!value || typeof value !== 'string') {
return { name: '', subtype: '' }
}
const [name, ...rest] = value.split(':')
return {
name,
subtype: rest.join(':') || ''
}
}
function getSpecialName(value) {
return parseSpecialValue(value).name
}
function getSpecialSubtype(value) {
return parseSpecialValue(value).subtype
}
</script>
<template>
<v-card style="margin-bottom: 16px; padding-bottom: 8px; background-color: rgb(var(--v-theme-background));" rounded="md" variant="outlined">
<v-card v-if="shouldShowSection()" style="margin-bottom: 16px; padding-bottom: 8px; background-color: rgb(var(--v-theme-background));"
rounded="md" variant="outlined">
<v-card-text class="config-section" v-if="metadata[metadataKey]?.type === 'object'" style="padding-bottom: 8px;">
<v-list-item-title class="config-title">
{{ metadata[metadataKey]?.description }}
@@ -187,22 +223,16 @@ function hasVisibleItemsAfter(items, currentIndex) {
<!-- Boolean switch for JSON selector -->
<v-switch v-else-if="itemMeta?.type === 'bool'" v-model="createSelectorModel(itemKey).value"
color="primary" inset density="compact" hide-details style="display: flex; justify-content: end;"></v-switch>
color="primary" inset density="compact" hide-details
style="display: flex; justify-content: end;"></v-switch>
<!-- List item for JSON selector -->
<ListConfigItem
v-else-if="itemMeta?.type === 'list'"
v-model="createSelectorModel(itemKey).value"
button-text="修改"
class="config-field"
/>
<ListConfigItem v-else-if="itemMeta?.type === 'list'" v-model="createSelectorModel(itemKey).value"
button-text="修改" class="config-field" />
<!-- Object editor for JSON selector -->
<ObjectEditor
v-else-if="itemMeta?.type === 'dict'"
v-model="createSelectorModel(itemKey).value"
class="config-field"
/>
<ObjectEditor v-else-if="itemMeta?.type === 'dict'" v-model="createSelectorModel(itemKey).value"
class="config-field" />
<!-- Fallback for JSON selector -->
<v-text-field v-else v-model="createSelectorModel(itemKey).value" density="compact" variant="outlined"
@@ -211,50 +241,36 @@ function hasVisibleItemsAfter(items, currentIndex) {
<!-- Special handling for specific metadata types -->
<div v-else-if="itemMeta?._special === 'select_provider'">
<ProviderSelector
v-model="createSelectorModel(itemKey).value"
:provider-type="'chat_completion'"
/>
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'chat_completion'" />
</div>
<div v-else-if="itemMeta?._special === 'select_provider_stt'">
<ProviderSelector
v-model="createSelectorModel(itemKey).value"
:provider-type="'speech_to_text'"
/>
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'speech_to_text'" />
</div>
<div v-else-if="itemMeta?._special === 'select_provider_tts'">
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'text_to_speech'" />
</div>
<div v-else-if="getSpecialName(itemMeta?._special) === 'select_agent_runner_provider'">
<ProviderSelector
v-model="createSelectorModel(itemKey).value"
:provider-type="'text_to_speech'"
:provider-type="'agent_runner'"
:provider-subtype="getSpecialSubtype(itemMeta?._special)"
/>
</div>
<div v-else-if="itemMeta?._special === 'provider_pool'">
<ProviderSelector
v-model="createSelectorModel(itemKey).value"
:provider-type="'chat_completion'"
button-text="选择提供商池..."
/>
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'chat_completion'"
button-text="选择提供商池..." />
</div>
<div v-else-if="itemMeta?._special === 'select_persona'">
<PersonaSelector
v-model="createSelectorModel(itemKey).value"
/>
<PersonaSelector v-model="createSelectorModel(itemKey).value" />
</div>
<div v-else-if="itemMeta?._special === 'persona_pool'">
<PersonaSelector
v-model="createSelectorModel(itemKey).value"
button-text="选择人格池..."
/>
<PersonaSelector v-model="createSelectorModel(itemKey).value" button-text="选择人格池..." />
</div>
<div v-else-if="itemMeta?._special === 'select_knowledgebase'">
<KnowledgeBaseSelector
v-model="createSelectorModel(itemKey).value"
/>
<KnowledgeBaseSelector v-model="createSelectorModel(itemKey).value" />
</div>
<div v-else-if="itemMeta?._special === 'select_plugin_set'">
<PluginSetSelector
v-model="createSelectorModel(itemKey).value"
/>
<PluginSetSelector v-model="createSelectorModel(itemKey).value" />
</div>
<div v-else-if="itemMeta?._special === 't2i_template'">
<T2ITemplateEditor />
@@ -263,21 +279,17 @@ function hasVisibleItemsAfter(items, currentIndex) {
</v-row>
<!-- Plugin Set Selector full-width display area -->
<v-row v-if="!itemMeta?.invisible && itemMeta?._special === 'select_plugin_set'" class="plugin-set-display-row">
<v-row v-if="!itemMeta?.invisible && itemMeta?._special === 'select_plugin_set'"
class="plugin-set-display-row">
<v-col cols="12" class="plugin-set-display">
<div v-if="createSelectorModel(itemKey).value && createSelectorModel(itemKey).value.length > 0" class="selected-plugins-full-width">
<div v-if="createSelectorModel(itemKey).value && createSelectorModel(itemKey).value.length > 0"
class="selected-plugins-full-width">
<div class="plugins-header">
<small class="text-grey">已选择的插件</small>
</div>
<div class="d-flex flex-wrap ga-2 mt-2">
<v-chip
v-for="plugin in (createSelectorModel(itemKey).value || [])"
:key="plugin"
size="small"
label
color="primary"
variant="outlined"
>
<v-chip v-for="plugin in (createSelectorModel(itemKey).value || [])" :key="plugin" size="small" label
color="primary" variant="outlined">
{{ plugin === '*' ? '所有插件' : plugin }}
</v-chip>
</div>
@@ -285,7 +297,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
</v-col>
</v-row>
</template>
<v-divider class="config-divider" v-if="shouldShowItem(itemMeta, itemKey) && hasVisibleItemsAfter(metadata[metadataKey].items, index)"></v-divider>
<v-divider class="config-divider"
v-if="shouldShowItem(itemMeta, itemKey) && hasVisibleItemsAfter(metadata[metadataKey].items, index)"></v-divider>
</div>
</div>

View File

@@ -94,6 +94,10 @@ const props = defineProps({
type: String,
default: 'chat_completion'
},
providerSubtype: {
type: String,
default: ''
},
buttonText: {
type: String,
default: '选择提供商...'
@@ -127,7 +131,10 @@ async function loadProviders() {
}
})
if (response.data.status === 'ok') {
providerList.value = response.data.data || []
const providers = response.data.data || []
providerList.value = props.providerSubtype
? providers.filter((provider) => matchesProviderSubtype(provider, props.providerSubtype))
: providers
}
} catch (error) {
console.error('加载提供商列表失败:', error)
@@ -137,6 +144,17 @@ async function loadProviders() {
}
}
function matchesProviderSubtype(provider, subtype) {
if (!subtype) {
return true
}
const normalized = String(subtype).toLowerCase()
const candidates = [provider.type, provider.provider, provider.id]
.filter(Boolean)
.map((value) => String(value).toLowerCase())
return candidates.includes(normalized)
}
function selectProvider(provider) {
selectedProvider.value = provider.id
}

View File

@@ -4,8 +4,12 @@ import { useRouter } from 'vue-router';
export interface Session {
session_id: string;
display_name: string;
display_name: string | null;
updated_at: string;
platform_id: string;
creator: string;
is_group: number;
created_at: string;
}
export function useSessions(chatboxMode: boolean = false) {

View File

@@ -74,6 +74,7 @@
"delete": "Delete",
"copy": "Copy",
"edit": "Edit",
"copy": "Copy",
"noData": "No data available"
}
}

View File

@@ -51,7 +51,8 @@
"editDisplayName": "Edit Session Name",
"displayName": "Session Name",
"displayNameUpdated": "Session name updated",
"displayNameUpdateFailed": "Failed to update session name"
"displayNameUpdateFailed": "Failed to update session name",
"confirmDelete": "Are you sure you want to delete \"{name}\"? This action cannot be undone."
},
"modes": {
"darkMode": "Switch to Dark Mode",
@@ -84,5 +85,9 @@
"reconnected": "Chat connection re-established",
"failed": "Connection failed, please refresh the page"
}
},
"errors": {
"sendMessageFailed": "Failed to send message, please try again",
"createSessionFailed": "Failed to create session, please refresh the page"
}
}

View File

@@ -9,6 +9,7 @@
"tabs": {
"all": "All",
"chatCompletion": "Chat Completion",
"agentRunner": "Agent Runner",
"speechToText": "Speech to Text",
"textToSpeech": "Text to Speech",
"embedding": "Embedding",
@@ -44,12 +45,13 @@
"title": "Service Provider",
"tabs": {
"basic": "Basic",
"agentRunner": "Agent Runner",
"speechToText": "Speech to Text",
"textToSpeech": "Text to Speech",
"embedding": "Embedding",
"rerank": "Rerank"
},
"noTemplates": "No {type} type provider templates available"
"noTemplates": "No this type provider templates available"
},
"config": {
"addTitle": "Add",

View File

@@ -51,7 +51,8 @@
"editDisplayName": "编辑会话名称",
"displayName": "会话名称",
"displayNameUpdated": "会话名称已更新",
"displayNameUpdateFailed": "更新会话名称失败"
"displayNameUpdateFailed": "更新会话名称失败",
"confirmDelete": "确定要删除“{name}”吗?此操作无法撤销。"
},
"modes": {
"darkMode": "切换到夜间模式",
@@ -84,5 +85,9 @@
"reconnected": "聊天连接已重新建立",
"failed": "连接失败,请刷新页面重试"
}
},
"errors": {
"sendMessageFailed": "发送消息失败,请重试",
"createSessionFailed": "创建会话失败,请刷新页面重试"
}
}

View File

@@ -8,7 +8,8 @@
"providerType": "提供商类型",
"tabs": {
"all": "全部",
"chatCompletion": "基本对话",
"chatCompletion": "对话",
"agentRunner": "Agent 执行器",
"speechToText": "语音转文字",
"textToSpeech": "文字转语音",
"embedding": "嵌入(Embedding)",
@@ -44,13 +45,14 @@
"addProvider": {
"title": "模型提供商",
"tabs": {
"basic": "基本",
"basic": "对话",
"agentRunner": "Agent 执行器",
"speechToText": "语音转文字",
"textToSpeech": "文字转语音",
"embedding": "嵌入(Embedding)",
"rerank": "重排序(Rerank)"
},
"noTemplates": "暂无{type}类型的提供商模板"
"noTemplates": "暂无类型的提供商模板"
},
"config": {
"addTitle": "新增",

View File

@@ -45,6 +45,15 @@
@click="configToString(); codeEditorDialog = true">
</v-btn>
<v-tooltip text="测试当前配置" location="left" v-if="!isSystemConfig">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" icon="mdi-chat-processing" size="x-large"
style="position: fixed; right: 52px; bottom: 196px;" color="secondary"
@click="openTestChat">
</v-btn>
</template>
</v-tooltip>
</div>
</v-slide-y-transition>
@@ -135,6 +144,34 @@
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
<!-- 测试聊天抽屉 -->
<v-overlay
v-model="testChatDrawer"
class="test-chat-overlay"
location="right"
transition="slide-x-reverse-transition"
:scrim="true"
@click:outside="closeTestChat"
>
<v-card class="test-chat-card" elevation="12">
<div class="test-chat-header">
<div>
<span class="text-h6">测试配置</span>
<div v-if="selectedConfigInfo.name" class="text-caption text-grey">
{{ selectedConfigInfo.name }} ({{ testConfigId }})
</div>
</div>
<v-btn icon variant="text" @click="closeTestChat">
<v-icon>mdi-close</v-icon>
</v-btn>
</div>
<v-divider></v-divider>
<div class="test-chat-content">
<StandaloneChat v-if="testChatDrawer" :configId="testConfigId" />
</div>
</v-card>
</v-overlay>
</template>
@@ -142,6 +179,7 @@
import axios from 'axios';
import AstrBotCoreConfigWrapper from '@/components/config/AstrBotCoreConfigWrapper.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import StandaloneChat from '@/components/chat/StandaloneChat.vue';
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
import { useI18n, useModuleI18n } from '@/i18n/composables';
@@ -150,7 +188,8 @@ export default {
components: {
AstrBotCoreConfigWrapper,
VueMonacoEditor,
WaitingForRestart
WaitingForRestart,
StandaloneChat
},
props: {
initialConfigId: {
@@ -238,6 +277,10 @@ export default {
name: '',
},
editingConfigId: null,
// 测试聊天
testChatDrawer: false,
testConfigId: null,
}
},
mounted() {
@@ -506,6 +549,20 @@ export default {
this.getConfigInfoList("default");
}
}
},
openTestChat() {
if (!this.selectedConfigID) {
this.save_message = "请先选择一个配置文件";
this.save_message_snack = true;
this.save_message_success = "warning";
return;
}
this.testConfigId = this.selectedConfigID;
this.testChatDrawer = true;
},
closeTestChat() {
this.testChatDrawer = false;
this.testConfigId = null;
}
},
}
@@ -565,4 +622,32 @@ export default {
width: 100%;
}
}
/* 测试聊天抽屉样式 */
.test-chat-overlay {
align-items: stretch;
justify-content: flex-end;
}
.test-chat-card {
width: clamp(320px, 50vw, 720px);
height: calc(100vh - 32px);
display: flex;
flex-direction: column;
margin: 16px;
}
.test-chat-header {
display: flex;
align-items: center;
justify-content: space-between;
padding: 16px 20px 12px 20px;
}
.test-chat-content {
flex: 1;
overflow: hidden;
padding: 0;
border-radius: 0 0 16px 16px;
}
</style>

View File

@@ -30,6 +30,10 @@
<v-icon start>mdi-message-text</v-icon>
{{ tm('providers.tabs.chatCompletion') }}
</v-tab>
<v-tab value="agent_runner" class="font-weight-medium px-3">
<v-icon start>mdi-message-text</v-icon>
{{ tm('providers.tabs.agentRunner') }}
</v-tab>
<v-tab value="speech_to_text" class="font-weight-medium px-3">
<v-icon start>mdi-microphone-message</v-icon>
{{ tm('providers.tabs.speechToText') }}
@@ -48,30 +52,62 @@
</v-tab>
</v-tabs>
<v-row v-if="filteredProviders.length === 0">
<v-col cols="12" class="text-center pa-8">
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
<p class="text-grey mt-4">{{ getEmptyText() }}</p>
</v-col>
</v-row>
<v-row v-else>
<v-col v-for="(provider, index) in filteredProviders" :key="index" cols="12" md="6" lg="4" xl="3">
<item-card :item="provider" title-field="id" enabled-field="enable"
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
@copy="copyProvider" :show-copy-button="true">
<template #actions="{ item }">
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
{{ tm('availability.test') }}
</v-btn>
</template>
<template v-slot:details="{ item }">
</template>
</item-card>
</v-col>
</v-row>
<template v-if="activeProviderTypeTab === 'all'">
<v-row v-if="groupedProviders.length === 0">
<v-col cols="12" class="text-center pa-8">
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
<p class="text-grey mt-4">{{ getEmptyText() }}</p>
</v-col>
</v-row>
<div v-else>
<div v-for="group in groupedProviders" :key="group.typeKey" class="mb-8">
<h1 class="text-h3 font-weight-bold mb-4">{{ group.label }}</h1>
<v-row>
<v-col v-for="(provider, index) in group.items" :key="`${group.typeKey}-${index}`" cols="12" md="6"
lg="4" xl="3">
<item-card :item="provider" title-field="id" enabled-field="enable"
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
@copy="copyProvider" :show-copy-button="true">
<template #actions="{ item }">
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
{{ tm('availability.test') }}
</v-btn>
</template>
<template v-slot:details="{ item }">
</template>
</item-card>
</v-col>
</v-row>
</div>
</div>
</template>
<template v-else>
<v-row v-if="filteredProviders.length === 0">
<v-col cols="12" class="text-center pa-8">
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
<p class="text-grey mt-4">{{ getEmptyText() }}</p>
</v-col>
</v-row>
<v-row v-else>
<v-col v-for="(provider, index) in filteredProviders" :key="index" cols="12" md="6" lg="4" xl="3">
<item-card :item="provider" title-field="id" enabled-field="enable"
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
@copy="copyProvider" :show-copy-button="true">
<template #actions="{ item }">
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
{{ tm('availability.test') }}
</v-btn>
</template>
<template v-slot:details="{ item }">
</template>
</item-card>
</v-col>
</v-row>
</template>
</div>
<!-- 供应商状态部分 -->
@@ -289,8 +325,8 @@ export default {
"anthropic_chat_completion": "chat_completion",
"googlegenai_chat_completion": "chat_completion",
"zhipu_chat_completion": "chat_completion",
"dify": "chat_completion",
"coze": "chat_completion",
"dify": "agent_runner",
"coze": "agent_runner",
"dashscope": "chat_completion",
"openai_whisper_api": "speech_to_text",
"openai_whisper_selfhost": "speech_to_text",
@@ -334,6 +370,7 @@ export default {
},
tabTypes: {
'chat_completion': this.tm('providers.tabs.chatCompletion'),
'agent_runner': this.tm('providers.tabs.agentRunner'),
'speech_to_text': this.tm('providers.tabs.speechToText'),
'text_to_speech': this.tm('providers.tabs.textToSpeech'),
'embedding': this.tm('providers.tabs.embedding'),
@@ -363,6 +400,52 @@ export default {
};
},
groupedProviders() {
if (!this.config_data.provider) {
return [];
}
const typeOrder = [
'chat_completion',
'agent_runner',
'speech_to_text',
'text_to_speech',
'embedding',
'rerank',
];
const assigned = new Set();
const groups = typeOrder
.map((typeKey) => {
const items = this.config_data.provider.filter((provider) => {
const resolved = this.getProviderType(provider);
if (resolved === typeKey) {
assigned.add(provider.id);
return true;
}
return false;
});
return {
typeKey,
label: this.messages.tabTypes[typeKey] || typeKey,
items,
};
})
.filter((group) => group.items.length > 0);
const remaining = this.config_data.provider.filter(
(provider) => !assigned.has(provider.id),
);
if (remaining.length > 0) {
groups.push({
typeKey: 'others',
label: this.tm('providers.tabs.all'),
items: remaining,
});
}
return groups;
},
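
The `groupedProviders` computed property above buckets providers in a fixed type order and collects anything unresolved into a trailing "others" group. A simplified Python sketch of that logic (`group_providers` and the `resolve_type` callback are hypothetical names):

```python
TYPE_ORDER = [
    "chat_completion", "agent_runner", "speech_to_text",
    "text_to_speech", "embedding", "rerank",
]

def group_providers(providers: list[dict], resolve_type) -> list[dict]:
    """Bucket providers by resolved type, preserving TYPE_ORDER;
    providers whose type resolves to nothing land in 'others'."""
    assigned: set[str] = set()
    groups: list[dict] = []
    for type_key in TYPE_ORDER:
        items = [p for p in providers if resolve_type(p) == type_key]
        assigned.update(p["id"] for p in items)
        if items:  # skip empty groups, as the computed property does
            groups.append({"typeKey": type_key, "items": items})
    remaining = [p for p in providers if p["id"] not in assigned]
    if remaining:
        groups.append({"typeKey": "others", "items": remaining})
    return groups
```

Tracking assigned ids in a set guarantees every provider appears exactly once, even if its type is unknown to the mapping.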
// 根据选择的标签过滤提供商列表
filteredProviders() {
if (!this.config_data.provider || this.activeProviderTypeTab === 'all') {
@@ -371,13 +454,7 @@ export default {
return this.config_data.provider.filter(provider => {
// 如果provider.provider_type已经存在直接使用它
if (provider.provider_type) {
return provider.provider_type === this.activeProviderTypeTab;
}
// 否则使用映射关系
const mappedType = this.oldVersionProviderTypeMapping[provider.type];
return mappedType === this.activeProviderTypeTab;
return this.getProviderType(provider) === this.activeProviderTypeTab;
});
}
},
@@ -387,6 +464,14 @@ export default {
},
methods: {
getProviderType(provider) {
if (!provider) return undefined;
if (provider.provider_type) {
return provider.provider_type;
}
return this.oldVersionProviderTypeMapping[provider.type];
},
getConfig() {
axios.get('/api/config/get').then((res) => {
this.config_data = res.data.data.config;
@@ -690,6 +775,9 @@ export default {
if (!provider.enable) {
throw new Error('该提供商未被用户启用');
}
if (provider.provider_type === 'agent_runner') {
throw new Error('暂时无法测试 Agent Runner 类型的提供商');
}
const res = await axios.get(`/api/config/provider/check_one?id=${provider.id}`);
if (res.data && res.data.status === 'ok') {

View File

@@ -2,14 +2,18 @@ import datetime
from astrbot.api import logger, sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.platform.message_type import MessageType
from astrbot.core.provider.sources.coze_source import ProviderCoze
from astrbot.core.provider.sources.dify_source import ProviderDify
from ..long_term_memory import LongTermMemory
from .utils.rst_scene import RstScene
THIRD_PARTY_AGENT_RUNNER_KEY = {
"dify": "dify_conversation_id",
"coze": "coze_conversation_id",
}
THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys())
class ConversationCommands:
def __init__(self, context: star.Context, ltm: LongTermMemory | None = None):
@@ -38,9 +42,9 @@ class ConversationCommands:
async def reset(self, message: AstrMessageEvent):
"""重置 LLM 会话"""
is_unique_session = self.context.get_config()["platform_settings"][
"unique_session"
]
umo = message.unified_msg_origin
cfg = self.context.get_config(umo=message.unified_msg_origin)
is_unique_session = cfg["platform_settings"]["unique_session"]
is_group = bool(message.get_group_id())
scene = RstScene.get_scene(is_group, is_unique_session)
@@ -63,28 +67,23 @@ class ConversationCommands:
)
return
if not self.context.get_using_provider(message.unified_msg_origin):
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
await sp.remove_async(
scope="umo",
scope_id=umo,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("重置对话成功。"))
return
if not self.context.get_using_provider(umo):
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type in ["dify", "coze"]:
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
"provider type is not dify or coze"
)
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message(
"已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。",
),
)
return
cid = await self.context.conversation_manager.get_curr_conversation_id(
message.unified_msg_origin,
)
cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
if not cid:
message.set_result(
@@ -95,7 +94,7 @@ class ConversationCommands:
return
await self.context.conversation_manager.update_conversation(
message.unified_msg_origin,
umo,
cid,
[],
)
@@ -152,29 +151,14 @@ class ConversationCommands:
async def convs(self, message: AstrMessageEvent, page: int = 1):
"""查看对话列表"""
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
"""原有的Dify处理逻辑保持不变"""
parts = ["Dify 对话列表:\n"]
assert isinstance(provider, ProviderDify)
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
idx = 1
for conv in data["data"]:
ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime(
"%m-%d %H:%M",
)
parts.append(
f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n"
)
idx += 1
if idx == 1:
parts.append("没有找到任何对话。")
dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None)
parts.append(
f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。"
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
message.set_result(
MessageEventResult().message(
f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。",
),
)
ret = "".join(parts)
message.set_result(MessageEventResult().message(ret))
return
size_per_page = 6
@@ -227,9 +211,8 @@ class ConversationCommands:
else:
ret += "\n当前对话: 无"
unique_session = self.context.get_config()["platform_settings"][
"unique_session"
]
cfg = self.context.get_config(umo=message.unified_msg_origin)
unique_session = cfg["platform_settings"]["unique_session"]
if unique_session:
ret += "\n会话隔离粒度: 个人"
else:
@@ -243,15 +226,15 @@ class ConversationCommands:
async def new_conv(self, message: AstrMessageEvent):
"""创建新对话"""
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type in ["dify", "coze"]:
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
"provider type is not dify or coze"
)
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message("成功,下次聊天将是新对话。"),
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
await sp.remove_async(
scope="umo",
scope_id=message.unified_msg_origin,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("已创建新对话。"))
return
cpersona = await self._get_current_persona_id(message.unified_msg_origin)
@@ -274,19 +257,9 @@ class ConversationCommands:
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
"""创建新群聊对话"""
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type in ["dify", "coze"]:
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
"provider type is not dify or coze"
)
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message("成功,下次聊天将是新对话。"),
)
return
if sid:
session = str(
MessageSesion(
MessageSession(
platform_name=message.platform_meta.id,
message_type=MessageType("GroupMessage"),
session_id=sid,
@@ -321,31 +294,6 @@ class ConversationCommands:
)
return
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify), "provider type is not dify"
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
if not data["data"]:
message.set_result(MessageEventResult().message("未找到任何对话。"))
return
selected_conv = None
if index is not None:
try:
selected_conv = data["data"][index - 1]
except IndexError:
message.set_result(
MessageEventResult().message("对话序号错误,请使用 /ls 查看"),
)
return
else:
selected_conv = data["data"][0]
ret = (
f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。"
)
provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"]
message.set_result(MessageEventResult().message(ret))
return
if index is None:
message.set_result(
MessageEventResult().message(
@@ -378,19 +326,6 @@ class ConversationCommands:
if not new_name:
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
return
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
cid = provider.conversation_ids.get(message.unified_msg_origin, None)
if not cid:
message.set_result(MessageEventResult().message("未找到当前对话。"))
return
await provider.api_client.rename(cid, new_name, message.unified_msg_origin)
message.set_result(MessageEventResult().message("重命名对话成功。"))
return
await self.context.conversation_manager.update_conversation_title(
message.unified_msg_origin,
new_name,
@@ -399,9 +334,8 @@ class ConversationCommands:
async def del_conv(self, message: AstrMessageEvent):
"""删除当前对话"""
is_unique_session = self.context.get_config()["platform_settings"][
"unique_session"
]
cfg = self.context.get_config(umo=message.unified_msg_origin)
is_unique_session = cfg["platform_settings"]["unique_session"]
if message.get_group_id() and not is_unique_session and message.role != "admin":
# 群聊,没开独立会话,发送人不是管理员
message.set_result(
@@ -411,20 +345,14 @@ class ConversationCommands:
)
return
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None)
if dify_cid:
await provider.api_client.delete_chat_conv(
message.unified_msg_origin,
dify_cid,
)
message.set_result(
MessageEventResult().message(
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。",
),
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
await sp.remove_async(
scope="umo",
scope_id=message.unified_msg_origin,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("重置对话成功。"))
return
session_curr_cid = (

View File

@@ -5,7 +5,6 @@ from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.message_components import Image, Plain
from astrbot.api.provider import LLMResponse, ProviderRequest
from astrbot.core import logger
from astrbot.core.provider.sources.dify_source import ProviderDify
from .commands import (
AdminCommands,
@@ -279,33 +278,20 @@ class Main(star.Star):
return
try:
conv = None
if provider.meta().type != "dify":
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
event.unified_msg_origin,
)
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
event.unified_msg_origin,
)
if not session_curr_cid:
logger.error(
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
)
return
if not session_curr_cid:
logger.error(
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
)
return
conv = await self.context.conversation_manager.get_conversation(
event.unified_msg_origin,
session_curr_cid,
)
else:
# Dify 自己有维护对话,不需要 bot 端维护。
assert isinstance(provider, ProviderDify)
cid = provider.conversation_ids.get(
event.unified_msg_origin,
None,
)
if cid is None:
logger.error(
"[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
)
return
conv = await self.context.conversation_manager.get_conversation(
event.unified_msg_origin,
session_curr_cid,
)
prompt = event.message_str

View File

@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.6.0"
version = "4.6.1"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.10"