Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
817ef94742 Implement Telegram topic isolation feature - share conversations across topics
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:30:18 +00:00
copilot-swe-agent[bot]
10e91fe375 Initial plan 2025-08-16 16:18:47 +00:00
31 changed files with 153 additions and 323 deletions

View File

@@ -124,17 +124,15 @@ def build_plug_list(plugins_dir: Path) -> list:
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
result.append(
{
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
}
)
result.append({
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
})
# 获取在线插件列表
online_plugins = []
@@ -144,17 +142,15 @@ def build_plug_list(plugins_dir: Path) -> list:
resp.raise_for_status()
data = resp.json()
for plugin_id, plugin_info in data.items():
online_plugins.append(
{
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
}
)
online_plugins.append({
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
})
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)

View File

@@ -65,7 +65,7 @@ DEFAULT_CONFIG = {
"show_tool_use_status": False,
"streaming_segmented": False,
"separate_provider": True,
"max_agent_step": 30,
"max_agent_step": 30
},
"provider_stt_settings": {
"enable": False,
@@ -598,8 +598,11 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"hint": "也兼容所有与OpenAI API兼容的服务。",
"model_config": {
"model": "gpt-4o-mini",
"temperature": 0.4
},
"hint": "也兼容所有与OpenAI API兼容的服务。"
},
"Azure OpenAI": {
"id": "azure",
@@ -611,7 +614,10 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"model_config": {
"model": "gpt-4o-mini",
"temperature": 0.4
},
},
"xAI": {
"id": "xai",
@@ -622,7 +628,10 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
"model_config": {
"model": "grok-2-latest",
"temperature": 0.4
},
},
"Anthropic": {
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
@@ -637,11 +646,11 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "claude-3-5-sonnet-latest",
"max_tokens": 4096,
"temperature": 0.2,
"temperature": 0.2
},
},
"Ollama": {
"hint": "启用前请确保已正确安装并运行 Ollama 服务端Ollama默认不带鉴权无需修改key",
"hint":"启用前请确保已正确安装并运行 Ollama 服务端Ollama默认不带鉴权无需修改key",
"id": "ollama_default",
"provider": "ollama",
"type": "openai_chat_completion",
@@ -649,7 +658,10 @@ CONFIG_METADATA_2 = {
"enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1",
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"model_config": {
"model": "llama3.1-8b",
"temperature": 0.4
},
},
"LM Studio": {
"id": "lm_studio",
@@ -674,9 +686,8 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
"temperature": 0.4,
"temperature": 0.4
},
"enable_google_search": False,
},
"Gemini": {
"id": "gemini_default",
@@ -689,7 +700,7 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"model_config": {
"model": "gemini-2.0-flash-exp",
"temperature": 0.4,
"temperature": 0.4
},
"gm_resp_image_modal": False,
"gm_native_search": False,
@@ -714,7 +725,10 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.deepseek.com/v1",
"timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
"model_config": {
"model": "deepseek-chat",
"temperature": 0.4
},
},
"302.AI": {
"id": "302ai",
@@ -725,7 +739,10 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.302.ai/v1",
"timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"model_config": {
"model": "gpt-4.1-mini",
"temperature": 0.4
},
},
"硅基流动": {
"id": "siliconflow",
@@ -738,7 +755,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.siliconflow.cn/v1",
"model_config": {
"model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4,
"temperature": 0.4
},
},
"PPIO派欧云": {
@@ -752,7 +769,7 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"model_config": {
"model": "deepseek/deepseek-r1",
"temperature": 0.4,
"temperature": 0.4
},
},
"优云智算": {
@@ -777,7 +794,10 @@ CONFIG_METADATA_2 = {
"key": [],
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"model_config": {
"model": "moonshot-v1-8k",
"temperature": 0.4
},
},
"智谱 AI": {
"id": "zhipu_default",
@@ -805,7 +825,7 @@ CONFIG_METADATA_2 = {
"dify_query_input_key": "astrbot_text_query",
"variables": {},
"timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!"
},
"阿里云百炼应用": {
"id": "dashscope",
@@ -833,7 +853,10 @@ CONFIG_METADATA_2 = {
"key": [],
"timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"model_config": {
"model": "Qwen/Qwen3-32B",
"temperature": 0.4
},
},
"FastGPT": {
"id": "fastgpt",
@@ -1577,11 +1600,6 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
},
"enable_google_search": {
"description": "启用 Google 搜索",
"type": "bool",
"hint": "启用后Gemini(OpenAI兼容) 提供商将支持 googleSearch 函数工具调用,用于网页搜索功能。",
},
"model_config": {
"description": "模型配置",
"type": "object",

View File

@@ -1,3 +1,3 @@
from .vec_db import FaissVecDB
__all__ = ["FaissVecDB"]
__all__ = ["FaissVecDB"]

View File

@@ -128,7 +128,6 @@ class Plain(BaseMessageComponent):
async def to_dict(self):
return {"type": "text", "data": {"text": self.text}}
class Face(BaseMessageComponent):
type: ComponentType = "Face"
id: int

View File

@@ -8,11 +8,10 @@ from enum import Enum, auto
class AgentState(Enum):
"""Agent 状态枚举"""
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
class AgentResponseData(T.TypedDict):

View File

@@ -57,7 +57,6 @@ class DingtalkMessageEvent(AstrMessageEvent):
logger.error(f"钉钉图片处理失败: {e}")
logger.warning(f"跳过图片发送: {image_path}")
continue
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)

View File

@@ -41,8 +41,7 @@ class DiscordBotClient(discord.Bot):
await self.on_ready_once_callback()
except Exception as e:
logger.error(
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
)
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
@@ -91,6 +90,7 @@ class DiscordBotClient(discord.Bot):
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
interaction_type = interaction.type

View File

@@ -79,12 +79,9 @@ class DiscordButton(BaseMessageComponent):
self.url = url
self.disabled = disabled
class DiscordReference(BaseMessageComponent):
"""Discord引用组件"""
type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str):
self.message_id = message_id
self.channel_id = channel_id
@@ -101,6 +98,7 @@ class DiscordView(BaseMessageComponent):
self.components = components or []
self.timeout = timeout
def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout)

View File

@@ -53,13 +53,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
# 解析消息链为 Discord 所需的对象
try:
(
content,
files,
view,
embeds,
reference_message_id,
) = await self._parse_to_discord(message)
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
except Exception as e:
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
return
@@ -212,7 +206,8 @@ class DiscordPlatformEvent(AstrMessageEvent):
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
files.append(
discord.File(BytesIO(file_bytes), filename=i.name)
discord.File(BytesIO(file_bytes),
filename=i.name)
)
else:
logger.warning(

View File

@@ -308,9 +308,7 @@ class SlackAdapter(Platform):
base64_content = base64.b64encode(content).decode("utf-8")
return base64_content
else:
logger.error(
f"Failed to download slack file: {resp.status} {await resp.text()}"
)
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]:

View File

@@ -75,13 +75,7 @@ class SlackMessageEvent(AstrMessageEvent):
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
file_url = response["files"][0]["permalink"]
return {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
},
}
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
else:
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}

View File

@@ -223,9 +223,10 @@ class TelegramPlatformAdapter(Platform):
message.type = MessageType.GROUP_MESSAGE
message.group_id = str(update.message.chat.id)
if update.message.message_thread_id:
# Topic Group
# Topic Group - keep topic info in group_id for message routing,
# but use base group_id for session_id to share conversations
message.group_id += "#" + str(update.message.message_thread_id)
message.session_id = message.group_id
# session_id remains the base group_id for shared conversation
message.message_id = str(update.message.message_id)
message.sender = MessageMember(
@@ -349,6 +350,13 @@ class TelegramPlatformAdapter(Platform):
session_id=message.session_id,
client=self.client,
)
# Store topic information for shared conversations
if message.type == MessageType.GROUP_MESSAGE and "#" in message.group_id:
group_id, topic_id = message.group_id.split("#")
message_event.set_extra("telegram_topic_id", topic_id)
message_event.set_extra("telegram_base_group_id", group_id)
self.commit_event(message_event)
def get_client(self) -> ExtBot:

View File

@@ -150,7 +150,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name,
}
if message_thread_id:
payload["reply_to_message_id"] = message_thread_id
payload["message_thread_id"] = message_thread_id
delta = ""
current_content = ""

View File

@@ -1,6 +1,5 @@
import asyncio
class WebChatQueueMgr:
def __init__(self) -> None:
self.queues = {}
@@ -31,5 +30,4 @@ class WebChatQueueMgr:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
webchat_queue_mgr = WebChatQueueMgr()

View File

@@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform):
def _extract_auth_key(self, data):
"""Helper method to extract auth_key from response data."""
if isinstance(data, dict):
auth_keys = data.get("authKeys") # 新接口
auth_keys = data.get("authKeys") # 新接口
if isinstance(auth_keys, list) and auth_keys:
return auth_keys[0]
elif isinstance(data, list) and data: # 旧接口
elif isinstance(data, list) and data: # 旧接口
return data[0]
return None
@@ -234,9 +234,7 @@ class WeChatPadProAdapter(Platform):
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(
f"生成授权码失败: {response.status}, {await response.text()}"
)
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
return
response_data = await response.json()
@@ -247,9 +245,7 @@ class WeChatPadProAdapter(Platform):
if self.auth_key:
logger.info("成功获取授权码")
else:
logger.error(
f"生成授权码成功但未找到授权码: {response_data}"
)
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
else:
logger.error(f"生成授权码失败: {response_data}")
except aiohttp.ClientConnectorError as e:

View File

@@ -48,12 +48,7 @@ class WeChatKF(BaseWeChatAPI):
注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。
:return: 接口调用结果
"""
data = {
"token": token,
"cursor": cursor,
"limit": limit,
"open_kfid": open_kfid,
}
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
return self._post("kf/sync_msg", data=data)
def get_service_state(self, open_kfid, external_userid):
@@ -77,9 +72,7 @@ class WeChatKF(BaseWeChatAPI):
}
return self._post("kf/service_state/get", data=data)
def trans_service_state(
self, open_kfid, external_userid, service_state, servicer_userid=""
):
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
"""
变更会话状态
@@ -187,9 +180,7 @@ class WeChatKF(BaseWeChatAPI):
"""
return self._get("kf/customer/get_upgrade_service_config")
def upgrade_service(
self, open_kfid, external_userid, service_type, member=None, groupchat=None
):
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
"""
为客户升级为专员或客户群服务
@@ -255,9 +246,7 @@ class WeChatKF(BaseWeChatAPI):
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
return self._post("kf/get_corp_statistic", data=data)
def get_servicer_statistic(
self, start_time, end_time, open_kfid=None, servicer_userid=None
):
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
"""
获取「客户数据统计」接待人员明细数据

View File

@@ -26,7 +26,6 @@ from optionaldict import optionaldict
from wechatpy.client.api.base import BaseWeChatAPI
class WeChatKFMessage(BaseWeChatAPI):
"""
发送微信客服消息
@@ -126,55 +125,35 @@ class WeChatKFMessage(BaseWeChatAPI):
msg={"msgtype": "news", "link": {"link": articles_data}},
)
def send_msgmenu(
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
):
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "msgmenu",
"msgmenu": {
"head_content": head_content,
"list": menu_list,
"tail_content": tail_content,
},
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
},
)
def send_location(
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
):
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "location",
"msgmenu": {
"name": name,
"address": address,
"latitude": latitude,
"longitude": longitude,
},
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
},
)
def send_miniprogram(
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
):
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "miniprogram",
"msgmenu": {
"appid": appid,
"title": title,
"thumb_media_id": thumb_media_id,
"pagepath": pagepath,
},
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
},
)

View File

@@ -160,9 +160,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.wexin_event_workers[msg.id] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(
asyncio.shield(future), 60
) # wait for 60s
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
return result # xml. see weixin_offacc_event.py

View File

@@ -150,6 +150,7 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
return
logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,

View File

@@ -81,17 +81,6 @@ class ProviderOpenAIOfficial(Provider):
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
# Check if we need to add googleSearch function for Gemini(OpenAI Compatible)
if (
self.provider_config.get("enable_google_search", False)
and self.provider_config.get("api_base", "").find(
"generativelanguage.googleapis.com"
)
!= -1
):
# Add googleSearch function as alias to web_search
await self._add_google_search_tool(tools)
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
tool_list = tools.get_func_desc_openai_style(
@@ -135,17 +124,6 @@ class ProviderOpenAIOfficial(Provider):
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API逐步返回结果"""
if tools:
# Check if we need to add googleSearch function for Gemini(OpenAI Compatible)
if (
self.provider_config.get("enable_google_search", False)
and self.provider_config.get("api_base", "").find(
"generativelanguage.googleapis.com"
)
!= -1
):
# Add googleSearch function as alias to web_search
await self._add_google_search_tool(tools)
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
tool_list = tools.get_func_desc_openai_style(
@@ -575,35 +553,3 @@ class ProviderOpenAIOfficial(Provider):
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
async def _add_google_search_tool(self, tools: FuncCall) -> None:
"""Add googleSearch function as an alias to web_search for Gemini(OpenAI Compatible)"""
# Check if googleSearch is already added
for func in tools.func_list:
if func.name == "googleSearch":
return
# Check if web_search exists
web_search_func = None
for func in tools.func_list:
if func.name == "web_search":
web_search_func = func
break
if web_search_func is None:
# If web_search is not available, don't add googleSearch
return
# Add googleSearch as an alias to web_search with English description
tools.add_func(
name="googleSearch",
func_args=[
{
"type": "string",
"name": "query",
"description": "The most relevant search keywords for the user's question, used to search on Google.",
}
],
desc="Search the internet to answer user questions using Google search. Call this tool when users need to search the web for real-time information.",
handler=web_search_func.handler,
)

View File

@@ -2,12 +2,7 @@ from asyncio import Queue
from typing import List, Union
from astrbot.core import sp
from astrbot.core.provider.provider import (
Provider,
TTSProvider,
STTProvider,
EmbeddingProvider,
)
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig

View File

@@ -113,7 +113,8 @@ class CommandGroupFilter(HandlerFilter):
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
)
raise ValueError(
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
+ tree
)
# complete_command_names = [name + " " for name in complete_command_names]

View File

@@ -7,7 +7,6 @@ from .star import star_map
T = TypeVar("T", bound="StarHandlerMetadata")
class StarHandlerRegistry(Generic[T]):
def __init__(self):
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
@@ -50,8 +49,7 @@ class StarHandlerRegistry(Generic[T]):
self, module_name: str
) -> List[StarHandlerMetadata]:
return [
handler
for handler in self._handlers
handler for handler in self._handlers
if handler.handler_module_path == module_name
]
@@ -69,7 +67,6 @@ class StarHandlerRegistry(Generic[T]):
def __len__(self):
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry()

View File

@@ -809,11 +809,11 @@ class PluginManager:
if star_metadata.star_cls is None:
return
if "__del__" in star_metadata.star_cls_type.__dict__:
if '__del__' in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
)
elif "terminate" in star_metadata.star_cls_type.__dict__:
elif 'terminate' in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str):

View File

@@ -182,9 +182,7 @@ class StarTools:
plugin_name = metadata.name
data_dir = Path(
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
)
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
try:
data_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -56,7 +56,9 @@ class AstrBotUpdator(RepoZipUpdator):
try:
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
if os.name == "nt":
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
args = [
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
]
else:
args = sys.argv[1:]
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)

View File

@@ -5,7 +5,6 @@ from .astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:

View File

@@ -210,16 +210,11 @@ class ConfigRoute(Route):
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
)
logger.debug(
f"Received response from {status_info['name']}: {response}"
)
logger.debug(f"Received response from {status_info['name']}: {response}")
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if (
hasattr(response, "completion_text")
and response.completion_text
):
if hasattr(response, "completion_text") and response.completion_text:
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
@@ -238,48 +233,29 @@ class ConfigRoute(Route):
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
)
else:
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
)
status_info["error"] = "Test call returned None, but expected an LLMResponse object."
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.")
except asyncio.TimeoutError:
status_info["error"] = (
"Connection timed out after 45 seconds during test call."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
)
status_info["error"] = "Connection timed out after 45 seconds during test call."
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.")
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
)
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}")
logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}")
elif provider_capability_type == ProviderType.EMBEDDING:
try:
# For embedding, we can call the get_embedding method with a short prompt.
embedding_result = await provider.get_embedding("health_check")
if isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
):
if isinstance(embedding_result, list) and (not embedding_result or isinstance(embedding_result[0], float)):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"Embedding test failed: unexpected result type {type(embedding_result)}"
)
status_info["error"] = f"Embedding test failed: unexpected result type {type(embedding_result)}"
except Exception as e:
logger.error(
f"Error testing embedding provider {provider_name}: {e}",
exc_info=True,
)
logger.error(f"Error testing embedding provider {provider_name}: {e}", exc_info=True)
status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: {str(e)}"
@@ -291,71 +267,41 @@ class ConfigRoute(Route):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"TTS test failed: unexpected result type {type(audio_result)}"
)
status_info["error"] = f"TTS test failed: unexpected result type {type(audio_result)}"
except Exception as e:
logger.error(
f"Error testing TTS provider {provider_name}: {e}", exc_info=True
)
logger.error(f"Error testing TTS provider {provider_name}: {e}", exc_info=True)
status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: {str(e)}"
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
try:
logger.debug(
f"Sending health check audio to provider: {status_info['name']}"
)
sample_audio_path = os.path.join(
get_astrbot_path(), "samples", "stt_health_check.wav"
)
logger.debug(f"Sending health check audio to provider: {status_info['name']}")
sample_audio_path = os.path.join(get_astrbot_path(), "samples", "stt_health_check.wav")
if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable"
status_info["error"] = (
"STT test failed: sample audio file not found."
)
logger.warning(
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}"
)
status_info["error"] = "STT test failed: sample audio file not found."
logger.warning(f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}")
else:
text_result = await provider.get_text(sample_audio_path)
if isinstance(text_result, str) and text_result:
status_info["status"] = "available"
snippet = (
text_result[:70] + "..."
if len(text_result) > 70
else text_result
)
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'"
)
snippet = text_result[:70] + "..." if len(text_result) > 70 else text_result
logger.info(f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'")
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"STT test failed: unexpected result type {type(text_result)}"
)
logger.warning(
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}"
)
status_info["error"] = f"STT test failed: unexpected result type {type(text_result)}"
logger.warning(f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}")
except Exception as e:
logger.error(
f"Error testing STT provider {provider_name}: {e}", exc_info=True
)
logger.error(f"Error testing STT provider {provider_name}: {e}", exc_info=True)
status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {str(e)}"
else:
logger.debug(
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}"
)
logger.debug(f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}")
status_info["status"] = "available"
status_info["error"] = (
"This provider type is not tested and is assumed to be available."
)
status_info["error"] = "This provider type is not tested and is assumed to be available."
return status_info
def _error_response(
self, message: str, status_code: int = 500, log_fn=logger.error
):
def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error):
log_fn(message)
# 记录更详细的traceback信息但只在是严重错误时
if status_code == 500:
@@ -366,9 +312,7 @@ class ConfigRoute(Route):
"""API: check a single LLM Provider's status by id"""
provider_id = request.args.get("id")
if not provider_id:
return self._error_response(
"Missing provider_id parameter", 400, logger.warning
)
return self._error_response("Missing provider_id parameter", 400, logger.warning)
logger.info(f"API call: /config/provider/check_one id={provider_id}")
try:
@@ -376,21 +320,16 @@ class ConfigRoute(Route):
target = prov_mgr.inst_map.get(provider_id)
if not target:
logger.warning(
f"Provider with id '{provider_id}' not found in provider_manager."
)
return (
Response()
.error(f"Provider with id '{provider_id}' not found")
.__dict__
)
logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.")
return Response().error(f"Provider with id '{provider_id}' not found").__dict__
result = await self._test_single_provider(target)
return Response().ok(result).__dict__
except Exception as e:
return self._error_response(
f"Critical error checking provider {provider_id}: {e}", 500
f"Critical error checking provider {provider_id}: {e}",
500
)
async def get_configs(self):

View File

@@ -10,9 +10,7 @@ class LogRoute(Route):
super().__init__(context)
self.log_broker = log_broker
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
self.app.add_url_rule(
"/api/log-history", view_func=self.log_history, methods=["GET"]
)
self.app.add_url_rule("/api/log-history", view_func=self.log_history, methods=["GET"])
async def log(self):
async def stream():
@@ -50,15 +48,9 @@ class LogRoute(Route):
"""获取日志历史"""
try:
logs = list(self.log_broker.log_cache)
return (
Response()
.ok(
data={
"logs": logs,
}
)
.__dict__
)
return Response().ok(data={
"logs": logs,
}).__dict__
except BaseException as e:
logger.error(f"获取日志历史失败: {e}")
return Response().error(f"获取日志历史失败: {e}").__dict__

View File

@@ -1316,6 +1316,16 @@ UID: {user_id} 此 ID 可用于设置管理员。
except BaseException as e:
logger.error(f"处理引用图片失败: {e}")
# Add Telegram topic context for shared conversations
if event.get_platform_name() == "telegram":
topic_id = event.get_extra("telegram_topic_id")
base_group_id = event.get_extra("telegram_base_group_id")
if topic_id and base_group_id:
req.system_prompt += (
f"\n[Context: This message is from Telegram topic "
f"#{topic_id} in group {base_group_id}]\n"
)
if self.ltm:
try:
await self.ltm.on_req_llm(event, req)

View File

@@ -30,11 +30,9 @@ class Main(star.Star):
websearch = self.context.get_config()["provider_settings"]["web_search"]
if websearch:
self.context.activate_llm_tool("web_search")
self.context.activate_llm_tool("googleSearch")
self.context.activate_llm_tool("fetch_url")
else:
self.context.deactivate_llm_tool("web_search")
self.context.deactivate_llm_tool("googleSearch")
self.context.deactivate_llm_tool("fetch_url")
async def _tidy_text(self, text: str) -> str:
@@ -72,14 +70,12 @@ class Main(star.Star):
self.context.get_config()["provider_settings"]["web_search"] = True
self.context.get_config().save_config()
self.context.activate_llm_tool("web_search")
self.context.activate_llm_tool("googleSearch")
self.context.activate_llm_tool("fetch_url")
event.set_result(MessageEventResult().message("已开启网页搜索功能"))
elif oper == "off":
self.context.get_config()["provider_settings"]["web_search"] = False
self.context.get_config().save_config()
self.context.deactivate_llm_tool("web_search")
self.context.deactivate_llm_tool("googleSearch")
self.context.deactivate_llm_tool("fetch_url")
event.set_result(MessageEventResult().message("已关闭网页搜索功能"))
else:
@@ -143,16 +139,6 @@ class Main(star.Star):
return ret
@llm_tool("googleSearch")
async def google_search_alias(self, event: AstrMessageEvent, query: str) -> str:
"""Search the internet to answer user questions using Google search. Call this tool when users need to search the web for real-time information.
Args:
query(string): The most relevant search keywords for the user's question, used to search on Google.
"""
# This is an alias for web_search to provide better OpenAI API compatibility
return await self.search_from_search_engine(event, query)
@llm_tool("fetch_url")
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
"""fetch the content of a website with the given web url