Merge pull request #3760 from AstrBotDevs/feat/agent-runner
refactor: transfer dify, coze and alibaba dashscope from chat provider to agent runner
This commit is contained in:
@@ -2,13 +2,12 @@ import abc
|
||||
import typing as T
|
||||
from enum import Enum, auto
|
||||
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..response import AgentResponse
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
@@ -24,9 +23,7 @@ class BaseAgentRunner(T.Generic[TContext]):
|
||||
@abc.abstractmethod
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
@@ -60,3 +57,9 @@ class BaseAgentRunner(T.Generic[TContext]):
|
||||
This method should be called after the agent is done.
|
||||
"""
|
||||
...
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""Transition the agent state."""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
367
astrbot/core/agent/runners/coze/coze_agent_runner.py
Normal file
367
astrbot/core/agent/runners/coze/coze_agent_runner.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Coze Agent Runner"""
|
||||
|
||||
    @override
    async def reset(
        self,
        request: ProviderRequest,
        run_context: ContextWrapper[TContext],
        agent_hooks: BaseAgentRunHooks[TContext],
        provider_config: dict,
        **kwargs: T.Any,
    ) -> None:
        """Prepare the runner for a new Coze conversation turn.

        Validates the provider configuration, builds the Coze API client,
        and resets all per-run state (IDLE state, empty final response).

        Args:
            request: The incoming provider request to serve.
            run_context: Context wrapper passed through to agent hooks.
            agent_hooks: Lifecycle hooks invoked at begin/done.
            provider_config: Coze provider settings (api key, bot id, ...).
            **kwargs: Extra options; ``streaming`` (bool) enables
                incremental delta responses.

        Raises:
            Exception: If the API key, bot id, or API base URL is invalid.
        """
        self.req = request
        self.streaming = kwargs.get("streaming", False)
        self.final_llm_resp = None
        self._state = AgentState.IDLE
        self.agent_hooks = agent_hooks
        self.run_context = run_context

        self.api_key = provider_config.get("coze_api_key", "")
        if not self.api_key:
            raise Exception("Coze API Key 不能为空。")
        self.bot_id = provider_config.get("bot_id", "")
        if not self.bot_id:
            raise Exception("Coze Bot ID 不能为空。")
        self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")

        # Reject obviously malformed base URLs early rather than failing at
        # request time.
        if not isinstance(self.api_base, str) or not self.api_base.startswith(
            ("http://", "https://"),
        ):
            raise Exception(
                "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
            )

        # Config may deliver the timeout as a string; normalize to int seconds.
        self.timeout = provider_config.get("timeout", 120)
        if isinstance(self.timeout, str):
            self.timeout = int(self.timeout)
        self.auto_save_history = provider_config.get("auto_save_history", True)

        # Create the API client
        self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)

        # Per-session cache: session_id -> {url-hash: coze file_id}, so the
        # same image is uploaded at most once per session.
        self.file_id_cache: dict[str, dict[str, str]] = {}
|
||||
|
||||
    @override
    async def step(self):
        """Execute one step of the Coze agent.

        A single step drives the whole Coze request to completion: fires the
        on_agent_begin hook (first step only), streams AgentResponse objects
        from the API call, and on failure transitions to ERROR while yielding
        an error response instead of raising.

        Raises:
            ValueError: If reset() was not called before stepping.
        """
        if not self.req:
            raise ValueError("Request is not set. Please call reset() first.")

        if self._state == AgentState.IDLE:
            try:
                await self.agent_hooks.on_agent_begin(self.run_context)
            except Exception as e:
                # Hook failures are logged but never abort the run.
                logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)

            # Start processing; move to the running state.
            self._transition_state(AgentState.RUNNING)

        try:
            # Execute the Coze request and relay its results.
            async for response in self._execute_coze_request():
                yield response
        except Exception as e:
            logger.error(f"Coze 请求失败:{str(e)}")
            self._transition_state(AgentState.ERROR)
            self.final_llm_resp = LLMResponse(
                role="err", completion_text=f"Coze 请求失败:{str(e)}"
            )
            yield AgentResponse(
                type="err",
                data=AgentResponseData(
                    chain=MessageChain().message(f"Coze 请求失败:{str(e)}")
                ),
            )
        finally:
            # The request is fully consumed within one step, so the client
            # can be closed here.
            await self.api_client.close()
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
    async def _execute_coze_request(self):
        """Core logic for one Coze chat request.

        Builds the ``additional_messages`` payload (system prompt, prior
        context, current user message with uploaded images), streams events
        from the Coze chat API, persists the conversation id, and finally
        yields an ``llm_result`` AgentResponse with the accumulated text.

        Yields:
            AgentResponse: ``streaming_delta`` items while streaming (if
            enabled) and one final ``llm_result``.

        Raises:
            Exception: If the Coze API reports an ``error`` event.
        """
        prompt = self.req.prompt or ""
        session_id = self.req.session_id or "unknown"
        image_urls = self.req.image_urls or []
        contexts = self.req.contexts or []
        system_prompt = self.req.system_prompt

        # User id parameter — the session id doubles as the Coze user id.
        user_id = session_id

        # Fetch the persisted conversation id (empty string = new chat).
        conversation_id = await sp.get_async(
            scope="umo",
            scope_id=user_id,
            key="coze_conversation_id",
            default="",
        )

        # Build the message list.
        additional_messages = []

        # Only inject the system prompt when Coze is not already carrying
        # history for an existing conversation.
        if system_prompt:
            if not self.auto_save_history or not conversation_id:
                additional_messages.append(
                    {
                        "role": "system",
                        "content": system_prompt,
                        "content_type": "text",
                    },
                )

        # Replay local history only when Coze-side history is disabled.
        if not self.auto_save_history and contexts:
            for ctx in contexts:
                if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
                    # Handle images embedded in the context.
                    content = ctx["content"]
                    if isinstance(content, list):
                        # Multimodal content; images must be re-uploaded.
                        processed_content = []
                        for item in content:
                            if isinstance(item, dict):
                                if item.get("type") == "text":
                                    processed_content.append(item)
                                elif item.get("type") == "image_url":
                                    # Upload the image and reference its file_id.
                                    try:
                                        image_data = item.get("image_url", {})
                                        url = image_data.get("url", "")
                                        if url:
                                            file_id = (
                                                await self._download_and_upload_image(
                                                    url, session_id
                                                )
                                            )
                                            # NOTE(review): history images use
                                            # type "file" while the current
                                            # message uses type "image" —
                                            # confirm against the Coze API.
                                            processed_content.append(
                                                {
                                                    "type": "file",
                                                    "file_id": file_id,
                                                    "file_url": url,
                                                }
                                            )
                                    except Exception as e:
                                        logger.warning(f"处理上下文图片失败: {e}")
                                        continue

                        if processed_content:
                            additional_messages.append(
                                {
                                    "role": ctx["role"],
                                    "content": processed_content,
                                    "content_type": "object_string",
                                }
                            )
                    else:
                        # Plain-text content.
                        additional_messages.append(
                            {
                                "role": ctx["role"],
                                "content": content,
                                "content_type": "text",
                            }
                        )

        # Build the current user message.
        if prompt or image_urls:
            if image_urls:
                # Multimodal message: text part plus uploaded images.
                object_string_content = []
                if prompt:
                    object_string_content.append({"type": "text", "text": prompt})

                for url in image_urls:
                    # the url is a base64 string
                    try:
                        image_data = base64.b64decode(url)
                        file_id = await self.api_client.upload_file(image_data)
                        object_string_content.append(
                            {
                                "type": "image",
                                "file_id": file_id,
                            }
                        )
                    except Exception as e:
                        logger.warning(f"处理图片失败 {url}: {e}")
                        continue

                if object_string_content:
                    # object_string content is serialized JSON, not a list.
                    content = json.dumps(object_string_content, ensure_ascii=False)
                    additional_messages.append(
                        {
                            "role": "user",
                            "content": content,
                            "content_type": "object_string",
                        }
                    )
            elif prompt:
                # Plain text message.
                additional_messages.append(
                    {
                        "role": "user",
                        "content": prompt,
                        "content_type": "text",
                    },
                )

        # Execute the Coze API request.
        accumulated_content = ""
        message_started = False

        async for chunk in self.api_client.chat_messages(
            bot_id=self.bot_id,
            user_id=user_id,
            additional_messages=additional_messages,
            conversation_id=conversation_id,
            auto_save_history=self.auto_save_history,
            stream=True,
            timeout=self.timeout,
        ):
            event_type = chunk.get("event")
            data = chunk.get("data", {})

            # Persist the conversation id as soon as the chat is created so
            # later turns reuse the same Coze conversation.
            if event_type == "conversation.chat.created":
                if isinstance(data, dict) and "conversation_id" in data:
                    await sp.put_async(
                        scope="umo",
                        scope_id=user_id,
                        key="coze_conversation_id",
                        value=data["conversation_id"],
                    )

            if event_type == "conversation.message.delta":
                # Incremental message; the text may live under "content",
                # "delta.content", or "text" depending on the event shape.
                content = data.get("content", "")
                if not content and "delta" in data:
                    content = data["delta"].get("content", "")
                if not content and "text" in data:
                    content = data.get("text", "")

                if content:
                    accumulated_content += content
                    message_started = True

                    # When streaming, forward the delta immediately.
                    if self.streaming:
                        yield AgentResponse(
                            type="streaming_delta",
                            data=AgentResponseData(
                                chain=MessageChain().message(content)
                            ),
                        )

            elif event_type == "conversation.message.completed":
                # Message finished.
                logger.debug("Coze message completed")
                message_started = True

            elif event_type == "conversation.chat.completed":
                # Whole chat finished; stop consuming events.
                logger.debug("Coze chat completed")
                break

            elif event_type == "error":
                # API-side failure: surface it to step() as an exception.
                error_msg = data.get("msg", "未知错误")
                error_code = data.get("code", "UNKNOWN")
                logger.error(f"Coze 出现错误: {error_code} - {error_msg}")
                raise Exception(f"Coze 出现错误: {error_code} - {error_msg}")

        if not message_started and not accumulated_content:
            logger.warning("Coze 未返回任何内容")
            # accumulated_content is already "" here; kept for clarity.
            accumulated_content = ""

        # Build the final response.
        chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
        self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
        self._transition_state(AgentState.DONE)

        try:
            await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
        except Exception as e:
            # Hook failures are logged but do not suppress the final result.
            logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)

        # Emit the final result.
        yield AgentResponse(
            type="llm_result",
            data=AgentResponseData(chain=chain),
        )
|
||||
|
||||
async def _download_and_upload_image(
|
||||
self,
|
||||
image_url: str,
|
||||
session_id: str | None = None,
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
import hashlib
|
||||
|
||||
# 计算哈希实现缓存
|
||||
cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest()
|
||||
|
||||
if session_id:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}")
|
||||
return file_id
|
||||
|
||||
try:
|
||||
image_data = await self.api_client.download_image(image_url)
|
||||
file_id = await self.api_client.upload_file(image_data)
|
||||
|
||||
if session_id:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {image_url}: {e!s}")
|
||||
raise Exception(f"处理图片失败: {e!s}")
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
    @override
    def get_final_llm_resp(self) -> LLMResponse | None:
        """Return the final LLM response, or None if no run has completed yet."""
        return self.final_llm_resp
|
||||
403
astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py
Normal file
403
astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py
Normal file
@@ -0,0 +1,403 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import queue
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import typing as T
|
||||
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Dashscope Agent Runner"""
|
||||
|
||||
    @override
    async def reset(
        self,
        request: ProviderRequest,
        run_context: ContextWrapper[TContext],
        agent_hooks: BaseAgentRunHooks[TContext],
        provider_config: dict,
        **kwargs: T.Any,
    ) -> None:
        """Prepare the runner for a new Dashscope (Bailian) application call.

        Validates the provider configuration and resets all per-run state.

        Args:
            request: The incoming provider request to serve.
            run_context: Context wrapper passed through to agent hooks.
            agent_hooks: Lifecycle hooks invoked at begin/done.
            provider_config: Dashscope settings (api key, app id, app type,
                variables, rag_options, timeout).
            **kwargs: Extra options; ``streaming`` (bool) enables
                incremental delta responses.

        Raises:
            Exception: If the API key, app id, or app type is missing.
        """
        self.req = request
        self.streaming = kwargs.get("streaming", False)
        self.final_llm_resp = None
        self._state = AgentState.IDLE
        self.agent_hooks = agent_hooks
        self.run_context = run_context

        self.api_key = provider_config.get("dashscope_api_key", "")
        if not self.api_key:
            raise Exception("阿里云百炼 API Key 不能为空。")
        self.app_id = provider_config.get("dashscope_app_id", "")
        if not self.app_id:
            raise Exception("阿里云百炼 APP ID 不能为空。")
        self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
        if not self.dashscope_app_type:
            raise Exception("阿里云百炼 APP 类型不能为空。")

        self.variables: dict = provider_config.get("variables", {}) or {}
        self.rag_options: dict = provider_config.get("rag_options", {})
        self.output_reference = self.rag_options.get("output_reference", False)
        # Copy before popping so the original config dict is not mutated;
        # output_reference is a local flag, not a Dashscope parameter.
        self.rag_options = self.rag_options.copy()
        self.rag_options.pop("output_reference", None)

        # Config may deliver the timeout as a string; normalize to int seconds.
        self.timeout = provider_config.get("timeout", 120)
        if isinstance(self.timeout, str):
            self.timeout = int(self.timeout)
|
||||
|
||||
def has_rag_options(self):
|
||||
"""判断是否有 RAG 选项
|
||||
|
||||
Returns:
|
||||
bool: 是否有 RAG 选项
|
||||
|
||||
"""
|
||||
if self.rag_options and (
|
||||
len(self.rag_options.get("pipeline_ids", [])) > 0
|
||||
or len(self.rag_options.get("file_ids", [])) > 0
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
    @override
    async def step(self):
        """Execute one step of the Dashscope agent.

        A single step drives the whole Dashscope request to completion:
        fires the on_agent_begin hook (first step only), streams
        AgentResponse objects from the API call, and on failure transitions
        to ERROR while yielding an error response instead of raising.

        Raises:
            ValueError: If reset() was not called before stepping.
        """
        if not self.req:
            raise ValueError("Request is not set. Please call reset() first.")

        if self._state == AgentState.IDLE:
            try:
                await self.agent_hooks.on_agent_begin(self.run_context)
            except Exception as e:
                # Hook failures are logged but never abort the run.
                logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)

            # Start processing; move to the running state.
            self._transition_state(AgentState.RUNNING)

        try:
            # Execute the Dashscope request and relay its results.
            async for response in self._execute_dashscope_request():
                yield response
        except Exception as e:
            logger.error(f"阿里云百炼请求失败:{str(e)}")
            self._transition_state(AgentState.ERROR)
            self.final_llm_resp = LLMResponse(
                role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
            )
            yield AgentResponse(
                type="err",
                data=AgentResponseData(
                    chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}")
                ),
            )
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
def _consume_sync_generator(
|
||||
self, response: T.Any, response_queue: queue.Queue
|
||||
) -> None:
|
||||
"""在线程中消费同步generator,将结果放入队列
|
||||
|
||||
Args:
|
||||
response: 同步generator对象
|
||||
response_queue: 用于传递数据的队列
|
||||
|
||||
"""
|
||||
try:
|
||||
if self.streaming:
|
||||
for chunk in response:
|
||||
response_queue.put(("data", chunk))
|
||||
else:
|
||||
response_queue.put(("data", response))
|
||||
except Exception as e:
|
||||
response_queue.put(("error", e))
|
||||
finally:
|
||||
response_queue.put(("done", None))
|
||||
|
||||
    async def _process_stream_chunk(
        self, chunk: ApplicationResponse, output_text: str
    ) -> tuple[str, list | None, AgentResponse | None]:
        """Process a single chunk of a streaming Dashscope response.

        On an HTTP error status this transitions the runner to ERROR and
        returns an "err" AgentResponse; otherwise it appends the chunk text
        (with RAG citation markers normalized) to the accumulated output.

        Args:
            chunk: One Dashscope ApplicationResponse chunk.
            output_text: Text accumulated so far.

        Returns:
            Tuple of (updated output_text, doc_references or None,
            AgentResponse or None).
        """
        logger.debug(f"dashscope stream chunk: {chunk}")

        if chunk.status_code != 200:
            logger.error(
                f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
            )
            self._transition_state(AgentState.ERROR)
            error_msg = (
                f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
            )
            self.final_llm_resp = LLMResponse(
                role="err",
                result_chain=MessageChain().message(error_msg),
            )
            return (
                output_text,
                None,
                AgentResponse(
                    type="err",
                    data=AgentResponseData(chain=MessageChain().message(error_msg)),
                ),
            )

        chunk_text = chunk.output.get("text", "") or ""
        # Normalize RAG citation markers: <ref>[n]</ref> -> [n].
        chunk_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", chunk_text)

        response = None
        if chunk_text:
            output_text += chunk_text
            response = AgentResponse(
                type="streaming_delta",
                data=AgentResponseData(chain=MessageChain().message(chunk_text)),
            )

        # Extract document references, when the chunk carries them.
        doc_references = chunk.output.get("doc_references", None)

        return output_text, doc_references, response
|
||||
|
||||
def _format_doc_references(self, doc_references: list) -> str:
|
||||
"""格式化文档引用为文本
|
||||
|
||||
Args:
|
||||
doc_references: 文档引用列表
|
||||
|
||||
Returns:
|
||||
格式化后的引用文本
|
||||
|
||||
"""
|
||||
ref_parts = []
|
||||
for ref in doc_references:
|
||||
ref_title = (
|
||||
ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
|
||||
)
|
||||
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
|
||||
ref_str = "".join(ref_parts)
|
||||
return f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
    async def _build_request_payload(
        self, prompt: str, session_id: str, contexts: list, system_prompt: str
    ) -> dict:
        """Build the keyword payload for ``dashscope.Application.call``.

        Merges configured variables with per-session variables, and — for app
        types that support multi-turn chat without RAG — attaches the
        persisted Dashscope session id.

        Args:
            prompt: User input text.
            session_id: Session id (also the persistence scope id).
            contexts: Context history list.
            system_prompt: System prompt text.

        Returns:
            dict: Keyword arguments for Application.call.

        NOTE(review): ``contexts`` and ``system_prompt`` are currently not
        used when building the payload — confirm whether Dashscope apps are
        expected to manage history/system prompt server-side.
        """
        conversation_id = await sp.get_async(
            scope="umo",
            scope_id=session_id,
            key="dashscope_conversation_id",
            default="",
        )
        # Merge static variables with per-session dynamic variables; the
        # session values win on key collisions.
        payload_vars = self.variables.copy()
        session_var = await sp.get_async(
            scope="umo",
            scope_id=session_id,
            key="session_variables",
            default={},
        )
        payload_vars.update(session_var)

        if (
            self.dashscope_app_type in ["agent", "dialog-workflow"]
            and not self.has_rag_options()
        ):
            # App types that support multi-turn conversations.
            p = {
                "app_id": self.app_id,
                "api_key": self.api_key,
                "prompt": prompt,
                "biz_params": payload_vars or None,
                "stream": self.streaming,
                "incremental_output": True,
            }
            # Resume the server-side session when one exists.
            if conversation_id:
                p["session_id"] = conversation_id
            return p
        else:
            # App types without multi-turn conversation support.
            payload = {
                "app_id": self.app_id,
                "prompt": prompt,
                "api_key": self.api_key,
                "biz_params": payload_vars or None,
                "stream": self.streaming,
                "incremental_output": True,
            }
            if self.rag_options:
                payload["rag_options"] = self.rag_options
            return payload
|
||||
|
||||
    async def _handle_streaming_response(
        self, response: T.Any, session_id: str
    ) -> T.AsyncGenerator[AgentResponse, None]:
        """Bridge a synchronous Dashscope response into async AgentResponses.

        A daemon thread drains the (possibly streaming) response into a
        queue; this coroutine consumes the queue without blocking the event
        loop, persists the Dashscope session id, appends RAG references, and
        emits the final ``llm_result``.

        Args:
            response: Dashscope response — a sync generator when streaming,
                a single ApplicationResponse otherwise.
            session_id: Session id used to persist the conversation id.

        Yields:
            AgentResponse objects (streaming deltas and one final result).
        """
        response_queue = queue.Queue()
        consumer_thread = threading.Thread(
            target=self._consume_sync_generator,
            args=(response, response_queue),
            daemon=True,
        )
        consumer_thread.start()

        output_text = ""
        doc_references = None

        while True:
            try:
                # Block at most 1s in the executor so the loop can re-check
                # instead of hanging the event loop indefinitely.
                item_type, item_data = await asyncio.get_event_loop().run_in_executor(
                    None, response_queue.get, True, 1
                )
            except queue.Empty:
                continue

            if item_type == "done":
                break
            elif item_type == "error":
                # Re-raise the worker thread's exception in async context.
                raise item_data
            elif item_type == "data":
                chunk = item_data
                assert isinstance(chunk, ApplicationResponse)

                (
                    output_text,
                    chunk_doc_refs,
                    response,
                ) = await self._process_stream_chunk(chunk, output_text)

                if response:
                    # An "err" response terminates the whole run.
                    if response.type == "err":
                        yield response
                        return
                    yield response

                if chunk_doc_refs:
                    doc_references = chunk_doc_refs

                # Persist the Dashscope session id for multi-turn reuse.
                if chunk.output.session_id:
                    await sp.put_async(
                        scope="umo",
                        scope_id=session_id,
                        key="dashscope_conversation_id",
                        value=chunk.output.session_id,
                    )

        # Append RAG references, when enabled and present.
        if self.output_reference and doc_references:
            ref_text = self._format_doc_references(doc_references)
            output_text += ref_text

            if self.streaming:
                yield AgentResponse(
                    type="streaming_delta",
                    data=AgentResponseData(chain=MessageChain().message(ref_text)),
                )

        # Build the final response.
        chain = MessageChain(chain=[Comp.Plain(output_text)])
        self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
        self._transition_state(AgentState.DONE)

        try:
            await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
        except Exception as e:
            # Hook failures are logged but do not suppress the final result.
            logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)

        # Emit the final result.
        yield AgentResponse(
            type="llm_result",
            data=AgentResponseData(chain=chain),
        )
|
||||
|
||||
    async def _execute_dashscope_request(self):
        """Core logic for one Dashscope application request.

        Builds the payload, invokes the blocking ``Application.call`` in an
        executor, and relays the (streaming or whole) response through
        ``_handle_streaming_response``.

        Yields:
            AgentResponse objects (streaming deltas and one final result).
        """
        prompt = self.req.prompt or ""
        session_id = self.req.session_id or "unknown"
        image_urls = self.req.image_urls or []
        contexts = self.req.contexts or []
        system_prompt = self.req.system_prompt

        # Image inputs are not supported by this provider; warn and drop.
        if image_urls:
            logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")

        # Build the request payload.
        payload = await self._build_request_payload(
            prompt, session_id, contexts, system_prompt
        )

        # Non-streaming mode wants one complete output, not increments.
        if not self.streaming:
            payload["incremental_output"] = False

        # Application.call is blocking; run it off the event loop.
        partial = functools.partial(Application.call, **payload)
        response = await asyncio.get_event_loop().run_in_executor(None, partial)

        async for resp in self._handle_streaming_response(response, session_id):
            yield resp
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
    @override
    def get_final_llm_resp(self) -> LLMResponse | None:
        """Return the final LLM response, or None if no run has completed yet."""
        return self.final_llm_resp
|
||||
336
astrbot/core/agent/runners/dify/dify_agent_runner.py
Normal file
336
astrbot/core/agent/runners/dify/dify_agent_runner.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .dify_api_client import DifyAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DifyAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Dify Agent Runner"""
|
||||
|
||||
    @override
    async def reset(
        self,
        request: ProviderRequest,
        run_context: ContextWrapper[TContext],
        agent_hooks: BaseAgentRunHooks[TContext],
        provider_config: dict,
        **kwargs: T.Any,
    ) -> None:
        """Prepare the runner for a new Dify request.

        Reads the Dify configuration, builds the API client, and resets all
        per-run state (IDLE state, empty final response).

        Args:
            request: The incoming provider request to serve.
            run_context: Context wrapper passed through to agent hooks.
            agent_hooks: Lifecycle hooks invoked at begin/done.
            provider_config: Dify settings (api key/base, api type,
                workflow keys, variables, timeout).
            **kwargs: Extra options; ``streaming`` (bool) enables
                incremental delta responses.
        """
        self.req = request
        self.streaming = kwargs.get("streaming", False)
        self.final_llm_resp = None
        self._state = AgentState.IDLE
        self.agent_hooks = agent_hooks
        self.run_context = run_context

        # NOTE(review): unlike the Coze/Dashscope runners, an empty api_key
        # is not rejected here — confirm whether that is intentional.
        self.api_key = provider_config.get("dify_api_key", "")
        self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
        self.api_type = provider_config.get("dify_api_type", "chat")
        # Workflow mode: which output key carries the final answer.
        self.workflow_output_key = provider_config.get(
            "dify_workflow_output_key",
            "astrbot_wf_output",
        )
        # Workflow mode: which input key receives the user's text query.
        self.dify_query_input_key = provider_config.get(
            "dify_query_input_key",
            "astrbot_text_query",
        )
        self.variables: dict = provider_config.get("variables", {}) or {}
        # Config may deliver the timeout as a string; normalize to int seconds.
        self.timeout = provider_config.get("timeout", 60)
        if isinstance(self.timeout, str):
            self.timeout = int(self.timeout)

        self.api_client = DifyAPIClient(self.api_key, self.api_base)
|
||||
|
||||
    @override
    async def step(self):
        """Execute one step of the Dify agent.

        A single step drives the whole Dify request to completion: fires the
        on_agent_begin hook (first step only), streams AgentResponse objects
        from the API call, and on failure transitions to ERROR while
        yielding an error response instead of raising.

        Raises:
            ValueError: If reset() was not called before stepping.
        """
        if not self.req:
            raise ValueError("Request is not set. Please call reset() first.")

        if self._state == AgentState.IDLE:
            try:
                await self.agent_hooks.on_agent_begin(self.run_context)
            except Exception as e:
                # Hook failures are logged but never abort the run.
                logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)

            # Start processing; move to the running state.
            self._transition_state(AgentState.RUNNING)

        try:
            # Execute the Dify request and relay its results.
            async for response in self._execute_dify_request():
                yield response
        except Exception as e:
            logger.error(f"Dify 请求失败:{str(e)}")
            self._transition_state(AgentState.ERROR)
            self.final_llm_resp = LLMResponse(
                role="err", completion_text=f"Dify 请求失败:{str(e)}"
            )
            yield AgentResponse(
                type="err",
                data=AgentResponseData(
                    chain=MessageChain().message(f"Dify 请求失败:{str(e)}")
                ),
            )
        finally:
            # The request is fully consumed within one step, so the client
            # can be closed here.
            await self.api_client.close()
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _execute_dify_request(self):
|
||||
"""执行 Dify 请求的核心逻辑"""
|
||||
prompt = self.req.prompt or ""
|
||||
session_id = self.req.session_id or "unknown"
|
||||
image_urls = self.req.image_urls or []
|
||||
system_prompt = self.req.system_prompt
|
||||
|
||||
conversation_id = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dify_conversation_id",
|
||||
default="",
|
||||
)
|
||||
result = ""
|
||||
|
||||
# 处理图片上传
|
||||
files_payload = []
|
||||
for image_url in image_urls:
|
||||
# image_url is a base64 string
|
||||
try:
|
||||
image_data = base64.b64decode(image_url)
|
||||
file_response = await self.api_client.file_upload(
|
||||
file_data=image_data,
|
||||
user=session_id,
|
||||
mime_type="image/png",
|
||||
file_name="image.png",
|
||||
)
|
||||
logger.debug(f"Dify 上传图片响应:{file_response}")
|
||||
if "id" not in file_response:
|
||||
logger.warning(
|
||||
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
|
||||
)
|
||||
continue
|
||||
files_payload.append(
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response["id"],
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"上传图片失败:{e}")
|
||||
continue
|
||||
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_var = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_variables",
|
||||
default={},
|
||||
)
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
# 处理不同的 API 类型
|
||||
match self.api_type:
|
||||
case "chat" | "agent" | "chatflow":
|
||||
if not prompt:
|
||||
prompt = "请描述这张图片。"
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={
|
||||
**payload_vars,
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk["event"] == "message" or chunk["event"] == "agent_message":
|
||||
result += chunk["answer"]
|
||||
if not conversation_id:
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dify_conversation_id",
|
||||
value=chunk["conversation_id"],
|
||||
)
|
||||
conversation_id = chunk["conversation_id"]
|
||||
|
||||
# 如果是流式响应,发送增量数据
|
||||
if self.streaming and chunk["answer"]:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(chunk["answer"])
|
||||
),
|
||||
)
|
||||
elif chunk["event"] == "message_end":
|
||||
logger.debug("Dify message end")
|
||||
break
|
||||
elif chunk["event"] == "error":
|
||||
logger.error(f"Dify 出现错误:{chunk}")
|
||||
raise Exception(
|
||||
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
|
||||
)
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
self.dify_query_input_key: prompt,
|
||||
"astrbot_session_id": session_id,
|
||||
**payload_vars,
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify workflow resp chunk: {chunk}")
|
||||
match chunk["event"]:
|
||||
case "workflow_started":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。"
|
||||
)
|
||||
case "node_finished":
|
||||
logger.debug(
|
||||
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。"
|
||||
)
|
||||
case "text_chunk":
|
||||
if self.streaming and chunk["data"]["text"]:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
chunk["data"]["text"]
|
||||
)
|
||||
),
|
||||
)
|
||||
case "workflow_finished":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
|
||||
)
|
||||
logger.debug(f"Dify 工作流结果:{chunk}")
|
||||
if chunk["data"]["error"]:
|
||||
logger.error(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
||||
)
|
||||
raise Exception(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
||||
)
|
||||
if self.workflow_output_key not in chunk["data"]["outputs"]:
|
||||
raise Exception(
|
||||
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
|
||||
)
|
||||
result = chunk
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
|
||||
if not result:
|
||||
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
|
||||
|
||||
# 解析结果
|
||||
chain = await self.parse_dify_result(result)
|
||||
|
||||
# 创建最终响应
|
||||
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回最终结果
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=chain),
|
||||
)
|
||||
|
||||
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||
"""解析 Dify 的响应结果"""
|
||||
if isinstance(chunk, str):
|
||||
# Chat
|
||||
return MessageChain(chain=[Comp.Plain(chunk)])
|
||||
|
||||
async def parse_file(item: dict):
|
||||
match item["type"]:
|
||||
case "image":
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
return Comp.Video(file=item["url"])
|
||||
case _:
|
||||
return Comp.File(name=item["filename"], file=item["url"])
|
||||
|
||||
output = chunk["data"]["outputs"][self.workflow_output_key]
|
||||
chains = []
|
||||
if isinstance(output, str):
|
||||
# 纯文本输出
|
||||
chains.append(Comp.Plain(output))
|
||||
elif isinstance(output, list):
|
||||
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
|
||||
for item in output:
|
||||
# handle Array[File]
|
||||
if (
|
||||
not isinstance(item, dict)
|
||||
or item.get("dify_model_identity", "") != "__dify__file__"
|
||||
):
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
break
|
||||
else:
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
|
||||
# scan file
|
||||
files = chunk["data"].get("files", [])
|
||||
for item in files:
|
||||
comp = await parse_file(item)
|
||||
chains.append(comp)
|
||||
|
||||
return MessageChain(chain=chains)
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
@override
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import ClientResponse, ClientSession
|
||||
from aiohttp import ClientResponse, ClientSession, FormData
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -101,21 +101,59 @@ class DifyAPIClient:
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
file_path: str,
|
||||
user: str,
|
||||
file_path: str | None = None,
|
||||
file_data: bytes | None = None,
|
||||
file_name: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Upload a file to Dify. Must provide either file_path or file_data.
|
||||
|
||||
Args:
|
||||
user: The user ID.
|
||||
file_path: The path to the file to upload.
|
||||
file_data: The file data in bytes.
|
||||
file_name: Optional file name when using file_data.
|
||||
Returns:
|
||||
A dictionary containing the uploaded file information.
|
||||
"""
|
||||
url = f"{self.api_base}/files/upload"
|
||||
with open(file_path, "rb") as f:
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": f,
|
||||
}
|
||||
async with self.session.post(
|
||||
url,
|
||||
data=payload,
|
||||
headers=self.headers,
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
form = FormData()
|
||||
form.add_field("user", user)
|
||||
|
||||
if file_data is not None:
|
||||
# 使用 bytes 数据
|
||||
form.add_field(
|
||||
"file",
|
||||
file_data,
|
||||
filename=file_name or "uploaded_file",
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
)
|
||||
elif file_path is not None:
|
||||
# 使用文件路径
|
||||
import os
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
form.add_field(
|
||||
"file",
|
||||
file_content,
|
||||
filename=os.path.basename(file_path),
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
)
|
||||
else:
|
||||
raise ValueError("file_path 和 file_data 不能同时为 None")
|
||||
|
||||
async with self.session.post(
|
||||
url,
|
||||
data=form,
|
||||
headers=self.headers, # 不包含 Content-Type,让 aiohttp 自动设置
|
||||
) as resp:
|
||||
if resp.status != 200 and resp.status != 201:
|
||||
text = await resp.text()
|
||||
raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
@@ -69,12 +69,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
if self.streaming:
|
||||
|
||||
@@ -68,6 +68,10 @@ DEFAULT_CONFIG = {
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"agent_runner_type": "local",
|
||||
"dify_agent_runner_provider_id": "",
|
||||
"coze_agent_runner_provider_id": "",
|
||||
"dashscope_agent_runner_provider_id": "",
|
||||
"unsupported_streaming_strategy": "realtime_segmenting",
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
@@ -1011,7 +1015,7 @@ CONFIG_METADATA_2 = {
|
||||
"id": "dify_app_default",
|
||||
"provider": "dify",
|
||||
"type": "dify",
|
||||
"provider_type": "chat_completion",
|
||||
"provider_type": "agent_runner",
|
||||
"enable": True,
|
||||
"dify_api_type": "chat",
|
||||
"dify_api_key": "",
|
||||
@@ -1025,20 +1029,20 @@ CONFIG_METADATA_2 = {
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
"provider": "coze",
|
||||
"provider_type": "chat_completion",
|
||||
"provider_type": "agent_runner",
|
||||
"type": "coze",
|
||||
"enable": True,
|
||||
"coze_api_key": "",
|
||||
"bot_id": "",
|
||||
"coze_api_base": "https://api.coze.cn",
|
||||
"timeout": 60,
|
||||
"auto_save_history": True,
|
||||
# "auto_save_history": True,
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"provider": "dashscope",
|
||||
"type": "dashscope",
|
||||
"provider_type": "chat_completion",
|
||||
"provider_type": "agent_runner",
|
||||
"enable": True,
|
||||
"dashscope_app_type": "agent",
|
||||
"dashscope_api_key": "",
|
||||
@@ -1907,7 +1911,6 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用",
|
||||
"type": "bool",
|
||||
"hint": "是否启用。",
|
||||
},
|
||||
"key": {
|
||||
"description": "API Key",
|
||||
@@ -2037,12 +2040,22 @@ CONFIG_METADATA_2 = {
|
||||
"unsupported_streaming_strategy": {
|
||||
"type": "string",
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"type": "string",
|
||||
},
|
||||
"dify_agent_runner_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"dashscope_agent_runner_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"max_agent_step": {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
},
|
||||
"tool_call_timeout": {
|
||||
"description": "工具调用超时时间(秒)",
|
||||
"type": "int",
|
||||
},
|
||||
},
|
||||
@@ -2180,30 +2193,75 @@ CONFIG_METADATA_3 = {
|
||||
"ai_group": {
|
||||
"name": "AI 配置",
|
||||
"metadata": {
|
||||
"ai": {
|
||||
"description": "模型",
|
||||
"agent_runner": {
|
||||
"description": "Agent 执行方式",
|
||||
"hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.enable": {
|
||||
"description": "启用大语言模型聊天",
|
||||
"description": "启用",
|
||||
"type": "bool",
|
||||
"hint": "AI 对话总开关",
|
||||
},
|
||||
"provider_settings.agent_runner_type": {
|
||||
"description": "执行器",
|
||||
"type": "string",
|
||||
"options": ["local", "dify", "coze", "dashscope"],
|
||||
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"],
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent 执行器提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_agent_runner_provider:coze",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "coze",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.dify_agent_runner_provider_id": {
|
||||
"description": "Dify Agent 执行器提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_agent_runner_provider:dify",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "dify",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.dashscope_agent_runner_provider_id": {
|
||||
"description": "阿里云百炼应用 Agent 执行器提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_agent_runner_provider:dashscope",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "dashscope",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ai": {
|
||||
"description": "模型",
|
||||
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.default_provider_id": {
|
||||
"description": "默认聊天模型",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时使用第一个模型。",
|
||||
"hint": "留空时使用第一个模型",
|
||||
},
|
||||
"provider_settings.default_image_caption_provider_id": {
|
||||
"description": "默认图片转述模型",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||
"hint": "留空代表不使用,可用于非多模态模型",
|
||||
},
|
||||
"provider_stt_settings.enable": {
|
||||
"description": "启用语音转文本",
|
||||
"type": "bool",
|
||||
"hint": "STT 总开关。",
|
||||
"hint": "STT 总开关",
|
||||
},
|
||||
"provider_stt_settings.provider_id": {
|
||||
"description": "默认语音转文本模型",
|
||||
@@ -2217,12 +2275,11 @@ CONFIG_METADATA_3 = {
|
||||
"provider_tts_settings.enable": {
|
||||
"description": "启用文本转语音",
|
||||
"type": "bool",
|
||||
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
|
||||
"hint": "TTS 总开关",
|
||||
},
|
||||
"provider_tts_settings.provider_id": {
|
||||
"description": "默认文本转语音模型",
|
||||
"type": "string",
|
||||
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
|
||||
"_special": "select_provider_tts",
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
@@ -2233,6 +2290,9 @@ CONFIG_METADATA_3 = {
|
||||
"type": "text",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
"description": "人格",
|
||||
@@ -2244,6 +2304,10 @@ CONFIG_METADATA_3 = {
|
||||
"_special": "select_persona",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"knowledgebase": {
|
||||
"description": "知识库",
|
||||
@@ -2272,6 +2336,10 @@ CONFIG_METADATA_3 = {
|
||||
"hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"websearch": {
|
||||
"description": "网页搜索",
|
||||
@@ -2308,6 +2376,10 @@ CONFIG_METADATA_3 = {
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
@@ -2316,34 +2388,51 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.display_reasoning_text": {
|
||||
"description": "显示思考内容",
|
||||
"type": "bool",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.identifier": {
|
||||
"description": "用户识别",
|
||||
"type": "bool",
|
||||
"hint": "启用后,会在提示词前包含用户 ID 信息。",
|
||||
},
|
||||
"provider_settings.group_name_display": {
|
||||
"description": "显示群名称",
|
||||
"type": "bool",
|
||||
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
|
||||
"hint": "启用后,在支持的平台(OneBot v11)上会在提示词前包含群名称信息。",
|
||||
},
|
||||
"provider_settings.datetime_system_prompt": {
|
||||
"description": "现实世界时间感知",
|
||||
"type": "bool",
|
||||
"hint": "启用后,会在系统提示词中附带当前时间信息。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.show_tool_use_status": {
|
||||
"description": "输出函数调用状态",
|
||||
"type": "bool",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.max_agent_step": {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.tool_call_timeout": {
|
||||
"description": "工具调用超时时间(秒)",
|
||||
"type": "int",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.streaming_response": {
|
||||
"description": "流式回复",
|
||||
"description": "流式输出",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.unsupported_streaming_strategy": {
|
||||
@@ -2359,17 +2448,23 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数。",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
|
||||
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求",
|
||||
},
|
||||
"provider_settings.prompt_prefix": {
|
||||
"description": "用户提示词",
|
||||
@@ -2381,6 +2476,9 @@ CONFIG_METADATA_3 = {
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -16,13 +16,12 @@ import time
|
||||
import traceback
|
||||
from asyncio import Queue
|
||||
|
||||
from astrbot.core import LogBroker, logger, sp
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
|
||||
@@ -34,6 +33,7 @@ from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.migra_helper import migra
|
||||
|
||||
from . import astrbot_config, html_renderer
|
||||
from .event_bus import EventBus
|
||||
@@ -97,18 +97,16 @@ class AstrBotCoreLifecycle:
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
# 4.5 to 4.6 migration for umop_config_router
|
||||
# apply migration
|
||||
try:
|
||||
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
|
||||
await migra(
|
||||
self.db,
|
||||
self.astrbot_config_mgr,
|
||||
self.umop_config_router,
|
||||
self.astrbot_config_mgr,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for webchat session
|
||||
try:
|
||||
await migrate_webchat_session(self.db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(f"AstrBot migration failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 初始化事件队列
|
||||
|
||||
48
astrbot/core/pipeline/process_stage/method/agent_request.py
Normal file
48
astrbot/core/pipeline/process_stage/method/agent_request.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from .agent_sub_stages.internal import InternalAgentSubStage
|
||||
from .agent_sub_stages.third_party import ThirdPartyAgentSubStage
|
||||
|
||||
|
||||
class AgentRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
|
||||
self.bot_wake_prefixs: list[str] = self.config["wake_prefix"]
|
||||
self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"]
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.prov_wake_prefix.startswith(bwp):
|
||||
logger.info(
|
||||
f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
|
||||
)
|
||||
self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :]
|
||||
|
||||
agent_runner_type = self.config["provider_settings"]["agent_runner_type"]
|
||||
if agent_runner_type == "local":
|
||||
self.agent_sub_stage = InternalAgentSubStage()
|
||||
else:
|
||||
self.agent_sub_stage = ThirdPartyAgentSubStage()
|
||||
await self.agent_sub_stage.initialize(ctx)
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]:
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug(
|
||||
"This pipeline does not enable AI capability, skip processing."
|
||||
)
|
||||
return
|
||||
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(
|
||||
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
|
||||
)
|
||||
return
|
||||
|
||||
async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix):
|
||||
yield resp
|
||||
@@ -21,27 +21,24 @@ from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from ....astr_agent_context import AgentContextWrapper
|
||||
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....astr_agent_run_util import AgentRunner, run_agent
|
||||
from ....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from ...context import PipelineContext, call_event_hook
|
||||
from ..stage import Stage
|
||||
from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
|
||||
from .....astr_agent_context import AgentContextWrapper
|
||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from .....astr_agent_run_util import AgentRunner, run_agent
|
||||
from .....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
class InternalAgentSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
|
||||
self.provider_wake_prefix: str = settings["wake_prefix"] # str
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
@@ -59,13 +56,6 @@ class LLMRequestSubStage(Stage):
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
logger.info(
|
||||
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
|
||||
)
|
||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -304,21 +294,10 @@ class LLMRequestSubStage(Stage):
|
||||
return fixed_messages
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
_nested: bool = False,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
@@ -348,12 +327,12 @@ class LLMRequestSubStage(Stage):
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if self.provider_wake_prefix and not event.message_str.startswith(
|
||||
self.provider_wake_prefix
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
@@ -0,0 +1,202 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
|
||||
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
|
||||
DashscopeAgentRunner,
|
||||
)
|
||||
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.agent.runners.base import BaseAgentRunner
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
|
||||
AGENT_RUNNER_TYPE_KEY = {
|
||||
"dify": "dify_agent_runner_provider_id",
|
||||
"coze": "coze_agent_runner_provider_id",
|
||||
"dashscope": "dashscope_agent_runner_provider_id",
|
||||
}
|
||||
|
||||
|
||||
async def run_third_party_agent(
|
||||
runner: "BaseAgentRunner",
|
||||
stream_to_general: bool = False,
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
"""
|
||||
运行第三方 agent runner 并转换响应格式
|
||||
类似于 run_agent 函数,但专门处理第三方 agent runner
|
||||
"""
|
||||
try:
|
||||
async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
|
||||
if resp.type == "streaming_delta":
|
||||
if stream_to_general:
|
||||
continue
|
||||
yield resp.data["chain"]
|
||||
elif resp.type == "llm_result":
|
||||
if stream_to_general:
|
||||
yield resp.data["chain"]
|
||||
except Exception as e:
|
||||
logger.error(f"Third party agent runner error: {e}")
|
||||
err_msg = (
|
||||
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
|
||||
f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
)
|
||||
yield MessageChain().message(err_msg)
|
||||
|
||||
|
||||
class ThirdPartyAgentSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.conf = ctx.astrbot_config
|
||||
self.runner_type = self.conf["provider_settings"]["agent_runner_type"]
|
||||
self.prov_id = self.conf["provider_settings"].get(
|
||||
AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
|
||||
"",
|
||||
)
|
||||
settings = ctx.astrbot_config["provider_settings"]
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
]
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
self.prov_cfg: dict = next(
|
||||
(p for p in self.conf["provider"] if p["id"] == self.prov_id),
|
||||
{},
|
||||
)
|
||||
if not self.prov_id or not self.prov_cfg:
|
||||
logger.error(
|
||||
"Third Party Agent Runner provider ID is not configured properly."
|
||||
)
|
||||
return
|
||||
|
||||
# make provider request
|
||||
req = ProviderRequest()
|
||||
req.session_id = event.unified_msg_origin
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_base64()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
if self.runner_type == "dify":
|
||||
runner = DifyAgentRunner[AstrAgentContext]()
|
||||
elif self.runner_type == "coze":
|
||||
runner = CozeAgentRunner[AstrAgentContext]()
|
||||
elif self.runner_type == "dashscope":
|
||||
runner = DashscopeAgentRunner[AstrAgentContext]()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported third party agent runner type: {self.runner_type}",
|
||||
)
|
||||
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
|
||||
await runner.reset(
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=60,
|
||||
),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
provider_config=self.prov_cfg,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_third_party_agent(
|
||||
runner,
|
||||
stream_to_general=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if runner.done():
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
if final_resp and final_resp.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=final_resp.result_chain.chain or [],
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 非流式响应或转换为普通响应
|
||||
async for _ in run_third_party_agent(
|
||||
runner,
|
||||
stream_to_general=stream_to_general,
|
||||
):
|
||||
yield
|
||||
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
|
||||
if not final_resp or not final_resp.result_chain:
|
||||
logger.warning("Agent Runner 未返回最终结果。")
|
||||
return
|
||||
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=final_resp.result_chain.chain or [],
|
||||
result_content_type=ResultContentType.LLM_RESULT,
|
||||
),
|
||||
)
|
||||
yield
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=self.runner_type,
|
||||
provider_type=self.runner_type,
|
||||
),
|
||||
)
|
||||
@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[None, None]:
|
||||
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
||||
"activated_handlers",
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
from .method.llm_request import LLMRequestSubStage
|
||||
from .method.agent_request import AgentRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
|
||||
|
||||
@@ -17,9 +17,12 @@ class ProcessStage(Stage):
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
self.llm_request_sub_stage = LLMRequestSubStage()
|
||||
await self.llm_request_sub_stage.initialize(ctx)
|
||||
|
||||
# initialize agent sub stage
|
||||
self.agent_sub_stage = AgentRequestSubStage()
|
||||
await self.agent_sub_stage.initialize(ctx)
|
||||
|
||||
# initialize star request sub stage
|
||||
self.star_request_sub_stage = StarRequestSubStage()
|
||||
await self.star_request_sub_stage.initialize(ctx)
|
||||
|
||||
@@ -39,7 +42,7 @@ class ProcessStage(Stage):
|
||||
# Handler 的 LLM 请求
|
||||
event.set_extra("provider_request", resp)
|
||||
_t = False
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
_t = True
|
||||
yield
|
||||
if not _t:
|
||||
@@ -67,5 +70,5 @@ class ProcessStage(Stage):
|
||||
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||
return
|
||||
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -227,6 +227,8 @@ class ProviderManager:
|
||||
async def load_provider(self, provider_config: dict):
|
||||
if not provider_config["enable"]:
|
||||
return
|
||||
if provider_config.get("provider_type", "") == "agent_runner":
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...",
|
||||
@@ -247,14 +249,6 @@ class ProviderManager:
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "coze":
|
||||
from .sources.coze_source import ProviderCoze as ProviderCoze
|
||||
case "dashscope":
|
||||
from .sources.dashscope_source import (
|
||||
ProviderDashscope as ProviderDashscope,
|
||||
)
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import (
|
||||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||
|
||||
@@ -1,650 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
|
||||
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
|
||||
class ProviderCoze(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
)
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
self.bot_id = provider_config.get("bot_id", "")
|
||||
if not self.bot_id:
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||
|
||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||
("http://", "https://"),
|
||||
):
|
||||
raise Exception(
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
|
||||
)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.auto_save_history = provider_config.get("auto_save_history", True)
|
||||
self.conversation_ids: dict[str, str] = {}
|
||||
self.file_id_cache: dict[str, dict[str, str]] = {}
|
||||
|
||||
# 创建 API 客户端
|
||||
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
||||
|
||||
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
|
||||
"""生成统一的缓存键
|
||||
|
||||
Args:
|
||||
data: 图片数据或路径
|
||||
is_base64: 是否是 base64 数据
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
|
||||
"""
|
||||
try:
|
||||
if is_base64 and data.startswith("data:image/"):
|
||||
try:
|
||||
header, encoded = data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
cache_key = hashlib.md5(image_bytes).hexdigest()
|
||||
return cache_key
|
||||
except Exception:
|
||||
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
elif data.startswith(("http://", "https://")):
|
||||
# URL图片,使用URL作为缓存键
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
clean_path = (
|
||||
data.split("_")[0]
|
||||
if "_" in data and len(data.split("_")) >= 3
|
||||
else data
|
||||
)
|
||||
|
||||
if os.path.exists(clean_path):
|
||||
with open(clean_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
cache_key = hashlib.md5(file_content).hexdigest()
|
||||
return cache_key
|
||||
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
|
||||
return cache_key
|
||||
|
||||
async def _upload_file(
|
||||
self,
|
||||
file_data: bytes,
|
||||
session_id: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
) -> str:
|
||||
"""上传文件到 Coze 并返回 file_id"""
|
||||
# 使用 API 客户端上传文件
|
||||
file_id = await self.api_client.upload_file(file_data)
|
||||
|
||||
# 缓存 file_id
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
async def _download_and_upload_image(
|
||||
self,
|
||||
image_url: str,
|
||||
session_id: str | None = None,
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
# 计算哈希实现缓存
|
||||
cache_key = self._generate_cache_key(image_url) if session_id else None
|
||||
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
return file_id
|
||||
|
||||
try:
|
||||
image_data = await self.api_client.download_image(image_url)
|
||||
|
||||
file_id = await self._upload_file(image_data, session_id, cache_key)
|
||||
|
||||
if session_id and cache_key:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {image_url}: {e!s}")
|
||||
raise Exception(f"处理图片失败: {e!s}")
|
||||
|
||||
async def _process_context_images(
|
||||
self,
|
||||
content: str | list,
|
||||
session_id: str,
|
||||
) -> str:
|
||||
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
processed_content = []
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
processed_content.append(item)
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
processed_content.append(item)
|
||||
elif item.get("type") == "image_url":
|
||||
# 处理图片逻辑
|
||||
if "file_id" in item:
|
||||
# 已经有 file_id
|
||||
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
|
||||
processed_content.append(item)
|
||||
else:
|
||||
# 获取图片数据
|
||||
image_data = ""
|
||||
if "image_url" in item and isinstance(item["image_url"], dict):
|
||||
image_data = item["image_url"].get("url", "")
|
||||
elif "data" in item:
|
||||
image_data = item.get("data", "")
|
||||
elif "url" in item:
|
||||
image_data = item.get("url", "")
|
||||
|
||||
if not image_data:
|
||||
continue
|
||||
# 计算哈希用于缓存
|
||||
cache_key = self._generate_cache_key(
|
||||
image_data,
|
||||
is_base64=image_data.startswith("data:image/"),
|
||||
)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id},
|
||||
)
|
||||
else:
|
||||
# 上传图片并缓存
|
||||
if image_data.startswith("data:image/"):
|
||||
# base64 处理
|
||||
_, encoded = image_data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
elif image_data.startswith(("http://", "https://")):
|
||||
# URL 图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
image_data,
|
||||
session_id,
|
||||
)
|
||||
# 为URL图片也添加缓存
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
elif os.path.exists(image_data):
|
||||
# 本地文件
|
||||
with open(image_data, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"无法处理的图片格式: {image_data[:50]}...",
|
||||
)
|
||||
continue
|
||||
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id},
|
||||
)
|
||||
|
||||
result = json.dumps(processed_content, ensure_ascii=False)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"处理上下文图片失败: {e!s}")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""文本对话, 内部使用流式接口实现非流式
|
||||
|
||||
Args:
|
||||
prompt (str): 用户提示词
|
||||
session_id (str): 会话ID
|
||||
image_urls (List[str]): 图片URL列表
|
||||
func_tool (FuncCall): 函数调用工具(不支持)
|
||||
contexts (List): 上下文列表
|
||||
system_prompt (str): 系统提示语
|
||||
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
|
||||
model (str): 模型名称(不支持)
|
||||
|
||||
Returns:
|
||||
LLMResponse: LLM响应对象
|
||||
|
||||
"""
|
||||
accumulated_content = ""
|
||||
final_response = None
|
||||
|
||||
async for llm_response in self.text_chat_stream(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
model=model,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.completion_text:
|
||||
accumulated_content += llm_response.completion_text
|
||||
else:
|
||||
final_response = llm_response
|
||||
|
||||
if final_response:
|
||||
return final_response
|
||||
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
return LLMResponse(role="assistant", completion_text="")
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话接口"""
|
||||
# 用户ID参数(参考文档, 可以自定义)
|
||||
user_id = session_id or kwargs.get("user", "default_user")
|
||||
|
||||
# 获取或创建会话ID
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
# 构建消息
|
||||
additional_messages = []
|
||||
|
||||
if system_prompt:
|
||||
if not self.auto_save_history or not conversation_id:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
contexts = self._ensure_message_to_dicts(contexts)
|
||||
if not self.auto_save_history and contexts:
|
||||
# 如果关闭了自动保存历史,传入上下文
|
||||
for ctx in contexts:
|
||||
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
||||
content = ctx["content"]
|
||||
content_type = ctx.get("content_type", "text")
|
||||
|
||||
# 处理可能包含图片的上下文
|
||||
if (
|
||||
content_type == "object_string"
|
||||
or (isinstance(content, str) and content.startswith("["))
|
||||
or (
|
||||
isinstance(content, list)
|
||||
and any(
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "image_url"
|
||||
for item in content
|
||||
)
|
||||
)
|
||||
):
|
||||
processed_content = await self._process_context_images(
|
||||
content,
|
||||
user_id,
|
||||
)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": processed_content,
|
||||
"content_type": "object_string",
|
||||
},
|
||||
)
|
||||
else:
|
||||
# 纯文本
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": (
|
||||
content
|
||||
if isinstance(content, str)
|
||||
else json.dumps(content, ensure_ascii=False)
|
||||
),
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
|
||||
|
||||
if prompt or image_urls:
|
||||
if image_urls:
|
||||
# 多模态
|
||||
object_string_content = []
|
||||
if prompt:
|
||||
object_string_content.append({"type": "text", "text": prompt})
|
||||
|
||||
for url in image_urls:
|
||||
try:
|
||||
if url.startswith(("http://", "https://")):
|
||||
# 网络图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
url,
|
||||
user_id,
|
||||
)
|
||||
else:
|
||||
# 本地文件或 base64
|
||||
if url.startswith("data:image/"):
|
||||
# base64
|
||||
_, encoded = url.split(",", 1)
|
||||
image_data = base64.b64decode(encoded)
|
||||
cache_key = self._generate_cache_key(
|
||||
url,
|
||||
is_base64=True,
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data,
|
||||
user_id,
|
||||
cache_key,
|
||||
)
|
||||
# 本地文件
|
||||
elif os.path.exists(url):
|
||||
with open(url, "rb") as f:
|
||||
image_data = f.read()
|
||||
# 用文件路径和修改时间来缓存
|
||||
file_stat = os.stat(url)
|
||||
cache_key = self._generate_cache_key(
|
||||
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
|
||||
is_base64=False,
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data,
|
||||
user_id,
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"图片文件不存在: {url}")
|
||||
continue
|
||||
|
||||
object_string_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"file_id": file_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {url}: {e!s}")
|
||||
continue
|
||||
|
||||
if object_string_content:
|
||||
content = json.dumps(object_string_content, ensure_ascii=False)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content_type": "object_string",
|
||||
},
|
||||
)
|
||||
# 纯文本
|
||||
elif prompt:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
accumulated_content = ""
|
||||
message_started = False
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = chunk.get("event")
|
||||
data = chunk.get("data", {})
|
||||
|
||||
if event_type == "conversation.chat.created":
|
||||
if isinstance(data, dict) and "conversation_id" in data:
|
||||
self.conversation_ids[user_id] = data["conversation_id"]
|
||||
|
||||
elif event_type == "conversation.message.delta":
|
||||
if isinstance(data, dict):
|
||||
content = data.get("content", "")
|
||||
if not content and "delta" in data:
|
||||
content = data["delta"].get("content", "")
|
||||
if not content and "text" in data:
|
||||
content = data.get("text", "")
|
||||
|
||||
if content:
|
||||
message_started = True
|
||||
accumulated_content += content
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=content,
|
||||
is_chunk=True,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.message.completed":
|
||||
if isinstance(data, dict):
|
||||
msg_type = data.get("type")
|
||||
if msg_type == "answer" and data.get("role") == "assistant":
|
||||
final_content = data.get("content", "")
|
||||
if not accumulated_content and final_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(final_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.chat.completed":
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
elif event_type == "done":
|
||||
break
|
||||
|
||||
elif event_type == "error":
|
||||
error_msg = (
|
||||
data.get("message", "未知错误")
|
||||
if isinstance(data, dict)
|
||||
else str(data)
|
||||
)
|
||||
logger.error(f"Coze 流式响应错误: {error_msg}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 错误: {error_msg}",
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
if not message_started and not accumulated_content:
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="LLM 未响应任何内容。",
|
||||
is_chunk=False,
|
||||
)
|
||||
elif message_started and accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Coze 流式请求失败: {e!s}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 流式请求失败: {e!s}",
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
async def forget(self, session_id: str):
|
||||
"""清空指定会话的上下文"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if user_id in self.file_id_cache:
|
||||
self.file_id_cache.pop(user_id, None)
|
||||
|
||||
if not conversation_id:
|
||||
return True
|
||||
|
||||
try:
|
||||
response = await self.api_client.clear_context(conversation_id)
|
||||
|
||||
if "code" in response and response["code"] == 0:
|
||||
self.conversation_ids.pop(user_id, None)
|
||||
return True
|
||||
logger.warning(f"清空 Coze 会话上下文失败: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空 Coze 会话失败: {e!s}")
|
||||
return False
|
||||
|
||||
async def get_current_key(self):
|
||||
"""获取当前API Key"""
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key: str):
|
||||
"""设置新的API Key"""
|
||||
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
"""获取可用模型列表"""
|
||||
return [f"bot_{self.bot_id}"]
|
||||
|
||||
def get_model(self):
|
||||
"""获取当前模型"""
|
||||
return f"bot_{self.bot_id}"
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""设置模型(在Coze中是Bot ID)"""
|
||||
if model.startswith("bot_"):
|
||||
self.bot_id = model[4:]
|
||||
else:
|
||||
self.bot_id = model
|
||||
|
||||
async def get_human_readable_context(
|
||||
self,
|
||||
session_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
):
|
||||
"""获取人类可读的上下文历史"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if not conversation_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await self.api_client.get_message_list(
|
||||
conversation_id=conversation_id,
|
||||
order="desc",
|
||||
limit=page_size,
|
||||
offset=(page - 1) * page_size,
|
||||
)
|
||||
|
||||
if data.get("code") != 0:
|
||||
logger.warning(f"获取 Coze 消息历史失败: {data}")
|
||||
return []
|
||||
|
||||
messages = data.get("data", {}).get("messages", [])
|
||||
|
||||
readable_history = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
msg_type = msg.get("type", "")
|
||||
|
||||
if role == "user":
|
||||
readable_history.append(f"用户: {content}")
|
||||
elif role == "assistant" and msg_type == "answer":
|
||||
readable_history.append(f"助手: {content}")
|
||||
|
||||
return readable_history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Coze 消息历史失败: {e!s}")
|
||||
return []
|
||||
|
||||
async def terminate(self):
|
||||
"""清理资源"""
|
||||
await self.api_client.close()
|
||||
@@ -1,207 +0,0 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import re
|
||||
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
||||
class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
Provider.__init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
)
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("阿里云百炼 API Key 不能为空。")
|
||||
self.app_id = provider_config.get("dashscope_app_id", "")
|
||||
if not self.app_id:
|
||||
raise Exception("阿里云百炼 APP ID 不能为空。")
|
||||
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
|
||||
if not self.dashscope_app_type:
|
||||
raise Exception("阿里云百炼 APP 类型不能为空。")
|
||||
self.model_name = "dashscope"
|
||||
self.variables: dict = provider_config.get("variables", {})
|
||||
self.rag_options: dict = provider_config.get("rag_options", {})
|
||||
self.output_reference = self.rag_options.get("output_reference", False)
|
||||
self.rag_options = self.rag_options.copy()
|
||||
self.rag_options.pop("output_reference", None)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
def has_rag_options(self):
|
||||
"""判断是否有 RAG 选项
|
||||
|
||||
Returns:
|
||||
bool: 是否有 RAG 选项
|
||||
|
||||
"""
|
||||
if self.rag_options and (
|
||||
len(self.rag_options.get("pipeline_ids", [])) > 0
|
||||
or len(self.rag_options.get("file_ids", [])) > 0
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if image_urls is None:
|
||||
image_urls = []
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
|
||||
if (
|
||||
self.dashscope_app_type in ["agent", "dialog-workflow"]
|
||||
and not self.has_rag_options()
|
||||
):
|
||||
# 支持多轮对话的
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
if image_urls:
|
||||
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
|
||||
contexts_no_img = await self._remove_image_from_context(contexts)
|
||||
context_query = [*contexts_no_img, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
# 调用阿里云百炼 API
|
||||
payload = {
|
||||
"app_id": self.app_id,
|
||||
"api_key": self.api_key,
|
||||
"messages": context_query,
|
||||
"biz_params": payload_vars or None,
|
||||
}
|
||||
partial = functools.partial(
|
||||
Application.call,
|
||||
**payload,
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
else:
|
||||
# 不支持多轮对话的
|
||||
# 调用阿里云百炼 API
|
||||
payload = {
|
||||
"app_id": self.app_id,
|
||||
"prompt": prompt,
|
||||
"api_key": self.api_key,
|
||||
"biz_params": payload_vars or None,
|
||||
}
|
||||
if self.rag_options:
|
||||
payload["rag_options"] = self.rag_options
|
||||
partial = functools.partial(
|
||||
Application.call,
|
||||
**payload,
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
|
||||
assert isinstance(response, ApplicationResponse)
|
||||
|
||||
logger.debug(f"dashscope resp: {response}")
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
|
||||
)
|
||||
return LLMResponse(
|
||||
role="err",
|
||||
result_chain=MessageChain().message(
|
||||
f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
|
||||
),
|
||||
)
|
||||
|
||||
output_text = response.output.get("text", "") or ""
|
||||
# RAG 引用脚标格式化
|
||||
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
|
||||
if self.output_reference and response.output.get("doc_references", None):
|
||||
ref_parts = []
|
||||
for ref in response.output.get("doc_references", []) or []:
|
||||
ref_title = (
|
||||
ref.get("title", "")
|
||||
if ref.get("title")
|
||||
else ref.get("doc_name", "")
|
||||
)
|
||||
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
|
||||
ref_str = "".join(ref_parts)
|
||||
output_text += f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.result_chain = MessageChain().message(output_text)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def forget(self, session_id):
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key):
|
||||
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
|
||||
|
||||
async def terminate(self):
|
||||
pass
|
||||
@@ -1,285 +0,0 @@
|
||||
import os
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url
|
||||
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
class ProviderDify(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Dify API Key 不能为空。")
|
||||
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
||||
self.api_type = provider_config.get("dify_api_type", "")
|
||||
if not self.api_type:
|
||||
raise Exception("Dify API 类型不能为空。")
|
||||
self.model_name = "dify"
|
||||
self.workflow_output_key = provider_config.get(
|
||||
"dify_workflow_output_key",
|
||||
"astrbot_wf_output",
|
||||
)
|
||||
self.dify_query_input_key = provider_config.get(
|
||||
"dify_query_input_key",
|
||||
"astrbot_text_query",
|
||||
)
|
||||
if not self.dify_query_input_key:
|
||||
self.dify_query_input_key = "astrbot_text_query"
|
||||
if not self.workflow_output_key:
|
||||
self.workflow_output_key = "astrbot_wf_output"
|
||||
self.variables: dict = provider_config.get("variables", {})
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.conversation_ids = {}
|
||||
"""记录当前 session id 的对话 ID"""
|
||||
|
||||
self.api_client = DifyAPIClient(self.api_key, api_base)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if image_urls is None:
|
||||
image_urls = []
|
||||
result = ""
|
||||
session_id = session_id or kwargs.get("user") or "unknown" # 1734
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
files_payload = []
|
||||
for image_url in image_urls:
|
||||
image_path = (
|
||||
await download_image_by_url(image_url)
|
||||
if image_url.startswith("http")
|
||||
else image_url
|
||||
)
|
||||
file_response = await self.api_client.file_upload(
|
||||
image_path,
|
||||
user=session_id,
|
||||
)
|
||||
logger.debug(f"Dify 上传图片响应:{file_response}")
|
||||
if "id" not in file_response:
|
||||
logger.warning(
|
||||
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。",
|
||||
)
|
||||
continue
|
||||
files_payload.append(
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response["id"],
|
||||
},
|
||||
)
|
||||
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
try:
|
||||
match self.api_type:
|
||||
case "chat" | "agent" | "chatflow":
|
||||
if not prompt:
|
||||
prompt = "请描述这张图片。"
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={
|
||||
**payload_vars,
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if (
|
||||
chunk["event"] == "message"
|
||||
or chunk["event"] == "agent_message"
|
||||
):
|
||||
result += chunk["answer"]
|
||||
if not conversation_id:
|
||||
self.conversation_ids[session_id] = chunk[
|
||||
"conversation_id"
|
||||
]
|
||||
conversation_id = chunk["conversation_id"]
|
||||
elif chunk["event"] == "message_end":
|
||||
logger.debug("Dify message end")
|
||||
break
|
||||
elif chunk["event"] == "error":
|
||||
logger.error(f"Dify 出现错误:{chunk}")
|
||||
raise Exception(
|
||||
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}",
|
||||
)
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
self.dify_query_input_key: prompt,
|
||||
"astrbot_session_id": session_id,
|
||||
**payload_vars,
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
match chunk["event"]:
|
||||
case "workflow_started":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。",
|
||||
)
|
||||
case "node_finished":
|
||||
logger.debug(
|
||||
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。",
|
||||
)
|
||||
case "workflow_finished":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束",
|
||||
)
|
||||
logger.debug(f"Dify 工作流结果:{chunk}")
|
||||
if chunk["data"]["error"]:
|
||||
logger.error(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}",
|
||||
)
|
||||
raise Exception(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}",
|
||||
)
|
||||
if (
|
||||
self.workflow_output_key
|
||||
not in chunk["data"]["outputs"]
|
||||
):
|
||||
raise Exception(
|
||||
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}",
|
||||
)
|
||||
result = chunk
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Dify 请求失败:{e!s}")
|
||||
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}")
|
||||
|
||||
if not result:
|
||||
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
|
||||
|
||||
chain = await self.parse_dify_result(result)
|
||||
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||
if isinstance(chunk, str):
|
||||
# Chat
|
||||
return MessageChain(chain=[Comp.Plain(chunk)])
|
||||
|
||||
async def parse_file(item: dict):
|
||||
match item["type"]:
|
||||
case "image":
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
return Comp.Video(file=item["url"])
|
||||
case _:
|
||||
return Comp.File(name=item["filename"], file=item["url"])
|
||||
|
||||
output = chunk["data"]["outputs"][self.workflow_output_key]
|
||||
chains = []
|
||||
if isinstance(output, str):
|
||||
# 纯文本输出
|
||||
chains.append(Comp.Plain(output))
|
||||
elif isinstance(output, list):
|
||||
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
|
||||
for item in output:
|
||||
# handle Array[File]
|
||||
if (
|
||||
not isinstance(item, dict)
|
||||
or item.get("dify_model_identity", "") != "__dify__file__"
|
||||
):
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
break
|
||||
else:
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
|
||||
# scan file
|
||||
files = chunk["data"].get("files", [])
|
||||
for item in files:
|
||||
comp = await parse_file(item)
|
||||
chains.append(comp)
|
||||
|
||||
return MessageChain(chain=chains)
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.conversation_ids[session_id] = ""
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key):
|
||||
raise Exception("Dify 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
raise Exception("暂不支持获得 Dify 的历史消息记录。")
|
||||
|
||||
async def terminate(self):
|
||||
await self.api_client.close()
|
||||
@@ -85,3 +85,22 @@ class UmopConfigRouter:
|
||||
|
||||
self.umop_to_conf_id[umo] = conf_id
|
||||
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
|
||||
|
||||
async def delete_route(self, umo: str):
|
||||
"""删除一条路由
|
||||
|
||||
Args:
|
||||
umo (str): 需要删除的 UMO 字符串
|
||||
|
||||
Raises:
|
||||
ValueError: 当 umo 格式不正确时抛出
|
||||
"""
|
||||
|
||||
if not isinstance(umo, str) or len(umo.split(":")) != 3:
|
||||
raise ValueError(
|
||||
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
|
||||
)
|
||||
|
||||
if umo in self.umop_to_conf_id:
|
||||
del self.umop_to_conf_id[umo]
|
||||
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
|
||||
|
||||
73
astrbot/core/utils/migra_helper.py
Normal file
73
astrbot/core/utils/migra_helper.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import traceback
|
||||
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
|
||||
|
||||
def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None:
    """
    Migra agent runner configs from provider configs.

    When the config's default chat provider actually points at a dify /
    coze / dashscope provider, clear it and rewire the matching
    ``*_agent_runner_provider_id`` + ``agent_runner_type`` settings instead.
    Failures are logged, never raised.
    """
    try:
        settings = conf["provider_settings"]
        default_prov_id = settings["default_provider_id"]
        if default_prov_id in ids_map:
            # The id no longer names a chat provider — detach it.
            settings["default_provider_id"] = ""
            prov = ids_map[default_prov_id]
            if prov["type"] == "dify":
                settings["dify_agent_runner_provider_id"] = prov["id"]
                settings["agent_runner_type"] = "dify"
            elif prov["type"] == "coze":
                settings["coze_agent_runner_provider_id"] = prov["id"]
                settings["agent_runner_type"] = "coze"
            elif prov["type"] == "dashscope":
                settings["dashscope_agent_runner_provider_id"] = prov["id"]
                settings["agent_runner_type"] = "dashscope"
            conf.save_config()
    except Exception as e:
        logger.error(f"Migration for third party agent runner configs failed: {e!s}")
        logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def migra(
    db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager
) -> None:
    """
    Stores the migration logic here.

    Each migration step is best-effort: failures are logged and the
    remaining steps still run.

    btw, i really don't like migration :(
    """
    # 4.5 to 4.6 migration for umop_config_router
    try:
        await migrate_45_to_46(astrbot_config_mgr, umop_config_router)
    except Exception as e:
        logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
        logger.error(traceback.format_exc())

    # migration for webchat session
    try:
        await migrate_webchat_session(db)
    except Exception as e:
        logger.error(f"Migration for webchat session failed: {e!s}")
        logger.error(traceback.format_exc())

    # Migrate third-party agent runner configs: retag dify/coze/dashscope
    # entries as agent runners and remember their ids for per-config rewiring.
    changed = False
    ids_map = {}
    for prov in astrbot_config["provider"]:
        prov_type = prov.get("type")
        if prov_type in ("dify", "coze", "dashscope"):
            prov["provider_type"] = "agent_runner"
            ids_map[prov["id"]] = {
                "type": prov_type,
                "id": prov["id"],
            }
            changed = True
    if changed:
        astrbot_config.save_config()

    for conf in acm.confs.values():
        _migra_agent_runner_configs(conf, ids_map)
|
||||
@@ -40,9 +40,6 @@ class SharedPreferences:
|
||||
else:
|
||||
ret = default
|
||||
return ret
|
||||
raise ValueError(
|
||||
"scope_id and key cannot be None when getting a specific preference.",
|
||||
)
|
||||
|
||||
async def range_get_async(
|
||||
self,
|
||||
@@ -56,30 +53,6 @@ class SharedPreferences:
|
||||
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
||||
return ret
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: None,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str,
|
||||
key: None,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: None,
|
||||
key: None,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str | None,
|
||||
@@ -88,7 +61,7 @@ class SharedPreferences:
|
||||
) -> _VT | list[Preference]:
|
||||
"""获取会话范围的偏好设置
|
||||
|
||||
Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
Note: 当 umo 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
"""
|
||||
if umo is None or key is None:
|
||||
return await self.range_get_async("umo", umo, key)
|
||||
|
||||
@@ -56,6 +56,7 @@ class ChatRoute(Route):
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
self.db = db
|
||||
self.umop_config_router = core_lifecycle.umop_config_router
|
||||
|
||||
self.running_convs: dict[str, bool] = {}
|
||||
|
||||
@@ -266,7 +267,8 @@ class ChatRoute(Route):
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
# 删除该会话下的所有对话
|
||||
unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}"
|
||||
message_type = "GroupMessage" if session.is_group else "FriendMessage"
|
||||
unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}"
|
||||
await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
|
||||
|
||||
# 删除消息历史
|
||||
@@ -276,6 +278,16 @@ class ChatRoute(Route):
|
||||
offset_sec=99999999,
|
||||
)
|
||||
|
||||
# 删除与会话关联的配置路由
|
||||
try:
|
||||
await self.umop_config_router.delete_route(unified_msg_origin)
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Failed to delete UMO route %s during session cleanup: %s",
|
||||
unified_msg_origin,
|
||||
exc,
|
||||
)
|
||||
|
||||
# 清理队列(仅对 webchat)
|
||||
if session.platform_id == "webchat":
|
||||
webchat_queue_mgr.remove_queues(session_id)
|
||||
|
||||
@@ -87,6 +87,8 @@
|
||||
:disabled="isStreaming || isConvRunning"
|
||||
:enableStreaming="enableStreaming"
|
||||
:isRecording="isRecording"
|
||||
:session-id="currSessionId || null"
|
||||
:current-session="getCurrentSession"
|
||||
@send="handleSendMessage"
|
||||
@toggleStreaming="toggleStreaming"
|
||||
@removeImage="removeImage"
|
||||
|
||||
@@ -11,7 +11,13 @@
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
<div style="display: flex; justify-content: space-between; align-items: center; padding: 0px 12px;">
|
||||
<div style="display: flex; justify-content: flex-start; margin-top: 4px; align-items: center; gap: 8px;">
|
||||
<ProviderModelSelector ref="providerModelSelectorRef" />
|
||||
<ConfigSelector
|
||||
:session-id="sessionId || null"
|
||||
:platform-id="sessionPlatformId"
|
||||
:is-group="sessionIsGroup"
|
||||
@config-changed="handleConfigChange"
|
||||
/>
|
||||
<ProviderModelSelector v-if="showProviderSelector" ref="providerModelSelectorRef" />
|
||||
|
||||
<v-tooltip :text="enableStreaming ? tm('streaming.enabled') : tm('streaming.disabled')" location="top">
|
||||
<template v-slot:activator="{ props }">
|
||||
@@ -58,9 +64,11 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, onBeforeUnmount, watch } from 'vue';
|
||||
import { ref, computed, onMounted, onBeforeUnmount } from 'vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import ProviderModelSelector from './ProviderModelSelector.vue';
|
||||
import ConfigSelector from './ConfigSelector.vue';
|
||||
import type { Session } from '@/composables/useSessions';
|
||||
|
||||
interface Props {
|
||||
prompt: string;
|
||||
@@ -69,9 +77,14 @@ interface Props {
|
||||
disabled: boolean;
|
||||
enableStreaming: boolean;
|
||||
isRecording: boolean;
|
||||
sessionId?: string | null;
|
||||
currentSession?: Session | null;
|
||||
}
|
||||
|
||||
const props = defineProps<Props>();
|
||||
const props = withDefaults(defineProps<Props>(), {
|
||||
sessionId: null,
|
||||
currentSession: null
|
||||
});
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:prompt': [value: string];
|
||||
@@ -90,12 +103,16 @@ const { tm } = useModuleI18n('features/chat');
|
||||
const inputField = ref<HTMLTextAreaElement | null>(null);
|
||||
const imageInputRef = ref<HTMLInputElement | null>(null);
|
||||
const providerModelSelectorRef = ref<InstanceType<typeof ProviderModelSelector> | null>(null);
|
||||
const showProviderSelector = ref(true);
|
||||
|
||||
const localPrompt = computed({
|
||||
get: () => props.prompt,
|
||||
set: (value) => emit('update:prompt', value)
|
||||
});
|
||||
|
||||
const sessionPlatformId = computed(() => props.currentSession?.platform_id || 'webchat');
|
||||
const sessionIsGroup = computed(() => Boolean(props.currentSession?.is_group));
|
||||
|
||||
const canSend = computed(() => {
|
||||
return (props.prompt && props.prompt.trim()) || props.stagedImagesUrl.length > 0 || props.stagedAudioUrl;
|
||||
});
|
||||
@@ -168,7 +185,16 @@ function handleRecordClick() {
|
||||
}
|
||||
}
|
||||
|
||||
function handleConfigChange(payload: { configId: string; agentRunnerType: string }) {
|
||||
const runnerType = (payload.agentRunnerType || '').toLowerCase();
|
||||
const isInternal = runnerType === 'internal' || runnerType === 'local';
|
||||
showProviderSelector.value = isInternal;
|
||||
}
|
||||
|
||||
function getCurrentSelection() {
|
||||
if (!showProviderSelector.value) {
|
||||
return null;
|
||||
}
|
||||
return providerModelSelectorRef.value?.getCurrentSelection();
|
||||
}
|
||||
|
||||
|
||||
311
dashboard/src/components/chat/ConfigSelector.vue
Normal file
311
dashboard/src/components/chat/ConfigSelector.vue
Normal file
@@ -0,0 +1,311 @@
|
||||
<template>
|
||||
<div>
|
||||
<v-tooltip text="选择用于当前会话的配置文件" location="top">
|
||||
<template #activator="{ props: tooltipProps }">
|
||||
<v-chip
|
||||
v-bind="tooltipProps"
|
||||
class="text-none config-chip"
|
||||
variant="tonal"
|
||||
size="x-small"
|
||||
rounded="lg"
|
||||
@click="openDialog"
|
||||
:disabled="loadingConfigs || saving"
|
||||
>
|
||||
<v-icon start size="14">mdi-cog</v-icon>
|
||||
{{ selectedConfigLabel }}
|
||||
</v-chip>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
<v-dialog v-model="dialog" max-width="480" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="d-flex align-center justify-space-between">
|
||||
<span>选择配置文件</span>
|
||||
<v-btn icon variant="text" @click="closeDialog">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<div v-if="loadingConfigs" class="text-center py-6">
|
||||
<v-progress-circular indeterminate color="primary"></v-progress-circular>
|
||||
</div>
|
||||
|
||||
<v-list v-else class="config-list" density="comfortable">
|
||||
<v-list-item
|
||||
v-for="config in configOptions"
|
||||
:key="config.id"
|
||||
:active="tempSelectedConfig === config.id"
|
||||
rounded="lg"
|
||||
variant="text"
|
||||
@click="tempSelectedConfig = config.id"
|
||||
>
|
||||
<v-list-item-title>{{ config.name }}</v-list-item-title>
|
||||
<v-list-item-subtitle class="text-caption text-grey">
|
||||
{{ config.id }}
|
||||
</v-list-item-subtitle>
|
||||
<template #append>
|
||||
<v-icon v-if="tempSelectedConfig === config.id" color="primary">mdi-check</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
<div v-if="configOptions.length === 0" class="text-center text-body-2 text-medium-emphasis">
|
||||
暂无可选配置,请先在配置页创建。
|
||||
</div>
|
||||
</v-list>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="closeDialog">取消</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="confirmSelection"
|
||||
:disabled="!tempSelectedConfig"
|
||||
:loading="saving"
|
||||
>
|
||||
应用
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, ref, watch } from 'vue';
|
||||
import axios from 'axios';
|
||||
import { useToast } from '@/utils/toast';
|
||||
|
||||
interface ConfigInfo {
|
||||
id: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
interface ConfigChangedPayload {
|
||||
configId: string;
|
||||
agentRunnerType: string;
|
||||
}
|
||||
|
||||
const STORAGE_KEY = 'chat.selectedConfigId';
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
sessionId?: string | null;
|
||||
platformId?: string;
|
||||
isGroup?: boolean;
|
||||
}>(), {
|
||||
sessionId: null,
|
||||
platformId: 'webchat',
|
||||
isGroup: false
|
||||
});
|
||||
|
||||
const emit = defineEmits<{ 'config-changed': [ConfigChangedPayload] }>();
|
||||
|
||||
const configOptions = ref<ConfigInfo[]>([]);
|
||||
const loadingConfigs = ref(false);
|
||||
const dialog = ref(false);
|
||||
const tempSelectedConfig = ref('');
|
||||
const selectedConfigId = ref('default');
|
||||
const agentRunnerType = ref('local');
|
||||
const saving = ref(false);
|
||||
const pendingSync = ref(false);
|
||||
const routingEntries = ref<Array<{ pattern: string; confId: string }>>([]);
|
||||
const configCache = ref<Record<string, string>>({});
|
||||
|
||||
const toast = useToast();
|
||||
|
||||
const normalizedSessionId = computed(() => {
|
||||
const id = props.sessionId?.trim();
|
||||
return id ? id : null;
|
||||
});
|
||||
|
||||
const hasActiveSession = computed(() => !!normalizedSessionId.value);
|
||||
|
||||
const messageType = computed(() => (props.isGroup ? 'GroupMessage' : 'FriendMessage'));
|
||||
|
||||
const username = computed(() => localStorage.getItem('user') || 'guest');
|
||||
|
||||
const sessionKey = computed(() => {
|
||||
if (!normalizedSessionId.value) {
|
||||
return null;
|
||||
}
|
||||
return `${props.platformId}!${username.value}!${normalizedSessionId.value}`;
|
||||
});
|
||||
|
||||
const targetUmo = computed(() => {
|
||||
if (!sessionKey.value) {
|
||||
return null;
|
||||
}
|
||||
return `${props.platformId}:${messageType.value}:${sessionKey.value}`;
|
||||
});
|
||||
|
||||
const selectedConfigLabel = computed(() => {
|
||||
const target = configOptions.value.find((item) => item.id === selectedConfigId.value);
|
||||
return target?.name || selectedConfigId.value || 'default';
|
||||
});
|
||||
|
||||
function openDialog() {
|
||||
tempSelectedConfig.value = selectedConfigId.value;
|
||||
dialog.value = true;
|
||||
}
|
||||
|
||||
function closeDialog() {
|
||||
dialog.value = false;
|
||||
}
|
||||
|
||||
async function fetchConfigList() {
|
||||
loadingConfigs.value = true;
|
||||
try {
|
||||
const res = await axios.get('/api/config/abconfs');
|
||||
configOptions.value = res.data.data?.info_list || [];
|
||||
} catch (error) {
|
||||
console.error('加载配置文件列表失败', error);
|
||||
configOptions.value = [];
|
||||
} finally {
|
||||
loadingConfigs.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchRoutingEntries() {
|
||||
try {
|
||||
const res = await axios.get('/api/config/umo_abconf_routes');
|
||||
const routing = res.data.data?.routing || {};
|
||||
routingEntries.value = Object.entries(routing).map(([pattern, confId]) => ({
|
||||
pattern,
|
||||
confId: confId as string
|
||||
}));
|
||||
} catch (error) {
|
||||
console.error('获取配置路由失败', error);
|
||||
routingEntries.value = [];
|
||||
}
|
||||
}
|
||||
|
||||
function matchesPattern(pattern: string, target: string): boolean {
|
||||
const parts = pattern.split(':');
|
||||
const targetParts = target.split(':');
|
||||
if (parts.length !== 3 || targetParts.length !== 3) {
|
||||
return false;
|
||||
}
|
||||
return parts.every((part, index) => part === '' || part === '*' || part === targetParts[index]);
|
||||
}
|
||||
|
||||
function resolveConfigId(umo: string | null): string {
|
||||
if (!umo) {
|
||||
return 'default';
|
||||
}
|
||||
for (const entry of routingEntries.value) {
|
||||
if (matchesPattern(entry.pattern, umo)) {
|
||||
return entry.confId;
|
||||
}
|
||||
}
|
||||
return 'default';
|
||||
}
|
||||
|
||||
async function getAgentRunnerType(confId: string): Promise<string> {
|
||||
if (configCache.value[confId]) {
|
||||
return configCache.value[confId];
|
||||
}
|
||||
try {
|
||||
const res = await axios.get('/api/config/abconf', {
|
||||
params: { id: confId }
|
||||
});
|
||||
const type = res.data.data?.config?.provider_settings?.agent_runner_type || 'local';
|
||||
configCache.value[confId] = type;
|
||||
return type;
|
||||
} catch (error) {
|
||||
console.error('获取配置文件详情失败', error);
|
||||
return 'local';
|
||||
}
|
||||
}
|
||||
|
||||
async function setSelection(confId: string) {
|
||||
const normalized = confId || 'default';
|
||||
selectedConfigId.value = normalized;
|
||||
const runnerType = await getAgentRunnerType(normalized);
|
||||
agentRunnerType.value = runnerType;
|
||||
emit('config-changed', {
|
||||
configId: normalized,
|
||||
agentRunnerType: runnerType
|
||||
});
|
||||
}
|
||||
|
||||
async function applySelectionToBackend(confId: string): Promise<boolean> {
|
||||
if (!targetUmo.value) {
|
||||
pendingSync.value = true;
|
||||
return true;
|
||||
}
|
||||
saving.value = true;
|
||||
try {
|
||||
await axios.post('/api/config/umo_abconf_route/update', {
|
||||
umo: targetUmo.value,
|
||||
conf_id: confId
|
||||
});
|
||||
const filtered = routingEntries.value.filter((entry) => entry.pattern !== targetUmo.value);
|
||||
filtered.push({ pattern: targetUmo.value, confId });
|
||||
routingEntries.value = filtered;
|
||||
return true;
|
||||
} catch (error) {
|
||||
const err = error as any;
|
||||
console.error('更新配置文件失败', err);
|
||||
toast.error(err?.response?.data?.message || '配置文件应用失败');
|
||||
return false;
|
||||
} finally {
|
||||
saving.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
async function confirmSelection() {
|
||||
if (!tempSelectedConfig.value) {
|
||||
return;
|
||||
}
|
||||
const previousId = selectedConfigId.value;
|
||||
await setSelection(tempSelectedConfig.value);
|
||||
localStorage.setItem(STORAGE_KEY, tempSelectedConfig.value);
|
||||
const applied = await applySelectionToBackend(tempSelectedConfig.value);
|
||||
if (!applied) {
|
||||
localStorage.setItem(STORAGE_KEY, previousId);
|
||||
await setSelection(previousId);
|
||||
}
|
||||
dialog.value = false;
|
||||
}
|
||||
|
||||
async function syncSelectionForSession() {
|
||||
if (!targetUmo.value) {
|
||||
pendingSync.value = true;
|
||||
return;
|
||||
}
|
||||
if (pendingSync.value) {
|
||||
pendingSync.value = false;
|
||||
await applySelectionToBackend(selectedConfigId.value);
|
||||
return;
|
||||
}
|
||||
await fetchRoutingEntries();
|
||||
const resolved = resolveConfigId(targetUmo.value);
|
||||
await setSelection(resolved);
|
||||
localStorage.setItem(STORAGE_KEY, resolved);
|
||||
}
|
||||
|
||||
watch(
|
||||
() => [props.sessionId, props.platformId, props.isGroup],
|
||||
async () => {
|
||||
await syncSelectionForSession();
|
||||
}
|
||||
);
|
||||
|
||||
onMounted(async () => {
|
||||
await fetchConfigList();
|
||||
const stored = localStorage.getItem(STORAGE_KEY) || 'default';
|
||||
selectedConfigId.value = stored;
|
||||
await setSelection(stored);
|
||||
await syncSelectionForSession();
|
||||
});
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.config-chip {
|
||||
cursor: pointer;
|
||||
justify-content: flex-start;
|
||||
}
|
||||
|
||||
.config-list {
|
||||
max-height: 360px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
</style>
|
||||
@@ -3,6 +3,7 @@
|
||||
<!-- 选择提供商和模型按钮 -->
|
||||
<v-chip class="text-none" variant="tonal" size="x-small"
|
||||
v-if="selectedProviderId && selectedModelName" @click="openDialog">
|
||||
<v-icon start size="14">mdi-creation</v-icon>
|
||||
{{ selectedProviderId }} / {{ selectedModelName }}
|
||||
</v-chip>
|
||||
<v-chip variant="tonal" rounded="xl" size="x-small" v-else @click="openDialog">
|
||||
|
||||
@@ -7,6 +7,10 @@
|
||||
<v-icon start>mdi-message-text</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.basic') }}
|
||||
</v-tab>
|
||||
<v-tab value="agent_runner" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-cogs</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.agentRunner') }}
|
||||
</v-tab>
|
||||
<v-tab value="speech_to_text" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-microphone-message</v-icon>
|
||||
{{ tm('dialogs.addProvider.tabs.speechToText') }}
|
||||
@@ -27,7 +31,7 @@
|
||||
|
||||
<v-window v-model="activeProviderTab" class="mt-4">
|
||||
<v-window-item
|
||||
v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
|
||||
v-for="tabType in ['chat_completion', 'agent_runner', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank']"
|
||||
:key="tabType" :value="tabType">
|
||||
<v-row class="mt-1">
|
||||
<v-col v-for="(template, name) in getTemplatesByType(tabType)" :key="name" cols="12" sm="6"
|
||||
@@ -36,7 +40,7 @@
|
||||
@click="selectProviderTemplate(name)">
|
||||
<div class="provider-card-content">
|
||||
<div class="provider-card-text">
|
||||
<v-card-title class="provider-card-title">接入 {{ name }}</v-card-title>
|
||||
<v-card-title class="provider-card-title">{{ name }}</v-card-title>
|
||||
<v-card-text
|
||||
class="text-caption text-medium-emphasis provider-card-description">
|
||||
{{ getProviderDescription(template, name) }}
|
||||
@@ -54,7 +58,7 @@
|
||||
</v-col>
|
||||
<v-col v-if="Object.keys(getTemplatesByType(tabType)).length === 0" cols="12">
|
||||
<v-alert type="info" variant="tonal">
|
||||
{{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
|
||||
{{ t('dialogs.addProvider.noTemplates') }}
|
||||
</v-alert>
|
||||
</v-col>
|
||||
</v-row>
|
||||
@@ -104,19 +108,6 @@ export default {
|
||||
this.$emit('update:show', value);
|
||||
}
|
||||
},
|
||||
|
||||
// 翻译消息的计算属性
|
||||
messages() {
|
||||
return {
|
||||
tabTypes: {
|
||||
'chat_completion': this.tm('providers.tabs.chatCompletion'),
|
||||
'speech_to_text': this.tm('providers.tabs.speechToText'),
|
||||
'text_to_speech': this.tm('providers.tabs.textToSpeech'),
|
||||
'embedding': this.tm('providers.tabs.embedding'),
|
||||
'rerank': this.tm('providers.tabs.rerank')
|
||||
}
|
||||
};
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
closeDialog() {
|
||||
@@ -140,11 +131,6 @@ export default {
|
||||
// 从工具函数导入
|
||||
getProviderIcon,
|
||||
|
||||
// 获取Tab类型的中文名称
|
||||
getTabTypeName(tabType) {
|
||||
return this.messages.tabTypes[tabType] || tabType;
|
||||
},
|
||||
|
||||
// 获取提供商简介
|
||||
getProviderDescription(template, name) {
|
||||
return getProviderDescription(template, name, this.tm);
|
||||
|
||||
@@ -101,6 +101,21 @@ function shouldShowItem(itemMeta, itemKey) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查最外层的 object 是否应该显示
|
||||
function shouldShowSection() {
|
||||
const sectionMeta = props.metadata[props.metadataKey]
|
||||
if (!sectionMeta?.condition) {
|
||||
return true
|
||||
}
|
||||
for (const [conditionKey, expectedValue] of Object.entries(sectionMeta.condition)) {
|
||||
const actualValue = getValueBySelector(props.iterable, conditionKey)
|
||||
if (actualValue !== expectedValue) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
function hasVisibleItemsAfter(items, currentIndex) {
|
||||
const itemEntries = Object.entries(items)
|
||||
|
||||
@@ -114,12 +129,33 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
function parseSpecialValue(value) {
|
||||
if (!value || typeof value !== 'string') {
|
||||
return { name: '', subtype: '' }
|
||||
}
|
||||
const [name, ...rest] = value.split(':')
|
||||
return {
|
||||
name,
|
||||
subtype: rest.join(':') || ''
|
||||
}
|
||||
}
|
||||
|
||||
function getSpecialName(value) {
|
||||
return parseSpecialValue(value).name
|
||||
}
|
||||
|
||||
function getSpecialSubtype(value) {
|
||||
return parseSpecialValue(value).subtype
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<template>
|
||||
|
||||
|
||||
<v-card style="margin-bottom: 16px; padding-bottom: 8px; background-color: rgb(var(--v-theme-background));" rounded="md" variant="outlined">
|
||||
<v-card v-if="shouldShowSection()" style="margin-bottom: 16px; padding-bottom: 8px; background-color: rgb(var(--v-theme-background));"
|
||||
rounded="md" variant="outlined">
|
||||
<v-card-text class="config-section" v-if="metadata[metadataKey]?.type === 'object'" style="padding-bottom: 8px;">
|
||||
<v-list-item-title class="config-title">
|
||||
{{ metadata[metadataKey]?.description }}
|
||||
@@ -187,22 +223,16 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
|
||||
<!-- Boolean switch for JSON selector -->
|
||||
<v-switch v-else-if="itemMeta?.type === 'bool'" v-model="createSelectorModel(itemKey).value"
|
||||
color="primary" inset density="compact" hide-details style="display: flex; justify-content: end;"></v-switch>
|
||||
color="primary" inset density="compact" hide-details
|
||||
style="display: flex; justify-content: end;"></v-switch>
|
||||
|
||||
<!-- List item for JSON selector -->
|
||||
<ListConfigItem
|
||||
v-else-if="itemMeta?.type === 'list'"
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
button-text="修改"
|
||||
class="config-field"
|
||||
/>
|
||||
<ListConfigItem v-else-if="itemMeta?.type === 'list'" v-model="createSelectorModel(itemKey).value"
|
||||
button-text="修改" class="config-field" />
|
||||
|
||||
<!-- Object editor for JSON selector -->
|
||||
<ObjectEditor
|
||||
v-else-if="itemMeta?.type === 'dict'"
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
class="config-field"
|
||||
/>
|
||||
<ObjectEditor v-else-if="itemMeta?.type === 'dict'" v-model="createSelectorModel(itemKey).value"
|
||||
class="config-field" />
|
||||
|
||||
<!-- Fallback for JSON selector -->
|
||||
<v-text-field v-else v-model="createSelectorModel(itemKey).value" density="compact" variant="outlined"
|
||||
@@ -211,50 +241,36 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
|
||||
<!-- Special handling for specific metadata types -->
|
||||
<div v-else-if="itemMeta?._special === 'select_provider'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'chat_completion'"
|
||||
/>
|
||||
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'chat_completion'" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_provider_stt'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'speech_to_text'"
|
||||
/>
|
||||
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'speech_to_text'" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_provider_tts'">
|
||||
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'text_to_speech'" />
|
||||
</div>
|
||||
<div v-else-if="getSpecialName(itemMeta?._special) === 'select_agent_runner_provider'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'text_to_speech'"
|
||||
:provider-type="'agent_runner'"
|
||||
:provider-subtype="getSpecialSubtype(itemMeta?._special)"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'provider_pool'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'chat_completion'"
|
||||
button-text="选择提供商池..."
|
||||
/>
|
||||
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'chat_completion'"
|
||||
button-text="选择提供商池..." />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_persona'">
|
||||
<PersonaSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
<PersonaSelector v-model="createSelectorModel(itemKey).value" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'persona_pool'">
|
||||
<PersonaSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
button-text="选择人格池..."
|
||||
/>
|
||||
<PersonaSelector v-model="createSelectorModel(itemKey).value" button-text="选择人格池..." />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
<KnowledgeBaseSelector v-model="createSelectorModel(itemKey).value" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_plugin_set'">
|
||||
<PluginSetSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
<PluginSetSelector v-model="createSelectorModel(itemKey).value" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 't2i_template'">
|
||||
<T2ITemplateEditor />
|
||||
@@ -263,21 +279,17 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-row>
|
||||
|
||||
<!-- Plugin Set Selector 全宽显示区域 -->
|
||||
<v-row v-if="!itemMeta?.invisible && itemMeta?._special === 'select_plugin_set'" class="plugin-set-display-row">
|
||||
<v-row v-if="!itemMeta?.invisible && itemMeta?._special === 'select_plugin_set'"
|
||||
class="plugin-set-display-row">
|
||||
<v-col cols="12" class="plugin-set-display">
|
||||
<div v-if="createSelectorModel(itemKey).value && createSelectorModel(itemKey).value.length > 0" class="selected-plugins-full-width">
|
||||
<div v-if="createSelectorModel(itemKey).value && createSelectorModel(itemKey).value.length > 0"
|
||||
class="selected-plugins-full-width">
|
||||
<div class="plugins-header">
|
||||
<small class="text-grey">已选择的插件:</small>
|
||||
</div>
|
||||
<div class="d-flex flex-wrap ga-2 mt-2">
|
||||
<v-chip
|
||||
v-for="plugin in (createSelectorModel(itemKey).value || [])"
|
||||
:key="plugin"
|
||||
size="small"
|
||||
label
|
||||
color="primary"
|
||||
variant="outlined"
|
||||
>
|
||||
<v-chip v-for="plugin in (createSelectorModel(itemKey).value || [])" :key="plugin" size="small" label
|
||||
color="primary" variant="outlined">
|
||||
{{ plugin === '*' ? '所有插件' : plugin }}
|
||||
</v-chip>
|
||||
</div>
|
||||
@@ -285,7 +297,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-col>
|
||||
</v-row>
|
||||
</template>
|
||||
<v-divider class="config-divider" v-if="shouldShowItem(itemMeta, itemKey) && hasVisibleItemsAfter(metadata[metadataKey].items, index)"></v-divider>
|
||||
<v-divider class="config-divider"
|
||||
v-if="shouldShowItem(itemMeta, itemKey) && hasVisibleItemsAfter(metadata[metadataKey].items, index)"></v-divider>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -94,6 +94,10 @@ const props = defineProps({
|
||||
type: String,
|
||||
default: 'chat_completion'
|
||||
},
|
||||
providerSubtype: {
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择提供商...'
|
||||
@@ -127,7 +131,10 @@ async function loadProviders() {
|
||||
}
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
providerList.value = response.data.data || []
|
||||
const providers = response.data.data || []
|
||||
providerList.value = props.providerSubtype
|
||||
? providers.filter((provider) => matchesProviderSubtype(provider, props.providerSubtype))
|
||||
: providers
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载提供商列表失败:', error)
|
||||
@@ -137,6 +144,17 @@ async function loadProviders() {
|
||||
}
|
||||
}
|
||||
|
||||
function matchesProviderSubtype(provider, subtype) {
|
||||
if (!subtype) {
|
||||
return true
|
||||
}
|
||||
const normalized = String(subtype).toLowerCase()
|
||||
const candidates = [provider.type, provider.provider, provider.id]
|
||||
.filter(Boolean)
|
||||
.map((value) => String(value).toLowerCase())
|
||||
return candidates.includes(normalized)
|
||||
}
|
||||
|
||||
function selectProvider(provider) {
|
||||
selectedProvider.value = provider.id
|
||||
}
|
||||
|
||||
@@ -4,8 +4,12 @@ import { useRouter } from 'vue-router';
|
||||
|
||||
export interface Session {
|
||||
session_id: string;
|
||||
display_name: string;
|
||||
display_name: string | null;
|
||||
updated_at: string;
|
||||
platform_id: string;
|
||||
creator: string;
|
||||
is_group: number;
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export function useSessions(chatboxMode: boolean = false) {
|
||||
|
||||
@@ -74,6 +74,7 @@
|
||||
"delete": "Delete",
|
||||
"copy": "Copy",
|
||||
"edit": "Edit",
|
||||
"copy": "Copy",
|
||||
"noData": "No data available"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
"tabs": {
|
||||
"all": "All",
|
||||
"chatCompletion": "Chat Completion",
|
||||
"agentRunner": "Agent Runner",
|
||||
"speechToText": "Speech to Text",
|
||||
"textToSpeech": "Text to Speech",
|
||||
"embedding": "Embedding",
|
||||
@@ -44,12 +45,13 @@
|
||||
"title": "Service Provider",
|
||||
"tabs": {
|
||||
"basic": "Basic",
|
||||
"agentRunner": "Agent Runner",
|
||||
"speechToText": "Speech to Text",
|
||||
"textToSpeech": "Text to Speech",
|
||||
"embedding": "Embedding",
|
||||
"rerank": "Rerank"
|
||||
},
|
||||
"noTemplates": "No {type} type provider templates available"
|
||||
"noTemplates": "No this type provider templates available"
|
||||
},
|
||||
"config": {
|
||||
"addTitle": "Add",
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
"providerType": "提供商类型",
|
||||
"tabs": {
|
||||
"all": "全部",
|
||||
"chatCompletion": "基本对话",
|
||||
"chatCompletion": "对话",
|
||||
"agentRunner": "Agent 执行器",
|
||||
"speechToText": "语音转文字",
|
||||
"textToSpeech": "文字转语音",
|
||||
"embedding": "嵌入(Embedding)",
|
||||
@@ -44,13 +45,14 @@
|
||||
"addProvider": {
|
||||
"title": "模型提供商",
|
||||
"tabs": {
|
||||
"basic": "基本",
|
||||
"basic": "对话",
|
||||
"agentRunner": "Agent 执行器",
|
||||
"speechToText": "语音转文字",
|
||||
"textToSpeech": "文字转语音",
|
||||
"embedding": "嵌入(Embedding)",
|
||||
"rerank": "重排序(Rerank)"
|
||||
},
|
||||
"noTemplates": "暂无{type}类型的提供商模板"
|
||||
"noTemplates": "暂无该类型的提供商模板"
|
||||
},
|
||||
"config": {
|
||||
"addTitle": "新增",
|
||||
|
||||
@@ -30,6 +30,10 @@
|
||||
<v-icon start>mdi-message-text</v-icon>
|
||||
{{ tm('providers.tabs.chatCompletion') }}
|
||||
</v-tab>
|
||||
<v-tab value="agent_runner" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-message-text</v-icon>
|
||||
{{ tm('providers.tabs.agentRunner') }}
|
||||
</v-tab>
|
||||
<v-tab value="speech_to_text" class="font-weight-medium px-3">
|
||||
<v-icon start>mdi-microphone-message</v-icon>
|
||||
{{ tm('providers.tabs.speechToText') }}
|
||||
@@ -48,30 +52,62 @@
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
|
||||
<v-row v-if="filteredProviders.length === 0">
|
||||
<v-col cols="12" class="text-center pa-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
|
||||
<p class="text-grey mt-4">{{ getEmptyText() }}</p>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<v-row v-else>
|
||||
<v-col v-for="(provider, index) in filteredProviders" :key="index" cols="12" md="6" lg="4" xl="3">
|
||||
<item-card :item="provider" title-field="id" enabled-field="enable"
|
||||
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
|
||||
@copy="copyProvider" :show-copy-button="true">
|
||||
<template #actions="{ item }">
|
||||
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
|
||||
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
|
||||
{{ tm('availability.test') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<template v-slot:details="{ item }">
|
||||
</template>
|
||||
</item-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
<template v-if="activeProviderTypeTab === 'all'">
|
||||
<v-row v-if="groupedProviders.length === 0">
|
||||
<v-col cols="12" class="text-center pa-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
|
||||
<p class="text-grey mt-4">{{ getEmptyText() }}</p>
|
||||
</v-col>
|
||||
</v-row>
|
||||
<div v-else>
|
||||
<div v-for="group in groupedProviders" :key="group.typeKey" class="mb-8">
|
||||
<h1 class="text-h3 font-weight-bold mb-4">{{ group.label }}</h1>
|
||||
<v-row>
|
||||
<v-col v-for="(provider, index) in group.items" :key="`${group.typeKey}-${index}`" cols="12" md="6"
|
||||
lg="4" xl="3">
|
||||
<item-card :item="provider" title-field="id" enabled-field="enable"
|
||||
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
|
||||
@copy="copyProvider" :show-copy-button="true">
|
||||
<template #actions="{ item }">
|
||||
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
|
||||
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
|
||||
{{ tm('availability.test') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<template v-slot:details="{ item }">
|
||||
</template>
|
||||
</item-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<template v-else>
|
||||
<v-row v-if="filteredProviders.length === 0">
|
||||
<v-col cols="12" class="text-center pa-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
|
||||
<p class="text-grey mt-4">{{ getEmptyText() }}</p>
|
||||
</v-col>
|
||||
</v-row>
|
||||
<v-row v-else>
|
||||
<v-col v-for="(provider, index) in filteredProviders" :key="index" cols="12" md="6" lg="4" xl="3">
|
||||
<item-card :item="provider" title-field="id" enabled-field="enable"
|
||||
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
|
||||
@copy="copyProvider" :show-copy-button="true">
|
||||
<template #actions="{ item }">
|
||||
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
|
||||
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
|
||||
{{ tm('availability.test') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<template v-slot:details="{ item }">
|
||||
</template>
|
||||
</item-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- 供应商状态部分 -->
|
||||
@@ -289,8 +325,8 @@ export default {
|
||||
"anthropic_chat_completion": "chat_completion",
|
||||
"googlegenai_chat_completion": "chat_completion",
|
||||
"zhipu_chat_completion": "chat_completion",
|
||||
"dify": "chat_completion",
|
||||
"coze": "chat_completion",
|
||||
"dify": "agent_runner",
|
||||
"coze": "agent_runner",
|
||||
"dashscope": "chat_completion",
|
||||
"openai_whisper_api": "speech_to_text",
|
||||
"openai_whisper_selfhost": "speech_to_text",
|
||||
@@ -334,6 +370,7 @@ export default {
|
||||
},
|
||||
tabTypes: {
|
||||
'chat_completion': this.tm('providers.tabs.chatCompletion'),
|
||||
'agent_runner': this.tm('providers.tabs.agentRunner'),
|
||||
'speech_to_text': this.tm('providers.tabs.speechToText'),
|
||||
'text_to_speech': this.tm('providers.tabs.textToSpeech'),
|
||||
'embedding': this.tm('providers.tabs.embedding'),
|
||||
@@ -363,6 +400,52 @@ export default {
|
||||
};
|
||||
},
|
||||
|
||||
groupedProviders() {
|
||||
if (!this.config_data.provider) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const typeOrder = [
|
||||
'chat_completion',
|
||||
'agent_runner',
|
||||
'speech_to_text',
|
||||
'text_to_speech',
|
||||
'embedding',
|
||||
'rerank',
|
||||
];
|
||||
|
||||
const assigned = new Set();
|
||||
const groups = typeOrder
|
||||
.map((typeKey) => {
|
||||
const items = this.config_data.provider.filter((provider) => {
|
||||
const resolved = this.getProviderType(provider);
|
||||
if (resolved === typeKey) {
|
||||
assigned.add(provider.id);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return {
|
||||
typeKey,
|
||||
label: this.messages.tabTypes[typeKey] || typeKey,
|
||||
items,
|
||||
};
|
||||
})
|
||||
.filter((group) => group.items.length > 0);
|
||||
|
||||
const remaining = this.config_data.provider.filter(
|
||||
(provider) => !assigned.has(provider.id),
|
||||
);
|
||||
if (remaining.length > 0) {
|
||||
groups.push({
|
||||
typeKey: 'others',
|
||||
label: this.tm('providers.tabs.all'),
|
||||
items: remaining,
|
||||
});
|
||||
}
|
||||
return groups;
|
||||
},
|
||||
|
||||
// 根据选择的标签过滤提供商列表
|
||||
filteredProviders() {
|
||||
if (!this.config_data.provider || this.activeProviderTypeTab === 'all') {
|
||||
@@ -371,13 +454,7 @@ export default {
|
||||
|
||||
return this.config_data.provider.filter(provider => {
|
||||
// 如果provider.provider_type已经存在,直接使用它
|
||||
if (provider.provider_type) {
|
||||
return provider.provider_type === this.activeProviderTypeTab;
|
||||
}
|
||||
|
||||
// 否则使用映射关系
|
||||
const mappedType = this.oldVersionProviderTypeMapping[provider.type];
|
||||
return mappedType === this.activeProviderTypeTab;
|
||||
return this.getProviderType(provider) === this.activeProviderTypeTab;
|
||||
});
|
||||
}
|
||||
},
|
||||
@@ -387,6 +464,14 @@ export default {
|
||||
},
|
||||
|
||||
methods: {
|
||||
getProviderType(provider) {
|
||||
if (!provider) return undefined;
|
||||
if (provider.provider_type) {
|
||||
return provider.provider_type;
|
||||
}
|
||||
return this.oldVersionProviderTypeMapping[provider.type];
|
||||
},
|
||||
|
||||
getConfig() {
|
||||
axios.get('/api/config/get').then((res) => {
|
||||
this.config_data = res.data.data.config;
|
||||
@@ -690,6 +775,9 @@ export default {
|
||||
if (!provider.enable) {
|
||||
throw new Error('该提供商未被用户启用');
|
||||
}
|
||||
if (provider.provider_type === 'agent_runner') {
|
||||
throw new Error('暂时无法测试 Agent Runner 类型的提供商');
|
||||
}
|
||||
|
||||
const res = await axios.get(`/api/config/provider/check_one?id=${provider.id}`);
|
||||
if (res.data && res.data.status === 'ok') {
|
||||
|
||||
@@ -2,14 +2,18 @@ import datetime
|
||||
|
||||
from astrbot.api import logger, sp, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.provider.sources.coze_source import ProviderCoze
|
||||
from astrbot.core.provider.sources.dify_source import ProviderDify
|
||||
|
||||
from ..long_term_memory import LongTermMemory
|
||||
from .utils.rst_scene import RstScene
|
||||
|
||||
THIRD_PARTY_AGENT_RUNNER_KEY = {
|
||||
"dify": "dify_conversation_id",
|
||||
"coze": "coze_conversation_id",
|
||||
}
|
||||
THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys())
|
||||
|
||||
|
||||
class ConversationCommands:
|
||||
def __init__(self, context: star.Context, ltm: LongTermMemory | None = None):
|
||||
@@ -38,6 +42,7 @@ class ConversationCommands:
|
||||
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
"""重置 LLM 会话"""
|
||||
umo = message.unified_msg_origin
|
||||
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||
is_unique_session = cfg["platform_settings"]["unique_session"]
|
||||
is_group = bool(message.get_group_id())
|
||||
@@ -62,28 +67,23 @@ class ConversationCommands:
|
||||
)
|
||||
return
|
||||
|
||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
|
||||
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
|
||||
await sp.remove_async(
|
||||
scope="umo",
|
||||
scope_id=umo,
|
||||
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
|
||||
)
|
||||
message.set_result(MessageEventResult().message("重置对话成功。"))
|
||||
return
|
||||
|
||||
if not self.context.get_using_provider(umo):
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
message.unified_msg_origin,
|
||||
)
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
|
||||
|
||||
if not cid:
|
||||
message.set_result(
|
||||
@@ -94,7 +94,7 @@ class ConversationCommands:
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.update_conversation(
|
||||
message.unified_msg_origin,
|
||||
umo,
|
||||
cid,
|
||||
[],
|
||||
)
|
||||
@@ -151,29 +151,14 @@ class ConversationCommands:
|
||||
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1):
|
||||
"""查看对话列表"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
"""原有的Dify处理逻辑保持不变"""
|
||||
parts = ["Dify 对话列表:\n"]
|
||||
assert isinstance(provider, ProviderDify)
|
||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||
idx = 1
|
||||
for conv in data["data"]:
|
||||
ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime(
|
||||
"%m-%d %H:%M",
|
||||
)
|
||||
parts.append(
|
||||
f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n"
|
||||
)
|
||||
idx += 1
|
||||
if idx == 1:
|
||||
parts.append("没有找到任何对话。")
|
||||
dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None)
|
||||
parts.append(
|
||||
f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。"
|
||||
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
|
||||
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。",
|
||||
),
|
||||
)
|
||||
ret = "".join(parts)
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
return
|
||||
|
||||
size_per_page = 6
|
||||
@@ -241,15 +226,15 @@ class ConversationCommands:
|
||||
|
||||
async def new_conv(self, message: AstrMessageEvent):
|
||||
"""创建新对话"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message("成功,下次聊天将是新对话。"),
|
||||
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
|
||||
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
|
||||
await sp.remove_async(
|
||||
scope="umo",
|
||||
scope_id=message.unified_msg_origin,
|
||||
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
|
||||
)
|
||||
message.set_result(MessageEventResult().message("已创建新对话。"))
|
||||
return
|
||||
|
||||
cpersona = await self._get_current_persona_id(message.unified_msg_origin)
|
||||
@@ -272,19 +257,9 @@ class ConversationCommands:
|
||||
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
|
||||
"""创建新群聊对话"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message("成功,下次聊天将是新对话。"),
|
||||
)
|
||||
return
|
||||
if sid:
|
||||
session = str(
|
||||
MessageSesion(
|
||||
MessageSession(
|
||||
platform_name=message.platform_meta.id,
|
||||
message_type=MessageType("GroupMessage"),
|
||||
session_id=sid,
|
||||
@@ -319,31 +294,6 @@ class ConversationCommands:
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify), "provider type is not dify"
|
||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||
if not data["data"]:
|
||||
message.set_result(MessageEventResult().message("未找到任何对话。"))
|
||||
return
|
||||
selected_conv = None
|
||||
if index is not None:
|
||||
try:
|
||||
selected_conv = data["data"][index - 1]
|
||||
except IndexError:
|
||||
message.set_result(
|
||||
MessageEventResult().message("对话序号错误,请使用 /ls 查看"),
|
||||
)
|
||||
return
|
||||
else:
|
||||
selected_conv = data["data"][0]
|
||||
ret = (
|
||||
f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。"
|
||||
)
|
||||
provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"]
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
return
|
||||
|
||||
if index is None:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
@@ -376,19 +326,6 @@ class ConversationCommands:
|
||||
if not new_name:
|
||||
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
cid = provider.conversation_ids.get(message.unified_msg_origin, None)
|
||||
if not cid:
|
||||
message.set_result(MessageEventResult().message("未找到当前对话。"))
|
||||
return
|
||||
await provider.api_client.rename(cid, new_name, message.unified_msg_origin)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.update_conversation_title(
|
||||
message.unified_msg_origin,
|
||||
new_name,
|
||||
@@ -408,20 +345,14 @@ class ConversationCommands:
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None)
|
||||
if dify_cid:
|
||||
await provider.api_client.delete_chat_conv(
|
||||
message.unified_msg_origin,
|
||||
dify_cid,
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。",
|
||||
),
|
||||
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
|
||||
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
|
||||
await sp.remove_async(
|
||||
scope="umo",
|
||||
scope_id=message.unified_msg_origin,
|
||||
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
|
||||
)
|
||||
message.set_result(MessageEventResult().message("重置对话成功。"))
|
||||
return
|
||||
|
||||
session_curr_cid = (
|
||||
|
||||
@@ -5,7 +5,6 @@ from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from astrbot.api.provider import LLMResponse, ProviderRequest
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.provider.sources.dify_source import ProviderDify
|
||||
|
||||
from .commands import (
|
||||
AdminCommands,
|
||||
@@ -279,33 +278,20 @@ class Main(star.Star):
|
||||
return
|
||||
try:
|
||||
conv = None
|
||||
if provider.meta().type != "dify":
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
|
||||
if not session_curr_cid:
|
||||
logger.error(
|
||||
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
|
||||
)
|
||||
return
|
||||
if not session_curr_cid:
|
||||
logger.error(
|
||||
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
|
||||
)
|
||||
return
|
||||
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
session_curr_cid,
|
||||
)
|
||||
else:
|
||||
# Dify 自己有维护对话,不需要 bot 端维护。
|
||||
assert isinstance(provider, ProviderDify)
|
||||
cid = provider.conversation_ids.get(
|
||||
event.unified_msg_origin,
|
||||
None,
|
||||
)
|
||||
if cid is None:
|
||||
logger.error(
|
||||
"[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
|
||||
)
|
||||
return
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
session_curr_cid,
|
||||
)
|
||||
|
||||
prompt = event.message_str
|
||||
|
||||
|
||||
Reference in New Issue
Block a user