Compare commits
3 Commits
fix/llm-me
...
fix/3826-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
52a8fc0152 | ||
|
|
6bb695f850 | ||
|
|
c13c51f499 |
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from typing import Any, ClassVar, Literal, cast
|
from typing import Any, ClassVar, Literal, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, GetCoreSchemaHandler
|
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
|
||||||
@@ -145,23 +145,39 @@ class Message(BaseModel):
|
|||||||
"tool",
|
"tool",
|
||||||
]
|
]
|
||||||
|
|
||||||
content: str | list[ContentPart]
|
content: str | list[ContentPart] | None = None
|
||||||
"""The content of the message."""
|
"""The content of the message."""
|
||||||
|
|
||||||
|
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||||
|
"""The tool calls of the message."""
|
||||||
|
|
||||||
|
tool_call_id: str | None = None
|
||||||
|
"""The ID of the tool call."""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_content_required(self):
|
||||||
|
# assistant + tool_calls is not None: allow content to be None
|
||||||
|
if self.role == "assistant" and self.tool_calls is not None:
|
||||||
|
return self
|
||||||
|
|
||||||
|
# other all cases: content is required
|
||||||
|
if self.content is None:
|
||||||
|
raise ValueError(
|
||||||
|
"content is required unless role='assistant' and tool_calls is not None"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class AssistantMessageSegment(Message):
|
class AssistantMessageSegment(Message):
|
||||||
"""A message segment from the assistant."""
|
"""A message segment from the assistant."""
|
||||||
|
|
||||||
role: Literal["assistant"] = "assistant"
|
role: Literal["assistant"] = "assistant"
|
||||||
content: str | list[ContentPart] | None = None
|
|
||||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCallMessageSegment(Message):
|
class ToolCallMessageSegment(Message):
|
||||||
"""A message segment representing a tool call."""
|
"""A message segment representing a tool call."""
|
||||||
|
|
||||||
role: Literal["tool"] = "tool"
|
role: Literal["tool"] = "tool"
|
||||||
tool_call_id: str
|
|
||||||
|
|
||||||
|
|
||||||
class UserMessageSegment(Message):
|
class UserMessageSegment(Message):
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from astrbot.api import star
|
|||||||
from astrbot.api.event import AstrMessageEvent
|
from astrbot.api.event import AstrMessageEvent
|
||||||
from astrbot.api.message_components import At, Image, Plain
|
from astrbot.api.message_components import At, Image, Plain
|
||||||
from astrbot.api.platform import MessageType
|
from astrbot.api.platform import MessageType
|
||||||
from astrbot.api.provider import Provider, ProviderRequest
|
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
|
||||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -158,8 +158,12 @@ class LongTermMemory:
|
|||||||
cfg = self.cfg(event)
|
cfg = self.cfg(event)
|
||||||
if cfg["enable_active_reply"]:
|
if cfg["enable_active_reply"]:
|
||||||
prompt = req.prompt
|
prompt = req.prompt
|
||||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
req.prompt = (
|
||||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||||
|
f"\nNow, a new message is coming: `{prompt}`. "
|
||||||
|
"Please react to it. Only output your response and do not output any other information. "
|
||||||
|
"You MUST use the SAME language as the chatroom is using."
|
||||||
|
)
|
||||||
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
||||||
else:
|
else:
|
||||||
req.system_prompt += (
|
req.system_prompt += (
|
||||||
@@ -167,13 +171,15 @@ class LongTermMemory:
|
|||||||
)
|
)
|
||||||
req.system_prompt += chats_str
|
req.system_prompt += chats_str
|
||||||
|
|
||||||
async def after_req_llm(self, event: AstrMessageEvent):
|
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
|
||||||
if event.unified_msg_origin not in self.session_chats:
|
if event.unified_msg_origin not in self.session_chats:
|
||||||
return
|
return
|
||||||
|
|
||||||
if event.get_result() and event.get_result().is_llm_result():
|
if llm_resp.completion_text:
|
||||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}"
|
||||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
logger.debug(
|
||||||
|
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
|
||||||
|
)
|
||||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||||
cfg = self.cfg(event)
|
cfg = self.cfg(event)
|
||||||
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
||||||
|
|||||||
@@ -322,7 +322,7 @@ class Main(star.Star):
|
|||||||
|
|
||||||
@filter.on_llm_response()
|
@filter.on_llm_response()
|
||||||
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||||
"""在 LLM 响应后基于配置注入思考过程文本"""
|
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
|
||||||
umo = event.unified_msg_origin
|
umo = event.unified_msg_origin
|
||||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||||
show_reasoning = cfg.get("display_reasoning_text", False)
|
show_reasoning = cfg.get("display_reasoning_text", False)
|
||||||
@@ -331,12 +331,9 @@ class Main(star.Star):
|
|||||||
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@filter.after_message_sent()
|
|
||||||
async def after_llm_req(self, event: AstrMessageEvent):
|
|
||||||
"""在 LLM 请求后记录对话"""
|
|
||||||
if self.ltm and self.ltm_enabled(event):
|
if self.ltm and self.ltm_enabled(event):
|
||||||
try:
|
try:
|
||||||
await self.ltm.after_req_llm(event)
|
await self.ltm.after_req_llm(event, resp)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"ltm: {e}")
|
logger.error(f"ltm: {e}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user