fix: typing error
This commit is contained in:
@@ -44,7 +44,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@@ -102,7 +102,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=input_,
|
||||
system_prompt=tool.description,
|
||||
system_prompt=tool.description or "",
|
||||
image_urls=[], # 暂时不传递原始 agent 的上下文
|
||||
contexts=[], # 暂时不传递原始 agent 的上下文
|
||||
func_tool=toolset,
|
||||
@@ -239,7 +239,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
yield res
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
@@ -337,7 +337,7 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
"""选择使用的 LLM 提供商"""
|
||||
sel_provider = event.get_extra("selected_provider")
|
||||
_ctx = self.ctx.plugin_manager.context
|
||||
@@ -382,6 +382,9 @@ class LLMRequestSubStage(Stage):
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
@@ -520,8 +523,10 @@ class LLMRequestSubStage(Stage):
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
)
|
||||
else:
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
@@ -553,6 +558,8 @@ class LLMRequestSubStage(Stage):
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
if not req.conversation:
|
||||
return
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
|
||||
from typing import List, Union, Optional, AsyncGenerator, Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
@@ -26,8 +26,6 @@ from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
def __init__(
|
||||
@@ -177,9 +175,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self._extras[key] = value
|
||||
|
||||
def get_extra(
|
||||
self, key: str | None = None, default: _VT = None
|
||||
) -> dict[str, Any] | _VT:
|
||||
def get_extra(self, key: str | None = None, default=None) -> Any:
|
||||
"""
|
||||
获取额外的信息。
|
||||
"""
|
||||
|
||||
@@ -10,7 +10,7 @@ from anthropic.types import Message
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from typing import AsyncGenerator
|
||||
@@ -104,7 +104,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
@@ -135,7 +135,7 @@ class ProviderAnthropic(Provider):
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
self, payloads: dict, tools: ToolSet | None
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
@@ -326,7 +326,7 @@ class ProviderAnthropic(Provider):
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
async def assemble_context(self, text: str, image_urls: List[str] | None = None):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import re
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
from astrbot.core import logger, sp
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
|
||||
|
||||
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
||||
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
session_id=None,
|
||||
image_urls=[],
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
|
||||
assert isinstance(response, ApplicationResponse)
|
||||
|
||||
logger.debug(f"dashscope resp: {response}")
|
||||
|
||||
if response.status_code != 200:
|
||||
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
),
|
||||
)
|
||||
|
||||
output_text = response.output.get("text", "")
|
||||
output_text = response.output.get("text", "") or ""
|
||||
# RAG 引用脚标格式化
|
||||
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
|
||||
if self.output_reference and response.output.get("doc_references", None):
|
||||
ref_str = ""
|
||||
for ref in response.output.get("doc_references", []):
|
||||
for ref in response.output.get("doc_references", []) or []:
|
||||
ref_title = (
|
||||
ref.get("title", "")
|
||||
if ref.get("title")
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import astrbot.core.message.components as Comp
|
||||
import os
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
@@ -55,11 +53,11 @@ class ProviderDify(Provider):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
@@ -223,7 +221,7 @@ class ProviderDify(Provider):
|
||||
# Chat
|
||||
return MessageChain(chain=[Comp.Plain(chunk)])
|
||||
|
||||
async def parse_file(item: dict) -> Comp:
|
||||
async def parse_file(item: dict):
|
||||
match item["type"]:
|
||||
case "image":
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
|
||||
@@ -16,7 +16,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from typing import List, AsyncGenerator
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
@@ -49,7 +49,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
api_version=provider_config.get("api_version", None),
|
||||
base_url=provider_config.get("api_base", None),
|
||||
base_url=provider_config.get("api_base", ""),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
else:
|
||||
@@ -79,7 +79,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
|
||||
if tools:
|
||||
model = payloads.get("model", "").lower()
|
||||
omit_empty_param_field = "gemini" in model
|
||||
@@ -126,7 +126,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
self, payloads: dict, tools: ToolSet
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式查询API,逐步返回结果"""
|
||||
if tools:
|
||||
@@ -183,9 +183,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
yield llm_response
|
||||
|
||||
async def parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: FuncCall
|
||||
):
|
||||
async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
|
||||
"""解析 OpenAI 的 ChatCompletion 响应"""
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
@@ -208,7 +206,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# workaround for #1359
|
||||
tool_call = json.loads(tool_call)
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
if (
|
||||
tool_call.type == "function"
|
||||
and tool.name == tool_call.function.name
|
||||
):
|
||||
# workaround for #1454
|
||||
if isinstance(tool_call.function.arguments, str):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
@@ -277,7 +278,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
e: Exception,
|
||||
payloads: dict,
|
||||
context_query: list,
|
||||
func_tool: FuncCall,
|
||||
func_tool: ToolSet,
|
||||
chosen_key: str,
|
||||
available_api_keys: List[str],
|
||||
retry_cnt: int,
|
||||
@@ -420,7 +421,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if success:
|
||||
break
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
if retry_cnt == max_retries - 1 or llm_response is None:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
@@ -430,10 +431,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
@@ -526,7 +527,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
|
||||
async def assemble_context(
|
||||
self, text: str, image_urls: List[str] | None = None
|
||||
) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
if image_urls:
|
||||
user_content = {
|
||||
|
||||
@@ -30,7 +30,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
|
||||
Reference in New Issue
Block a user