Compare commits
3 Commits
fix/llm-me
...
fix/3826-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
52a8fc0152 | ||
|
|
6bb695f850 | ||
|
|
c13c51f499 |
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -145,23 +145,39 @@ class Message(BaseModel):
|
||||
"tool",
|
||||
]
|
||||
|
||||
content: str | list[ContentPart]
|
||||
content: str | list[ContentPart] | None = None
|
||||
"""The content of the message."""
|
||||
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
"""The tool calls of the message."""
|
||||
|
||||
tool_call_id: str | None = None
|
||||
"""The ID of the tool call."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content_required(self):
|
||||
# assistant + tool_calls is not None: allow content to be None
|
||||
if self.role == "assistant" and self.tool_calls is not None:
|
||||
return self
|
||||
|
||||
# other all cases: content is required
|
||||
if self.content is None:
|
||||
raise ValueError(
|
||||
"content is required unless role='assistant' and tool_calls is not None"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: str | list[ContentPart] | None = None
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
|
||||
|
||||
class ToolCallMessageSegment(Message):
|
||||
"""A message segment representing a tool call."""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class UserMessageSegment(Message):
|
||||
|
||||
@@ -8,7 +8,7 @@ from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import At, Image, Plain
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
|
||||
"""
|
||||
@@ -158,8 +158,12 @@ class LongTermMemory:
|
||||
cfg = self.cfg(event)
|
||||
if cfg["enable_active_reply"]:
|
||||
prompt = req.prompt
|
||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
||||
req.prompt = (
|
||||
f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
f"\nNow, a new message is coming: `{prompt}`. "
|
||||
"Please react to it. Only output your response and do not output any other information. "
|
||||
"You MUST use the SAME language as the chatroom is using."
|
||||
)
|
||||
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
||||
else:
|
||||
req.system_prompt += (
|
||||
@@ -167,13 +171,15 @@ class LongTermMemory:
|
||||
)
|
||||
req.system_prompt += chats_str
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent):
|
||||
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
if event.get_result() and event.get_result().is_llm_result():
|
||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
if llm_resp.completion_text:
|
||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}"
|
||||
logger.debug(
|
||||
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
|
||||
)
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
cfg = self.cfg(event)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
||||
|
||||
@@ -322,7 +322,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后基于配置注入思考过程文本"""
|
||||
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
|
||||
umo = event.unified_msg_origin
|
||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||
show_reasoning = cfg.get("display_reasoning_text", False)
|
||||
@@ -331,12 +331,9 @@ class Main(star.Star):
|
||||
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
||||
)
|
||||
|
||||
@filter.after_message_sent()
|
||||
async def after_llm_req(self, event: AstrMessageEvent):
|
||||
"""在 LLM 请求后记录对话"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.after_req_llm(event)
|
||||
await self.ltm.after_req_llm(event, resp)
|
||||
except Exception as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user