Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b77df0556 | ||
|
|
1fa11062de | ||
|
|
6883de0f1c | ||
|
|
bdde0fe094 | ||
|
|
ab22b8103e | ||
|
|
641d5cd67b |
@@ -65,6 +65,11 @@ class AstrBotBootstrap():
|
||||
self.context.plugin_updator = self.plugin_manager.updator
|
||||
self.context.message_handler = self.message_handler
|
||||
self.context.command_manager = self.command_manager
|
||||
|
||||
|
||||
# load dashboard
|
||||
self.dashboard.run_http_server()
|
||||
dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
|
||||
|
||||
if self.test_mode:
|
||||
return
|
||||
@@ -77,9 +82,7 @@ class AstrBotBootstrap():
|
||||
platform_tasks = self.load_platform()
|
||||
# load metrics uploader
|
||||
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader")
|
||||
# load dashboard
|
||||
self.dashboard.run_http_server()
|
||||
dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
|
||||
|
||||
tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks]
|
||||
tasks = [self.handle_task(task) for task in tasks]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@@ -60,6 +60,9 @@ class ContentSafetyHelper():
|
||||
from astrbot.message.baidu_aip_judge import BaiduJudge
|
||||
self.baidu_judge = BaiduJudge(aip)
|
||||
logger.info("已启用百度 AI 内容审核。")
|
||||
except ImportError as e:
|
||||
logger.error("检测到库依赖不完整,将不会启用百度 AI 内容审核。请先使用 pip 安装 `baidu_aip` 包。")
|
||||
logger.error(e)
|
||||
except BaseException as e:
|
||||
logger.error("百度 AI 内容审核初始化失败。")
|
||||
logger.error(e)
|
||||
|
||||
@@ -15,29 +15,48 @@ class DashBoardHelper():
|
||||
self.context = context
|
||||
self.config_key_dont_show = ['dashboard', 'config_version']
|
||||
|
||||
def try_cast(self, value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
return int(value)
|
||||
elif type_ == "float" and isinstance(value, str) \
|
||||
and value.replace(".", "", 1).isdigit():
|
||||
return float(value)
|
||||
elif type_ == "float" and isinstance(value, int):
|
||||
return float(value)
|
||||
|
||||
|
||||
def validate_config(self, data):
|
||||
errors = []
|
||||
# 递归验证数据
|
||||
def validate(data, path=""):
|
||||
for key, meta in CONFIG_METADATA_2.items():
|
||||
def validate(data, metadata=CONFIG_METADATA_2, path=""):
|
||||
for key, meta in metadata.items():
|
||||
if key not in data:
|
||||
if key not in self.config_key_dont_show:
|
||||
# 这些key不会传给前端,所以不需要验证
|
||||
errors.append(f"Missing key: {path}{key}")
|
||||
continue
|
||||
value = data[key]
|
||||
if meta["type"] == "int" and not isinstance(value, int):
|
||||
errors.append(f"Invalid type for {path}{key}: expected int, got {type(value).__name__}")
|
||||
elif meta["type"] == "bool" and not isinstance(value, bool):
|
||||
errors.append(f"Invalid type for {path}{key}: expected bool, got {type(value).__name__}")
|
||||
elif meta["type"] == "string" and not isinstance(value, str):
|
||||
errors.append(f"Invalid type for {path}{key}: expected string, got {type(value).__name__}")
|
||||
elif meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(f"Invalid type for {path}{key}: expected list, got {type(value).__name__}")
|
||||
# 递归验证
|
||||
if meta["type"] == "list" and isinstance(value, list):
|
||||
for item in value:
|
||||
validate(item, meta["items"], path=f"{path}{key}.")
|
||||
elif meta["type"] == "dict" and not isinstance(value, dict):
|
||||
errors.append(f"Invalid type for {path}{key}: expected dict, got {type(value).__name__}")
|
||||
elif meta["type"] == "object" and isinstance(value, dict):
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
|
||||
if meta["type"] == "int" and not isinstance(value, int):
|
||||
casted = self.try_cast(value, "int")
|
||||
if casted is None:
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}")
|
||||
data[key] = casted
|
||||
elif meta["type"] == "float" and not isinstance(value, float):
|
||||
casted = self.try_cast(value, "float")
|
||||
if casted is None:
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}")
|
||||
data[key] = casted
|
||||
elif meta["type"] == "bool" and not isinstance(value, bool):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "string" and not isinstance(value, str):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "object" and not isinstance(value, dict):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
validate(data)
|
||||
|
||||
@@ -68,6 +87,6 @@ class DashBoardHelper():
|
||||
typ = item['val_type']
|
||||
if typ == 'int':
|
||||
if not value.isdigit():
|
||||
raise ValueError(f"Invalid type for {namespace}.{key}: expected int, got {type(value).__name__}")
|
||||
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
|
||||
value = int(value)
|
||||
update_config(namespace, key, value)
|
||||
|
||||
@@ -206,7 +206,8 @@ class AstrBotDashBoard():
|
||||
repo_url = post_data["url"]
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
self.plugin_manager.install_plugin(repo_url)
|
||||
# self.plugin_manager.install_plugin(repo_url)
|
||||
asyncio.run_coroutine_threadsafe(self.plugin_manager.install_plugin(repo_url), self.loop).result()
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"安装插件 {repo_url} 成功,2秒后重启")
|
||||
return Response(
|
||||
@@ -272,7 +273,8 @@ class AstrBotDashBoard():
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
self.plugin_manager.update_plugin(plugin_name)
|
||||
# self.plugin_manager.update_plugin(plugin_name)
|
||||
asyncio.run_coroutine_threadsafe(self.plugin_manager.update_plugin(plugin_name), self.loop).result()
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
|
||||
return Response(
|
||||
@@ -301,7 +303,9 @@ class AstrBotDashBoard():
|
||||
@self.dashboard_be.get("/api/check_update")
|
||||
def get_update_info():
|
||||
try:
|
||||
ret = self.astrbot_updator.check_update(None, None)
|
||||
# ret = self.astrbot_updator.check_update(None, None)
|
||||
ret = asyncio.run_coroutine_threadsafe(
|
||||
self.astrbot_updator.check_update(None, None), self.loop).result()
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
@@ -326,7 +330,8 @@ class AstrBotDashBoard():
|
||||
else:
|
||||
latest = False
|
||||
try:
|
||||
self.astrbot_updator.update(latest=latest, version=version)
|
||||
# await self.astrbot_updator.update(latest=latest, version=version)
|
||||
asyncio.run_coroutine_threadsafe(self.astrbot_updator.update(latest=latest, version=version), self.loop).result()
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
return Response(
|
||||
status="success",
|
||||
|
||||
3
main.py
3
main.py
@@ -27,6 +27,8 @@ def main():
|
||||
# delete qqbotpy's logger
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
logger.info(logo_tmpl)
|
||||
|
||||
bootstrap = AstrBotBootstrap()
|
||||
asyncio.run(bootstrap.run())
|
||||
@@ -58,5 +60,4 @@ if __name__ == "__main__":
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
logger.info(logo_tmpl)
|
||||
main()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import aiohttp
|
||||
import aiohttp, os
|
||||
|
||||
from model.command.manager import CommandManager
|
||||
from model.plugin.manager import PluginManager
|
||||
@@ -27,6 +27,13 @@ class InternalCommandHandler:
|
||||
self.manager.register("t2i", "文转图", 10, self.t2i_toggle)
|
||||
self.manager.register("myid", "用户ID", 10, self.myid)
|
||||
self.manager.register("provider", "LLM 接入源", 10, self.provider)
|
||||
|
||||
def _check_auth(self, message: AstrMessageEvent, context: Context):
|
||||
if os.environ.get("TEST_MODE", "off") == "on":
|
||||
return
|
||||
if message.role != "admin":
|
||||
user_id = message.message_obj.sender.user_id
|
||||
raise Exception(f"用户(ID: {user_id}) 没有足够的权限使用该指令。")
|
||||
|
||||
def provider(self, message: AstrMessageEvent, context: Context):
|
||||
if len(context.llms) == 0:
|
||||
@@ -57,9 +64,8 @@ class InternalCommandHandler:
|
||||
return CommandResult().message("provider: 参数错误。")
|
||||
|
||||
def set_nick(self, message: AstrMessageEvent, context: Context):
|
||||
self._check_auth(message, context)
|
||||
message_str = message.message_str
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("你没有权限使用该指令。")
|
||||
l = message_str.split(" ")
|
||||
if len(l) == 1:
|
||||
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词是:{context.config_helper.wake_prefix[0]}")
|
||||
@@ -74,15 +80,10 @@ class InternalCommandHandler:
|
||||
message_chain=f"已经成功将唤醒前缀设定为 {nick}。",
|
||||
)
|
||||
|
||||
def update(self, message: AstrMessageEvent, context: Context):
|
||||
async def update(self, message: AstrMessageEvent, context: Context):
|
||||
self._check_auth(message, context)
|
||||
tokens = self.manager.command_parser.parse(message.message_str)
|
||||
if message.role != "admin":
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain="你没有权限使用该指令",
|
||||
)
|
||||
update_info = context.updator.check_update(None, None)
|
||||
update_info = await context.updator.check_update(None, None)
|
||||
if tokens.len == 1:
|
||||
ret = ""
|
||||
if not update_info:
|
||||
@@ -93,13 +94,13 @@ class InternalCommandHandler:
|
||||
else:
|
||||
if tokens.get(1) == "latest":
|
||||
try:
|
||||
context.updator.update()
|
||||
await context.updator.update()
|
||||
return CommandResult().message(f"已经成功更新到最新版本 v{update_info.version}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启")
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"更新失败。原因:{str(e)}")
|
||||
elif tokens.get(1).startswith("v"):
|
||||
try:
|
||||
context.updator.update(version=tokens.get(1))
|
||||
await context.updator.update(version=tokens.get(1))
|
||||
return CommandResult().message(f"已经成功更新到版本 v{tokens.get(1)}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启")
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"更新失败。原因:{str(e)}")
|
||||
@@ -107,12 +108,7 @@ class InternalCommandHandler:
|
||||
return CommandResult().message("update: 参数错误。")
|
||||
|
||||
def reboot(self, message: AstrMessageEvent, context: Context):
|
||||
if message.role != "admin":
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain="你没有权限使用该指令",
|
||||
)
|
||||
self._check_auth(message, context)
|
||||
context.updator._reboot(3, context)
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
@@ -120,7 +116,7 @@ class InternalCommandHandler:
|
||||
message_chain="AstrBot 将在 3s 后重启。",
|
||||
)
|
||||
|
||||
def plugin(self, message: AstrMessageEvent, context: Context):
|
||||
async def plugin(self, message: AstrMessageEvent, context: Context):
|
||||
tokens = self.manager.command_parser.parse(message.message_str)
|
||||
if tokens.len == 1:
|
||||
ret = "# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n"
|
||||
@@ -133,10 +129,10 @@ class InternalCommandHandler:
|
||||
if plugin_list_info.strip() == "":
|
||||
return CommandResult().message("plugin v: 没有找到插件。")
|
||||
return CommandResult().message(plugin_list_info)
|
||||
|
||||
self._check_auth(message, context)
|
||||
|
||||
elif tokens.get(1) == "d":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin d: 你没有权限使用该指令。")
|
||||
if tokens.get(1) == "d":
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin d: 请指定要卸载的插件名。")
|
||||
plugin_name = tokens.get(2)
|
||||
@@ -147,25 +143,21 @@ class InternalCommandHandler:
|
||||
return CommandResult().message(f"plugin d: 已经成功卸载插件 {plugin_name}。")
|
||||
|
||||
elif tokens.get(1) == "i":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin i: 你没有权限使用该指令。")
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin i: 请指定要安装的插件的 Github 地址,或者前往可视化面板安装。")
|
||||
plugin_url = tokens.get(2)
|
||||
try:
|
||||
self.plugin_manager.install_plugin(plugin_url)
|
||||
await self.plugin_manager.install_plugin(plugin_url)
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"plugin i: 安装插件失败。原因:{str(e)}")
|
||||
return CommandResult().message("plugin i: 已经成功安装插件。")
|
||||
|
||||
elif tokens.get(1) == "u":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin u: 你没有权限使用该指令。")
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin u: 请指定要更新的插件名。")
|
||||
plugin_name = tokens.get(2)
|
||||
try:
|
||||
self.plugin_manager.update_plugin(plugin_name)
|
||||
await context.plugin_updator.update(plugin_name)
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"plugin u: 更新插件失败。原因:{str(e)}")
|
||||
return CommandResult().message(f"plugin u: 已经成功更新插件 {plugin_name}。")
|
||||
|
||||
@@ -107,13 +107,13 @@ class PluginManager():
|
||||
rc = process.poll()
|
||||
|
||||
|
||||
def install_plugin(self, repo_url: str):
|
||||
async def install_plugin(self, repo_url: str):
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
# we no longer use Git anymore :)
|
||||
# Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
|
||||
|
||||
plugin_path = self.updator.update(repo_url)
|
||||
plugin_path = await self.updator.update(repo_url)
|
||||
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
|
||||
f.write(repo_url)
|
||||
|
||||
@@ -124,14 +124,14 @@ class PluginManager():
|
||||
# if not ok:
|
||||
# raise Exception(err)
|
||||
|
||||
def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
logger.info(f"正在下载插件 {repo} ...")
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
releases = self.updator.fetch_release_info(url=release_url)
|
||||
releases = await self.updator.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warn(f"未在插件 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
@@ -139,7 +139,7 @@ class PluginManager():
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
download_file(release_url, target_path + ".zip")
|
||||
await download_file(release_url, target_path + ".zip")
|
||||
|
||||
def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin:
|
||||
for p in self.context.cached_plugins:
|
||||
@@ -156,12 +156,12 @@ class PluginManager():
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
|
||||
def update_plugin(self, plugin_name: str):
|
||||
async def update_plugin(self, plugin_name: str):
|
||||
plugin = self.get_registered_plugin(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
|
||||
self.updator.update(plugin)
|
||||
await self.updator.update(plugin)
|
||||
|
||||
def plugin_reload(self):
|
||||
cached_plugins = self.context.cached_plugins
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
pydantic~=1.10.4
|
||||
aiohttp
|
||||
requests
|
||||
openai
|
||||
qq-botpy
|
||||
chardet~=5.1.0
|
||||
@@ -10,7 +9,6 @@ beautifulsoup4
|
||||
googlesearch-python
|
||||
tiktoken
|
||||
readability-lxml
|
||||
baidu-aip
|
||||
websockets
|
||||
flask
|
||||
psutil
|
||||
|
||||
51
tests/test_http_server.py
Normal file
51
tests/test_http_server.py
Normal file
@@ -0,0 +1,51 @@
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
BASE_URL = "http://0.0.0.0:6185/api"
|
||||
|
||||
async def get_url(url):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
return await response.json()
|
||||
|
||||
async def post_url(url, data):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=data) as response:
|
||||
return await response.json()
|
||||
|
||||
class TestHTTPServer:
|
||||
@pytest.mark.asyncio
|
||||
async def test_config(self):
|
||||
configs = await get_url(f"{BASE_URL}/configs")
|
||||
assert 'data' in configs and 'metadata' in configs['data'] \
|
||||
and 'config' in configs['data']
|
||||
config = configs['data']['config']
|
||||
# test post config
|
||||
await post_url(f"{BASE_URL}/astrbot-configs", config)
|
||||
# text post config with invalid data
|
||||
assert 'rate_limit' in config['platform_settings']
|
||||
config['platform_settings']['rate_limit'] = "invalid"
|
||||
ret = await post_url(f"{BASE_URL}/astrbot-configs", config)
|
||||
assert 'status' in ret and ret['status'] == 'error'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update(self):
|
||||
await get_url(f"{BASE_URL}/check_update")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(self):
|
||||
pname = "astrbot_plugin_bilibili"
|
||||
url = f"https://github.com/Soulter/{pname}"
|
||||
|
||||
await get_url(f"{BASE_URL}/extensions")
|
||||
|
||||
# test install plugin
|
||||
await post_url(f"{BASE_URL}/extensions/install", {
|
||||
"url": url
|
||||
})
|
||||
|
||||
# test uninstall plugin
|
||||
await post_url(f"{BASE_URL}/extensions/uninstall", {
|
||||
"name": pname
|
||||
})
|
||||
@@ -135,7 +135,17 @@ class TestInteralCommandHsandle():
|
||||
abm = self.create("/t2i")
|
||||
await aiocqhttp.handle_msg(abm)
|
||||
await self.fast_test("/help")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin(self):
|
||||
pname = "astrbot_plugin_bilibili"
|
||||
url = f"https://github.com/Soulter/{pname}"
|
||||
await self.fast_test("/plugin")
|
||||
await self.fast_test(f"/plugin l")
|
||||
await self.fast_test(f"/plugin i {url}")
|
||||
await self.fast_test(f"/plugin u {url}")
|
||||
await self.fast_test(f"/plugin d {pname}")
|
||||
|
||||
class TestLLMChat():
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_chat(self):
|
||||
|
||||
15
util/io.py
15
util/io.py
@@ -4,7 +4,6 @@ import shutil
|
||||
import socket
|
||||
import time
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from PIL import Image
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
@@ -99,16 +98,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def download_file(url: str, path: str):
|
||||
async def download_file(url: str, path: str):
|
||||
'''
|
||||
从指定 url 下载文件到指定路径 path
|
||||
'''
|
||||
try:
|
||||
logger.info(f"下载文件: {url}")
|
||||
with requests.get(url, stream=True) as r:
|
||||
with open(path, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
import requests
|
||||
import aiohttp
|
||||
import json
|
||||
import sys
|
||||
|
||||
@@ -57,8 +57,9 @@ class MetricUploader():
|
||||
"command_stats": self.command_stats,
|
||||
"sys": sys.platform, # 系统版本
|
||||
}
|
||||
resp = requests.post(
|
||||
'https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5) as resp:
|
||||
pass
|
||||
if resp.status_code == 200:
|
||||
ok = resp.json()
|
||||
if ok['status'] == 'ok':
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
import requests
|
||||
import aiohttp
|
||||
from io import BytesIO
|
||||
|
||||
from .base_strategy import RenderStrategy
|
||||
from PIL import ImageFont, Image, ImageDraw
|
||||
@@ -82,8 +83,9 @@ class LocalRenderStrategy(RenderStrategy):
|
||||
try:
|
||||
image_url = re.findall(IMAGE_REGEX, line)[0]
|
||||
print(image_url)
|
||||
image_res = Image.open(requests.get(
|
||||
image_url, stream=True, timeout=5).raw)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as resp:
|
||||
image_res = Image.open(BytesIO(await resp.read()))
|
||||
images[i] = image_res
|
||||
# 最大不得超过image_width的50%
|
||||
img_height = image_res.size[1]
|
||||
|
||||
@@ -31,6 +31,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
pass
|
||||
|
||||
def _reboot(self, delay: int = None, context = None):
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
logger.info("测试模式下不会重启。")
|
||||
return
|
||||
# if delay: time.sleep(delay)
|
||||
py = sys.executable
|
||||
context.running = False
|
||||
@@ -43,11 +46,11 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||
raise e
|
||||
|
||||
def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
return super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
|
||||
|
||||
def update(self, reboot = False, latest = True, version = None):
|
||||
update_data = self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
|
||||
async def update(self, reboot = False, latest = True, version = None):
|
||||
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
|
||||
file_url = None
|
||||
|
||||
if latest:
|
||||
@@ -65,16 +68,10 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
raise Exception(f"未找到版本号为 {version} 的更新文件。")
|
||||
|
||||
try:
|
||||
download_file(file_url, "temp.zip")
|
||||
await download_file(file_url, "temp.zip")
|
||||
self.unzip_file("temp.zip", self.MAIN_PATH)
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
if reboot:
|
||||
self._reboot()
|
||||
|
||||
def unzip_file(self, zip_path: str, target_dir: str):
|
||||
'''
|
||||
解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir
|
||||
'''
|
||||
pass
|
||||
@@ -18,7 +18,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
def get_plugin_store_path(self) -> str:
|
||||
return self.plugin_store_path
|
||||
|
||||
def update(self, plugin: Union[RegisteredPlugin, str]) -> str:
|
||||
async def update(self, plugin: Union[RegisteredPlugin, str]) -> str:
|
||||
repo_url = None
|
||||
|
||||
if not isinstance(plugin, str):
|
||||
@@ -33,7 +33,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
plugin_path = os.path.join(self.plugin_store_path, self.format_repo_name(repo_url))
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
self.download_from_repo_url(plugin_path, repo_url)
|
||||
await self.download_from_repo_url(plugin_path, repo_url)
|
||||
|
||||
try:
|
||||
remove_dir(plugin_path)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import requests, os, zipfile, shutil
|
||||
import aiohttp, os, zipfile, shutil
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from util.io import on_error, download_file
|
||||
@@ -23,14 +23,15 @@ class RepoZipUpdator():
|
||||
self.path = path
|
||||
self.rm_on_error = on_error
|
||||
|
||||
def fetch_release_info(self, url: str, latest: bool = True) -> list:
|
||||
async def fetch_release_info(self, url: str, latest: bool = True) -> list:
|
||||
'''
|
||||
请求版本信息。
|
||||
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
|
||||
'''
|
||||
result = requests.get(url).json()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
result = await response.json()
|
||||
if not result: return []
|
||||
if latest:
|
||||
ret = self.github_api_release_parser([result[0]])
|
||||
@@ -66,7 +67,7 @@ class RepoZipUpdator():
|
||||
def unzip(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self):
|
||||
async def update(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def compare_version(self, v1: str, v2: str) -> int:
|
||||
@@ -86,8 +87,8 @@ class RepoZipUpdator():
|
||||
return -1
|
||||
return 0
|
||||
|
||||
def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
update_data = self.fetch_release_info(url)
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
update_data = await self.fetch_release_info(url)
|
||||
tag_name = update_data[0]['tag_name']
|
||||
|
||||
if self.compare_version(current_version, tag_name) >= 0:
|
||||
@@ -98,22 +99,22 @@ class RepoZipUpdator():
|
||||
body=update_data[0]['body']
|
||||
)
|
||||
|
||||
def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
logger.info(f"正在下载更新 {repo} ...")
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
releases = self.fetch_release_info(url=release_url)
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warn(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
download_file(release_url, target_path + ".zip")
|
||||
await download_file(release_url, target_path + ".zip")
|
||||
|
||||
|
||||
def unzip_file(self, zip_path: str, target_dir: str):
|
||||
|
||||
Reference in New Issue
Block a user