import re
import asyncio
import functools
from typing import List

from .. import Provider, Personality
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application


@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
class ProviderDashscope(ProviderOpenAIOfficial):
    """Provider adapter for the Aliyun Bailian (Dashscope) Application API.

    Wraps the blocking ``dashscope.Application.call`` in an executor so it
    can be awaited. Multi-turn app types ("agent", "dialog-workflow") send
    the full message history; other app types (or any app configured with
    RAG options) send a single prompt per request.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
        default_persona: Personality | None = None,
    ) -> None:
        # Deliberately bypass ProviderOpenAIOfficial.__init__ — only the base
        # Provider state is needed; no OpenAI client is created here.
        Provider.__init__(
            self,
            provider_config,
            provider_settings,
            default_persona,
        )
        self.api_key = provider_config.get("dashscope_api_key", "")
        if not self.api_key:
            raise Exception("阿里云百炼 API Key 不能为空。")
        self.app_id = provider_config.get("dashscope_app_id", "")
        if not self.app_id:
            raise Exception("阿里云百炼 APP ID 不能为空。")
        self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
        if not self.dashscope_app_type:
            raise Exception("阿里云百炼 APP 类型不能为空。")
        self.model_name = "dashscope"
        # Static variables merged into every request's biz_params; per-session
        # variables (fetched in text_chat) override them.
        self.variables: dict = provider_config.get("variables", {})
        # Read "output_reference" first, then drop it from a copy so the
        # caller's config dict is never mutated and the flag is not forwarded
        # to the API inside rag_options.
        rag_options: dict = provider_config.get("rag_options", {})
        self.output_reference = rag_options.get("output_reference", False)
        self.rag_options = rag_options.copy()
        self.rag_options.pop("output_reference", None)
        self.timeout = provider_config.get("timeout", 120)
        if isinstance(self.timeout, str):
            self.timeout = int(self.timeout)

    def has_rag_options(self) -> bool:
        """Return True if RAG options name at least one pipeline or file ID.

        Returns:
            bool: whether any "pipeline_ids" or "file_ids" are configured.
        """
        if not self.rag_options:
            return False
        return bool(
            self.rag_options.get("pipeline_ids", [])
            or self.rag_options.get("file_ids", [])
        )

    async def _call_app(self, payload: dict):
        """Run the blocking ``Application.call`` in the default executor.

        Args:
            payload: keyword arguments forwarded to ``Application.call``.

        Returns:
            The raw dashscope response object.
        """
        call = functools.partial(Application.call, **payload)
        return await asyncio.get_running_loop().run_in_executor(None, call)

    async def text_chat(
        self,
        prompt: str,
        session_id: str | None = None,
        image_urls: List[str] | None = None,
        func_tool: FuncCall | None = None,
        contexts: List | None = None,
        system_prompt: str | None = None,
        model=None,
        **kwargs,
    ) -> LLMResponse:
        """Send one chat request to the Bailian app and return the reply.

        Args:
            prompt: the user's message.
            session_id: used to look up per-session variables.
            image_urls: ignored — Bailian does not support image input; a
                warning is logged when non-empty. (Was a mutable default
                ``[]``; now ``None``.)
            func_tool: unused by this adapter; kept for interface parity.
            contexts: prior conversation turns (role/content dicts).
            system_prompt: prepended as a system message for multi-turn apps.
            model: unused; kept for interface parity.

        Returns:
            LLMResponse: role "assistant" on success, role "err" on a
            non-200 API status.
        """
        if contexts is None:
            contexts = []
        # Merge static config variables with dynamic per-session variables;
        # session values win.
        payload_vars = self.variables.copy()
        session_var = await sp.session_get(session_id, "session_variables", default={})
        payload_vars.update(session_var)

        if (
            self.dashscope_app_type in ["agent", "dialog-workflow"]
            and not self.has_rag_options()
        ):
            # App types that support multi-turn conversation: send history.
            new_record = {"role": "user", "content": prompt}
            if image_urls:
                logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
            contexts_no_img = await self._remove_image_from_context(contexts)
            context_query = [*contexts_no_img, new_record]
            if system_prompt:
                context_query.insert(0, {"role": "system", "content": system_prompt})
            for part in context_query:
                # Strip internal bookkeeping keys before sending upstream.
                part.pop("_no_save", None)
            payload = {
                "app_id": self.app_id,
                "api_key": self.api_key,
                "messages": context_query,
                "biz_params": payload_vars or None,
            }
        else:
            # Single-turn app types (optionally with RAG options).
            payload = {
                "app_id": self.app_id,
                "prompt": prompt,
                "api_key": self.api_key,
                "biz_params": payload_vars or None,
            }
            if self.rag_options:
                payload["rag_options"] = self.rag_options
        response = await self._call_app(payload)

        logger.debug(f"dashscope resp: {response}")

        if response.status_code != 200:
            logger.error(
                f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
            )
            return LLMResponse(
                role="err",
                result_chain=MessageChain().message(
                    f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
                ),
            )
        output_text = response.output.get("text", "")
        # NOTE(review): this substitution rewrites "[N]" as "[N]" — currently
        # a no-op; presumably it was meant to normalize RAG citation markers
        # to some other format. Confirm the intended replacement.
        output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text)
        if self.output_reference and response.output.get("doc_references", None):
            # Append a numbered source list when reference output is enabled.
            ref_str = ""
            for ref in response.output.get("doc_references", []):
                # Prefer a non-empty title; fall back to the document name.
                ref_title = ref.get("title") or ref.get("doc_name", "")
                ref_str += f"{ref['index_id']}. {ref_title}\n"
            output_text += f"\n\n回答来源:\n{ref_str}"
        llm_response = LLMResponse("assistant")
        llm_response.result_chain = MessageChain().message(output_text)
        return llm_response

    async def text_chat_stream(
        self,
        prompt,
        session_id=None,
        image_urls=None,
        func_tool=None,
        contexts=None,
        system_prompt=None,
        tool_calls_result=None,
        model=None,
        **kwargs,
    ):
        """Simulate streaming by delegating to ``text_chat``.

        Yields the full response twice: once flagged as a chunk, then once as
        the final message.

        Bug fix: the defaults were ``...`` (Ellipsis), which is truthy and not
        a list — forwarding them to text_chat triggered a spurious image
        warning and a TypeError when unpacking contexts. They now default to
        ``None``, which text_chat handles.
        """
        llm_response = await self.text_chat(
            prompt=prompt,
            session_id=session_id,
            image_urls=image_urls,
            func_tool=func_tool,
            contexts=contexts,
            system_prompt=system_prompt,
            tool_calls_result=tool_calls_result,
        )
        llm_response.is_chunk = True
        yield llm_response
        llm_response.is_chunk = False
        yield llm_response

    async def forget(self, session_id):
        """No server-side history is kept; nothing to forget."""
        return True

    async def get_current_key(self):
        """Return the configured Bailian API key."""
        return self.api_key

    async def set_key(self, key):
        """Key rotation is not supported by this adapter."""
        raise Exception("阿里云百炼 适配器不支持设置 API Key。")

    async def get_models(self):
        """Return the single pseudo-model this adapter exposes."""
        return [self.get_model()]

    async def get_human_readable_context(self, session_id, page, page_size):
        """History retrieval is not supported by this adapter."""
        raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")

    async def terminate(self):
        """No resources to release."""
        pass