feat(third-party-agent): implement streaming response handling and enhance agent execution flow

This commit is contained in:
Soulter
2025-11-23 23:03:56 +08:00
parent 520f521887
commit a6dc458212

View File

@@ -1,5 +1,6 @@
import asyncio
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from astrbot.core import logger
from astrbot.core.agent.runners.coze_agent_runner import CozeAgentRunner
@@ -7,9 +8,13 @@ from astrbot.core.agent.runners.dashscope_agent_runner import DashscopeAgentRunn
from astrbot.core.agent.runners.dify_agent_runner import DifyAgentRunner
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
ProviderRequest,
@@ -29,6 +34,32 @@ AGENT_RUNNER_TYPE_KEY = {
}
async def run_third_party_agent(
runner: "BaseAgentRunner",
stream_to_general: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""
运行第三方 agent runner 并转换响应格式
类似于 run_agent 函数,但专门处理第三方 agent runner
"""
try:
async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
if resp.type == "streaming_delta":
if stream_to_general:
continue
yield resp.data["chain"]
elif resp.type == "llm_result":
if stream_to_general:
yield resp.data["chain"]
except Exception as e:
logger.error(f"Third party agent runner error: {e}")
err_msg = (
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
)
yield MessageChain().message(err_msg)
class ThirdPartyAgentSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
@@ -38,10 +69,11 @@ class ThirdPartyAgentSubStage(Stage):
AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
"",
)
self.prov_cfg: dict = next(
(p for p in self.conf["provider"] if p["id"] == self.prov_id),
{},
)
settings = ctx.astrbot_config["provider_settings"]
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
]
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
@@ -52,6 +84,11 @@ class ThirdPartyAgentSubStage(Stage):
provider_wake_prefix
):
return
self.prov_cfg: dict = next(
(p for p in self.conf["provider"] if p["id"] == self.prov_id),
{},
)
if not self.prov_id or not self.prov_cfg:
logger.error(
"Third Party Agent Runner provider ID is not configured properly."
@@ -90,6 +127,15 @@ class ThirdPartyAgentSubStage(Stage):
event=event,
)
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
await runner.reset(
request=req,
run_context=AgentContextWrapper(
@@ -98,24 +144,52 @@ class ThirdPartyAgentSubStage(Stage):
),
agent_hooks=MAIN_AGENT_HOOKS,
provider_config=self.prov_cfg,
streaming=streaming_response,
)
async for _ in runner.step_until_done():
pass
if streaming_response and not stream_to_general:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_third_party_agent(
runner,
stream_to_general=False,
),
),
)
yield
if runner.done():
final_resp = runner.get_final_llm_resp()
if final_resp and final_resp.result_chain:
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
# 非流式响应或转换为普通响应
async for _ in run_third_party_agent(
runner,
stream_to_general=stream_to_general,
):
yield
final_resp = runner.get_final_llm_resp()
final_resp = runner.get_final_llm_resp()
if not final_resp or not final_resp.result_chain:
logger.warning("Agent Runner 未返回最终结果。")
return
if not final_resp or not final_resp.result_chain:
logger.warning("Agent Runner 未返回最终结果。")
return
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.LLM_RESULT,
),
)
yield
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.LLM_RESULT,
),
)
yield
asyncio.create_task(
Metric.upload(