Compare commits
1 Commits
feat/file-
...
perf/provi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
26467fbe22 |
@@ -76,11 +76,6 @@ DEFAULT_CONFIG = {
|
||||
"reachability_check": False,
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
"file_extract": {
|
||||
"enable": False,
|
||||
"provider": "moonshotai",
|
||||
"moonshotai_api_key": "",
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -2074,20 +2069,6 @@ CONFIG_METADATA_2 = {
|
||||
"tool_call_timeout": {
|
||||
"type": "int",
|
||||
},
|
||||
"file_extract": {
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"type": "bool",
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
},
|
||||
"moonshotai_api_key": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
@@ -2422,36 +2403,6 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"file_extract": {
|
||||
"description": "文档解析能力",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.file_extract.enable": {
|
||||
"description": "启用文档解析能力",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.file_extract.provider": {
|
||||
"description": "文档解析提供商",
|
||||
"type": "string",
|
||||
"options": ["moonshotai"],
|
||||
"condition": {
|
||||
"provider_settings.file_extract.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.file_extract.moonshotai_api_key": {
|
||||
"description": "Moonshot AI API Key",
|
||||
"type": "string",
|
||||
"condition": {
|
||||
"provider_settings.file_extract.provider": "moonshotai",
|
||||
"provider_settings.file_extract.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
|
||||
@@ -722,12 +722,7 @@ class File(BaseMessageComponent):
|
||||
"""下载文件"""
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
if self.name:
|
||||
name, ext = os.path.splitext(self.name)
|
||||
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
else:
|
||||
filename = f"{uuid.uuid4().hex}"
|
||||
file_path = os.path.join(download_dir, filename)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from astrbot.core import logger
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import File, Image, Reply
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
@@ -22,7 +22,6 @@ from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -57,13 +56,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||
|
||||
file_extract_conf: dict = settings.get("file_extract", {})
|
||||
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
|
||||
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
|
||||
self.file_extract_msh_api_key: str = file_extract_conf.get(
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -122,50 +114,6 @@ class InternalAgentSubStage(Stage):
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
async def _apply_file_extract(
    self,
    event: AstrMessageEvent,
    req: ProviderRequest,
):
    """Extract text from files in the message and add it to the provider request.

    ``File`` components are collected from the top-level message chain and
    from the chain of any ``Reply`` component. Each file is sent to the
    configured extraction provider, and the extracted text is appended to
    ``req.contexts`` as one ``system`` message per file.

    ``req`` is left untouched when the message carries no files, when the
    configured provider is unsupported, or when the Moonshot AI API key is
    missing (the latter two cases are logged as errors).

    Args:
        event: The message event whose components are scanned for files.
        req: The provider request to enrich. ``req.prompt`` and
            ``req.contexts`` may be mutated in place.

    Raises:
        Exception: Anything raised by ``File.get_file`` or by the extraction
            call propagates to the caller; nothing is caught here.
    """
    # Collect File components from the message itself and from quoted replies,
    # preserving the order in which they appear.
    file_comps = []
    for comp in event.message_obj.message:
        if isinstance(comp, File):
            file_comps.append(comp)
        elif isinstance(comp, Reply) and comp.chain:
            file_comps.extend(
                reply_comp for reply_comp in comp.chain if isinstance(reply_comp, File)
            )
    if not file_comps:
        return

    # Validate the configuration before resolving any file, so that a
    # misconfigured setup neither fetches files nor alters the request.
    if self.file_extract_prov != "moonshotai":
        logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
        return
    if not self.file_extract_msh_api_key:
        logger.error("Moonshot AI API key for file extract is not set")
        return

    file_paths = [await comp.get_file() for comp in file_comps]
    file_contents = await asyncio.gather(
        *(
            extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
            for file_path in file_paths
        )
    )

    # Only fall back to the default "summarise the file" prompt once the
    # extraction has succeeded; otherwise the model would be asked about a
    # file whose content was never added to the contexts.
    if not req.prompt:
        req.prompt = "总结一下文件里面讲了什么?"

    # Add the extraction results to the contexts, one system message per file.
    for comp, file_content in zip(file_comps, file_contents):
        file_name = comp.name
        req.contexts.append(
            {
                "role": "system",
                "content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
            },
        )
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
@@ -398,17 +346,6 @@ class InternalAgentSubStage(Stage):
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
@@ -419,6 +356,10 @@ class InternalAgentSubStage(Stage):
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
|
||||
@@ -246,13 +246,7 @@ class AiocqhttpAdapter(Platform):
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
# 检查多个可能的文件名字段
|
||||
file_name = (
|
||||
m["data"].get("file_name", "")
|
||||
or m["data"].get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or "file"
|
||||
)
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
@@ -271,14 +265,7 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
# 优先从 API 返回值获取文件名,其次从原始消息数据获取
|
||||
file_name = (
|
||||
ret.get("file_name", "")
|
||||
or ret.get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or m["data"].get("file_name", "")
|
||||
)
|
||||
a = File(name=file_name, url=file_url)
|
||||
a = File(name="", url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
|
||||
@@ -381,9 +381,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
f"Telegram document file_path is None, cannot save the file {file_name}.",
|
||||
)
|
||||
else:
|
||||
message.message.append(
|
||||
Comp.File(file=file_path, name=file_name, url=file_path)
|
||||
)
|
||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
async def extract_file_moonshotai(file_path: str, api_key: str) -> str:
    """Extract text from a file using the Moonshot AI API.

    The file is uploaded with the ``file-extract`` purpose and the content
    returned for the resulting file id is read back as text.

    Args:
        file_path: The path to the file to extract text from.
        api_key: The API key to use to extract text from the file.

    Returns:
        The text extracted from the file.
    """
    # Use the client as an async context manager so its underlying HTTP
    # connections are closed when the call finishes, instead of lingering
    # until garbage collection for every extracted file.
    async with AsyncOpenAI(
        api_key=api_key,
        base_url="https://api.moonshot.cn/v1",
    ) as client:
        file_object = await client.files.create(
            file=Path(file_path),
            purpose="file-extract",  # type: ignore
        )
        return (await client.files.content(file_id=file_object.id)).text
|
||||
@@ -109,22 +109,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"file_extract": {
|
||||
"description": "File Extract",
|
||||
"provider_settings": {
|
||||
"file_extract": {
|
||||
"enable": {
|
||||
"description": "Enable File Extract"
|
||||
},
|
||||
"provider": {
|
||||
"description": "File Extract Provider"
|
||||
},
|
||||
"moonshotai_api_key": {
|
||||
"description": "Moonshot AI API Key"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "Other Settings",
|
||||
"provider_settings": {
|
||||
|
||||
@@ -11,12 +11,7 @@
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"description": "执行器",
|
||||
"labels": [
|
||||
"内置 Agent",
|
||||
"Dify",
|
||||
"Coze",
|
||||
"阿里云百炼应用"
|
||||
]
|
||||
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"]
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent 执行器提供商 ID"
|
||||
@@ -114,22 +109,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"file_extract": {
|
||||
"description": "文档解析能力",
|
||||
"provider_settings": {
|
||||
"file_extract": {
|
||||
"enable": {
|
||||
"description": "启用文档解析能力"
|
||||
},
|
||||
"provider": {
|
||||
"description": "文档解析提供商"
|
||||
},
|
||||
"moonshotai_api_key": {
|
||||
"description": "Moonshot AI API Key"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"provider_settings": {
|
||||
@@ -163,10 +142,7 @@
|
||||
"unsupported_streaming_strategy": {
|
||||
"description": "不支持流式回复的平台",
|
||||
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
|
||||
"labels": [
|
||||
"实时分段回复",
|
||||
"关闭流式回复"
|
||||
]
|
||||
"labels": ["实时分段回复", "关闭流式回复"]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
@@ -481,4 +457,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import At, Image, Plain
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
|
||||
"""
|
||||
@@ -158,12 +158,8 @@ class LongTermMemory:
|
||||
cfg = self.cfg(event)
|
||||
if cfg["enable_active_reply"]:
|
||||
prompt = req.prompt
|
||||
req.prompt = (
|
||||
f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
f"\nNow, a new message is coming: `{prompt}`. "
|
||||
"Please react to it. Only output your response and do not output any other information. "
|
||||
"You MUST use the SAME language as the chatroom is using."
|
||||
)
|
||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
||||
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
||||
else:
|
||||
req.system_prompt += (
|
||||
@@ -171,15 +167,13 @@ class LongTermMemory:
|
||||
)
|
||||
req.system_prompt += chats_str
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
|
||||
async def after_req_llm(self, event: AstrMessageEvent):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
if llm_resp.completion_text:
|
||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}"
|
||||
logger.debug(
|
||||
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
|
||||
)
|
||||
if event.get_result() and event.get_result().is_llm_result():
|
||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
cfg = self.cfg(event)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
||||
|
||||
@@ -322,7 +322,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
|
||||
"""在 LLM 响应后基于配置注入思考过程文本"""
|
||||
umo = event.unified_msg_origin
|
||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||
show_reasoning = cfg.get("display_reasoning_text", False)
|
||||
@@ -331,9 +331,12 @@ class Main(star.Star):
|
||||
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
||||
)
|
||||
|
||||
@filter.after_message_sent()
|
||||
async def after_llm_req(self, event: AstrMessageEvent):
|
||||
"""在 LLM 请求后记录对话"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.after_req_llm(event, resp)
|
||||
await self.ltm.after_req_llm(event)
|
||||
except Exception as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user