diff --git a/.gitignore b/.gitignore index a863e36e..a3b2aad9 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,5 @@ venv/* packages/python_interpreter/workplace .venv/* .conda/ +.idea +pytest.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c38647e..4dece714 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ ci: autoupdate_commit_msg: ":balloon: pre-commit autoupdate" repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.10 + rev: v0.11.0 hooks: - id: ruff - id: ruff-format diff --git a/README.md b/README.md index b4d97fbb..d2cedd6e 100644 --- a/README.md +++ b/README.md @@ -10,14 +10,13 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_ Soulter%2FAstrBot | Trendshift -[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest) -python -Docker pull -Static Badge -[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) -![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=60) -[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot) -[![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg)](https://gitcode.com/Soulter/AstrBot) +[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest) +python +Docker pull +Static Badge +[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) +![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=60&style=for-the-badge&color=3b618e) +[![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot) English日本語 | @@ -27,6 +26,8 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。 +[![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg?style=for-the-badge)](https://gitcode.com/Soulter/AstrBot) + ## ✨ 主要功能 1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。 diff --git a/astrbot/api/platform/__init__.py b/astrbot/api/platform/__init__.py index dcc02bb4..5a98c590 100644 --- a/astrbot/api/platform/__init__.py +++ b/astrbot/api/platform/__init__.py @@ -5,6 +5,7 @@ from astrbot.core.platform import ( MessageMember, MessageType, PlatformMetadata, + Group, ) from astrbot.core.platform.register import register_platform_adapter @@ -18,4 +19,5 @@ __all__ = [ "MessageType", "PlatformMetadata", "register_platform_adapter", + "Group", ] diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index ba87ba52..7db7ff5f 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,7 +2,7 @@ 如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 """ -VERSION = "3.4.37" +VERSION = "3.4.39" DB_PATH = "data/data_v3.db" # 默认配置 @@ -85,6 +85,7 @@ DEFAULT_CONFIG = { "enable": True, "username": "astrbot", "password": "77b90590a8945a7d36c963981a307dc9", + "host": "0.0.0.0", "port": 6185, }, "platform": [], @@ -122,6 +123,7 @@ CONFIG_METADATA_2 = { "enable": False, "appid": "", "secret": "", + "callback_server_host": "0.0.0.0", "port": 6196, }, "aiocqhttp(OneBotv11)": { @@ -146,10 +148,11 @@ CONFIG_METADATA_2 = { "enable": False, "corpid": "", "secret": "", - "port": 6195, "token": "", "encoding_aes_key": "", "api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/", + "callback_server_host": "0.0.0.0", + "port": 6195, }, "lark(飞书)": { "id": "lark", @@ -220,7 +223,7 @@ CONFIG_METADATA_2 = { "hint": "启用后,机器人可以接收到频道的私聊消息。", }, "ws_reverse_host": { - "description": "反向 Websocket 主机地址", + "description": "反向 Websocket 主机地址(AstrBot 为服务器端)", "type": "string", "hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。", }, @@ -578,7 +581,7 @@ CONFIG_METADATA_2 = { "dify_api_type": "chat", "dify_api_key": "", "dify_api_base": "https://api.dify.ai/v1", - "dify_workflow_output_key": "", + "dify_workflow_output_key": "astrbot_wf_output", "dify_query_input_key": "astrbot_text_query", "variables": {}, "timeout": 60, @@ -590,6 +593,11 @@ CONFIG_METADATA_2 = { "dashscope_app_type": "agent", "dashscope_api_key": "", "dashscope_app_id": "", + "rag_options": { + "pipeline_ids": [], + "file_ids": [], + "output_reference": False, + }, "variables": {}, "timeout": 60, }, @@ -662,6 +670,30 @@ CONFIG_METADATA_2 = { }, }, "items": { + "rag_options": { + "description": "RAG 选项", + "type": "object", + "hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)", + "items": { + "pipeline_ids": { + "description": "知识库 ID 列表", + "type": "list", + "items": {"type": "string"}, + "hint": "对指定知识库内所有文档进行检索, 前往 https://bailian.console.aliyun.com/ 数据应用->知识索引创建和获取 ID。", + }, + "file_ids": { + "description": "非结构化文档 ID, 传入该参数将对指定非结构化文档进行检索。", + "type": "list", + "items": {"type": "string"}, + "hint": "对指定非结构化文档进行检索。前往 https://bailian.console.aliyun.com/ 数据管理创建和获取 ID。", + }, + "output_reference": { + "description": "是否输出知识库/文档的引用", + "type": "bool", + "hint": "在每次回答尾部加上引用源。默认为 False。", + }, + }, + }, "sensevoice_hint": { "description": "部署SenseVoice", "type": "string", @@ -678,12 +710,14 @@ CONFIG_METADATA_2 = { "type": "string", "hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。", }, - # "variables": { - # "description": "工作流固定输入变量", - # "type": "object", - # "obvious_hint": True, - # "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", - # }, + "variables": { + "description": "工作流固定输入变量", + "type": "object", + "obvious_hint": True, + "items": {}, + "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", + "invisible": True, + }, # "fastgpt_app_type": { # "description": "应用类型", # "type": "string", @@ -694,7 +728,7 @@ CONFIG_METADATA_2 = { "dashscope_app_type": { "description": "应用类型", "type": "string", - "hint": "阿里云百炼应用的应用类型。", + "hint": "百炼应用的应用类型。", "options": [ "agent", "agent-arrange", diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 7f73e5d9..64c324a9 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -25,9 +25,11 @@ SOFTWARE. import base64 import json import os +import uuid import typing as T from enum import Enum from pydantic.v1 import BaseModel +from astrbot.core.utils.io import download_image_by_url, file_to_base64 class ComponentType(Enum): @@ -146,6 +148,51 @@ class Record(BaseMessageComponent): return Record(file=url, **_) raise Exception("not a valid url") + async def convert_to_file_path(self) -> str: + """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 + + Returns: + str: 语音的本地路径,以绝对路径表示。 + """ + if self.file and self.file.startswith("file:///"): + file_path = self.file[8:] + return file_path + elif self.file and self.file.startswith("http"): + file_path = await download_image_by_url(self.file) + return os.path.abspath(file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file.removeprefix("base64://") + image_bytes = base64.b64decode(bs64_data) + file_path = f"data/temp/{uuid.uuid4()}.jpg" + with open(file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(file_path) + elif os.path.exists(self.file): + file_path = self.file + return os.path.abspath(file_path) + else: + raise Exception(f"not a valid file: {self.file}") + + async def convert_to_base64(self) -> str: + """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 + + Returns: + str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ + # convert to base64 + if self.file and self.file.startswith("file:///"): + bs64_data = file_to_base64(self.file[8:]) + elif self.file and self.file.startswith("http"): + file_path = await download_image_by_url(self.file) + bs64_data = file_to_base64(file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file + elif os.path.exists(self.file): + bs64_data = file_to_base64(self.file) + else: + raise Exception(f"not a valid file: {self.file}") + return bs64_data + class Video(BaseMessageComponent): type: ComponentType = "Video" @@ -279,10 +326,6 @@ class Image(BaseMessageComponent): file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 def __init__(self, file: T.Optional[str], **_): - # for k in _.keys(): - # if (k == "_type" and _[k] not in ["flash", "show", None]) or \ - # (k == "c" and _[k] not in [2, 3]): - # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") super().__init__(file=file, **_) @staticmethod @@ -307,6 +350,53 @@ class Image(BaseMessageComponent): def fromIO(IO): return Image.fromBytes(IO.read()) + async def convert_to_file_path(self) -> str: + """将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。 + + Returns: + str: 图片的本地路径,以绝对路径表示。 + """ + url = self.url if self.url else self.file + if url and url.startswith("file:///"): + image_file_path = url[8:] + return image_file_path + elif url and url.startswith("http"): + image_file_path = await download_image_by_url(url) + return os.path.abspath(image_file_path) + elif url and url.startswith("base64://"): + bs64_data = url.removeprefix("base64://") + image_bytes = base64.b64decode(bs64_data) + image_file_path = f"data/temp/{uuid.uuid4()}.jpg" + with open(image_file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(image_file_path) + elif os.path.exists(url): + image_file_path = url + return os.path.abspath(image_file_path) + else: + raise Exception(f"not a valid file: {url}") + + async def convert_to_base64(self) -> str: + """将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。 + + Returns: + str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ + # convert to base64 + url = self.url if self.url else self.file + if url and url.startswith("file:///"): + bs64_data = file_to_base64(url[8:]) + elif url and url.startswith("http"): + image_file_path = await download_image_by_url(url) + bs64_data = file_to_base64(image_file_path) + elif url and url.startswith("base64://"): + bs64_data = url + elif os.path.exists(url): + bs64_data = file_to_base64(url) + else: + raise Exception(f"not a valid file: {url}") + return bs64_data + class Reply(BaseMessageComponent): type: ComponentType = "Reply" diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 89aff17a..4cc7fb84 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -77,6 +77,10 @@ class MessageChain: self.use_t2i_ = use_t2i return self + def get_plain_text(self) -> str: + """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" + return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) + class EventResultType(enum.Enum): """用于描述事件处理的结果类型。 @@ -147,9 +151,5 @@ class MessageEventResult(MessageChain): """是否为 LLM 结果。""" return self.result_content_type == ResultContentType.LLM_RESULT - def get_plain_text(self) -> str: - """获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" - return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - CommandResult = MessageEventResult diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 22d67d32..210e62a7 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -64,8 +64,8 @@ class LLMRequestSubStage(Stage): req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: if isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - req.image_urls.append(image_url) + image_path = await comp.convert_to_file_path() + req.image_urls.append(image_path) # 获取对话上下文 conversation_id = await self.conv_manager.get_curr_conversation_id( @@ -148,11 +148,18 @@ class LLMRequestSubStage(Stage): if llm_response.role == "assistant": # text completion - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT) - ) + if llm_response.result_chain: + event.set_result( + MessageEventResult( + chain=llm_response.result_chain.chain + ).set_result_content_type(ResultContentType.LLM_RESULT) + ) + else: + event.set_result( + MessageEventResult() + .message(llm_response.completion_text) + .set_result_content_type(ResultContentType.LLM_RESULT) + ) elif llm_response.role == "err": event.set_result( MessageEventResult().message( diff --git a/astrbot/core/platform/__init__.py b/astrbot/core/platform/__init__.py index 48ea57b8..4007b2d9 100644 --- a/astrbot/core/platform/__init__.py +++ b/astrbot/core/platform/__init__.py @@ -1,7 +1,7 @@ from .platform import Platform from .astr_message_event import AstrMessageEvent from .platform_metadata import PlatformMetadata -from .astrbot_message import AstrBotMessage, MessageMember, MessageType +from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group __all__ = [ "Platform", @@ -10,4 +10,5 @@ __all__ = [ "AstrBotMessage", "MessageMember", "MessageType", + "Group", ] diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index fceb63ce..3e1b14ee 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,11 +1,9 @@ import abc import asyncio from dataclasses import dataclass -from .astrbot_message import AstrBotMessage -from .platform_metadata import PlatformMetadata -from astrbot.core.message.message_event_result import MessageEventResult, MessageChain -from astrbot.core.platform.message_type import MessageType -from typing import List, Union +from typing import List, Union, Optional + +from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( Plain, Image, @@ -16,9 +14,12 @@ from astrbot.core.message.components import ( Forward, Reply, ) -from astrbot.core.utils.metrics import Metric +from astrbot.core.message.message_event_result import MessageEventResult, MessageChain +from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.entites import ProviderRequest -from astrbot.core.db.po import Conversation +from astrbot.core.utils.metrics import Metric +from .astrbot_message import AstrBotMessage, Group +from .platform_metadata import PlatformMetadata @dataclass @@ -201,15 +202,6 @@ class AstrMessageEvent(abc.ABC): """ return self.role == "admin" - async def send(self, message: MessageChain): - """ - 发送消息到消息平台。 - """ - asyncio.create_task( - Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) - ) - self._has_send_oper = True - async def _pre_send(self): """调度器会在执行 send() 前调用该方法""" @@ -371,3 +363,26 @@ class AstrMessageEvent(abc.ABC): system_prompt=system_prompt, conversation=conversation, ) + + """平台适配器""" + + async def send(self, message: MessageChain): + """发送消息到消息平台。 + + Args: + message (MessageChain): 消息链,具体使用方式请参考文档。 + """ + asyncio.create_task( + Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) + ) + self._has_send_oper = True + + async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]: + """获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。 + + 适配情况: + + - gewechat + - aiocqhttp(OneBotv11) + """ + ... diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index ea55eaf4..e7bd4bd9 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -10,6 +10,41 @@ class MessageMember: user_id: str # 发送者id nickname: str = None + def __str__(self): + # 使用 f-string 来构建返回的字符串表示形式 + return ( + f"User ID: {self.user_id}," + f"Nickname: {self.nickname if self.nickname else 'N/A'}" + ) + + +@dataclass +class Group: + group_id: str + """群号""" + group_name: str = None + """群名称""" + group_avatar: str = None + """群头像""" + group_owner: str = None + """群主 id""" + group_admins: List[str] = None + """群管理员 id""" + members: List[MessageMember] = None + """所有群成员""" + + def __str__(self): + # 使用 f-string 来构建返回的字符串表示形式 + return ( + f"Group ID: {self.group_id}\n" + f"Name: {self.group_name if self.group_name else 'N/A'}\n" + f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n" + f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n" + f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n" + f"Members Len: {len(self.members) if self.members else 0}\n" + f"First Member: {self.members[0] if self.members else 'N/A'}\n" + ) + class AstrBotMessage: """ diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index ce38296e..c7aede7d 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,9 +1,9 @@ import asyncio - +import typing from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import Group, MessageMember from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes from aiocqhttp import CQHttp -from astrbot.core.utils.io import file_to_base64, download_image_by_url class AiocqhttpMessageEvent(AstrMessageEvent): @@ -21,20 +21,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent): d = segment.toDict() if isinstance(segment, Plain): d["type"] = "text" + d["data"]["text"] = segment.text.strip() elif isinstance(segment, (Image, Record)): # convert to base64 - if segment.file and segment.file.startswith("file:///"): - bs64_data = file_to_base64(segment.file[8:]) - image_file_path = segment.file[8:] - elif segment.file and segment.file.startswith("http"): - image_file_path = await download_image_by_url(segment.file) - bs64_data = file_to_base64(image_file_path) - elif segment.file and segment.file.startswith("base64://"): - bs64_data = segment.file - else: - bs64_data = file_to_base64(segment.file) + bs64 = await segment.convert_to_base64() d["data"] = { - "file": bs64_data, + "file": bs64, } elif isinstance(segment, At): d["data"] = { @@ -55,8 +47,13 @@ class AiocqhttpMessageEvent(AstrMessageEvent): if send_one_by_one: for seg in message.chain: - if isinstance(seg, Nodes): - # 带有多个节点的合并转发消息 + if isinstance(seg, (Node, Nodes)): + # 合并转发消息 + + if isinstance(seg, Node): + nodes = Nodes([seg]) + seg = nodes + payload = seg.toDict() if self.get_group_id(): payload["group_id"] = self.get_group_id() @@ -78,3 +75,46 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await self.bot.send(self.message_obj.raw_message, ret) await super().send(message) + + async def get_group(self, group_id=None, **kwargs): + if isinstance(group_id, str) and group_id.isdigit(): + group_id = int(group_id) + elif self.get_group_id(): + group_id = int(self.get_group_id()) + else: + return None + + info: dict = await self.bot.call_action( + "get_group_info", + group_id=group_id, + ) + + members: typing.List[typing.Dict] = await self.bot.call_action( + "get_group_member_list", + group_id=group_id, + ) + + owner_id = None + admin_ids = [] + for member in members: + if member["role"] == "owner": + owner_id = member["user_id"] + if member["role"] == "admin": + admin_ids.append(member["user_id"]) + + group = Group( + group_id=str(group_id), + group_name=info.get("group_name"), + group_avatar="", + group_admins=admin_ids, + group_owner=str(owner_id), + members=[ + MessageMember( + user_id=member["user_id"], + nickname=member.get("nickname") or member.get("card"), + ) + for member in members + ], + ) + + return group diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index bce637d7..95876a6f 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -1,17 +1,19 @@ -import threading import asyncio -import aiohttp -import quart import base64 import datetime -import re import os +import re +import threading + +import aiohttp import anyio -from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType -from astrbot.api.message_components import Plain, Image, At, Record +import quart + from astrbot.api import logger, sp -from .downloader import GeweDownloader +from astrbot.api.message_components import Plain, Image, At, Record +from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType from astrbot.core.utils.io import download_image_by_url +from .downloader import GeweDownloader class SimpleGewechatClient: @@ -51,11 +53,11 @@ class SimpleGewechatClient: self.server = quart.Quart(__name__) self.server.add_url_rule( - "/astrbot-gewechat/callback", view_func=self.callback, methods=["POST"] + "/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"] ) self.server.add_url_rule( "/astrbot-gewechat/file/", - view_func=self.handle_file, + view_func=self._handle_file, methods=["GET"], ) @@ -73,6 +75,7 @@ class SimpleGewechatClient: self.stop = False async def get_token_id(self): + """获取 Gewechat Token。""" async with aiohttp.ClientSession() as session: async with session.post(f"{self.base_url}/tools/getTokenId") as resp: json_blob = await resp.json() @@ -192,6 +195,11 @@ class SimpleGewechatClient: abm.sender = MessageMember(user_id, user_real_name) abm.raw_message = d abm.message_str = "" + + if user_id == "weixin": + # 忽略微信团队消息 + return + # 不同消息类型 match d["MsgType"]: case 1: @@ -253,7 +261,7 @@ class SimpleGewechatClient: logger.debug(f"abm: {abm}") return abm - async def callback(self): + async def _callback(self): data = await quart.request.json logger.debug(f"收到 gewechat 回调: {data}") @@ -275,7 +283,7 @@ class SimpleGewechatClient: return quart.jsonify({"r": "AstrBot ACK"}) - async def handle_file(self, file_id): + async def _handle_file(self, file_id): file_path = f"data/temp/{file_id}" return await quart.send_file(file_path) @@ -301,17 +309,17 @@ class SimpleGewechatClient: await self.server.run_task( host="0.0.0.0", port=self.port, - shutdown_trigger=self.shutdown_trigger_placeholder, + shutdown_trigger=self._shutdown_trigger_placeholder, ) - async def shutdown_trigger_placeholder(self): + async def _shutdown_trigger_placeholder(self): # TODO: use asyncio.Event while not self.event_queue.closed and not self.stop: # noqa: ASYNC110 await asyncio.sleep(1) logger.info("gewechat 适配器已关闭。") async def check_online(self, appid: str): - # /login/checkOnline + """检查 APPID 对应的设备是否在线。""" async with aiohttp.ClientSession() as session: async with session.post( f"{self.base_url}/login/checkOnline", @@ -322,6 +330,7 @@ class SimpleGewechatClient: return json_blob["data"] async def logout(self): + """登出 gewechat。""" if self.appid: online = await self.check_online(self.appid) if online: @@ -335,6 +344,7 @@ class SimpleGewechatClient: logger.info(f"登出结果: {json_blob}") async def login(self): + """登录 gewechat。一般来说插件用不到这个方法。""" if self.token is None: await self.get_token_id() @@ -446,9 +456,18 @@ class SimpleGewechatClient: self.appid = appid logger.info(f"已保存 APPID: {appid}") - """API""" + """API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1 + """ - async def get_chatroom_member_list(self, chatroom_wxid: str): + async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict: + """获取群成员列表。 + + Args: + chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。 + + Returns: + dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。 + """ payload = {"appId": self.appid, "chatroomId": chatroom_wxid} async with aiohttp.ClientSession() as session: @@ -461,6 +480,7 @@ class SimpleGewechatClient: return json_blob["data"] async def post_text(self, to_wxid, content: str, ats: str = ""): + """发送纯文本消息""" payload = { "appId": self.appid, "toWxid": to_wxid, @@ -477,6 +497,7 @@ class SimpleGewechatClient: logger.debug(f"发送消息结果: {json_blob}") async def post_image(self, to_wxid, image_url: str): + """发送图片消息""" payload = { "appId": self.appid, "toWxid": to_wxid, @@ -508,6 +529,12 @@ class SimpleGewechatClient: logger.debug(f"发送视频结果: {json_blob}") async def post_voice(self, to_wxid, voice_url: str, voice_duration: int): + """发送语音信息 + + Args: + voice_url (str): 语音文件的网络链接 + voice_duration (int): 语音时长,毫秒 + """ payload = { "appId": self.appid, "toWxid": to_wxid, @@ -525,6 +552,13 @@ class SimpleGewechatClient: logger.debug(f"发送语音结果: {json_blob}") async def post_file(self, to_wxid, file_url: str, file_name: str): + """发送文件 + + Args: + to_wxid (string): 微信ID + file_url (str): 文件的网络链接 + file_name (str): 文件名 + """ payload = { "appId": self.appid, "toWxid": to_wxid, @@ -538,3 +572,114 @@ class SimpleGewechatClient: ) as resp: json_blob = await resp.json() logger.debug(f"发送文件结果: {json_blob}") + + async def add_friend(self, v3: str, v4: str, content: str): + """申请添加好友""" + payload = { + "appId": self.appid, + "scene": 3, + "content": content, + "v4": v4, + "v3": v3, + "option": 2, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/contacts/addContacts", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"申请添加好友结果: {json_blob}") + return json_blob + + async def get_group(self, group_id: str): + payload = { + "appId": self.appid, + "chatroomId": group_id, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/group/getChatroomInfo", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob + + async def get_group_member(self, group_id: str): + payload = { + "appId": self.appid, + "chatroomId": group_id, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/group/getChatroomMemberList", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob + + async def accept_group_invite(self, url: str): + """同意进群""" + payload = {"appId": self.appid, "url": url} + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/group/agreeJoinRoom", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob + + async def add_group_member_to_friend( + self, group_id: str, to_wxid: str, content: str + ): + payload = { + "appId": self.appid, + "chatroomId": group_id, + "content": content, + "memberWxid": to_wxid, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/group/addGroupMemberAsFriend", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob + + async def get_user_or_group_info(self, *ids): + """ + 获取用户或群组信息。 + + :param ids: 可变数量的 wxid 参数 + """ + + wxids_str = list(ids) + + payload = { + "appId": self.appid, + "wxids": wxids_str, # 使用逗号分隔的字符串 + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/contacts/getDetailInfo", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index b43c0663..160b7c8a 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -3,13 +3,12 @@ import uuid import traceback import os -from astrbot.core.message.components import Video -from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file +from astrbot.core.utils.io import save_temp_img, download_file from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata -from astrbot.api.message_components import Plain, Image, Record, At, File +from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember +from astrbot.api.message_components import Plain, Image, Record, At, File, Video from .client import SimpleGewechatClient @@ -72,18 +71,10 @@ class GewechatPlatformEvent(AstrMessageEvent): await client.post_text(**payload) elif isinstance(comp, Image): - img_url = comp.file - img_path = "" - if img_url.startswith("file:///"): - img_path = img_url[8:] - elif comp.file and comp.file.startswith("http"): - img_path = await download_image_by_url(comp.file) - else: - img_path = img_url + img_path = await comp.convert_to_file_path() - # 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径 + # 检查 record_path 是否在 data/temp 目录中 temp_directory = os.path.abspath("data/temp") - img_path = os.path.abspath(img_path) if os.path.commonpath([temp_directory, img_path]) != temp_directory: with open(img_path, "rb") as f: img_path = save_temp_img(f.read()) @@ -137,14 +128,7 @@ class GewechatPlatformEvent(AstrMessageEvent): elif isinstance(comp, Record): # 默认已经存在 data/temp 中 record_url = comp.file - record_path = "" - - if record_url.startswith("file:///"): - record_path = record_url[8:] - elif record_url.startswith("http"): - await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") - else: - record_path = record_url + record_path = await comp.convert_to_file_path() silk_path = f"data/temp/{uuid.uuid4()}.silk" try: @@ -182,3 +166,30 @@ class GewechatPlatformEvent(AstrMessageEvent): to_wxid = self.message_obj.raw_message.get("to_wxid", None) await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client) await super().send(message) + + async def get_group(self, group_id=None, **kwargs): + # 确定有效的 group_id + if group_id is None: + group_id = self.get_group_id() + + if not group_id: + return None + + res = await self.client.get_group(group_id) + data: dict = res["data"] + + if not data["chatroomId"]: + return None + + members = [ + MessageMember(user_id=member["wxid"], nickname=member["nickName"]) + for member in data.get("memberList", []) + ] + + return Group( + group_id=data["chatroomId"], + group_name=data.get("nickName"), + group_avatar=data.get("smallHeadImgUrl"), + group_owner=data.get("chatRoomOwner"), + members=members, + ) diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index fd29b360..1ee30c48 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -2,6 +2,7 @@ import base64 import asyncio import json import re +import astrbot.api.message_components as Comp from astrbot.api.platform import ( Platform, @@ -11,7 +12,6 @@ from astrbot.api.platform import ( PlatformMetadata, ) from astrbot.api.event import MessageChain -from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion from .lark_event import LarkMessageEvent from ...register import register_platform_adapter @@ -92,7 +92,7 @@ class LarkPlatformAdapter(Platform): at_list = {} if message.mentions: for m in message.mentions: - at_list[m.key] = At(qq=m.id.open_id, name=m.name) + at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) if m.name == self.bot_name: abm.self_id = m.id.open_id @@ -111,7 +111,7 @@ class LarkPlatformAdapter(Platform): if s in at_list: abm.message.append(at_list[s]) else: - abm.message.append(Plain(parts[i].strip())) + abm.message.append(Comp.Plain(parts[i].strip())) elif message.message_type == "post": _ls = [] @@ -132,7 +132,7 @@ class LarkPlatformAdapter(Platform): if comp["tag"] == "at": abm.message.append(at_list[comp["user_id"]]) elif comp["tag"] == "text" and comp["text"].strip(): - abm.message.append(Plain(comp["text"].strip())) + abm.message.append(Comp.Plain(comp["text"].strip())) elif comp["tag"] == "img": image_key = comp["image_key"] request = ( @@ -147,10 +147,10 @@ class LarkPlatformAdapter(Platform): logger.error(f"无法下载飞书图片: {image_key}") image_bytes = response.file.read() image_base64 = base64.b64encode(image_bytes).decode() - abm.message.append(Image.fromBase64(image_base64)) + abm.message.append(Comp.Image.fromBase64(image_base64)) for comp in abm.message: - if isinstance(comp, Plain): + if isinstance(comp, Comp.Plain): abm.message_str += comp.text abm.message_id = message.message_id abm.raw_message = message diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 56257420..a219e249 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -15,6 +15,7 @@ class QQOfficialWebhook: self.appid = config["appid"] self.secret = config["secret"] self.port = config.get("port", 6196) + self.callback_server_host = config.get("callback_server_host", "0.0.0.0") if isinstance(self.port, str): self.port = int(self.port) @@ -95,8 +96,11 @@ class QQOfficialWebhook: return {"opcode": 12} async def start_polling(self): + logger.info( + f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。" + ) await self.server.run_task( - host="0.0.0.0", + host=self.callback_server_host, port=self.port, shutdown_trigger=self.shutdown_trigger_placeholder, ) diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index a8a04e2e..d19017a4 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -51,19 +51,8 @@ class TelegramPlatformEvent(AstrMessageEvent): at_flag = True await client.send_message(text=i.text, **payload) elif isinstance(i, Image): - if i.path: - image_path = i.path - else: - image_path = i.file - - if image_path.startswith("base64://"): - import base64 - - base64_data = image_path[9:] - image_bytes = base64.b64decode(base64_data) - await client.send_photo(photo=image_bytes, **payload) - else: - await client.send_photo(photo=image_path, **payload) + image_path = await i.convert_to_file_path() + await client.send_photo(photo=image_path, **payload) elif isinstance(i, File): if i.file.startswith("https://"): path = "data/temp/" + i.name @@ -72,7 +61,8 @@ class TelegramPlatformEvent(AstrMessageEvent): await client.send_document(document=i.file, filename=i.name, **payload) elif isinstance(i, Record): - await client.send_voice(voice=i.file, **payload) + path = await i.convert_to_file_path() + await client.send_voice(voice=path, **payload) async def send(self, message: MessageChain): if self.get_message_type() == MessageType.GROUP_MESSAGE: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 77eae03d..cef83b03 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -34,6 +34,7 @@ class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) self.port = int(config.get("port")) + self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.server.add_url_rule( "/callback/command", view_func=self.verify, methods=["GET"] ) @@ -86,9 +87,11 @@ class WecomServer: return "success" async def start_polling(self): - logger.info(f"将在 0.0.0.0:{self.port} 端口启动 企业微信 适配器。") + logger.info( + f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。" + ) await self.server.run_task( - host="0.0.0.0", + host=self.callback_server_host, port=self.port, shutdown_trigger=self.shutdown_trigger_placeholder, ) diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 83e99b5c..470b7b1f 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -3,7 +3,6 @@ from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import Plain, Image, Record from wechatpy.enterprise import WeChatClient -from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.api import logger @@ -43,14 +42,7 @@ class WecomPlatformEvent(AstrMessageEvent): message_obj.self_id, message_obj.session_id, comp.text ) elif isinstance(comp, Image): - img_url = comp.file - img_path = "" - if img_url.startswith("file:///"): - img_path = img_url[8:] - elif comp.file and comp.file.startswith("http"): - img_path = await download_image_by_url(comp.file) - else: - img_path = img_url + img_path = await comp.convert_to_file_path() with open(img_path, "rb") as f: try: @@ -68,16 +60,7 @@ class WecomPlatformEvent(AstrMessageEvent): response["media_id"], ) elif isinstance(comp, Record): - record_url = comp.file - record_path = "" - - if record_url.startswith("file:///"): - record_path = record_url[8:] - elif record_url.startswith("http"): - await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") - else: - record_path = record_url - + record_path = await comp.convert_to_file_path() # 转成amr record_path_amr = f"data/temp/{uuid.uuid4()}.amr" pydub.AudioSegment.from_wav(record_path).export( diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 3180b495..c51b860f 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -4,6 +4,8 @@ from typing import List, Dict, Type from .func_tool_manager import FuncCall from openai.types.chat.chat_completion import ChatCompletion from astrbot.core.db.po import Conversation +from astrbot.core.message.message_event_result import MessageChain +import astrbot.core.message.components as Comp class ProviderType(enum.Enum): @@ -56,8 +58,8 @@ class ProviderRequest: class LLMResponse: role: str """角色, assistant, tool, err""" - completion_text: str = "" - """LLM 返回的文本""" + result_chain: MessageChain = None + """返回的消息链""" tools_call_args: List[Dict[str, any]] = field(default_factory=list) """工具调用参数""" tools_call_name: List[str] = field(default_factory=list) @@ -65,3 +67,51 @@ class LLMResponse: raw_completion: ChatCompletion = None _new_record: Dict[str, any] = None + + _completion_text: str = "" + + def __init__( + self, + role: str, + completion_text: str = "", + result_chain: MessageChain = None, + tools_call_args: List[Dict[str, any]] = None, + tools_call_name: List[str] = None, + raw_completion: ChatCompletion = None, + _new_record: Dict[str, any] = None, + ): + """初始化 LLMResponse + + Args: + role (str): 角色, assistant, tool, err + completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "". + result_chain (MessageChain, optional): 返回的消息链. Defaults to None. + tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None. + tools_call_name (List[str], optional): 工具调用名称. Defaults to None. + raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None. + """ + self.role = role + self.completion_text = completion_text + self.result_chain = result_chain + self.tools_call_args = tools_call_args + self.tools_call_name = tools_call_name + self.raw_completion = raw_completion + self._new_record = _new_record + + @property + def completion_text(self): + if self.result_chain: + return self.result_chain.get_plain_text() + return self._completion_text + + @completion_text.setter + def completion_text(self, value): + if self.result_chain: + self.result_chain.chain = [ + comp + for comp in self.result_chain.chain + if not isinstance(comp, Comp.Plain) + ] # 清空 Plain 组件 + self.result_chain.chain.insert(0, Comp.Plain(value)) + else: + self._completion_text = value diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index cdb1b3d6..0f04628f 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -73,7 +73,7 @@ class FuncCall: handler=handler, ) self.func_list.append(_func) - logger.info(f"添加了函数调用工具({len(self.func_list)}): {name} - {desc}") + logger.info(f"添加函数调用工具: {name}") def remove_func(self, name: str) -> None: """ diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 9647b41c..7158d57b 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -1,3 +1,4 @@ +import re import asyncio import functools from typing import List @@ -40,11 +41,24 @@ class ProviderDashscope(ProviderOpenAIOfficial): raise Exception("阿里云百炼 APP 类型不能为空。") self.model_name = "dashscope" self.variables: dict = provider_config.get("variables", {}) + self.rag_options: dict = provider_config.get("rag_options", {}) + self.output_reference = self.rag_options.get("output_reference", False) + self.rag_options = self.rag_options.copy() + self.rag_options.pop("output_reference", None) self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): self.timeout = int(self.timeout) + def has_rag_options(self): + if ( + self.rag_options + and self.rag_options.get("pipeline_ids", None) + and self.rag_options.get("file_ids", None) + ): + return True + return False + async def text_chat( self, prompt: str, @@ -62,7 +76,10 @@ class ProviderDashscope(ProviderOpenAIOfficial): session_var = session_vars.get(session_id, {}) payload_vars.update(session_var) - if self.dashscope_app_type in ["agent", "dialog-workflow"]: + if ( + self.dashscope_app_type in ["agent", "dialog-workflow"] + and self.has_rag_options() + ): # 支持多轮对话的 new_record = {"role": "user", "content": prompt} if image_urls: @@ -86,12 +103,17 @@ class ProviderDashscope(ProviderOpenAIOfficial): else: # 不支持多轮对话的 # 调用阿里云百炼 API + payload = { + "app_id": self.app_id, + "prompt": prompt, + "api_key": self.api_key, + "biz_params": payload_vars or None, + } + if self.rag_options: + payload["rag_options"] = self.rag_options partial = functools.partial( Application.call, - app_id=self.app_id, - promtp=prompt, - api_key=self.api_key, - biz_params=payload_vars or None, + **payload, ) response = await asyncio.get_event_loop().run_in_executor(None, partial) @@ -107,6 +129,14 @@ class ProviderDashscope(ProviderOpenAIOfficial): ) output_text = response.output.get("text", "") + # RAG 引用脚标格式化 + output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) + if self.output_reference and response.output.get("doc_references", None): + ref_str = "" + for ref in response.output.get("doc_references", []): + ref_str += f"{ref['index_id']}. {ref['title']}\n" + output_text += f"\n\n回答来源:\n{ref_str}" + return LLMResponse(role="assistant", completion_text=output_text) async def forget(self, session_id): diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 37f575f2..8b5890c2 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -1,3 +1,5 @@ +import astrbot.core.message.components as Comp + from typing import List from .. import Provider, Personality from ..entites import LLMResponse @@ -5,8 +7,9 @@ from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter from astrbot.core.utils.dify_api_client import DifyAPIClient -from astrbot.core.utils.io import download_image_by_url +from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.core import logger, sp +from astrbot.core.message.message_event_result import MessageChain @register_provider_adapter("dify", "Dify APP 适配器。") @@ -30,7 +33,6 @@ class ProviderDify(Provider): if not self.api_key: raise Exception("Dify API Key 不能为空。") api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") - self.api_client = DifyAPIClient(self.api_key, api_base) self.api_type = provider_config.get("dify_api_type", "") if not self.api_type: raise Exception("Dify API 类型不能为空。") @@ -41,15 +43,19 @@ class ProviderDify(Provider): self.dify_query_input_key = provider_config.get( "dify_query_input_key", "astrbot_text_query" ) - self.variables: dict = provider_config.get("variables", {}) if not self.dify_query_input_key: self.dify_query_input_key = "astrbot_text_query" + if not self.workflow_output_key: + self.workflow_output_key = "astrbot_wf_output" + self.variables: dict = provider_config.get("variables", {}) self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): self.timeout = int(self.timeout) self.conversation_ids = {} """记录当前 session id 的对话 ID""" + self.api_client = DifyAPIClient(self.api_key, api_base) + async def text_chat( self, prompt: str, @@ -65,26 +71,27 @@ class ProviderDify(Provider): files_payload = [] for image_url in image_urls: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - file_response = await self.api_client.file_upload( - image_path, user=session_id + image_path = ( + await download_image_by_url(image_url) + if image_url.startswith("http") + else image_url + ) + file_response = await self.api_client.file_upload( + image_path, user=session_id + ) + logger.debug(f"Dify 上传图片响应:{file_response}") + if "id" not in file_response: + logger.warning( + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" ) - if "id" not in file_response: - logger.warning( - f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" - ) - continue - files_payload.append( - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - } - ) - else: - # TODO: 处理更多情况 - logger.warning(f"未知的图片链接:{image_url},图片将忽略。") + continue + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) # 获得会话变量 payload_vars = self.variables.copy() @@ -96,6 +103,9 @@ class ProviderDify(Provider): try: match self.api_type: case "chat" | "agent": + if not prompt: + prompt = "请描述这张图片。" + async for chunk in self.api_client.chat_messages( inputs={ **payload_vars, @@ -148,8 +158,9 @@ class ProviderDify(Provider): ) case "workflow_finished": logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。" + f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束" ) + logger.debug(f"Dify 工作流结果:{chunk}") if chunk["data"]["error"]: logger.error( f"Dify 工作流出现错误:{chunk['data']['error']}" @@ -164,9 +175,7 @@ class ProviderDify(Provider): raise Exception( f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}" ) - result = chunk["data"]["outputs"][ - self.workflow_output_key - ] + result = chunk case _: raise Exception(f"未知的 Dify API 类型:{self.api_type}") except Exception as e: @@ -176,7 +185,54 @@ class ProviderDify(Provider): if not result: logger.warning("Dify 请求结果为空,请查看 Debug 日志。") - return LLMResponse(role="assistant", completion_text=result) + chain = await self.parse_dify_result(result) + + return LLMResponse(role="assistant", result_chain=chain) + + async def parse_dify_result(self, chunk: dict | str) -> MessageChain: + if isinstance(chunk, str): + # Chat + return MessageChain(chain=[Comp.Plain(chunk)]) + + async def parse_file(item: dict) -> Comp: + match item["type"]: + case "image": + return Comp.Image(file=item["url"], url=item["url"]) + case "audio": + # 仅支持 wav + path = f"data/temp/{item['filename']}.wav" + await download_file(item["url"], path) + return Comp.Image(file=item["url"], url=item["url"]) + case "video": + return Comp.Video(file=item["url"]) + case _: + return Comp.File(name=item["filename"], file=item["url"]) + + output = chunk["data"]["outputs"][self.workflow_output_key] + chains = [] + if isinstance(output, str): + # 纯文本输出 + chains.append(Comp.Plain(output)) + elif isinstance(output, list): + # 主要适配 Dify 的 HTTP 请求结点的多模态输出 + for item in output: + # handle Array[File] + if ( + not isinstance(item, dict) + or item.get("dify_model_identity", "") != "__dify__file__" + ): + chains.append(Comp.Plain(str(output))) + break + else: + chains.append(Comp.Plain(str(output))) + + # scan file + files = chunk["data"].get("files", []) + for item in files: + comp = await parse_file(item) + chains.append(comp) + + return MessageChain(chain=chains) async def forget(self, session_id): self.conversation_ids[session_id] = "" diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 77d7fb48..0b9f7ad0 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -358,7 +358,9 @@ def register_llm_tool(name: str = None): } ) md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) - llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler) + llm_tools.add_func( + llm_tool_name, args, docstring.description.strip(), md.handler + ) return awaitable return decorator diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index 80be3fff..badf5d62 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -8,7 +8,7 @@ class DifyAPIClient: def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): self.api_key = api_key self.api_base = api_base - self.session = ClientSession() + self.session = ClientSession(trust_env=True) self.headers = { "Authorization": f"Bearer {self.api_key}", } diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 2e525a3c..34ba8fd3 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -3,6 +3,7 @@ import datetime from .route import Route, Response, RouteContext from quart import request from astrbot.core import WEBUI_SK +from astrbot import logger class AuthRoute(Route): @@ -19,9 +20,20 @@ class AuthRoute(Route): password = self.config["dashboard"]["password"] post_data = await request.json if post_data["username"] == username and post_data["password"] == password: + change_pwd_hint = False + if username == "astrbot" and password == "77b90590a8945a7d36c963981a307dc9": + change_pwd_hint = True + logger.warning("为了保证安全,请尽快修改默认密码。") + return ( Response() - .ok({"token": self.generate_jwt(username), "username": username}) + .ok( + { + "token": self.generate_jwt(username), + "username": username, + "change_pwd_hint": change_pwd_hint, + } + ) .__dict__ ) else: diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 54140d92..088c999f 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -29,11 +29,21 @@ def validate_config( ) -> typing.Tuple[typing.List[str], typing.Dict]: errors = [] - def validate(data, metadata=schema, path=""): - for key, meta in metadata.items(): - if key not in data: + def validate(data: dict, metadata: dict = schema, path=""): + for key, value in data.items(): + if key not in metadata: + # 无 schema 的配置项,执行类型猜测 + if isinstance(value, str): + if value.isdigit(): + data[key] = int(value) + elif value.replace(".", "", 1).isdigit(): + data[key] = float(value) + elif value == "true": + data[key] = True + elif value == "false": + data[key] = False continue - value = data[key] + meta = metadata[key] # null 转换 if value is None: data[key] = DEFAULT_VALUE_MAP[meta["type"]] @@ -43,6 +53,16 @@ def validate_config( errors.append( f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}" ) + elif ( + meta["type"] == "list" + and isinstance(value, list) + and value + and "items" in meta + and isinstance(value[0], dict) + ): + # 当前仅针对 list[dict] 的情况进行类型校验,以适配 AstrBot 中 platform、provider 的配置 + for item in value: + validate(item, meta["items"], path=f"{path}{key}.") elif meta["type"] == "object" and isinstance(value, dict): validate(value, meta["items"], path=f"{path}{key}.") diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 42bbad31..6fc0651f 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -120,12 +120,22 @@ class AstrBotDashboard: return f"获取进程信息失败: {str(e)}" def run(self): - try: - ip_addr = get_local_ip_addresses() - except Exception as _: - ip_addr = [] - + ip_addr = [] port = self.core_lifecycle.astrbot_config["dashboard"].get("port", 6185) + host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "0.0.0.0") + + logger.info(f"正在启动 WebUI, 监听地址: http://{host}:{port}") + + if host == "0.0.0.0": + logger.info( + "提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host)" + ) + + if host not in ["localhost", "127.0.0.1"]: + try: + ip_addr = get_local_ip_addresses() + except Exception as _: + pass if isinstance(port, str): port = int(port) @@ -142,15 +152,21 @@ class AstrBotDashboard: raise Exception(f"端口 {port} 已被占用") - display = f"\n ✨✨✨\n AstrBot v{VERSION} 管理面板已启动,可访问\n\n" + display = f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n" display += f" ➜ 本地: http://localhost:{port}\n" for ip in ip_addr: display += f" ➜ 网络: http://{ip}:{port}\n" display += " ➜ 默认用户名和密码: astrbot\n ✨✨✨\n" + + if not ip_addr: + display += ( + "可在 data/cmd_config.json 中配置 dashboard.host 以便远程访问。\n" + ) + logger.info(display) return self.app.run_task( - host="0.0.0.0", + host=host, port=port, shutdown_trigger=self.shutdown_trigger_placeholder, ) diff --git a/changelogs/v3.4.38.md b/changelogs/v3.4.38.md new file mode 100644 index 00000000..2a3504fc --- /dev/null +++ b/changelogs/v3.4.38.md @@ -0,0 +1,57 @@ +# What's Changed + +> Special thanks for all contributors and plugin developers and users who love AstrBot. 💖 + +## ✨ 新增的功能 + +1. 支持解析回复消息,支持 LLM 对所引用消息具有感知 #783 +2. 支持 Dify 的文件、图片、视频、音频输出 #819 +3. QQ 下支持嵌套转发(napcat) @zouyonghe +4. 配置页样式重写,更紧凑的 WebUI 配置 + +## 🎈 功能性优化 + +1. 使用系统时间而不是 UTC+8 时间作为默认时间以适应海外用户需求 @roeseth +2. 在对话隔离情况下也可以将整个群聊加入白名单 #746 +3. 在调用插件异常时更完整的报错输出 +4. gewechat 下对已知且没有业务处理的事件类型不显示详细日志 @diudiu62 +5. 优化 WebUI 悬浮文档 @IGCrystal +6. 支持自定义 WebUI、Wecom Webhook Server, QQ Official Webhook Server 的 host #821 +7. Dify 下当只有图片输入时的默认 prompt 防止一些报错 #837 + +## 🐛 修复的 Bug + +1. fishaudio 默认 baseurl 不可用 +2. gewechat 下重复登录后提示设备不存在导致无法重新登陆 @beat4ocean +3. gewechat 下用户本人发消息会触发消息回复 @beat4ocean +4. 钉钉 WebUI 文档不显示 +5. 更新插件后插件热重载不完全、函数工具重复添加 +6. OpenAI TTS API TypeError 报错 #755 +7. EdgeTTS 部分情况下无法使用 @Soulter @需要哦 +8. QQ 官方机器人平台下发送 base64 图片消息段报错 @Soulter @shuiping233 +9. QQ 官方机器人平台下命令参数报错信息无法正常发送 @shuiping233 +10. WebUI 错误地显示未知更新 +11. 部分情况下文件无法上传到 Telegram 群组 #601 +12. 插件管理的插件简介太长导致 “帮助”“操作”图标不显示 #790 +13. LLOnebot 合并消息转发错误 #842 +14. model_config 中自定义的配置项(如温度)类型自动变回 string #854 + +## 🧩 新增的插件 + +1. astrbot_plugin_image_understanding_Janus-Pro - 使用deepseek-ai/Janus-Pro系列模型为本地模型提供的图片理解补充 @xiewoc +2. astrbot_plugin_moyurenpro - 摸鱼人日历,支持自定义时间时区,自定义api,支持立即发送,工作日定时发送。 @quirrel-zh @DuBwTf +3. astrbot_plugin_wechat_manager - 微信关键字好友自动审核、关键字邀请进群。@diudiu62 +4. astrbot_plugin_qwq_filter - qwq 思考过滤工具 @beat4ocean +5. astrbot_plugin_chatsummary - 一个通过拉取历史聊天记录,调用LLM大模型接口实现消息总结功能。@laopanmemz +6. astrBot_PGR_Dialogue - 检测到部分战双角色的名称(或别称)时,有概率发送一条语音文本 @KurisuRee7 +7. astrbot_plugin_bv - 解析群内https://www.bilibili.com/video/BV号/ 的链接并获取视频数据与视频文件,以合并转发方式发送 @haliludaxuanfeng +8. astrbot_plugin_gemini_exp - 让你在AstrBot调用Gemini2.0-flash-exp来生成图片或者p图。Gemini2.0-flash-exp为原生多模态模型,其既是语言模型,也是生图模型,因此能够对图像使用简单的自然语言命令进行处理。@Elen123bot +9. astrbot_plugin_sjzb - 随机生成绝地潜兵2游戏中一组4个战备配置 @tenno1174 +10. astrbot_plugin_picture_manager - 图片管理插件,允许用户通过自定义触发指令从API或直接URL获取图片。@bigshabei +11. astrbot_plugin_bilibiliParse - 解析哔哩哔哩视频,并以图片的形式发送给用户 @7Hello12 +12. astrbot_plugin_sensoji - 这是一个模拟日本浅草寺抽签功能的插件。用户可以通过发送 /抽签 命令随机抽取一个签文,获取运势提示。签文包含吉凶结果(如“大吉”、“凶”等)以及对应的运势描述。 @Shouugou +13. astrbot_plugin_videosummary - 使用 bibigpt 实现视频总结 @kterna +14. astrbot_plugin_InitiativeDialogue - 使 bot 在用户长时间未发送消息时主动与用户对话的插件 @advent259141 +15. astrbot_plugin_emoji - 基于达莉娅综合群娱插件的表情包制作插件,仅保留了@其他群员制作表情包的部分。由桑帛云API提供表情包制作。@KurisuRee7 +16. astrbot_plugin_videos_analysis - 聚合视频分享链接解析(仅测试过napcat) @miaoxutao123 +17. astrbot_plugin_daily_news - 每日 60 秒新闻推送插件 - 自动推送每日热点新闻 @anka-afk \ No newline at end of file diff --git a/changelogs/v3.4.39.md b/changelogs/v3.4.39.md new file mode 100644 index 00000000..d80b4e86 --- /dev/null +++ b/changelogs/v3.4.39.md @@ -0,0 +1,4 @@ +# What's Changed + +1. 默认账户密码登录成功后弹出修改警告 +2. 将 WebUI 默认 host 改变回 v3.4.38 之前的版本以减少兼容性问题。 \ No newline at end of file diff --git a/compose.yml b/compose.yml index 805d30c1..3bab93fc 100644 --- a/compose.yml +++ b/compose.yml @@ -1,16 +1,21 @@ version: '3.8' +# 当接入 QQ NapCat 时,请使用这个 compose 文件一件部署: https://github.com/NapNeko/NapCat-Docker/blob/main/compose/astrbot.yml + services: astrbot: image: soulter/astrbot:latest container_name: astrbot + restart: always ports: # mappings description: https://github.com/Soulter/AstrBot/issues/497 - - "6185:6185" - - "6195:6195" # optional, wecom default port - - "6199:6199" # optional, aiocqhttp default port - - "6196:6196" # optional, qq official webhook default port - - "11451:11451" # optional, gewechat default port + - "6185:6185" # 必选,AstrBot WebUI 端口 + - "6195:6195" # 可选, 企业微信 Webhook 端口 + - "6199:6199" # 可选, QQ 个人号 WebSocket 端口 + - "6196:6196" # 可选, QQ 官方接口 Webhook 端口 + - "11451:11451" # 可选, 微信个人号 Webhook 端口 + environment: + - TZ=Asia/Shanghai volumes: - ./data:/AstrBot/data - - /etc/timezone:/etc/timezone:ro - - /etc/localtime:/etc/localtime:ro + # - /etc/timezone:/etc/timezone:ro + # - /etc/localtime:/etc/localtime:ro diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index cc39f7ff..2796f95d 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -2,9 +2,6 @@
{{ metadata[metadataKey]?.description }} ({{ metadataKey }}) - - object -
- +
+ style="border: 1px solid #e0e0e0; padding: 8px; margin-bottom: 16px; border-radius: 10px; margin-top: 16px">
- - + + - {{ metadata[metadataKey].items[key]?.description + '(' + key + ')' }} - {{ - metadata[metadataKey].items[key]?.type }} - + {{ metadata[metadataKey].items[key]?.description + '(' + key + ')' }} + {{ key }} @@ -45,7 +39,14 @@ - + + {{ + metadata[metadataKey].items[key]?.type || 'string' }} + + + +
- +
- + + +
- - + + {{ metadata[metadataKey]?.description + '(' + metadataKey + ')' }} - {{ - metadata[metadataKey]?.type }} - @@ -100,23 +99,35 @@ - + + + {{ + metadata[metadataKey]?.type }} + + + +
+ dense :disabled="metadata[metadataKey]?.readonly" density="compact" flat hide-details + single-line> + v-model="iterable[metadataKey]" variant="outlined" dense density="compact" flat hide-details + single-line> + v-model="iterable[metadataKey]" variant="outlined" dense density="compact" flat hide-details + single-line> + v-model="iterable[metadataKey]" variant="outlined" dense density="compact" flat hide-details + single-line> + v-model="iterable[metadataKey]" color="primary" hide-details> @@ -124,7 +135,7 @@ - +
diff --git a/dashboard/src/components/shared/ExtensionCard.vue b/dashboard/src/components/shared/ExtensionCard.vue index 6b30460d..a88d6beb 100644 --- a/dashboard/src/components/shared/ExtensionCard.vue +++ b/dashboard/src/components/shared/ExtensionCard.vue @@ -70,11 +70,10 @@ const viewHandlers = () => {

{{ extension.name }} - P - 有新版本可用: {{ extension.online_version }} {{ extension }} + 有新版本可用: {{ extension.online_version }}