Compare commits

...

4 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
c6b6eef8c4 Complete Docker compatibility fix with enhanced documentation
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:03:54 +00:00
copilot-swe-agent[bot]
50cf263076 Implement CLI Docker compatibility fix and login-info command
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:01:36 +00:00
copilot-swe-agent[bot]
2554548088 Initial commit: fix formatting and explore codebase
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 15:53:37 +00:00
copilot-swe-agent[bot]
aa4a2d10e2 Initial plan 2025-08-16 15:48:10 +00:00
30 changed files with 305 additions and 134 deletions

0
astrbot.lock Normal file
View File

View File

@@ -139,6 +139,14 @@ def conf():
- dashboard.password: Dashboard 密码 - dashboard.password: Dashboard 密码
- callback_api_base: 回调接口基址 - callback_api_base: 回调接口基址
可用子命令:
- set: 设置配置项值
- get: 获取配置项值
- login-info: 显示 Web 管理面板登录信息
""" """
pass pass
@@ -204,3 +212,44 @@ def get_config(key: str = None):
click.echo(f" {key}: {value}") click.echo(f" {key}: {value}")
except (KeyError, TypeError): except (KeyError, TypeError):
pass pass
@conf.command(name="login-info")
def get_login_info():
"""显示 Web 管理面板的登录信息
在 Docker 环境中使用示例:
docker exec -e ASTRBOT_ROOT=/AstrBot astrbot-container astrbot conf login-info
"""
config = _load_config()
try:
username = _get_nested_item(config, "dashboard.username")
# 注意我们不显示实际的MD5哈希密码而是提示用户如何重置
click.echo("🔐 Web 管理面板登录信息:")
click.echo(f" 用户名: {username}")
click.echo(" 密码: [已加密存储]")
click.echo()
click.echo("💡 如需重置密码,请使用以下命令:")
click.echo(" astrbot conf set dashboard.password <新密码>")
click.echo()
click.echo("🌐 访问地址:")
# 尝试获取端口信息
try:
port = _get_nested_item(config, "dashboard.port")
click.echo(f" http://localhost:{port}")
click.echo(f" http://your-server-ip:{port}")
except (KeyError, TypeError):
click.echo(" http://localhost:6185 (默认端口)")
click.echo(" http://your-server-ip:6185 (默认端口)")
click.echo()
click.echo("📋 Docker 环境使用说明:")
click.echo(" 如果在 Docker 中运行,请使用以下命令格式:")
click.echo(" docker exec -e ASTRBOT_ROOT=/AstrBot <容器名> astrbot conf login-info")
except KeyError:
click.echo("❌ 无法找到登录配置,请先运行 'astrbot init' 初始化")
except Exception as e:
raise click.UsageError(f"获取登录信息失败: {str(e)}")

View File

@@ -16,6 +16,12 @@ def check_astrbot_root(path: str | Path) -> bool:
def get_astrbot_root() -> Path: def get_astrbot_root() -> Path:
"""获取Astrbot根目录路径""" """获取Astrbot根目录路径"""
import os
# 使用与core应用相同的路径解析逻辑优先使用ASTRBOT_ROOT环境变量
if path := os.environ.get("ASTRBOT_ROOT"):
return Path(path)
else:
return Path.cwd() return Path.cwd()

View File

@@ -124,7 +124,8 @@ def build_plug_list(plugins_dir: Path) -> list:
if metadata and all( if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"] k in metadata for k in ["name", "desc", "version", "author", "repo"]
): ):
result.append({ result.append(
{
"name": str(metadata.get("name", "")), "name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")), "desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")), "version": str(metadata.get("version", "")),
@@ -132,7 +133,8 @@ def build_plug_list(plugins_dir: Path) -> list:
"repo": str(metadata.get("repo", "")), "repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED, "status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir), "local_path": str(plugin_dir),
}) }
)
# 获取在线插件列表 # 获取在线插件列表
online_plugins = [] online_plugins = []
@@ -142,7 +144,8 @@ def build_plug_list(plugins_dir: Path) -> list:
resp.raise_for_status() resp.raise_for_status()
data = resp.json() data = resp.json()
for plugin_id, plugin_info in data.items(): for plugin_id, plugin_info in data.items():
online_plugins.append({ online_plugins.append(
{
"name": str(plugin_id), "name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")), "desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")), "version": str(plugin_info.get("version", "")),
@@ -150,7 +153,8 @@ def build_plug_list(plugins_dir: Path) -> list:
"repo": str(plugin_info.get("repo", "")), "repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED, "status": PluginStatus.NOT_INSTALLED,
"local_path": None, "local_path": None,
}) }
)
except Exception as e: except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True) click.echo(f"获取在线插件列表失败: {e}", err=True)

View File

@@ -65,7 +65,7 @@ DEFAULT_CONFIG = {
"show_tool_use_status": False, "show_tool_use_status": False,
"streaming_segmented": False, "streaming_segmented": False,
"separate_provider": True, "separate_provider": True,
"max_agent_step": 30 "max_agent_step": 30,
}, },
"provider_stt_settings": { "provider_stt_settings": {
"enable": False, "enable": False,
@@ -598,11 +598,8 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.openai.com/v1", "api_base": "https://api.openai.com/v1",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"model": "gpt-4o-mini", "hint": "也兼容所有与OpenAI API兼容的服务。",
"temperature": 0.4
},
"hint": "也兼容所有与OpenAI API兼容的服务。"
}, },
"Azure OpenAI": { "Azure OpenAI": {
"id": "azure", "id": "azure",
@@ -614,10 +611,7 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "", "api_base": "",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"model": "gpt-4o-mini",
"temperature": 0.4
},
}, },
"xAI": { "xAI": {
"id": "xai", "id": "xai",
@@ -628,10 +622,7 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.x.ai/v1", "api_base": "https://api.x.ai/v1",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {"model": "grok-2-latest", "temperature": 0.4},
"model": "grok-2-latest",
"temperature": 0.4
},
}, },
"Anthropic": { "Anthropic": {
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错", "hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
@@ -646,11 +637,11 @@ CONFIG_METADATA_2 = {
"model_config": { "model_config": {
"model": "claude-3-5-sonnet-latest", "model": "claude-3-5-sonnet-latest",
"max_tokens": 4096, "max_tokens": 4096,
"temperature": 0.2 "temperature": 0.2,
}, },
}, },
"Ollama": { "Ollama": {
"hint":"启用前请确保已正确安装并运行 Ollama 服务端Ollama默认不带鉴权无需修改key", "hint": "启用前请确保已正确安装并运行 Ollama 服务端Ollama默认不带鉴权无需修改key",
"id": "ollama_default", "id": "ollama_default",
"provider": "ollama", "provider": "ollama",
"type": "openai_chat_completion", "type": "openai_chat_completion",
@@ -658,10 +649,7 @@ CONFIG_METADATA_2 = {
"enable": True, "enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama "key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1", "api_base": "http://localhost:11434/v1",
"model_config": { "model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"model": "llama3.1-8b",
"temperature": 0.4
},
}, },
"LM Studio": { "LM Studio": {
"id": "lm_studio", "id": "lm_studio",
@@ -686,7 +674,7 @@ CONFIG_METADATA_2 = {
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {
"model": "gemini-1.5-flash", "model": "gemini-1.5-flash",
"temperature": 0.4 "temperature": 0.4,
}, },
}, },
"Gemini": { "Gemini": {
@@ -700,7 +688,7 @@ CONFIG_METADATA_2 = {
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {
"model": "gemini-2.0-flash-exp", "model": "gemini-2.0-flash-exp",
"temperature": 0.4 "temperature": 0.4,
}, },
"gm_resp_image_modal": False, "gm_resp_image_modal": False,
"gm_native_search": False, "gm_native_search": False,
@@ -725,10 +713,7 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.deepseek.com/v1", "api_base": "https://api.deepseek.com/v1",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {"model": "deepseek-chat", "temperature": 0.4},
"model": "deepseek-chat",
"temperature": 0.4
},
}, },
"302.AI": { "302.AI": {
"id": "302ai", "id": "302ai",
@@ -739,10 +724,7 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.302.ai/v1", "api_base": "https://api.302.ai/v1",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"model": "gpt-4.1-mini",
"temperature": 0.4
},
}, },
"硅基流动": { "硅基流动": {
"id": "siliconflow", "id": "siliconflow",
@@ -755,7 +737,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.siliconflow.cn/v1", "api_base": "https://api.siliconflow.cn/v1",
"model_config": { "model_config": {
"model": "deepseek-ai/DeepSeek-V3", "model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4 "temperature": 0.4,
}, },
}, },
"PPIO派欧云": { "PPIO派欧云": {
@@ -769,7 +751,7 @@ CONFIG_METADATA_2 = {
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {
"model": "deepseek/deepseek-r1", "model": "deepseek/deepseek-r1",
"temperature": 0.4 "temperature": 0.4,
}, },
}, },
"优云智算": { "优云智算": {
@@ -794,10 +776,7 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"timeout": 120, "timeout": 120,
"api_base": "https://api.moonshot.cn/v1", "api_base": "https://api.moonshot.cn/v1",
"model_config": { "model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"model": "moonshot-v1-8k",
"temperature": 0.4
},
}, },
"智谱 AI": { "智谱 AI": {
"id": "zhipu_default", "id": "zhipu_default",
@@ -825,7 +804,7 @@ CONFIG_METADATA_2 = {
"dify_query_input_key": "astrbot_text_query", "dify_query_input_key": "astrbot_text_query",
"variables": {}, "variables": {},
"timeout": 60, "timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!" "hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
}, },
"阿里云百炼应用": { "阿里云百炼应用": {
"id": "dashscope", "id": "dashscope",
@@ -853,10 +832,7 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"timeout": 120, "timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1", "api_base": "https://api-inference.modelscope.cn/v1",
"model_config": { "model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"model": "Qwen/Qwen3-32B",
"temperature": 0.4
},
}, },
"FastGPT": { "FastGPT": {
"id": "fastgpt", "id": "fastgpt",

View File

@@ -128,6 +128,7 @@ class Plain(BaseMessageComponent):
async def to_dict(self): async def to_dict(self):
return {"type": "text", "data": {"text": self.text}} return {"type": "text", "data": {"text": self.text}}
class Face(BaseMessageComponent): class Face(BaseMessageComponent):
type: ComponentType = "Face" type: ComponentType = "Face"
id: int id: int

View File

@@ -8,6 +8,7 @@ from enum import Enum, auto
class AgentState(Enum): class AgentState(Enum):
"""Agent 状态枚举""" """Agent 状态枚举"""
IDLE = auto() # 初始状态 IDLE = auto() # 初始状态
RUNNING = auto() # 运行中 RUNNING = auto() # 运行中
DONE = auto() # 完成 DONE = auto() # 完成

View File

@@ -57,6 +57,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
logger.error(f"钉钉图片处理失败: {e}") logger.error(f"钉钉图片处理失败: {e}")
logger.warning(f"跳过图片发送: {image_path}") logger.warning(f"跳过图片发送: {image_path}")
continue continue
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
await self.send_with_client(self.client, message) await self.send_with_client(self.client, message)
await super().send(message) await super().send(message)

View File

@@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot):
await self.on_ready_once_callback() await self.on_ready_once_callback()
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True) f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
)
def _create_message_data(self, message: discord.Message) -> dict: def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典""" """从 discord.Message 创建数据字典"""
@@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot):
message_data = self._create_message_data(message) message_data = self._create_message_data(message)
await self.on_message_received(message_data) await self.on_message_received(message_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str: def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容""" """从交互中提取内容"""
interaction_type = interaction.type interaction_type = interaction.type

View File

@@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent):
self.url = url self.url = url
self.disabled = disabled self.disabled = disabled
class DiscordReference(BaseMessageComponent): class DiscordReference(BaseMessageComponent):
"""Discord引用组件""" """Discord引用组件"""
type: str = "discord_reference" type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str): def __init__(self, message_id: str, channel_id: str):
self.message_id = message_id self.message_id = message_id
self.channel_id = channel_id self.channel_id = channel_id
@@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent):
self.components = components or [] self.components = components or []
self.timeout = timeout self.timeout = timeout
def to_discord_view(self) -> discord.ui.View: def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象""" """转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout) view = discord.ui.View(timeout=self.timeout)

View File

@@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
# 解析消息链为 Discord 所需的对象 # 解析消息链为 Discord 所需的对象
try: try:
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message) (
content,
files,
view,
embeds,
reference_message_id,
) = await self._parse_to_discord(message)
except Exception as e: except Exception as e:
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True) logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
return return
@@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
if await asyncio.to_thread(path.exists): if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes) file_bytes = await asyncio.to_thread(path.read_bytes)
files.append( files.append(
discord.File(BytesIO(file_bytes), discord.File(BytesIO(file_bytes), filename=i.name)
filename=i.name)
) )
else: else:
logger.warning( logger.warning(

View File

@@ -308,7 +308,9 @@ class SlackAdapter(Platform):
base64_content = base64.b64encode(content).decode("utf-8") base64_content = base64.b64encode(content).decode("utf-8")
return base64_content return base64_content
else: else:
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}") logger.error(
f"Failed to download slack file: {resp.status} {await resp.text()}"
)
raise Exception(f"下载文件失败: {resp.status}") raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]: async def run(self) -> Awaitable[Any]:

View File

@@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent):
"text": {"type": "mrkdwn", "text": "文件上传失败"}, "text": {"type": "mrkdwn", "text": "文件上传失败"},
} }
file_url = response["files"][0]["permalink"] file_url = response["files"][0]["permalink"]
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}} return {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
},
}
else: else:
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}} return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}

View File

@@ -66,7 +66,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
return chunks return chunks
@classmethod @classmethod
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str): async def send_with_client(
cls, client: ExtBot, message: MessageChain, user_name: str
):
image_path = None image_path = None
has_reply = False has_reply = False

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
class WebChatQueueMgr: class WebChatQueueMgr:
def __init__(self) -> None: def __init__(self) -> None:
self.queues = {} self.queues = {}
@@ -30,4 +31,5 @@ class WebChatQueueMgr:
"""Check if a queue exists for the given conversation ID""" """Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues return conversation_id in self.queues
webchat_queue_mgr = WebChatQueueMgr() webchat_queue_mgr = WebChatQueueMgr()

View File

@@ -234,7 +234,9 @@ class WeChatPadProAdapter(Platform):
try: try:
async with session.post(url, params=params, json=payload) as response: async with session.post(url, params=params, json=payload) as response:
if response.status != 200: if response.status != 200:
logger.error(f"生成授权码失败: {response.status}, {await response.text()}") logger.error(
f"生成授权码失败: {response.status}, {await response.text()}"
)
return return
response_data = await response.json() response_data = await response.json()
@@ -245,7 +247,9 @@ class WeChatPadProAdapter(Platform):
if self.auth_key: if self.auth_key:
logger.info("成功获取授权码") logger.info("成功获取授权码")
else: else:
logger.error(f"生成授权码成功但未找到授权码: {response_data}") logger.error(
f"生成授权码成功但未找到授权码: {response_data}"
)
else: else:
logger.error(f"生成授权码失败: {response_data}") logger.error(f"生成授权码失败: {response_data}")
except aiohttp.ClientConnectorError as e: except aiohttp.ClientConnectorError as e:

View File

@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。 注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。
:return: 接口调用结果 :return: 接口调用结果
""" """
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid} data = {
"token": token,
"cursor": cursor,
"limit": limit,
"open_kfid": open_kfid,
}
return self._post("kf/sync_msg", data=data) return self._post("kf/sync_msg", data=data)
def get_service_state(self, open_kfid, external_userid): def get_service_state(self, open_kfid, external_userid):
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
} }
return self._post("kf/service_state/get", data=data) return self._post("kf/service_state/get", data=data)
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""): def trans_service_state(
self, open_kfid, external_userid, service_state, servicer_userid=""
):
""" """
变更会话状态 变更会话状态
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
""" """
return self._get("kf/customer/get_upgrade_service_config") return self._get("kf/customer/get_upgrade_service_config")
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None): def upgrade_service(
self, open_kfid, external_userid, service_type, member=None, groupchat=None
):
""" """
为客户升级为专员或客户群服务 为客户升级为专员或客户群服务
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time} data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
return self._post("kf/get_corp_statistic", data=data) return self._post("kf/get_corp_statistic", data=data)
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None): def get_servicer_statistic(
self, start_time, end_time, open_kfid=None, servicer_userid=None
):
""" """
获取「客户数据统计」接待人员明细数据 获取「客户数据统计」接待人员明细数据

View File

@@ -26,6 +26,7 @@ from optionaldict import optionaldict
from wechatpy.client.api.base import BaseWeChatAPI from wechatpy.client.api.base import BaseWeChatAPI
class WeChatKFMessage(BaseWeChatAPI): class WeChatKFMessage(BaseWeChatAPI):
""" """
发送微信客服消息 发送微信客服消息
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
msg={"msgtype": "news", "link": {"link": articles_data}}, msg={"msgtype": "news", "link": {"link": articles_data}},
) )
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""): def send_msgmenu(
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
):
return self.send( return self.send(
user_id, user_id,
open_kfid, open_kfid,
msgid, msgid,
msg={ msg={
"msgtype": "msgmenu", "msgtype": "msgmenu",
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content}, "msgmenu": {
"head_content": head_content,
"list": menu_list,
"tail_content": tail_content,
},
}, },
) )
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""): def send_location(
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
):
return self.send( return self.send(
user_id, user_id,
open_kfid, open_kfid,
msgid, msgid,
msg={ msg={
"msgtype": "location", "msgtype": "location",
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude}, "msgmenu": {
"name": name,
"address": address,
"latitude": latitude,
"longitude": longitude,
},
}, },
) )
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""): def send_miniprogram(
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
):
return self.send( return self.send(
user_id, user_id,
open_kfid, open_kfid,
msgid, msgid,
msg={ msg={
"msgtype": "miniprogram", "msgtype": "miniprogram",
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath}, "msgmenu": {
"appid": appid,
"title": title,
"thumb_media_id": thumb_media_id,
"pagepath": pagepath,
},
}, },
) )

View File

@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.wexin_event_workers[msg.id] = future self.wexin_event_workers[msg.id] = future
await self.convert_message(msg, future) await self.convert_message(msg, future)
# I love shield so much! # I love shield so much!
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s result = await asyncio.wait_for(
asyncio.shield(future), 60
) # wait for 60s
logger.debug(f"Got future result: {result}") logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None) self.wexin_event_workers.pop(msg.id, None)
return result # xml. see weixin_offacc_event.py return result # xml. see weixin_offacc_event.py

View File

@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
return return
logger.info(f"微信公众平台上传语音返回: {response}") logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode: if active_send_mode:
self.client.message.send_voice( self.client.message.send_voice(
message_obj.sender.user_id, message_obj.sender.user_id,

View File

@@ -2,7 +2,12 @@ from asyncio import Queue
from typing import List, Union from typing import List, Union
from astrbot.core import sp from astrbot.core import sp
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider from astrbot.core.provider.provider import (
Provider,
TTSProvider,
STTProvider,
EmbeddingProvider,
)
from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig

View File

@@ -113,8 +113,7 @@ class CommandGroupFilter(HandlerFilter):
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
) )
raise ValueError( raise ValueError(
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
+ tree
) )
# complete_command_names = [name + " " for name in complete_command_names] # complete_command_names = [name + " " for name in complete_command_names]

View File

@@ -7,6 +7,7 @@ from .star import star_map
T = TypeVar("T", bound="StarHandlerMetadata") T = TypeVar("T", bound="StarHandlerMetadata")
class StarHandlerRegistry(Generic[T]): class StarHandlerRegistry(Generic[T]):
def __init__(self): def __init__(self):
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {} self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
@@ -49,7 +50,8 @@ class StarHandlerRegistry(Generic[T]):
self, module_name: str self, module_name: str
) -> List[StarHandlerMetadata]: ) -> List[StarHandlerMetadata]:
return [ return [
handler for handler in self._handlers handler
for handler in self._handlers
if handler.handler_module_path == module_name if handler.handler_module_path == module_name
] ]
@@ -67,6 +69,7 @@ class StarHandlerRegistry(Generic[T]):
def __len__(self): def __len__(self):
return len(self._handlers) return len(self._handlers)
star_handlers_registry = StarHandlerRegistry() star_handlers_registry = StarHandlerRegistry()

View File

@@ -809,11 +809,11 @@ class PluginManager:
if star_metadata.star_cls is None: if star_metadata.star_cls is None:
return return
if '__del__' in star_metadata.star_cls_type.__dict__: if "__del__" in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor( asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__ None, star_metadata.star_cls.__del__
) )
elif 'terminate' in star_metadata.star_cls_type.__dict__: elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate() await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str): async def turn_on_plugin(self, plugin_name: str):

View File

@@ -182,7 +182,9 @@ class StarTools:
plugin_name = metadata.name plugin_name = metadata.name
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)) data_dir = Path(
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
)
try: try:
data_dir.mkdir(parents=True, exist_ok=True) data_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -56,9 +56,7 @@ class AstrBotUpdator(RepoZipUpdator):
try: try:
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
if os.name == "nt": if os.name == "nt":
args = [ args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
]
else: else:
args = sys.argv[1:] args = sys.argv[1:]
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args) os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)

View File

@@ -5,6 +5,7 @@ from .astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT") _VT = TypeVar("_VT")
class SharedPreferences: class SharedPreferences:
def __init__(self, path=None): def __init__(self, path=None):
if path is None: if path is None:

View File

@@ -210,11 +210,16 @@ class ConfigRoute(Route):
response = await asyncio.wait_for( response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0 provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
) )
logger.debug(f"Received response from {status_info['name']}: {response}") logger.debug(
f"Received response from {status_info['name']}: {response}"
)
if response is not None: if response is not None:
status_info["status"] = "available" status_info["status"] = "available"
response_text_snippet = "" response_text_snippet = ""
if hasattr(response, "completion_text") and response.completion_text: if (
hasattr(response, "completion_text")
and response.completion_text
):
response_text_snippet = ( response_text_snippet = (
response.completion_text[:70] + "..." response.completion_text[:70] + "..."
if len(response.completion_text) > 70 if len(response.completion_text) > 70
@@ -233,29 +238,48 @@ class ConfigRoute(Route):
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'" f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
) )
else: else:
status_info["error"] = "Test call returned None, but expected an LLMResponse object." status_info["error"] = (
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.") "Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
status_info["error"] = "Connection timed out after 45 seconds during test call." status_info["error"] = (
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.") "Connection timed out after 45 seconds during test call."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
)
except Exception as e: except Exception as e:
error_message = str(e) error_message = str(e)
status_info["error"] = error_message status_info["error"] = error_message
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}") logger.warning(
logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}") f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
)
elif provider_capability_type == ProviderType.EMBEDDING: elif provider_capability_type == ProviderType.EMBEDDING:
try: try:
# For embedding, we can call the get_embedding method with a short prompt. # For embedding, we can call the get_embedding method with a short prompt.
embedding_result = await provider.get_embedding("health_check") embedding_result = await provider.get_embedding("health_check")
if isinstance(embedding_result, list) and (not embedding_result or isinstance(embedding_result[0], float)): if isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
):
status_info["status"] = "available" status_info["status"] = "available"
else: else:
status_info["status"] = "unavailable" status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: unexpected result type {type(embedding_result)}" status_info["error"] = (
f"Embedding test failed: unexpected result type {type(embedding_result)}"
)
except Exception as e: except Exception as e:
logger.error(f"Error testing embedding provider {provider_name}: {e}", exc_info=True) logger.error(
f"Error testing embedding provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable" status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: {str(e)}" status_info["error"] = f"Embedding test failed: {str(e)}"
@@ -267,41 +291,71 @@ class ConfigRoute(Route):
status_info["status"] = "available" status_info["status"] = "available"
else: else:
status_info["status"] = "unavailable" status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: unexpected result type {type(audio_result)}" status_info["error"] = (
f"TTS test failed: unexpected result type {type(audio_result)}"
)
except Exception as e: except Exception as e:
logger.error(f"Error testing TTS provider {provider_name}: {e}", exc_info=True) logger.error(
f"Error testing TTS provider {provider_name}: {e}", exc_info=True
)
status_info["status"] = "unavailable" status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: {str(e)}" status_info["error"] = f"TTS test failed: {str(e)}"
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT: elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
try: try:
logger.debug(f"Sending health check audio to provider: {status_info['name']}") logger.debug(
sample_audio_path = os.path.join(get_astrbot_path(), "samples", "stt_health_check.wav") f"Sending health check audio to provider: {status_info['name']}"
)
sample_audio_path = os.path.join(
get_astrbot_path(), "samples", "stt_health_check.wav"
)
if not os.path.exists(sample_audio_path): if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable" status_info["status"] = "unavailable"
status_info["error"] = "STT test failed: sample audio file not found." status_info["error"] = (
logger.warning(f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}") "STT test failed: sample audio file not found."
)
logger.warning(
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}"
)
else: else:
text_result = await provider.get_text(sample_audio_path) text_result = await provider.get_text(sample_audio_path)
if isinstance(text_result, str) and text_result: if isinstance(text_result, str) and text_result:
status_info["status"] = "available" status_info["status"] = "available"
snippet = text_result[:70] + "..." if len(text_result) > 70 else text_result snippet = (
logger.info(f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'") text_result[:70] + "..."
if len(text_result) > 70
else text_result
)
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'"
)
else: else:
status_info["status"] = "unavailable" status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: unexpected result type {type(text_result)}" status_info["error"] = (
logger.warning(f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}") f"STT test failed: unexpected result type {type(text_result)}"
)
logger.warning(
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}"
)
except Exception as e: except Exception as e:
logger.error(f"Error testing STT provider {provider_name}: {e}", exc_info=True) logger.error(
f"Error testing STT provider {provider_name}: {e}", exc_info=True
)
status_info["status"] = "unavailable" status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {str(e)}" status_info["error"] = f"STT test failed: {str(e)}"
else: else:
logger.debug(f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}") logger.debug(
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}"
)
status_info["status"] = "available" status_info["status"] = "available"
status_info["error"] = "This provider type is not tested and is assumed to be available." status_info["error"] = (
"This provider type is not tested and is assumed to be available."
)
return status_info return status_info
def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error): def _error_response(
self, message: str, status_code: int = 500, log_fn=logger.error
):
log_fn(message) log_fn(message)
# 记录更详细的traceback信息但只在是严重错误时 # 记录更详细的traceback信息但只在是严重错误时
if status_code == 500: if status_code == 500:
@@ -312,7 +366,9 @@ class ConfigRoute(Route):
"""API: check a single LLM Provider's status by id""" """API: check a single LLM Provider's status by id"""
provider_id = request.args.get("id") provider_id = request.args.get("id")
if not provider_id: if not provider_id:
return self._error_response("Missing provider_id parameter", 400, logger.warning) return self._error_response(
"Missing provider_id parameter", 400, logger.warning
)
logger.info(f"API call: /config/provider/check_one id={provider_id}") logger.info(f"API call: /config/provider/check_one id={provider_id}")
try: try:
@@ -320,16 +376,21 @@ class ConfigRoute(Route):
target = prov_mgr.inst_map.get(provider_id) target = prov_mgr.inst_map.get(provider_id)
if not target: if not target:
logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.") logger.warning(
return Response().error(f"Provider with id '{provider_id}' not found").__dict__ f"Provider with id '{provider_id}' not found in provider_manager."
)
return (
Response()
.error(f"Provider with id '{provider_id}' not found")
.__dict__
)
result = await self._test_single_provider(target) result = await self._test_single_provider(target)
return Response().ok(result).__dict__ return Response().ok(result).__dict__
except Exception as e: except Exception as e:
return self._error_response( return self._error_response(
f"Critical error checking provider {provider_id}: {e}", f"Critical error checking provider {provider_id}: {e}", 500
500
) )
async def get_configs(self): async def get_configs(self):

View File

@@ -10,7 +10,9 @@ class LogRoute(Route):
super().__init__(context) super().__init__(context)
self.log_broker = log_broker self.log_broker = log_broker
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"]) self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
self.app.add_url_rule("/api/log-history", view_func=self.log_history, methods=["GET"]) self.app.add_url_rule(
"/api/log-history", view_func=self.log_history, methods=["GET"]
)
async def log(self): async def log(self):
async def stream(): async def stream():
@@ -48,9 +50,15 @@ class LogRoute(Route):
"""获取日志历史""" """获取日志历史"""
try: try:
logs = list(self.log_broker.log_cache) logs = list(self.log_broker.log_cache)
return Response().ok(data={ return (
Response()
.ok(
data={
"logs": logs, "logs": logs,
}).__dict__ }
)
.__dict__
)
except BaseException as e: except BaseException as e:
logger.error(f"获取日志历史失败: {e}") logger.error(f"获取日志历史失败: {e}")
return Response().error(f"获取日志历史失败: {e}").__dict__ return Response().error(f"获取日志历史失败: {e}").__dict__