fix: typing error

This commit is contained in:
Soulter
2025-10-21 10:56:44 +08:00
parent a0f8f3ae32
commit 36ffcf3cc3
7 changed files with 53 additions and 48 deletions

View File

@@ -44,7 +44,7 @@ except (ModuleNotFoundError, ImportError):
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@@ -102,7 +102,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description,
system_prompt=tool.description or "",
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
@@ -239,7 +239,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
yield res
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
@@ -337,7 +337,7 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
def _select_provider(self, event: AstrMessageEvent):
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
@@ -382,6 +382,9 @@ class LLMRequestSubStage(Stage):
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
@@ -520,8 +523,10 @@ class LLMRequestSubStage(Stage):
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
else:
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
@@ -553,6 +558,8 @@ class LLMRequestSubStage(Stage):
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
if not req.conversation:
return
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)

View File

@@ -4,7 +4,7 @@ import re
import hashlib
import uuid
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
from typing import List, Union, Optional, AsyncGenerator, Any
from astrbot import logger
from astrbot.core.db.po import Conversation
@@ -26,8 +26,6 @@ from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
_VT = TypeVar("_VT")
class AstrMessageEvent(abc.ABC):
def __init__(
@@ -177,9 +175,7 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(
self, key: str | None = None, default: _VT = None
) -> dict[str, Any] | _VT:
def get_extra(self, key: str | None = None, default=None) -> Any:
"""
获取额外的信息。
"""

View File

@@ -10,7 +10,7 @@ from anthropic.types import Message
from astrbot.core.utils.io import download_image_by_url
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.func_tool_manager import ToolSet
from ..register import register_provider_adapter
from astrbot.core.provider.entities import LLMResponse
from typing import AsyncGenerator
@@ -104,7 +104,7 @@ class ProviderAnthropic(Provider):
return system_prompt, new_messages
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
@@ -135,7 +135,7 @@ class ProviderAnthropic(Provider):
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
self, payloads: dict, tools: ToolSet | None
) -> AsyncGenerator[LLMResponse, None]:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
@@ -326,7 +326,7 @@ class ProviderAnthropic(Provider):
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
async def assemble_context(self, text: str, image_urls: List[str] | None = None):
"""组装上下文,支持文本和图片"""
if not image_urls:
return {"role": "user", "content": text}

View File

@@ -1,15 +1,14 @@
import re
import asyncio
import functools
from typing import List
from .. import Provider, Personality
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
session_id=None,
image_urls=[],
func_tool=None,
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
assert isinstance(response, ApplicationResponse)
logger.debug(f"dashscope resp: {response}")
if response.status_code != 200:
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
),
)
output_text = response.output.get("text", "")
output_text = response.output.get("text", "") or ""
# RAG 引用脚标格式化
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
if self.output_reference and response.output.get("doc_references", None):
ref_str = ""
for ref in response.output.get("doc_references", []):
for ref in response.output.get("doc_references", []) or []:
ref_title = (
ref.get("title", "")
if ref.get("title")

View File

@@ -1,9 +1,7 @@
import astrbot.core.message.components as Comp
import os
from typing import List
from .. import Provider
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url, download_file
@@ -55,11 +53,11 @@ class ProviderDify(Provider):
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
@@ -223,7 +221,7 @@ class ProviderDify(Provider):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict) -> Comp:
async def parse_file(item: dict):
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])

View File

@@ -16,7 +16,7 @@ from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.func_tool_manager import ToolSet
from typing import List, AsyncGenerator
from ..register import register_provider_adapter
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
@@ -49,7 +49,7 @@ class ProviderOpenAIOfficial(Provider):
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
base_url=provider_config.get("api_base", None),
base_url=provider_config.get("api_base", ""),
timeout=self.timeout,
)
else:
@@ -79,7 +79,7 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
@@ -126,7 +126,7 @@ class ProviderOpenAIOfficial(Provider):
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
self, payloads: dict, tools: ToolSet
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API逐步返回结果"""
if tools:
@@ -183,9 +183,7 @@ class ProviderOpenAIOfficial(Provider):
yield llm_response
async def parse_openai_completion(
self, completion: ChatCompletion, tools: FuncCall
):
async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
"""解析 OpenAI 的 ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
@@ -208,7 +206,10 @@ class ProviderOpenAIOfficial(Provider):
# workaround for #1359
tool_call = json.loads(tool_call)
for tool in tools.func_list:
if tool.name == tool_call.function.name:
if (
tool_call.type == "function"
and tool.name == tool_call.function.name
):
# workaround for #1454
if isinstance(tool_call.function.arguments, str):
args = json.loads(tool_call.function.arguments)
@@ -277,7 +278,7 @@ class ProviderOpenAIOfficial(Provider):
e: Exception,
payloads: dict,
context_query: list,
func_tool: FuncCall,
func_tool: ToolSet,
chosen_key: str,
available_api_keys: List[str],
retry_cnt: int,
@@ -420,7 +421,7 @@ class ProviderOpenAIOfficial(Provider):
if success:
break
if retry_cnt == max_retries - 1:
if retry_cnt == max_retries - 1 or llm_response is None:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
if last_exception is None:
raise Exception("未知错误")
@@ -430,10 +431,10 @@ class ProviderOpenAIOfficial(Provider):
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
@@ -526,7 +527,9 @@ class ProviderOpenAIOfficial(Provider):
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
async def assemble_context(
self, text: str, image_urls: List[str] | None = None
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
if image_urls:
user_content = {

View File

@@ -30,7 +30,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
timeout=timeout,
)
self.set_model(provider_config.get("model", None))
self.set_model(provider_config.get("model", ""))
async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")