diff --git a/astrbot/core/agent/runners/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope_agent_runner.py
index 6a7cada5..bcb03068 100644
--- a/astrbot/core/agent/runners/dashscope_agent_runner.py
+++ b/astrbot/core/agent/runners/dashscope_agent_runner.py
@@ -1,7 +1,9 @@
import asyncio
import functools
+import queue
import re
import sys
+import threading
import typing as T
from dashscope import Application
@@ -144,17 +146,116 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
result.append(ctx)
return result
- async def _execute_dashscope_request(self):
- """执行 Dashscope 请求的核心逻辑"""
- prompt = self.req.prompt or ""
- session_id = self.req.session_id or "unknown"
- image_urls = self.req.image_urls or []
- contexts = self.req.contexts or []
- system_prompt = self.req.system_prompt
+    def _consume_sync_generator(
+        self, response: T.Any, response_queue: queue.Queue
+    ) -> None:
+        """Drain a blocking Dashscope response into ``response_queue``.
+
+        Intended to run in a worker thread. Every item is pushed as a
+        ``(kind, payload)`` tuple: ``("data", item)`` for each response item,
+        ``("error", exc)`` if consuming raised, and always a final
+        ``("done", None)`` sentinel.
+
+        Args:
+            response: The synchronous generator (streaming mode) or the single
+                response object (non-streaming mode).
+            response_queue: Queue used to hand items over to the consumer.
+
+        """
+        try:
+            # A non-streaming call yields exactly one object; wrap it so both
+            # modes share the same loop.
+            items = response if self.streaming else (response,)
+            for item in items:
+                response_queue.put(("data", item))
+        except Exception as exc:
+            # Forward the failure instead of letting it die in the thread.
+            response_queue.put(("error", exc))
+        finally:
+            response_queue.put(("done", None))
+
+    async def _process_stream_chunk(
+        self, chunk: ApplicationResponse, output_text: str
+    ) -> tuple[str, list | None, AgentResponse | None]:
+        """Process a single chunk of a Dashscope response.
+
+        On a non-200 status the runner is moved to the ERROR state,
+        ``final_llm_resp`` is set to an error response and an ``err``
+        AgentResponse is returned. Otherwise the chunk text is appended to
+        ``output_text`` and wrapped in a ``streaming_delta`` AgentResponse.
+
+        Args:
+            chunk: Dashscope response chunk.
+            output_text: Text accumulated from the previous chunks.
+
+        Returns:
+            A tuple of (updated output_text, the chunk's ``doc_references`` or
+            None, an AgentResponse or None when the chunk carries no text).
+
+        """
+        logger.debug(f"dashscope stream chunk: {chunk}")
+
+        if chunk.status_code != 200:
+            logger.error(
+                f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
+            )
+            self._transition_state(AgentState.ERROR)
+            error_msg = (
+                f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
+            )
+            self.final_llm_resp = LLMResponse(
+                role="err",
+                result_chain=MessageChain().message(error_msg),
+            )
+            return (
+                output_text,
+                None,
+                AgentResponse(
+                    type="err",
+                    data=AgentResponseData(chain=MessageChain().message(error_msg)),
+                ),
+            )
+
+        chunk_text = chunk.output.get("text", "") or ""
+        # Normalize RAG citation markers: "<ref>[3]</ref>" -> "[3]".
+        # The previous pattern was a bare character class without a capture
+        # group, so the "\1" back-reference made re.sub raise re.error.
+        # NOTE(review): the substitution runs per chunk, so a marker that is
+        # split across two streamed chunks is left untouched.
+        chunk_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", chunk_text)
+
+        response = None
+        if chunk_text:
+            output_text += chunk_text
+            response = AgentResponse(
+                type="streaming_delta",
+                data=AgentResponseData(chain=MessageChain().message(chunk_text)),
+            )
+
+        # Document references attached to this chunk, if any.
+        doc_references = chunk.output.get("doc_references", None)
+
+        return output_text, doc_references, response
+
+    def _format_doc_references(self, doc_references: list) -> str:
+        """Render document references as a trailing "sources" section.
+
+        Each reference becomes one ``"<index_id>. <title>"`` line; the title
+        falls back to ``doc_name`` (then to an empty string) when ``title`` is
+        missing or empty.
+
+        Args:
+            doc_references: List of reference dicts; each must contain
+                ``index_id``.
+
+        Returns:
+            The formatted reference text, prefixed with a blank line and the
+            section header.
+
+        """
+
+        def _render(ref) -> str:
+            title = ref.get("title") or ref.get("doc_name", "")
+            return f"{ref['index_id']}. {title}\n"
+
+        listing = "".join(map(_render, doc_references))
+        return f"\n\n回答来源:\n{listing}"
+
+ async def _build_request_payload(
+ self, prompt: str, session_id: str, contexts: list, system_prompt: str
+ ) -> dict:
+ """构建请求payload
+
+ Args:
+ prompt: 用户输入
+ session_id: 会话ID
+ contexts: 上下文列表
+ system_prompt: 系统提示词
+
+ Returns:
+ 请求payload字典
+
+ """
# 获得会话变量
payload_vars = self.variables.copy()
- # 动态变量
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
@@ -169,8 +270,6 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
):
# 支持多轮对话的
new_record = {"role": "user", "content": prompt}
- if image_urls:
- logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
contexts_no_img = await self._remove_image_from_context(contexts)
context_query = [*contexts_no_img, new_record]
if system_prompt:
@@ -178,74 +277,91 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
- # 调用阿里云百炼 API
- payload = {
+
+ return {
"app_id": self.app_id,
"api_key": self.api_key,
"messages": context_query,
"biz_params": payload_vars or None,
+ "stream": self.streaming,
+ "incremental_output": True,
}
- partial = functools.partial(
- Application.call,
- **payload,
- )
- response = await asyncio.get_event_loop().run_in_executor(None, partial)
else:
# 不支持多轮对话的
- # 调用阿里云百炼 API
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
+ "stream": self.streaming,
+ "incremental_output": True,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
- partial = functools.partial(
- Application.call,
- **payload,
- )
- response = await asyncio.get_event_loop().run_in_executor(None, partial)
+ return payload
- assert isinstance(response, ApplicationResponse)
+ async def _handle_streaming_response(self, response: T.Any):
+ """处理流式响应
- logger.debug(f"dashscope resp: {response}")
+ Args:
+ response: Dashscope 流式响应 generator
- if response.status_code != 200:
- logger.error(
- f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
- )
- self._transition_state(AgentState.ERROR)
- self.final_llm_resp = LLMResponse(
- role="err",
- result_chain=MessageChain().message(
- f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
- ),
- )
- yield AgentResponse(
- type="err",
- data=AgentResponseData(
- chain=MessageChain().message(
- f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
- )
- ),
- )
- return
+ Yields:
+ AgentResponse 对象
- output_text = response.output.get("text", "") or ""
- # RAG 引用脚标格式化
- output_text = re.sub(r"[\[(\d+)\]]", r"[\1]", output_text)
- if self.output_reference and response.output.get("doc_references", None):
- ref_parts = []
- for ref in response.output.get("doc_references", []) or []:
- ref_title = (
- ref.get("title", "")
- if ref.get("title")
- else ref.get("doc_name", "")
+ """
+ response_queue = queue.Queue()
+ consumer_thread = threading.Thread(
+ target=self._consume_sync_generator,
+ args=(response, response_queue),
+ daemon=True,
+ )
+ consumer_thread.start()
+
+ output_text = ""
+ doc_references = None
+
+ while True:
+ try:
+ item_type, item_data = await asyncio.get_event_loop().run_in_executor(
+ None, response_queue.get, True, 1
+ )
+ except queue.Empty:
+ continue
+
+ if item_type == "done":
+ break
+ elif item_type == "error":
+ raise item_data
+ elif item_type == "data":
+ chunk = item_data
+ assert isinstance(chunk, ApplicationResponse)
+
+ (
+ output_text,
+ chunk_doc_refs,
+ response,
+ ) = await self._process_stream_chunk(chunk, output_text)
+
+ if response:
+ if response.type == "err":
+ yield response
+ return
+ yield response
+
+ if chunk_doc_refs:
+ doc_references = chunk_doc_refs
+
+ # 添加 RAG 引用
+ if self.output_reference and doc_references:
+ ref_text = self._format_doc_references(doc_references)
+ output_text += ref_text
+
+ if self.streaming:
+ yield AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(chain=MessageChain().message(ref_text)),
)
- ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
- ref_str = "".join(ref_parts)
- output_text += f"\n\n回答来源:\n{ref_str}"
# 创建最终响应
chain = MessageChain(chain=[Comp.Plain(output_text)])
@@ -263,6 +379,33 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
data=AgentResponseData(chain=chain),
)
+    async def _execute_dashscope_request(self):
+        """Run one Dashscope request and relay the resulting responses.
+
+        Reads the prompt, session id, image URLs, contexts and system prompt
+        from ``self.req``, builds the payload with ``_build_request_payload``,
+        issues the blocking ``Application.call`` in the default executor and
+        re-yields everything produced by ``_handle_streaming_response``.
+
+        Yields:
+            The AgentResponse objects yielded by ``_handle_streaming_response``.
+
+        """
+        prompt = self.req.prompt or ""
+        session_id = self.req.session_id or "unknown"
+        image_urls = self.req.image_urls or []
+        contexts = self.req.contexts or []
+        system_prompt = self.req.system_prompt
+
+        # Images are not forwarded; only warn so the user knows they are dropped.
+        if image_urls:
+            logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
+
+        payload = await self._build_request_payload(
+            prompt, session_id, contexts, system_prompt
+        )
+
+        # The payload builder enables incremental output; turn it off again
+        # when the request is not streamed.
+        if not self.streaming:
+            payload["incremental_output"] = False
+
+        # Application.call is blocking, so run it in the default executor.
+        # get_running_loop() is the documented way to obtain the loop from
+        # inside a coroutine.
+        call = functools.partial(Application.call, **payload)
+        response = await asyncio.get_running_loop().run_in_executor(None, call)
+
+        async for resp in self._handle_streaming_response(response):
+            yield resp
+
@override
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""