diff --git a/astrbot/core/agent/runners/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope_agent_runner.py index 6a7cada5..bcb03068 100644 --- a/astrbot/core/agent/runners/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope_agent_runner.py @@ -1,7 +1,9 @@ import asyncio import functools +import queue import re import sys +import threading import typing as T from dashscope import Application @@ -144,17 +146,116 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]): result.append(ctx) return result - async def _execute_dashscope_request(self): - """执行 Dashscope 请求的核心逻辑""" - prompt = self.req.prompt or "" - session_id = self.req.session_id or "unknown" - image_urls = self.req.image_urls or [] - contexts = self.req.contexts or [] - system_prompt = self.req.system_prompt + def _consume_sync_generator( + self, response: T.Any, response_queue: queue.Queue + ) -> None: + """在线程中消费同步generator,将结果放入队列 + Args: + response: 同步generator对象 + response_queue: 用于传递数据的队列 + + """ + try: + if self.streaming: + for chunk in response: + response_queue.put(("data", chunk)) + else: + response_queue.put(("data", response)) + except Exception as e: + response_queue.put(("error", e)) + finally: + response_queue.put(("done", None)) + + async def _process_stream_chunk( + self, chunk: ApplicationResponse, output_text: str + ) -> tuple[str, list | None, AgentResponse | None]: + """处理流式响应的单个chunk + + Args: + chunk: Dashscope响应chunk + output_text: 当前累积的输出文本 + + Returns: + (更新后的output_text, doc_references, AgentResponse或None) + + """ + logger.debug(f"dashscope stream chunk: {chunk}") + + if chunk.status_code != 200: + logger.error( + f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", + ) + self._transition_state(AgentState.ERROR) + error_msg = ( + f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}" + ) + self.final_llm_resp = LLMResponse( + role="err", + result_chain=MessageChain().message(error_msg), + ) + return ( + output_text, + None, + AgentResponse( + type="err", + data=AgentResponseData(chain=MessageChain().message(error_msg)), + ), + ) + + chunk_text = chunk.output.get("text", "") or "" + # RAG 引用脚标格式化 + chunk_text = re.sub(r"\[(\d+)\]", r"[\1]", chunk_text) + + response = None + if chunk_text: + output_text += chunk_text + response = AgentResponse( + type="streaming_delta", + data=AgentResponseData(chain=MessageChain().message(chunk_text)), + ) + + # 获取文档引用 + doc_references = chunk.output.get("doc_references", None) + + return output_text, doc_references, response + + def _format_doc_references(self, doc_references: list) -> str: + """格式化文档引用为文本 + + Args: + doc_references: 文档引用列表 + + Returns: + 格式化后的引用文本 + + """ + ref_parts = [] + for ref in doc_references: + ref_title = ( + ref.get("title", "") if ref.get("title") else ref.get("doc_name", "") + ) + ref_parts.append(f"{ref['index_id']}. {ref_title}\n") + ref_str = "".join(ref_parts) + return f"\n\n回答来源:\n{ref_str}" + + async def _build_request_payload( + self, prompt: str, session_id: str, contexts: list, system_prompt: str + ) -> dict: + """构建请求payload + + Args: + prompt: 用户输入 + session_id: 会话ID + contexts: 上下文列表 + system_prompt: 系统提示词 + + Returns: + 请求payload字典 + + """ # 获得会话变量 payload_vars = self.variables.copy() - # 动态变量 session_var = await sp.get_async( scope="umo", scope_id=session_id, @@ -169,8 +270,6 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]): ): # 支持多轮对话的 new_record = {"role": "user", "content": prompt} - if image_urls: - logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") contexts_no_img = await self._remove_image_from_context(contexts) context_query = [*contexts_no_img, new_record] if system_prompt: @@ -178,74 +277,91 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]): for part in context_query: if "_no_save" in part: del part["_no_save"] - # 调用阿里云百炼 API - payload = { + + return { "app_id": self.app_id, "api_key": self.api_key, "messages": context_query, "biz_params": payload_vars or None, + "stream": self.streaming, + "incremental_output": True, } - partial = functools.partial( - Application.call, - **payload, - ) - response = await asyncio.get_event_loop().run_in_executor(None, partial) else: # 不支持多轮对话的 - # 调用阿里云百炼 API payload = { "app_id": self.app_id, "prompt": prompt, "api_key": self.api_key, "biz_params": payload_vars or None, + "stream": self.streaming, + "incremental_output": True, } if self.rag_options: payload["rag_options"] = self.rag_options - partial = functools.partial( - Application.call, - **payload, - ) - response = await asyncio.get_event_loop().run_in_executor(None, partial) + return payload - assert isinstance(response, ApplicationResponse) + async def _handle_streaming_response(self, response: T.Any): + """处理流式响应 - logger.debug(f"dashscope resp: {response}") + Args: + response: Dashscope 流式响应 generator - if response.status_code != 200: - logger.error( - f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", - ) - self._transition_state(AgentState.ERROR) - self.final_llm_resp = LLMResponse( - role="err", - result_chain=MessageChain().message( - f"阿里云百炼请求失败: message={response.message} code={response.status_code}", - ), - ) - yield AgentResponse( - type="err", - data=AgentResponseData( - chain=MessageChain().message( - f"阿里云百炼请求失败: message={response.message} code={response.status_code}" - ) - ), - ) - return + Yields: + AgentResponse 对象 - output_text = response.output.get("text", "") or "" - # RAG 引用脚标格式化 - output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) - if self.output_reference and response.output.get("doc_references", None): - ref_parts = [] - for ref in response.output.get("doc_references", []) or []: - ref_title = ( - ref.get("title", "") - if ref.get("title") - else ref.get("doc_name", "") + """ + response_queue = queue.Queue() + consumer_thread = threading.Thread( + target=self._consume_sync_generator, + args=(response, response_queue), + daemon=True, + ) + consumer_thread.start() + + output_text = "" + doc_references = None + + while True: + try: + item_type, item_data = await asyncio.get_event_loop().run_in_executor( + None, response_queue.get, True, 1 + ) + except queue.Empty: + continue + + if item_type == "done": + break + elif item_type == "error": + raise item_data + elif item_type == "data": + chunk = item_data + assert isinstance(chunk, ApplicationResponse) + + ( + output_text, + chunk_doc_refs, + response, + ) = await self._process_stream_chunk(chunk, output_text) + + if response: + if response.type == "err": + yield response + return + yield response + + if chunk_doc_refs: + doc_references = chunk_doc_refs + + # 添加 RAG 引用 + if self.output_reference and doc_references: + ref_text = self._format_doc_references(doc_references) + output_text += ref_text + + if self.streaming: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData(chain=MessageChain().message(ref_text)), ) - ref_parts.append(f"{ref['index_id']}. {ref_title}\n") - ref_str = "".join(ref_parts) - output_text += f"\n\n回答来源:\n{ref_str}" # 创建最终响应 chain = MessageChain(chain=[Comp.Plain(output_text)]) @@ -263,6 +379,33 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]): data=AgentResponseData(chain=chain), ) + async def _execute_dashscope_request(self): + """执行 Dashscope 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + contexts = self.req.contexts or [] + system_prompt = self.req.system_prompt + + # 检查图片输入 + if image_urls: + logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") + + # 构建请求payload + payload = await self._build_request_payload( + prompt, session_id, contexts, system_prompt + ) + + if not self.streaming: + payload["incremental_output"] = False + + # 发起请求 + partial = functools.partial(Application.call, **payload) + response = await asyncio.get_event_loop().run_in_executor(None, partial) + + async for resp in self._handle_streaming_response(response): + yield resp + @override def done(self) -> bool: """检查 Agent 是否已完成工作"""