refactor: LLM response handling with reasoning content (#3632)
* refactor: LLM response handling with reasoning content

  - Added a `show_reasoning` parameter to `run_agent` to control the display of reasoning content.
  - Updated `LLMResponse` to include a `reasoning_content` field for storing reasoning text.
  - Modified `WebChatMessageEvent` to handle and send reasoning content in streaming responses.
  - Implemented reasoning extraction in various provider sources (e.g., OpenAI, Gemini).
  - Updated the chat interface to display reasoning content in a collapsible format.
  - Removed the deprecated `thinking_filter` package and its associated logic.
  - Updated localization files to include new reasoning-related strings.

* feat: add Groq chat completion provider and associated configurations

* Update astrbot/core/provider/sources/gemini_source.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
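For plugin authors, the new `reasoning_content` field is readable from the `on_llm_response` hook, exactly as the updated `long_term_memory` plugin in this diff does. A minimal sketch (the class and handler names are illustrative, not part of the commit):

# Sketch only: consuming the new field from a plugin hook shown in this diff.
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.provider import LLMResponse


class MyPlugin(star.Star):
    @filter.on_llm_response()
    async def show_thoughts(self, event: AstrMessageEvent, resp: LLMResponse):
        # reasoning_content is populated by the provider adapters (OpenAI, Gemini, Groq).
        if resp.reasoning_content:
            resp.completion_text = f"🤔 {resp.reasoning_content}\n\n{resp.completion_text}"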
@@ -36,7 +36,8 @@ from astrbot.core.star.config import *

# provider
from astrbot.core.provider import Provider, Personality, ProviderMetaData
from astrbot.core.provider import Provider, ProviderMetaData
from astrbot.core.db.po import Personality

# platform
from astrbot.core.platform import (

@@ -1,4 +1,5 @@
from astrbot.core.provider import Personality, Provider, STTProvider
from astrbot.core.db.po import Personality
from astrbot.core.provider import Provider, STTProvider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMetaData,

@@ -110,13 +110,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
else:
elif llm_response.completion_text:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text),
),
)
elif llm_response.reasoning_content:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain(type="reasoning").message(
llm_response.reasoning_content,
),
),
)
continue
llm_resp_result = llm_response
break # got final response
@@ -177,13 +186,16 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
chain=MessageChain(type="tool_call").message(
f"🔨 调用工具: {tool_call_name}"
),
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
result.type = "tool_call_result"
yield AgentResponse(
type="tool_call_result",
data=AgentResponseData(chain=result),

@@ -18,6 +18,7 @@ async def run_agent(
max_step: int = 30,
show_tool_use: bool = True,
stream_to_general: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
@@ -31,7 +32,6 @@ async def run_agent(
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
@@ -40,8 +40,7 @@ async def run_agent(
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
if show_tool_use:
await astr_event.send(resp.data["chain"])
continue

@@ -63,6 +62,10 @@ async def run_agent(
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
chain = resp.data["chain"]
if chain.type == "reasoning" and not show_reasoning:
# display the reasoning content only when configured
continue
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break

@@ -880,6 +880,23 @@ CONFIG_METADATA_2 = {
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"Groq": {
"id": "groq_default",
"provider": "groq",
"type": "groq_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.groq.com/openai/v1",
"timeout": 120,
"model_config": {
"model": "openai/gpt-oss-20b",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"302.AI": {
"id": "302ai",
"provider": "302ai",

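The new "Groq" template above mirrors the other OpenAI-compatible entries. A concrete provider entry derived from it would look roughly like the sketch below; the values are the template defaults and the key is a placeholder:

# Sketch only: a user-level provider entry built from the "Groq" template above.
groq_provider = {
    "id": "groq_default",
    "type": "groq_chat_completion",
    "enable": True,
    "key": ["gsk-..."],  # placeholder API key
    "api_base": "https://api.groq.com/openai/v1",
    "timeout": 120,
    "model_config": {
        "model": "openai/gpt-oss-20b",
        "temperature": 0.4,
    },
}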
@@ -57,6 +57,7 @@ class LLMRequestSubStage(Stage):
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_reasoning = settings.get("display_reasoning_text", False)

for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -419,7 +420,12 @@ class LLMRequestSubStage(Stage):
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use),
run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
show_reasoning=self.show_reasoning,
),
),
)
yield
@@ -443,7 +449,11 @@ class LLMRequestSubStage(Stage):
)
else:
async for _ in run_agent(
agent_runner, self.max_step, self.show_tool_use, stream_to_general
agent_runner,
self.max_step,
self.show_tool_use,
stream_to_general,
show_reasoning=self.show_reasoning,
):
yield

@@ -109,6 +109,7 @@ class WebChatMessageEvent(AstrMessageEvent):

async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
reasoning_content = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator:
@@ -124,16 +125,22 @@ class WebChatMessageEvent(AstrMessageEvent):
)
final_data = ""
continue
final_data += await WebChatMessageEvent._send(

r = await WebChatMessageEvent._send(
chain,
session_id=self.session_id,
streaming=True,
)
if chain.type == "reasoning":
reasoning_content += chain.get_plain_text()
else:
final_data += r

await web_chat_back_queue.put(
{
"type": "complete", # complete means we return the final result
"data": final_data,
"reasoning": reasoning_content,
"streaming": True,
"cid": cid,
},

@@ -1,4 +1,4 @@
from .entities import ProviderMetaData
from .provider import Personality, Provider, STTProvider
from .provider import Provider, STTProvider

__all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"]
__all__ = ["Provider", "ProviderMetaData", "STTProvider"]

@@ -202,25 +202,28 @@ class ProviderRequest:
@dataclass
class LLMResponse:
role: str
"""角色, assistant, tool, err"""
"""The role of the message, e.g., assistant, tool, err"""
result_chain: MessageChain | None = None
"""返回的消息链"""
"""A chain of message components representing the text completion from LLM."""
tools_call_args: list[dict[str, Any]] = field(default_factory=list)
"""工具调用参数"""
"""Tool call arguments."""
tools_call_name: list[str] = field(default_factory=list)
"""工具调用名称"""
"""Tool call names."""
tools_call_ids: list[str] = field(default_factory=list)
"""工具调用 ID"""
"""Tool call IDs."""
reasoning_content: str = ""
"""The reasoning content extracted from the LLM, if any."""

raw_completion: (
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
) = None
_new_record: dict[str, Any] | None = None
"""The raw completion response from the LLM provider."""

_completion_text: str = ""
"""The plain text of the completion."""

is_chunk: bool = False
"""是否是流式输出的单个 Chunk"""
"""Indicates if the response is a chunked response."""

def __init__(
self,
@@ -234,7 +237,6 @@ class LLMResponse:
| GenerateContentResponse
| AnthropicMessage
| None = None,
_new_record: dict[str, Any] | None = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
@@ -262,7 +264,6 @@ class LLMResponse:
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
self.is_chunk = is_chunk

@property

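As a quick illustration of the reshaped dataclass, a provider adapter now fills `reasoning_content` alongside the result chain, and downstream code branches on it. A sketch using only fields shown above:

# Sketch only: how an adapter populates the new field next to the result chain.
llm_response = LLMResponse("assistant")
llm_response.result_chain = MessageChain().message("Final answer text")
llm_response.reasoning_content = "Reasoning text extracted by the adapter"

# The agent runner emits a MessageChain(type="reasoning") delta only when this is non-empty.
if llm_response.reasoning_content:
    print(llm_response.reasoning_content)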
@@ -241,6 +241,8 @@ class ProviderManager:
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
@@ -396,7 +398,6 @@ class ProviderManager:
inst = cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
)

if getattr(inst, "initialize", None):

@@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator

from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Personality
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMeta,
@@ -52,15 +51,10 @@ class Provider(AbstractProvider):
self,
provider_config: dict,
provider_settings: dict,
default_persona: Personality | None = None,
) -> None:
super().__init__(provider_config)

self.provider_settings = provider_settings

self.curr_personality = default_persona
"""维护了当前的使用的 persona,即人格。可能为 None"""

@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError

@@ -25,12 +25,10 @@ class ProviderAnthropic(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)

self.chosen_api_key: str = ""

@@ -20,12 +20,10 @@ class ProviderCoze(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:

@@ -8,7 +8,7 @@ from dashscope.app.application_response import ApplicationResponse
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain

from .. import Personality, Provider
from .. import Provider
from ..entities import LLMResponse
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
@@ -20,13 +20,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
default_persona: Personality | None = None,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:

@@ -18,12 +18,10 @@ class ProviderDify(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:

@@ -53,12 +53,10 @@ class ProviderGoogleGenAI(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_keys: list = super().get_keys()
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
@@ -326,8 +324,18 @@ class ProviderGoogleGenAI(Provider):

return gemini_contents

@staticmethod
def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
"""Extract reasoning content from candidate parts"""
if not candidate.content or not candidate.content.parts:
return ""

thought_buf: list[str] = [
(p.text or "") for p in candidate.content.parts if p.thought
]
return "".join(thought_buf).strip()

def _process_content_parts(
self,
candidate: types.Candidate,
llm_response: LLMResponse,
) -> MessageChain:
@@ -358,6 +366,11 @@ class ProviderGoogleGenAI(Provider):
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
raise Exception("API 返回的 candidate.content.parts 为空。")

# 提取 reasoning content
reasoning = self._extract_reasoning_content(candidate)
if reasoning:
llm_response.reasoning_content = reasoning

chain = []
part: types.Part

@@ -515,6 +528,7 @@ class ProviderGoogleGenAI(Provider):

# Accumulate the complete response text for the final response
accumulated_text = ""
accumulated_reasoning = ""
final_response = None

async for chunk in result:
@@ -539,9 +553,19 @@ class ProviderGoogleGenAI(Provider):
yield llm_response
return

_f = False

# 提取 reasoning content
reasoning = self._extract_reasoning_content(chunk.candidates[0])
if reasoning:
_f = True
accumulated_reasoning += reasoning
llm_response.reasoning_content = reasoning
if chunk.text:
_f = True
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
if _f:
yield llm_response

if chunk.candidates[0].finish_reason:
@@ -559,6 +583,10 @@ class ProviderGoogleGenAI(Provider):
if not final_response:
final_response = LLMResponse("assistant", is_chunk=False)

# Set the complete accumulated reasoning in the final response
if accumulated_reasoning:
final_response.reasoning_content = accumulated_reasoning

# Set the complete accumulated text in the final response
if accumulated_text:
final_response.result_chain = MessageChain(

astrbot/core/provider/sources/groq_source.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial


@register_provider_adapter(
"groq_chat_completion", "Groq Chat Completion Provider Adapter"
)
class ProviderGroq(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.reasoning_key = "reasoning"
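The only behavioural difference from the base OpenAI adapter is the attribute name read when extracting reasoning: Groq exposes it as `reasoning`, while the OpenAI-compatible default is `reasoning_content`. A self-contained sketch of what the `reasoning_key` override changes (the fake message object is a stand-in for the OpenAI SDK message/delta):

# Sketch only: what self.reasoning_key selects when reasoning is extracted.
class _FakeGroqMessage:
    reasoning = "step-by-step thoughts"  # Groq-style field name
    content = "final answer"

reasoning_key = "reasoning"  # ProviderGroq sets this in __init__
message = _FakeGroqMessage()
reasoning_text = getattr(message, reasoning_key, None) or ""
print(reasoning_text)  # -> "step-by-step thoughts"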
@@ -4,12 +4,14 @@ import inspect
import json
import os
import random
import re
from collections.abc import AsyncGenerator

from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

import astrbot.core.message.components as Comp
from astrbot import logger
@@ -28,17 +30,8 @@ from ..register import register_provider_adapter
"OpenAI API Chat Completion 提供商适配器",
)
class ProviderOpenAIOfficial(Provider):
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
def __init__(self, provider_config, provider_settings) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = None
self.api_keys: list = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
@@ -53,9 +46,8 @@ class ProviderOpenAIOfficial(Provider):
for key in self.custom_headers:
self.custom_headers[key] = str(self.custom_headers[key])

# 适配 azure openai #332
if "api_version" in provider_config:
# 使用 azure api
# Using Azure OpenAI API
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
@@ -64,7 +56,7 @@ class ProviderOpenAIOfficial(Provider):
timeout=self.timeout,
)
else:
# 使用 openai api
# Using OpenAI Official API
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
@@ -80,6 +72,8 @@ class ProviderOpenAIOfficial(Provider):
model = model_config.get("model", "unknown")
self.set_model(model)

self.reasoning_key = "reasoning_content"

def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。

@@ -157,7 +151,7 @@ class ProviderOpenAIOfficial(Provider):

logger.debug(f"completion: {completion}")

llm_response = await self.parse_openai_completion(completion, tools)
llm_response = await self._parse_openai_completion(completion, tools)

return llm_response

@@ -210,36 +204,78 @@ class ProviderOpenAIOfficial(Provider):
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
# 处理文本内容
logger.debug(f"chunk delta: {delta}")
# handle the content delta
reasoning = self._extract_reasoning_content(chunk)
_y = False
if reasoning:
llm_response.reasoning_content = reasoning
_y = True
if delta.content:
completion_text = delta.content
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)],
)
_y = True
if _y:
yield llm_response

final_completion = state.get_final_completion()
llm_response = await self.parse_openai_completion(final_completion, tools)
llm_response = await self._parse_openai_completion(final_completion, tools)

yield llm_response

async def parse_openai_completion(
def _extract_reasoning_content(
self,
completion: ChatCompletion | ChatCompletionChunk,
) -> str:
"""Extract reasoning content from OpenAI ChatCompletion if available."""
reasoning_text = ""
if len(completion.choices) == 0:
return reasoning_text
if isinstance(completion, ChatCompletion):
choice = completion.choices[0]
reasoning_attr = getattr(choice.message, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
elif isinstance(completion, ChatCompletionChunk):
delta = completion.choices[0].delta
reasoning_attr = getattr(delta, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
return reasoning_text

async def _parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
"""解析 OpenAI 的 ChatCompletion 响应"""
"""Parse OpenAI ChatCompletion into LLMResponse"""
llm_response = LLMResponse("assistant")

if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]

# parse the text completion
if choice.message.content is not None:
# text completion
completion_text = str(choice.message.content).strip()
# specially, some providers may set <think> tags around reasoning content in the completion text,
# we use regex to remove them, and store them in the reasoning_content field
reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
matches = reasoning_pattern.findall(completion_text)
if matches:
llm_response.reasoning_content = "\n".join(
[match.strip() for match in matches],
)
completion_text = reasoning_pattern.sub("", completion_text).strip()
llm_response.result_chain = MessageChain().message(completion_text)

# parse the reasoning content if any
# the priority is higher than the <think> tag extraction
llm_response.reasoning_content = self._extract_reasoning_content(completion)

# parse tool calls if any
if choice.message.tool_calls and tools is not None:
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_call_ids = []
@@ -265,11 +301,11 @@ class ProviderOpenAIOfficial(Provider):
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_call_ids

# specially handle finish reason
if choice.finish_reason == "content_filter":
raise Exception(
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
)

if llm_response.completion_text is None and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}。")
raise Exception(f"API 返回的 completion 无法解析:{completion}。")

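For providers that inline their reasoning in the completion text, the `<think>` fallback in `_parse_openai_completion` above can be exercised in isolation. A standalone sketch of the same regex (the sample text is illustrative):

import re

# Standalone demo of the <think> fallback used above.
completion_text = "<think>Compare the two numbers digit by digit.</think>9.9 is larger."
reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)

reasoning = "\n".join(m.strip() for m in reasoning_pattern.findall(completion_text))
visible_text = reasoning_pattern.sub("", completion_text).strip()

print(reasoning)     # -> "Compare the two numbers digit by digit."
print(visible_text)  # -> "9.9 is larger."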
@@ -12,10 +12,5 @@ class ProviderZhipu(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
super().__init__(provider_config, provider_settings)

@@ -204,6 +204,8 @@ class ChatRoute(Route):
):
# 追加机器人消息
new_his = {"type": "bot", "message": result_text}
if "reasoning" in result:
new_his["reasoning"] = result["reasoning"]
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,

@@ -146,21 +146,6 @@
<span>Hello, I'm</span>
<span class="bot-name">AstrBot ⭐</span>
</div>
<div class="welcome-hint markdown-content">
<span>{{ t('core.common.type') }}</span>
<code>help</code>
<span>{{ tm('shortcuts.help') }} 😊</span>
</div>
<div class="welcome-hint markdown-content">
<span>{{ t('core.common.longPress') }}</span>
<code>Ctrl + B</code>
<span>{{ tm('shortcuts.voiceRecord') }} 🎤</span>
</div>
<div class="welcome-hint markdown-content">
<span>{{ t('core.common.press') }}</span>
<code>Ctrl + V</code>
<span>{{ tm('shortcuts.pasteImage') }} 🏞️</span>
</div>
</div>

<!-- 输入区域 -->
@@ -1031,17 +1016,26 @@ export default {
"content": bot_resp
});
} else if (chunk_json.type === 'plain') {
const chain_type = chunk_json.chain_type || 'normal';

if (!in_streaming) {
message_obj = {
type: 'bot',
message: this.ref(chunk_json.data),
message: this.ref(chain_type === 'reasoning' ? '' : chunk_json.data),
reasoning: this.ref(chain_type === 'reasoning' ? chunk_json.data : ''),
}
this.messages.push({
"content": message_obj
});
in_streaming = true;
} else {
message_obj.message.value += chunk_json.data;
if (chain_type === 'reasoning') {
// Append to reasoning content
message_obj.reasoning.value += chunk_json.data;
} else {
// Append to normal message
message_obj.message.value += chunk_json.data;
}
}
} else if (chunk_json.type === 'update_title') {
// 更新对话标题

@@ -37,6 +37,19 @@
</v-avatar>
<div class="bot-message-content">
<div class="message-bubble bot-bubble">
<!-- Reasoning Block (Collapsible) -->
<div v-if="msg.content.reasoning && msg.content.reasoning.trim()" class="reasoning-container">
<div class="reasoning-header" @click="toggleReasoning(index)">
<v-icon size="small" class="reasoning-icon">
{{ isReasoningExpanded(index) ? 'mdi-chevron-down' : 'mdi-chevron-right' }}
</v-icon>
<span class="reasoning-label">{{ tm('reasoning.thinking') }}</span>
</div>
<div v-if="isReasoningExpanded(index)" class="reasoning-content">
<div v-html="md.render(msg.content.reasoning)" class="markdown-content reasoning-text"></div>
</div>
</div>

<!-- Text -->
<div v-if="msg.content.message && msg.content.message.trim()"
v-html="md.render(msg.content.message)" class="markdown-content"></div>
@@ -125,7 +138,8 @@ export default {
copiedMessages: new Set(),
isUserNearBottom: true,
scrollThreshold: 1,
scrollTimer: null
scrollTimer: null,
expandedReasoning: new Set(), // Track which reasoning blocks are expanded
};
},
mounted() {
@@ -142,6 +156,22 @@ export default {
}
},
methods: {
// Toggle reasoning expansion state
toggleReasoning(messageIndex) {
if (this.expandedReasoning.has(messageIndex)) {
this.expandedReasoning.delete(messageIndex);
} else {
this.expandedReasoning.add(messageIndex);
}
// Force reactivity
this.expandedReasoning = new Set(this.expandedReasoning);
},

// Check if reasoning is expanded
isReasoningExpanded(messageIndex) {
return this.expandedReasoning.has(messageIndex);
},

// 复制代码到剪贴板
copyCodeToClipboard(code) {
navigator.clipboard.writeText(code).then(() => {
@@ -348,7 +378,7 @@ export default {
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
transform: translateY(0);
}

to {
@@ -539,6 +569,69 @@ export default {
.fade-in {
animation: fadeIn 0.3s ease-in-out;
}

/* Reasoning 区块样式 */
.reasoning-container {
margin-bottom: 12px;
margin-top: 6px;
border: 1px solid var(--v-theme-border);
border-radius: 8px;
overflow: hidden;
width: fit-content;
}

.v-theme--dark .reasoning-container {
background-color: rgba(103, 58, 183, 0.08);
}

.reasoning-header {
display: inline-flex;
align-items: center;
padding: 8px 8px;
cursor: pointer;
user-select: none;
transition: background-color 0.2s ease;
border-radius: 8px;
}

.reasoning-header:hover {
background-color: rgba(103, 58, 183, 0.08);
}

.v-theme--dark .reasoning-header:hover {
background-color: rgba(103, 58, 183, 0.15);
}

.reasoning-icon {
margin-right: 6px;
color: var(--v-theme-secondary);
transition: transform 0.2s ease;
}

.reasoning-label {
font-size: 13px;
font-weight: 500;
color: var(--v-theme-secondary);
letter-spacing: 0.3px;
}

.reasoning-content {
padding: 0px 12px;
border-top: 1px solid var(--v-theme-border);
color: gray;
animation: fadeIn 0.2s ease-in-out;
font-style: italic;
}

.reasoning-text {
font-size: 14px;
line-height: 1.6;
color: var(--v-theme-secondaryText);
}

.v-theme--dark .reasoning-text {
opacity: 0.85;
}
</style>

<style>

@@ -63,6 +63,9 @@
"on": "Stream",
"off": "Normal"
},
"reasoning": {
"thinking": "Thinking Process"
},
"connection": {
"title": "Connection Status Notice",
"message": "The system detected that the chat connection needs to be re-established.",

@@ -63,6 +63,9 @@
"on": "流式",
"off": "普通"
},
"reasoning": {
"thinking": "思考过程"
},
"connection": {
"title": "连接状态提醒",
"message": "系统检测到聊天连接需要重新建立。",

@@ -31,6 +31,7 @@ export function getProviderIcon(type) {
'302ai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
'microsoft': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/microsoft.svg',
'vllm': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/vllm.svg',
'groq': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/groq.svg',
};
return icons[type] || '';
}

@@ -3,7 +3,7 @@ import traceback
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.message_components import Image, Plain
from astrbot.api.provider import ProviderRequest
from astrbot.api.provider import LLMResponse, ProviderRequest
from astrbot.core import logger
from astrbot.core.provider.sources.dify_source import ProviderDify

@@ -334,6 +334,17 @@ class Main(star.Star):
except BaseException as e:
logger.error(f"ltm: {e}")

@filter.on_llm_response()
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
"""在 LLM 响应后基于配置注入思考过程文本"""
umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {})
show_reasoning = cfg.get("display_reasoning_text", False)
if show_reasoning and resp.reasoning_content:
resp.completion_text = (
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
)

@filter.after_message_sent()
async def after_llm_req(self, event: AstrMessageEvent):
"""在 LLM 请求后记录对话"""

@@ -1,208 +0,0 @@
import json
import logging
import re
from typing import Any

from openai.types.chat.chat_completion import ChatCompletion

from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.provider import LLMResponse
from astrbot.api.star import Context, Star

try:
# 谨慎引入,避免在未安装 google-genai 的环境下报错
from google.genai.types import GenerateContentResponse
except Exception: # pragma: no cover - 兼容无此依赖的运行环境
GenerateContentResponse = None # type: ignore


class R1Filter(Star):
def __init__(self, context: Context):
super().__init__(context)

@filter.on_llm_response()
async def resp(self, event: AstrMessageEvent, response: LLMResponse):
cfg = self.context.get_config(umo=event.unified_msg_origin).get(
"provider_settings",
{},
)
show_reasoning = cfg.get("display_reasoning_text", False)

# --- Gemini: 过滤/展示 thought:true 片段 ---
# Gemini 可能在 parts 中注入 {"thought": true, "text": "..."}
# 官方 SDK 默认不会返回此字段。
if GenerateContentResponse is not None and isinstance(
response.raw_completion,
GenerateContentResponse,
):
thought_text, answer_text = self._extract_gemini_texts(
response.raw_completion,
)

if thought_text or answer_text:
# 有明确的思考/正文分离信号,则按配置处理
if show_reasoning:
merged = (
(f"🤔思考:{thought_text}\n\n" if thought_text else "")
+ (answer_text or "")
).strip()
if merged:
response.completion_text = merged
return
# 默认隐藏思考内容,仅保留正文
elif answer_text:
response.completion_text = answer_text
return

# --- 非 Gemini 或无明确 thought:true 情况 ---
if show_reasoning:
# 显示推理内容的处理逻辑
if (
response
and response.raw_completion
and isinstance(response.raw_completion, ChatCompletion)
and len(response.raw_completion.choices) > 0
and response.raw_completion.choices[0].message
):
message = response.raw_completion.choices[0].message
reasoning_content = "" # 初始化 reasoning_content

# 检查 Groq deepseek-r1-distill-llama-70b 模型的 'reasoning' 属性
if hasattr(message, "reasoning") and message.reasoning:
reasoning_content = message.reasoning
# 检查 DeepSeek deepseek-reasoner 模型的 'reasoning_content'
elif (
hasattr(message, "reasoning_content") and message.reasoning_content
):
reasoning_content = message.reasoning_content

if reasoning_content:
response.completion_text = (
f"🤔思考:{reasoning_content}\n\n{message.content}"
)
else:
response.completion_text = message.content
else:
# 过滤推理标签的处理逻辑
completion_text = response.completion_text

# 检查并移除 <think> 标签
if r"<think>" in completion_text or r"</think>" in completion_text:
# 移除配对的标签及其内容
completion_text = re.sub(
r"<think>.*?</think>",
"",
completion_text,
flags=re.DOTALL,
).strip()

# 移除可能残留的单个标签
completion_text = (
completion_text.replace(r"<think>", "")
.replace(r"</think>", "")
.strip()
)

response.completion_text = completion_text

# ------------------------
# helpers
# ------------------------
def _get_part_dict(self, p: Any) -> dict:
"""优先使用 SDK 标准序列化方法获取字典,失败则逐级回退。

顺序: model_dump → model_dump_json → json → to_dict → dict → __dict__。
"""
for getter in ("model_dump", "model_dump_json", "json", "to_dict", "dict"):
fn = getattr(p, getter, None)
if callable(fn):
try:
result = fn()
if isinstance(result, (str, bytes)):
try:
if isinstance(result, bytes):
result = result.decode("utf-8", "ignore")
return json.loads(result) or {}
except json.JSONDecodeError:
continue
if isinstance(result, dict):
return result
except (AttributeError, TypeError):
continue
except Exception as e:
logging.exception(
f"Unexpected error when calling {getter} on {type(p).__name__}: {e}",
)
continue
try:
d = getattr(p, "__dict__", None)
if isinstance(d, dict):
return d
except (AttributeError, TypeError):
pass
except Exception as e:
logging.exception(
f"Unexpected error when accessing __dict__ on {type(p).__name__}: {e}",
)
return {}

def _is_thought_part(self, p: Any) -> bool:
"""判断是否为思考片段。

规则:
1) 直接 thought 属性
2) 字典字段 thought 或 metadata.thought
3) data/raw/extra/_raw 中嵌入的 JSON 串包含 thought: true
"""
try:
if getattr(p, "thought", False):
return True
except Exception:
# best-effort
pass

d = self._get_part_dict(p)
if d.get("thought") is True:
return True
meta = d.get("metadata")
if isinstance(meta, dict) and meta.get("thought") is True:
return True
for k in ("data", "raw", "extra", "_raw"):
v = d.get(k)
if isinstance(v, (str, bytes)):
try:
if isinstance(v, bytes):
v = v.decode("utf-8", "ignore")
parsed = json.loads(v)
if isinstance(parsed, dict) and parsed.get("thought") is True:
return True
except json.JSONDecodeError:
continue
return False

def _extract_gemini_texts(self, resp: Any) -> tuple[str, str]:
"""从 GenerateContentResponse 中提取 (思考文本, 正文文本)。"""
try:
cand0 = next(iter(getattr(resp, "candidates", []) or []), None)
if not cand0:
return "", ""
content = getattr(cand0, "content", None)
parts = getattr(content, "parts", None) or []
except (AttributeError, TypeError, ValueError):
return "", ""

thought_buf: list[str] = []
answer_buf: list[str] = []
for p in parts:
txt = getattr(p, "text", None)
if txt is None:
continue
txt_str = str(txt).strip()
if not txt_str:
continue
if self._is_thought_part(p):
thought_buf.append(txt_str)
else:
answer_buf.append(txt_str)

return "\n".join(thought_buf).strip(), "\n".join(answer_buf).strip()
@@ -1,5 +0,0 @@
name: thinking_filter
desc: 可选择是否过滤推理模型的思考内容
author: Soulter
version: 1.0.0
repo: https://astrbot.app