Compare commits
76 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53dbebb503 | ||
|
|
52df91eb60 | ||
|
|
a9a758d715 | ||
|
|
0226fa7a25 | ||
|
|
a4f47da35c | ||
|
|
29364000e2 | ||
|
|
ceecca44a4 | ||
|
|
50f62e66b0 | ||
|
|
ab39dfd254 | ||
|
|
708fad18b6 | ||
|
|
526ba34d87 | ||
|
|
5d4882dee9 | ||
|
|
48c4361d37 | ||
|
|
c1d070186e | ||
|
|
1a39fd9172 | ||
|
|
0c1ab4158e | ||
|
|
5221566335 | ||
|
|
2291c2d9ba | ||
|
|
0de14c4c8b | ||
|
|
51de0159fb | ||
|
|
37a756aeb3 | ||
|
|
353b6ed761 | ||
|
|
90815b1ac5 | ||
|
|
8a50786e61 | ||
|
|
3b77df0556 | ||
|
|
1fa11062de | ||
|
|
6883de0f1c | ||
|
|
bdde0fe094 | ||
|
|
ab22b8103e | ||
|
|
641d5cd67b | ||
|
|
9fe941e457 | ||
|
|
78060c9985 | ||
|
|
5bd6af3400 | ||
|
|
4ecd78d6a8 | ||
|
|
7e9f54ed2c | ||
|
|
7dd29c707f | ||
|
|
a1489fb1f9 | ||
|
|
5f0f5398e8 | ||
|
|
e3b2396f32 | ||
|
|
6fd70ed26a | ||
|
|
a93e6ff01a | ||
|
|
6db8c38c58 | ||
|
|
d3d3ff7970 | ||
|
|
c5b2b30f79 | ||
|
|
ac2144d65b | ||
|
|
c620b4f919 | ||
|
|
292a3a43ba | ||
|
|
5fc4693b9c | ||
|
|
6dfbaf1b88 | ||
|
|
14c6e56287 | ||
|
|
7e48514f67 | ||
|
|
d8e70c4d7f | ||
|
|
fb52989d62 | ||
|
|
5b72ebaad5 | ||
|
|
98863ab901 | ||
|
|
b5cb5eb969 | ||
|
|
7f4f96f77b | ||
|
|
3b3f75f03e | ||
|
|
a5db4d4e47 | ||
|
|
d3b0f25cfe | ||
|
|
a9c6a68c5f | ||
|
|
c27f172452 | ||
|
|
2eeb5822c1 | ||
|
|
743046d48f | ||
|
|
d3a5205bde | ||
|
|
ae6dd8929a | ||
|
|
dcf96896ef | ||
|
|
67792100bb | ||
|
|
48c1263417 | ||
|
|
12d37381fe | ||
|
|
dcec3f5f84 | ||
|
|
a7c87642b4 | ||
|
|
f8aef78d25 | ||
|
|
f5857aaa0c | ||
|
|
f4222e0923 | ||
|
|
f0caea9026 |
3
.codecov.yml
Normal file
3
.codecov.yml
Normal file
@@ -0,0 +1,3 @@
|
||||
comment:
|
||||
layout: "condensed_header, condensed_files, condensed_footer"
|
||||
hide_project_coverage: TRUE
|
||||
5
.coveragerc
Normal file
5
.coveragerc
Normal file
@@ -0,0 +1,5 @@
|
||||
[run]
|
||||
omit =
|
||||
*/site-packages/*
|
||||
*/dist-packages/*
|
||||
your_package_name/tests/*
|
||||
39
.github/workflows/coverage_test.yml
vendored
Normal file
39
.github/workflows/coverage_test.yml
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
name: Run tests and upload coverage
|
||||
|
||||
on:
|
||||
push
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run tests and collect coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir temp
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
export LLM_MODEL=${{ secrets.LLM_MODEL }}
|
||||
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
|
||||
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,7 +1,7 @@
|
||||
__pycache__
|
||||
botpy.log
|
||||
.vscode
|
||||
data.db
|
||||
data_v2.db
|
||||
configs/session
|
||||
configs/config.yaml
|
||||
**/.DS_Store
|
||||
@@ -10,4 +10,5 @@ cmd_config.json
|
||||
data/*
|
||||
cookies.json
|
||||
logs/
|
||||
addons/plugins
|
||||
addons/plugins
|
||||
.coverage
|
||||
@@ -8,6 +8,8 @@
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.9+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||

|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=EYGsuUTfe00_iOu9JTXS7_TEpMkXOvwv&jump_from=webapi&authKey=uUEMKCROfsseS+8IzqPjzV3y1tzy4AkykwTib2jNkOFdzezF9s9XknqnIaf3CDft">
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
</a>
|
||||
@@ -22,7 +24,6 @@
|
||||
🌍 支持的消息平台
|
||||
- QQ 群、QQ 频道(OneBot、QQ 官方接口)
|
||||
- Telegram([astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件)
|
||||
- WeChat(微信) ([astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件)
|
||||
|
||||
🌍 支持的大模型/底座:
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
# helloworld
|
||||
|
||||
AstrBot 插件模板
|
||||
|
||||
A template plugin for AstrBot plugin feature
|
||||
|
||||
# 支持
|
||||
|
||||
[帮助文档](https://astrbot.soulter.top/center/docs/%E5%BC%80%E5%8F%91/%E6%8F%92%E4%BB%B6%E5%BC%80%E5%8F%91/
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
https://github.com/Soulter/helloworld
|
||||
@@ -1,32 +0,0 @@
|
||||
flag_not_support = False
|
||||
try:
|
||||
from util.plugin_dev.api.v1.bot import Context, AstrMessageEvent, CommandResult
|
||||
from util.plugin_dev.api.v1.config import *
|
||||
except ImportError:
|
||||
flag_not_support = True
|
||||
print("导入接口失败。请升级到 AstrBot 最新版本。")
|
||||
|
||||
'''
|
||||
注意以格式 XXXPlugin 或 Main 来修改插件名。
|
||||
提示:把此模板仓库 fork 之后 clone 到机器人文件夹下的 addons/plugins/ 目录下,然后用 Pycharm/VSC 等工具打开可获更棒的编程体验(自动补全等)
|
||||
'''
|
||||
class HelloWorldPlugin:
|
||||
"""
|
||||
AstrBot 会传递 context 给插件。
|
||||
|
||||
- context.register_commands: 注册指令
|
||||
- context.register_task: 注册任务
|
||||
- context.message_handler: 消息处理器(平台类插件用)
|
||||
"""
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.context = context
|
||||
self.context.register_commands("helloworld", "helloworld", "内置测试指令。", 1, self.helloworld)
|
||||
|
||||
"""
|
||||
指令处理函数。
|
||||
|
||||
- 需要接收两个参数:message: AstrMessageEvent, context: Context
|
||||
- 返回 CommandResult 对象
|
||||
"""
|
||||
def helloworld(self, message: AstrMessageEvent, context: Context):
|
||||
return CommandResult().message("Hello, World!")
|
||||
@@ -1,6 +0,0 @@
|
||||
name: helloworld # 这是你的插件的唯一识别名。
|
||||
desc: 这是 AstrBot 的默认插件。
|
||||
help:
|
||||
version: v1.3 # 插件版本号。格式:v1.1.1 或者 v1.1
|
||||
author: Soulter # 作者
|
||||
repo: https://github.com/Soulter/helloworld # 插件的仓库地址
|
||||
@@ -1,52 +1,40 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import os
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from astrbot.persist.helper import dbConn
|
||||
from dashboard.server import AstrBotDashBoard
|
||||
from model.provider.provider import Provider
|
||||
from astrbot.db.sqlite import SQLiteDatabase
|
||||
from dashboard.server import AstrBotDashboard
|
||||
from model.command.manager import CommandManager
|
||||
from model.command.internal_handler import InternalCommandHandler
|
||||
from model.plugin.manager import PluginManager
|
||||
from model.platform.manager import PlatformManager
|
||||
from typing import Dict, List, Union
|
||||
from typing import Union
|
||||
from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from type.config import VERSION, DB_PATH
|
||||
from logging import Logger
|
||||
from util.cmd_config import CmdConfig
|
||||
from util.cmd_config import AstrBotConfig, try_migrate
|
||||
from util.metrics import MetricUploader
|
||||
from util.config_utils import *
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from util.log import LogManager
|
||||
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
class AstrBotBootstrap():
|
||||
def __init__(self) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.context = Context()
|
||||
self.config_helper = CmdConfig()
|
||||
|
||||
# load configs and ensure the backward compatibility
|
||||
try_migrate_config()
|
||||
try_migrate()
|
||||
self.config_helper = AstrBotConfig()
|
||||
self.context.config_helper = self.config_helper
|
||||
self.context.base_config = self.config_helper.cached_config
|
||||
|
||||
self.context.default_personality = {
|
||||
"name": "default",
|
||||
"prompt": self.context.base_config.get("default_personality_str", ""),
|
||||
}
|
||||
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
|
||||
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
|
||||
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
|
||||
self.context.nick = nick_qq
|
||||
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
|
||||
self.context.version = VERSION
|
||||
|
||||
logger.info("AstrBot v" + self.context.version)
|
||||
|
||||
# set log queue handler
|
||||
LogManager.set_queue_handler(logger, self.context._log_queue)
|
||||
logger.info("AstrBot v" + VERSION)
|
||||
# set log level
|
||||
logger.setLevel(self.config_helper.log_level)
|
||||
# apply proxy settings
|
||||
http_proxy = self.context.base_config.get("http_proxy")
|
||||
https_proxy = self.context.base_config.get("https_proxy")
|
||||
http_proxy = self.context.config_helper.http_proxy
|
||||
https_proxy = self.context.config_helper.https_proxy
|
||||
if http_proxy:
|
||||
os.environ['HTTP_PROXY'] = http_proxy
|
||||
if https_proxy:
|
||||
@@ -57,28 +45,44 @@ class AstrBotBootstrap():
|
||||
logger.info(f"使用代理: {http_proxy}, {https_proxy}")
|
||||
else:
|
||||
logger.info("未使用代理。")
|
||||
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
# set t2i endpoint
|
||||
if self.context.config_helper.t2i_endpoint:
|
||||
self.context.image_renderer.set_network_endpoint(
|
||||
self.context.config_helper.t2i_endpoint
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
self.command_manager = CommandManager()
|
||||
self.plugin_manager = PluginManager(self.context)
|
||||
self.updator = AstrBotUpdator()
|
||||
self.cmd_handler = InternalCommandHandler(self.command_manager, self.plugin_manager)
|
||||
self.db_conn_helper = dbConn()
|
||||
self.db_helper = SQLiteDatabase(DB_PATH)
|
||||
|
||||
# load llm provider
|
||||
self.llm_instance: Provider = None
|
||||
self.load_llm()
|
||||
|
||||
self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper, self.llm_instance)
|
||||
self.message_handler = MessageHandler(self.context, self.command_manager, self.db_helper)
|
||||
self.platfrom_manager = PlatformManager(self.context, self.message_handler)
|
||||
self.dashboard = AstrBotDashBoard(self.context, plugin_manager=self.plugin_manager, astrbot_updator=self.updator)
|
||||
self.metrics_uploader = MetricUploader(self.context)
|
||||
self.dashboard = AstrBotDashboard(self.context,
|
||||
plugin_manager=self.plugin_manager,
|
||||
astrbot_updator=self.updator,
|
||||
db_helper=self.db_helper)
|
||||
self.metrics_uploader = MetricUploader(self.context, self.db_helper)
|
||||
|
||||
self.context.metrics_uploader = self.metrics_uploader
|
||||
self.context.updator = self.updator
|
||||
self.context.plugin_updator = self.plugin_manager.updator
|
||||
self.context.message_handler = self.message_handler
|
||||
self.context.command_manager = self.command_manager
|
||||
|
||||
# load dashboard
|
||||
dashboard_server_task = asyncio.create_task(self.dashboard.run(), name="dashboard")
|
||||
|
||||
if self.test_mode:
|
||||
return
|
||||
|
||||
# load plugins, plugins' commands.
|
||||
self.load_plugins()
|
||||
@@ -88,10 +92,9 @@ class AstrBotBootstrap():
|
||||
platform_tasks = self.load_platform()
|
||||
# load metrics uploader
|
||||
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader")
|
||||
# load dashboard
|
||||
self.dashboard.run_http_server()
|
||||
dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
|
||||
tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks]
|
||||
|
||||
log_task = asyncio.create_task(self.dashboard.lr._receive_log_task(), name="log")
|
||||
tasks = [metrics_upload_task, dashboard_server_task, log_task, *platform_tasks, *self.context.ext_tasks]
|
||||
tasks = [self.handle_task(task) for task in tasks]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@@ -109,16 +112,27 @@ class AstrBotBootstrap():
|
||||
return
|
||||
|
||||
def load_llm(self):
|
||||
if 'openai' in self.config_helper.cached_config and \
|
||||
len(self.config_helper.cached_config['openai']['key']) and \
|
||||
self.config_helper.cached_config['openai']['key'][0] is not None:
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
f = False
|
||||
llms = self.context.config_helper.llm
|
||||
logger.info(f"加载 {len(llms)} 个 LLM Provider...")
|
||||
for llm in llms:
|
||||
if llm.enable:
|
||||
if llm.name == "openai":
|
||||
if not llm.key or not llm.enable:
|
||||
logger.warning("没有开启 LLM Provider 或 API Key 未填写。")
|
||||
continue
|
||||
self.load_openai(llm)
|
||||
f = True
|
||||
logger.info(f"已启用 LLM Provider(OpenAI API): {llm.name}。")
|
||||
if f:
|
||||
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
|
||||
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
|
||||
self.llm_instance = ProviderOpenAIOfficial(self.context)
|
||||
self.openai_command_handler.set_provider(self.llm_instance)
|
||||
self.context.register_provider("internal_openai", self.llm_instance)
|
||||
logger.info("已启用 OpenAI API 支持。")
|
||||
self.openai_command_handler.set_provider(self.context.llms[0].llm_instance)
|
||||
|
||||
def load_openai(self, llm_config):
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
inst = ProviderOpenAIOfficial(llm_config, self.db_helper)
|
||||
self.context.register_provider("internal_openai", inst)
|
||||
|
||||
def load_plugins(self):
|
||||
self.plugin_manager.plugin_reload()
|
||||
@@ -126,5 +140,5 @@ class AstrBotBootstrap():
|
||||
def load_platform(self):
|
||||
platforms = self.platfrom_manager.load_platforms()
|
||||
if not platforms:
|
||||
logger.warn("未启用任何消息平台。")
|
||||
logger.warning("未启用任何消息平台。")
|
||||
return platforms
|
||||
64
astrbot/db/__init__.py
Normal file
64
astrbot/db/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from astrbot.db.po import Stats, LLMHistory
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
'''
|
||||
数据库基类
|
||||
'''
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def insert_base_metrics(self, metrics: dict):
|
||||
'''插入基础指标数据'''
|
||||
self.insert_platform_metrics(metrics['platform_stats'])
|
||||
self.insert_plugin_metrics(metrics['plugin_stats'])
|
||||
self.insert_command_metrics(metrics['command_stats'])
|
||||
self.insert_llm_metrics(metrics['llm_stats'])
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
'''插入平台指标数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_plugin_metrics(self, metrics: dict):
|
||||
'''插入插件指标数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_command_metrics(self, metrics: dict):
|
||||
'''插入指令指标数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
'''插入 LLM 指标数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_llm_history(self, session_id: str, content: str):
|
||||
'''更新 LLM 历史记录。当不存在 session_id 时插入'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_llm_history(self, session_id: str = None) -> List[LLMHistory]:
|
||||
'''获取 LLM 历史记录, 如果 session_id 为 None, 返回所有'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
'''获取基础统计数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_total_message_count(self) -> int:
|
||||
'''获取总消息数'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
'''获取基础统计数据(合并)'''
|
||||
raise NotImplementedError
|
||||
42
astrbot/db/po.py
Normal file
42
astrbot/db/po.py
Normal file
@@ -0,0 +1,42 @@
|
||||
'''指标数据'''
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
# default_factory
|
||||
from typing import List
|
||||
|
||||
@dataclass
|
||||
class Platform():
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
@dataclass
|
||||
class Provider():
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
@dataclass
|
||||
class Plugin():
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
@dataclass
|
||||
class Command():
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
@dataclass
|
||||
class Stats():
|
||||
platform: List[Platform] = field(default_factory=list)
|
||||
command: List[Command] = field(default_factory=list)
|
||||
llm: List[Provider] = field(default_factory=list)
|
||||
|
||||
'''LLM 聊天时持久化的信息'''
|
||||
|
||||
@dataclass
|
||||
class LLMHistory():
|
||||
session_id: str
|
||||
content: str
|
||||
211
astrbot/db/sqlite.py
Normal file
211
astrbot/db/sqlite.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import time
|
||||
from astrbot.db.po import (
|
||||
Platform,
|
||||
Command,
|
||||
Provider,
|
||||
Stats,
|
||||
LLMHistory
|
||||
)
|
||||
from . import BaseDatabase
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
def __init__(self, db_path: str) -> None:
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
|
||||
sql = f.read()
|
||||
|
||||
# 初始化数据库
|
||||
self.conn = self._get_conn(self.db_path)
|
||||
c = self.conn.cursor()
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: Tuple = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
conn = self._get_conn(self.db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
if params:
|
||||
c.execute(sql, params)
|
||||
c.close()
|
||||
else:
|
||||
c.execute(sql)
|
||||
c.close()
|
||||
|
||||
conn.commit()
|
||||
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
|
||||
''', (k, v, int(time.time()))
|
||||
)
|
||||
|
||||
def insert_plugin_metrics(self, metrics: dict):
|
||||
pass
|
||||
|
||||
def insert_command_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?)
|
||||
''', (k, v, int(time.time()))
|
||||
)
|
||||
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
|
||||
''', (k, v, int(time.time()))
|
||||
)
|
||||
|
||||
def update_llm_history(self, session_id: str, content: str):
|
||||
res = self.get_llm_history(session_id)
|
||||
if res:
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE llm_history SET content = ? WHERE session_id = ?
|
||||
''', (content, session_id)
|
||||
)
|
||||
else:
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO llm_history(session_id, content) VALUES (?, ?)
|
||||
''', (session_id, content)
|
||||
)
|
||||
|
||||
def get_llm_history(self, session_id: str = None) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
where_clause = "" if session_id is None else f"WHERE session_id = '{session_id}'"
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM llm_history
|
||||
''' + where_clause
|
||||
)
|
||||
res = c.fetchall()
|
||||
histories = []
|
||||
for row in res:
|
||||
histories.append(LLMHistory(*row))
|
||||
c.close()
|
||||
return histories
|
||||
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
'''获取 offset_sec 秒前到现在的基础统计数据'''
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM platform
|
||||
''' + where_clause
|
||||
)
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
# c.execute(
|
||||
# '''
|
||||
# SELECT * FROM command
|
||||
# ''' + where_clause
|
||||
# )
|
||||
|
||||
# command = []
|
||||
# for row in c.fetchall():
|
||||
# command.append(Command(*row))
|
||||
|
||||
# c.execute(
|
||||
# '''
|
||||
# SELECT * FROM llm
|
||||
# ''' + where_clause
|
||||
# )
|
||||
|
||||
# llm = []
|
||||
# for row in c.fetchall():
|
||||
# llm.append(Provider(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
|
||||
def get_total_message_count(self) -> int:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT SUM(count) FROM platform
|
||||
'''
|
||||
)
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return res[0]
|
||||
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
'''获取 offset_sec 秒前到现在的基础统计数据(合并)'''
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT name, SUM(count), timestamp FROM platform
|
||||
''' + where_clause + " GROUP BY name"
|
||||
)
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
# c.execute(
|
||||
# '''
|
||||
# SELECT name, SUM(count), timestamp FROM command
|
||||
# ''' + where_clause + " GROUP BY name"
|
||||
# )
|
||||
|
||||
# command = []
|
||||
# for row in c.fetchall():
|
||||
# command.append(Command(*row))
|
||||
|
||||
# c.execute(
|
||||
# '''
|
||||
# SELECT name, SUM(count), timestamp FROM llm
|
||||
# ''' + where_clause + " GROUP BY name"
|
||||
# )
|
||||
|
||||
# llm = []
|
||||
# for row in c.fetchall():
|
||||
# llm.append(Provider(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
24
astrbot/db/sqlite_init.sql
Normal file
24
astrbot/db/sqlite_init.sql
Normal file
@@ -0,0 +1,24 @@
|
||||
CREATE TABLE IF NOT EXISTS platform(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS plugin(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS command(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm_history(
|
||||
session_id VARCHAR(32),
|
||||
content TEXT
|
||||
);
|
||||
@@ -1,16 +1,15 @@
|
||||
from aip import AipContentCensor
|
||||
from util.cmd_config import BaiduAIPConfig
|
||||
|
||||
|
||||
class BaiduJudge:
|
||||
def __init__(self, baidu_configs) -> None:
|
||||
if 'app_id' in baidu_configs and 'api_key' in baidu_configs and 'secret_key' in baidu_configs:
|
||||
self.app_id = str(baidu_configs['app_id'])
|
||||
self.api_key = baidu_configs['api_key']
|
||||
self.secret_key = baidu_configs['secret_key']
|
||||
self.client = AipContentCensor(
|
||||
self.app_id, self.api_key, self.secret_key)
|
||||
else:
|
||||
raise ValueError("Baidu configs error! 请填写百度内容审核服务相关配置!")
|
||||
def __init__(self, baidu_configs: BaiduAIPConfig) -> None:
|
||||
self.app_id = baidu_configs.app_id
|
||||
self.api_key = baidu_configs.api_key
|
||||
self.secret_key = baidu_configs.secret_key
|
||||
self.client = AipContentCensor(self.app_id,
|
||||
self.api_key,
|
||||
self.secret_key)
|
||||
|
||||
def judge(self, text):
|
||||
res = self.client.textCensorUserDefined(text)
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
import time
|
||||
import re
|
||||
import time, json
|
||||
import re, os
|
||||
import asyncio
|
||||
import traceback
|
||||
import astrbot.message.unfit_words as uw
|
||||
|
||||
from typing import Dict
|
||||
from astrbot.persist.helper import dbConn
|
||||
from astrbot.db import BaseDatabase
|
||||
from model.provider.provider import Provider
|
||||
from model.command.manager import CommandManager
|
||||
from type.message_event import AstrMessageEvent, MessageResult
|
||||
from type.types import Context
|
||||
from type.command import CommandResult
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
import util.agent.web_searcher as web_searcher
|
||||
from util.agent.func_call import FuncCall
|
||||
from openai._exceptions import *
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -22,16 +24,11 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
class RateLimitHelper():
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.user_rate_limit: Dict[int, int] = {}
|
||||
self.rate_limit_time: int = 60
|
||||
self.rate_limit_count: int = 10
|
||||
rl = context.config_helper.platform_settings.rate_limit
|
||||
self.rate_limit_time: int = rl.time
|
||||
self.rate_limit_count: int = rl.count
|
||||
self.user_frequency = {}
|
||||
|
||||
if 'limit' in context.base_config:
|
||||
if 'count' in context.base_config['limit']:
|
||||
self.rate_limit_count = context.base_config['limit']['count']
|
||||
if 'time' in context.base_config['limit']:
|
||||
self.rate_limit_time = context.base_config['limit']['time']
|
||||
|
||||
|
||||
def check_frequency(self, session_id: str) -> bool:
|
||||
'''
|
||||
检查发言频率
|
||||
@@ -56,13 +53,15 @@ class RateLimitHelper():
|
||||
class ContentSafetyHelper():
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.baidu_judge = None
|
||||
if 'baidu_api' in context.base_config and \
|
||||
'enable' in context.base_config['baidu_aip'] and \
|
||||
context.base_config['baidu_aip']['enable']:
|
||||
aip = context.config_helper.content_safety.baidu_aip
|
||||
if aip.enable:
|
||||
try:
|
||||
from astrbot.message.baidu_aip_judge import BaiduJudge
|
||||
self.baidu_judge = BaiduJudge(context.base_config['baidu_aip'])
|
||||
self.baidu_judge = BaiduJudge(aip)
|
||||
logger.info("已启用百度 AI 内容审核。")
|
||||
except ImportError as e:
|
||||
logger.error("检测到库依赖不完整,将不会启用百度 AI 内容审核。请先使用 pip 安装 `baidu_aip` 包。")
|
||||
logger.error(e)
|
||||
except BaseException as e:
|
||||
logger.error("百度 AI 内容审核初始化失败。")
|
||||
logger.error(e)
|
||||
@@ -104,20 +103,20 @@ class ContentSafetyHelper():
|
||||
class MessageHandler():
|
||||
def __init__(self, context: Context,
|
||||
command_manager: CommandManager,
|
||||
persist_manager: dbConn,
|
||||
provider: Provider) -> None:
|
||||
db_helper: BaseDatabase) -> None:
|
||||
self.context = context
|
||||
self.command_manager = command_manager
|
||||
self.persist_manager = persist_manager
|
||||
self.db_helper = db_helper
|
||||
self.rate_limit_helper = RateLimitHelper(context)
|
||||
self.content_safety_helper = ContentSafetyHelper(context)
|
||||
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
|
||||
self.llm_wake_prefix = self.context.config_helper.llm_settings.wake_prefix
|
||||
self.llm_identifier = self.context.config_helper.llm_settings.identifier
|
||||
if self.llm_wake_prefix:
|
||||
self.llm_wake_prefix = self.llm_wake_prefix.strip()
|
||||
self.nicks = self.context.nick
|
||||
self.provider = provider
|
||||
self.reply_prefix = str(self.context.reply_prefix)
|
||||
|
||||
self.provider = self.context.llms[0].llm_instance if len(self.context.llms) > 0 else None
|
||||
self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix)
|
||||
self.llm_tools = FuncCall(self.provider)
|
||||
|
||||
def set_provider(self, provider: Provider):
|
||||
self.provider = provider
|
||||
|
||||
@@ -128,23 +127,19 @@ class MessageHandler():
|
||||
`llm_provider`: the provider to use for LLM. If None, use the default provider
|
||||
'''
|
||||
msg_plain = message.message_str.strip()
|
||||
provider = llm_provider if llm_provider else self.provider
|
||||
inner_provider = False if llm_provider else True
|
||||
|
||||
self.persist_manager.record_message(message.platform.platform_name, message.session_id)
|
||||
provider = llm_provider if llm_provider else self.provider
|
||||
|
||||
# TODO: this should be configurable
|
||||
# if not message.message_str:
|
||||
# return MessageResult("Hi~")
|
||||
|
||||
# check the rate limit
|
||||
if not message.only_command and not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
|
||||
# return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
|
||||
logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制, 跳过。")
|
||||
if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
|
||||
logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制,已忽略。")
|
||||
return
|
||||
|
||||
|
||||
# remove the nick prefix
|
||||
for nick in self.nicks:
|
||||
for nick in self.context.config_helper.wake_prefix:
|
||||
if msg_plain.startswith(nick):
|
||||
msg_plain = msg_plain.removeprefix(nick)
|
||||
break
|
||||
@@ -159,12 +154,20 @@ class MessageHandler():
|
||||
is_command_call=True,
|
||||
use_t2i=cmd_res.is_use_t2i
|
||||
)
|
||||
|
||||
# next is the LLM part
|
||||
|
||||
# middlewares
|
||||
for middleware in self.context.middlewares:
|
||||
try:
|
||||
logger.info(f"执行中间件 {middleware.origin}/{middleware.name}...")
|
||||
await middleware.func(message, self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"中间件 {middleware.origin}/{middleware.name} 处理消息时发生异常:{e},跳过。")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if message.only_command:
|
||||
return
|
||||
|
||||
# next is the LLM part
|
||||
# check if the message is a llm-wake-up command
|
||||
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
|
||||
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
|
||||
@@ -183,31 +186,102 @@ class MessageHandler():
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
break
|
||||
web_search = self.context.web_search
|
||||
if not web_search and msg_plain.startswith("ws"):
|
||||
# leverage web search feature
|
||||
web_search = True
|
||||
msg_plain = msg_plain.removeprefix("ws").strip()
|
||||
|
||||
try:
|
||||
if web_search:
|
||||
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True)
|
||||
if not self.llm_tools.empty():
|
||||
# tools-use
|
||||
tool_use_flag = True
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
tools=self.llm_tools.get_func()
|
||||
)
|
||||
|
||||
if isinstance(llm_result, Function):
|
||||
logger.debug(f"function-calling: {llm_result}")
|
||||
func_obj = None
|
||||
for i in self.llm_tools.func_list:
|
||||
if i["name"] == llm_result.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。")
|
||||
try:
|
||||
args = json.loads(llm_result.arguments)
|
||||
args['ame'] = message
|
||||
args['context'] = self.context
|
||||
try:
|
||||
cmd_res = await func_obj(**args)
|
||||
except TypeError as e:
|
||||
args.pop('ame')
|
||||
args.pop('context')
|
||||
cmd_res = await func_obj(**args)
|
||||
if isinstance(cmd_res, CommandResult):
|
||||
return MessageResult(
|
||||
cmd_res.message_chain,
|
||||
is_command_call=True,
|
||||
use_t2i=cmd_res.is_use_t2i
|
||||
)
|
||||
elif isinstance(cmd_res, str):
|
||||
return MessageResult(cmd_res)
|
||||
elif not cmd_res:
|
||||
return
|
||||
else:
|
||||
return MessageResult(f"AstrBot Function-calling 异常:调用:{llm_result} 时,返回了未知的返回值类型。")
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return MessageResult("AstrBot Function-calling 异常:" + str(e))
|
||||
else:
|
||||
return MessageResult(llm_result)
|
||||
|
||||
else:
|
||||
# normal chat
|
||||
tool_use_flag = False
|
||||
# add user info to the prompt
|
||||
if self.llm_identifier:
|
||||
user_id = message.message_obj.sender.user_id
|
||||
user_nickname = message.message_obj.sender.nickname
|
||||
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
msg_plain = user_info + msg_plain
|
||||
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
image_url=image_url
|
||||
)
|
||||
except BadRequestError as e:
|
||||
if tool_use_flag:
|
||||
# seems like the model don't support function-calling
|
||||
logger.error(f"error: {e}. Using local function-calling implementation")
|
||||
|
||||
try:
|
||||
# use local function-calling implementation
|
||||
args = {
|
||||
'question': llm_result,
|
||||
'func_definition': self.llm_tools.func_dump(),
|
||||
}
|
||||
_, has_func = await self.llm_tools.func_call(**args)
|
||||
|
||||
if not has_func:
|
||||
# normal chat
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
image_url=image_url
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return CommandResult("AstrBot Function-calling 异常:" + str(e))
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"LLM 调用失败。")
|
||||
return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e))
|
||||
|
||||
# concatenate the reply prefix
|
||||
|
||||
# concatenate reply prefix
|
||||
if self.reply_prefix:
|
||||
llm_result = self.reply_prefix + llm_result
|
||||
|
||||
# mask the unsafe content
|
||||
# mask unsafe content
|
||||
llm_result = self.content_safety_helper.filter_content(llm_result)
|
||||
check = self.content_safety_helper.baidu_check(llm_result)
|
||||
if not check:
|
||||
|
||||
@@ -1,269 +0,0 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
class dbConn():
|
||||
def __init__(self):
|
||||
db_path = "data/data.db"
|
||||
if os.path.exists("data.db"):
|
||||
shutil.copy("data.db", db_path)
|
||||
with open(os.path.dirname(__file__) + "/initialization.sql", "r") as f:
|
||||
sql = f.read()
|
||||
|
||||
self.conn = sqlite3.connect(db_path)
|
||||
self.conn.text_factory = str
|
||||
c = self.conn.cursor()
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
def record_message(self, platform, session_id):
|
||||
curr_ts = int(time.time())
|
||||
self.increment_stat_session(platform, session_id, 1)
|
||||
self.increment_stat_message(curr_ts, 1)
|
||||
self.increment_stat_platform(curr_ts, platform, 1)
|
||||
|
||||
def insert_session(self, qq_id, history):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
INSERT INTO tb_session(qq_id, history) VALUES (?, ?)
|
||||
''', (qq_id, history)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def update_session(self, qq_id, history):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
UPDATE tb_session SET history = ? WHERE qq_id = ?
|
||||
''', (history, qq_id)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_session(self, qq_id):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_session WHERE qq_id = ?
|
||||
''', (qq_id, )
|
||||
)
|
||||
return c.fetchone()
|
||||
|
||||
def get_all_session(self):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_session
|
||||
'''
|
||||
)
|
||||
return c.fetchall()
|
||||
|
||||
def check_session(self, qq_id):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_session WHERE qq_id = ?
|
||||
''', (qq_id, )
|
||||
)
|
||||
return c.fetchone() is not None
|
||||
|
||||
def delete_session(self, qq_id):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
DELETE FROM tb_session WHERE qq_id = ?
|
||||
''', (qq_id, )
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def increment_stat_session(self, platform, session_id, cnt):
|
||||
# if not exist, insert
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
|
||||
if self.check_stat_session(platform, session_id):
|
||||
c.execute(
|
||||
'''
|
||||
UPDATE tb_stat_session SET cnt = cnt + ? WHERE platform = ? AND session_id = ?
|
||||
''', (cnt, platform, session_id)
|
||||
)
|
||||
conn.commit()
|
||||
else:
|
||||
c.execute(
|
||||
'''
|
||||
INSERT INTO tb_stat_session(platform, session_id, cnt) VALUES (?, ?, ?)
|
||||
''', (platform, session_id, cnt)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def check_stat_session(self, platform, session_id):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_stat_session WHERE platform = ? AND session_id = ?
|
||||
''', (platform, session_id)
|
||||
)
|
||||
return c.fetchone() is not None
|
||||
|
||||
def get_all_stat_session(self):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_stat_session
|
||||
'''
|
||||
)
|
||||
return c.fetchall()
|
||||
|
||||
def get_session_cnt_total(self):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT COUNT(*) FROM tb_stat_session
|
||||
'''
|
||||
)
|
||||
return c.fetchone()[0]
|
||||
|
||||
def increment_stat_message(self, ts, cnt):
|
||||
# 以一个小时为单位。ts的单位是秒。
|
||||
# 找到最近的一个小时,如果没有,就插入
|
||||
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
|
||||
ok, new_ts = self.check_stat_message(ts)
|
||||
|
||||
if ok:
|
||||
c.execute(
|
||||
'''
|
||||
UPDATE tb_stat_message SET cnt = cnt + ? WHERE ts = ?
|
||||
''', (cnt, new_ts)
|
||||
)
|
||||
conn.commit()
|
||||
else:
|
||||
c.execute(
|
||||
'''
|
||||
INSERT INTO tb_stat_message(ts, cnt) VALUES (?, ?)
|
||||
''', (new_ts, cnt)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def check_stat_message(self, ts) -> Tuple[bool, int]:
|
||||
# 换算成当地整点的时间戳
|
||||
|
||||
ts = ts - ts % 3600
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_stat_message WHERE ts = ?
|
||||
''', (ts, )
|
||||
)
|
||||
if c.fetchone() is not None:
|
||||
return True, ts
|
||||
else:
|
||||
return False, ts
|
||||
|
||||
def get_last_24h_stat_message(self):
|
||||
# 获取最近24小时的消息统计
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_stat_message WHERE ts > ?
|
||||
''', (time.time() - 86400, )
|
||||
)
|
||||
return c.fetchall()
|
||||
|
||||
def get_message_cnt_total(self) -> int:
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT SUM(cnt) FROM tb_stat_message
|
||||
'''
|
||||
)
|
||||
return c.fetchone()[0]
|
||||
|
||||
def increment_stat_platform(self, ts, platform, cnt):
|
||||
# 以一个小时为单位。ts的单位是秒。
|
||||
# 找到最近的一个小时,如果没有,就插入
|
||||
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
|
||||
ok, new_ts = self.check_stat_platform(ts, platform)
|
||||
|
||||
if ok:
|
||||
c.execute(
|
||||
'''
|
||||
UPDATE tb_stat_platform SET cnt = cnt + ? WHERE ts = ? AND platform = ?
|
||||
''', (cnt, new_ts, platform)
|
||||
)
|
||||
conn.commit()
|
||||
else:
|
||||
c.execute(
|
||||
'''
|
||||
INSERT INTO tb_stat_platform(ts, platform, cnt) VALUES (?, ?, ?)
|
||||
''', (new_ts, platform, cnt)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def check_stat_platform(self, ts, platform):
|
||||
# 换算成当地整点的时间戳
|
||||
|
||||
ts = ts - ts % 3600
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_stat_platform WHERE ts = ? AND platform = ?
|
||||
''', (ts, platform)
|
||||
)
|
||||
if c.fetchone() is not None:
|
||||
return True, ts
|
||||
else:
|
||||
return False, ts
|
||||
|
||||
def get_last_24h_stat_platform(self):
|
||||
# 获取最近24小时的消息统计
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_stat_platform WHERE ts > ?
|
||||
''', (time.time() - 86400, )
|
||||
)
|
||||
return c.fetchall()
|
||||
|
||||
def get_platform_cnt_total(self) -> int:
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT platform, SUM(cnt) FROM tb_stat_platform GROUP BY platform
|
||||
'''
|
||||
)
|
||||
# return c.fetchall()
|
||||
platforms = []
|
||||
ret = c.fetchall()
|
||||
for i in ret:
|
||||
# platforms[i[0]] = i[1]
|
||||
platforms.append({
|
||||
"name": i[0],
|
||||
"count": i[1]
|
||||
})
|
||||
return platforms
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
@@ -1,18 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS tb_session(
|
||||
qq_id VARCHAR(32) PRIMARY KEY,
|
||||
history TEXT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS tb_stat_session(
|
||||
platform VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
cnt INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS tb_stat_message(
|
||||
ts INTEGER,
|
||||
cnt INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS tb_stat_platform(
|
||||
ts INTEGER,
|
||||
platform VARCHAR(32),
|
||||
cnt INTEGER
|
||||
);
|
||||
@@ -1,11 +1,31 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from quart import Quart
|
||||
from type.types import Context
|
||||
|
||||
class DashBoardData():
|
||||
stats: dict = {}
|
||||
configs: dict = {}
|
||||
logger = logging.getLogger("astrbot")
|
||||
class Route():
|
||||
def __init__(self, context: Context, app: Quart):
|
||||
self.context = context
|
||||
self.app = app
|
||||
|
||||
def register_routes(self):
|
||||
for route, (method, func) in self.routes.items():
|
||||
self.app.add_url_rule(f"/api{route}", view_func=func, methods=[method])
|
||||
|
||||
@dataclass
|
||||
class Response():
|
||||
status: str
|
||||
message: str
|
||||
data: dict
|
||||
status: str = None
|
||||
message: str = None
|
||||
data: dict = None
|
||||
|
||||
def error(self, message: str):
|
||||
self.status = "error"
|
||||
self.message = message
|
||||
return self
|
||||
|
||||
def ok(self, data: dict={}, message: str=None):
|
||||
self.status = "ok"
|
||||
self.data = data
|
||||
self.message = message
|
||||
return self
|
||||
@@ -1 +0,0 @@
|
||||
.page-breadcrumb .v-toolbar{background:transparent}
|
||||
@@ -1 +0,0 @@
|
||||
import{x as i,o as l,c as _,w as s,a as e,f as a,J as m,V as c,b as t,t as u,ae as p,B as n,af as o,j as f}from"./index-5ac7c267.js";const b={class:"text-h3"},h={class:"d-flex align-center"},g={class:"d-flex align-center"},V=i({__name:"BaseBreadcrumb",props:{title:String,breadcrumbs:Array,icon:String},setup(d){const r=d;return(x,B)=>(l(),_(c,{class:"page-breadcrumb mb-1 mt-1"},{default:s(()=>[e(a,{cols:"12",md:"12"},{default:s(()=>[e(m,{variant:"outlined",elevation:"0",class:"px-4 py-3 withbg"},{default:s(()=>[e(c,{"no-gutters":"",class:"align-center"},{default:s(()=>[e(a,{md:"5"},{default:s(()=>[t("h3",b,u(r.title),1)]),_:1}),e(a,{md:"7",sm:"12",cols:"12"},{default:s(()=>[e(p,{items:r.breadcrumbs,class:"text-h5 justify-md-end pa-1"},{divider:s(()=>[t("div",h,[e(n(o),{size:"17"})])]),prepend:s(()=>[e(f,{size:"small",icon:"mdi-home",class:"text-secondary mr-2"}),t("div",g,[e(n(o),{size:"17"})])]),_:1},8,["items"])]),_:1})]),_:1})]),_:1})]),_:1})]),_:1}))}});export{V as _};
|
||||
@@ -1 +0,0 @@
|
||||
import{x as e,o as a,c as t,w as o,a as s,B as n,Z as r,W as c}from"./index-5ac7c267.js";const f=e({__name:"BlankLayout",setup(p){return(u,_)=>(a(),t(c,null,{default:o(()=>[s(n(r))]),_:1}))}});export{f as default};
|
||||
1
dashboard/dist/assets/BlankLayout-a97b3dac.js
vendored
Normal file
1
dashboard/dist/assets/BlankLayout-a97b3dac.js
vendored
Normal file
@@ -0,0 +1 @@
|
||||
import{q as e,o as a,c as t,w as o,d as s,x as n,U as r,X as c}from"./index-a2f0b905.js";const f=e({__name:"BlankLayout",setup(p){return(u,_)=>(a(),t(r,null,{default:o(()=>[s(n(c))]),_:1}))}});export{f as default};
|
||||
1
dashboard/dist/assets/ColorPage-beafd674.js
vendored
1
dashboard/dist/assets/ColorPage-beafd674.js
vendored
@@ -1 +0,0 @@
|
||||
import{_ as m}from"./BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js";import{_}from"./UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js";import{x as p,D as a,o as r,s,a as e,w as t,f as o,V as i,F as n,u as g,c as h,a0 as b,e as x,t as y}from"./index-5ac7c267.js";const P=p({__name:"ColorPage",setup(C){const c=a({title:"Colors Page"}),d=a([{title:"Utilities",disabled:!1,href:"#"},{title:"Colors",disabled:!0,href:"#"}]),u=a(["primary","lightprimary","secondary","lightsecondary","info","success","accent","warning","error","darkText","lightText","borderLight","inputBorder","containerBg"]);return(V,k)=>(r(),s(n,null,[e(m,{title:c.value.title,breadcrumbs:d.value},null,8,["title","breadcrumbs"]),e(i,null,{default:t(()=>[e(o,{cols:"12",md:"12"},{default:t(()=>[e(_,{title:"Color Palette"},{default:t(()=>[e(i,null,{default:t(()=>[(r(!0),s(n,null,g(u.value,(l,f)=>(r(),h(o,{md:"3",cols:"12",key:f},{default:t(()=>[e(b,{rounded:"md",class:"align-center justify-center d-flex",height:"100",width:"100%",color:l},{default:t(()=>[x("class: "+y(l),1)]),_:2},1032,["color"])]),_:2},1024))),128))]),_:1})]),_:1})]),_:1})]),_:1})],64))}});export{P as default};
|
||||
@@ -1 +0,0 @@
|
||||
import{o as l,s as o,u as c,c as n,w as u,Q as g,b as d,R as k,F as t,ac as h,O as p,t as m,a as V,ad as f,i as C,q as x,k as v,A as U}from"./index-5ac7c267.js";import{_ as w}from"./UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js";const S={__name:"ConfigDetailCard",props:{config:Array},setup(s){return(y,B)=>(l(!0),o(t,null,c(s.config,r=>(l(),n(w,{key:r.name,title:r.name,style:{"margin-bottom":"16px"}},{default:u(()=>[g(d("a",null,"No data",512),[[k,s.config.length===0]]),(l(!0),o(t,null,c(r.body,e=>(l(),o(t,null,[e.config_type==="item"?(l(),o(t,{key:0},[e.val_type==="bool"?(l(),n(h,{key:0,modelValue:e.value,"onUpdate:modelValue":a=>e.value=a,label:e.name,hint:e.description,color:"primary",inset:""},null,8,["modelValue","onUpdate:modelValue","label","hint"])):e.val_type==="str"?(l(),n(p,{key:1,modelValue:e.value,"onUpdate:modelValue":a=>e.value=a,label:e.name,hint:e.description,style:{"margin-bottom":"8px"},variant:"outlined"},null,8,["modelValue","onUpdate:modelValue","label","hint"])):e.val_type==="int"?(l(),n(p,{key:2,modelValue:e.value,"onUpdate:modelValue":a=>e.value=a,label:e.name,hint:e.description,style:{"margin-bottom":"8px"},variant:"outlined"},null,8,["modelValue","onUpdate:modelValue","label","hint"])):e.val_type==="list"?(l(),o(t,{key:3},[d("span",null,m(e.name),1),V(f,{modelValue:e.value,"onUpdate:modelValue":a=>e.value=a,chips:"",clearable:"",label:"请添加",multiple:"","prepend-icon":"mdi-tag-multiple-outline"},{selection:u(({attrs:a,item:i,select:b,selected:_})=>[V(C,x(a,{"model-value":_,closable:"",onClick:b,"onClick:close":D=>y.remove(i)}),{default:u(()=>[d("strong",null,m(i),1)]),_:2},1040,["model-value","onClick","onClick:close"])]),_:2},1032,["modelValue","onUpdate:modelValue"])],64)):v("",!0)],64)):e.config_type==="divider"?(l(),n(U,{key:1,style:{"margin-top":"8px","margin-bottom":"8px"}})):v("",!0)],64))),256))]),_:2},1032,["title"]))),128))}};export{S as _};
|
||||
1
dashboard/dist/assets/ConfigPage-56ea019d.js
vendored
1
dashboard/dist/assets/ConfigPage-56ea019d.js
vendored
@@ -1 +0,0 @@
|
||||
import{_ as b}from"./UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js";import{x as h,o,c as u,w as t,a,a8 as y,b as c,K as x,e as f,t as g,G as V,A as w,L as S,a9 as $,J as B,s as _,d as v,F as d,u as p,f as G,V as T,ab as j,T as l}from"./index-5ac7c267.js";import{_ as m}from"./ConfigDetailCard-756c045d.js";const D={class:"d-sm-flex align-center justify-space-between"},C=h({__name:"ConfigGroupCard",props:{title:String},setup(e){const s=e;return(i,n)=>(o(),u(B,{variant:"outlined",elevation:"0",class:"withbg",style:{width:"50%"}},{default:t(()=>[a(y,{style:{padding:"10px 20px"}},{default:t(()=>[c("div",D,[a(x,null,{default:t(()=>[f(g(s.title),1)]),_:1}),a(V)])]),_:1}),a(w),a(S,null,{default:t(()=>[$(i.$slots,"default")]),_:3})]),_:3}))}}),I={style:{display:"flex","flex-direction":"row","justify-content":"space-between","align-items":"center","margin-bottom":"12px"}},N={style:{display:"flex","flex-direction":"row"}},R={style:{"margin-right":"10px",color:"black"}},F={style:{color:"#222"}},k=h({__name:"ConfigGroupItem",props:{title:String,desc:String,btnRoute:String,namespace:String},setup(e){const s=e;return(i,n)=>(o(),_("div",I,[c("div",N,[c("h3",R,g(s.title),1),c("p",F,g(s.desc),1)]),a(v,{to:s.btnRoute,color:"primary",class:"ml-2",style:{"border-radius":"10px"}},{default:t(()=>[f("配置")]),_:1},8,["to"])]))}}),L={style:{display:"flex","flex-direction":"row",padding:"16px",gap:"16px",width:"100%"}},P={name:"ConfigPage",components:{UiParentCard:b,ConfigGroupCard:C,ConfigGroupItem:k,ConfigDetailCard:m},data(){return{config_data:[],config_base:[],save_message_snack:!1,save_message:"",save_message_success:"",config_outline:[],namespace:""}},mounted(){this.getConfig()},methods:{switchConfig(e){l.get("/api/configs?namespace="+e).then(s=>{this.namespace=e,this.config_data=s.data.data,console.log(this.config_data)}).catch(s=>{save_message=s,save_message_snack=!0,save_message_success="error"})},getConfig(){l.get("/api/config_outline").then(e=>{this.config_outline=e.data.data,console.log(this.config_outline)}).catch(e=>{save_message=e,save_message_snack=!0,save_message_success="error"}),l.get("/api/configs").then(e=>{this.config_base=e.data.data,console.log(this.config_data)}).catch(e=>{save_message=e,save_message_snack=!0,save_message_success="error"})},updateConfig(){l.post("/api/configs",{base_config:this.config_base,config:this.config_data,namespace:this.namespace}).then(e=>{e.data.status==="success"?(this.save_message=e.data.message,this.save_message_snack=!0,this.save_message_success="success"):(this.save_message=e.data.message,this.save_message_snack=!0,this.save_message_success="error")}).catch(e=>{this.save_message=e,this.save_message_snack=!0,this.save_message_success="error"})}}},J=Object.assign(P,{setup(e){return(s,i)=>(o(),_(d,null,[a(T,null,{default:t(()=>[c("div",L,[(o(!0),_(d,null,p(s.config_outline,n=>(o(),u(C,{key:n.name,title:n.name},{default:t(()=>[(o(!0),_(d,null,p(n.body,r=>(o(),u(k,{title:r.title,desc:r.desc,namespace:r.namespace,onClick:U=>s.switchConfig(r.namespace)},null,8,["title","desc","namespace","onClick"]))),256))]),_:2},1032,["title"]))),128))]),a(G,{cols:"12",md:"12"},{default:t(()=>[a(m,{config:s.config_data},null,8,["config"]),a(m,{config:s.config_base},null,8,["config"])]),_:1})]),_:1}),a(v,{icon:"mdi-content-save",size:"x-large",style:{position:"fixed",right:"52px",bottom:"52px"},color:"darkprimary",onClick:s.updateConfig},null,8,["onClick"]),a(j,{timeout:2e3,elevation:"24",color:s.save_message_success,modelValue:s.save_message_snack,"onUpdate:modelValue":i[0]||(i[0]=n=>s.save_message_snack=n)},{default:t(()=>[f(g(s.save_message),1)]),_:1},8,["color","modelValue"])],64))}});export{J as default};
|
||||
1
dashboard/dist/assets/ConfigPage-e09e97c8.js
vendored
Normal file
1
dashboard/dist/assets/ConfigPage-e09e97c8.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
dashboard/dist/assets/ConfigPage-f564cc69.css
vendored
Normal file
1
dashboard/dist/assets/ConfigPage-f564cc69.css
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.v-tab{text-transform:none!important}
|
||||
File diff suppressed because one or more lines are too long
1
dashboard/dist/assets/DefaultDashboard-512a61eb.js
vendored
Normal file
1
dashboard/dist/assets/DefaultDashboard-512a61eb.js
vendored
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
|
||||
.CardMediaWrapper{max-width:720px;margin:0 auto;position:relative}.CardMediaBuild{position:absolute;top:0;left:0;width:100%;animation:5s bounce ease-in-out infinite}.CardMediaParts{position:absolute;top:0;left:0;width:100%;animation:10s blink ease-in-out infinite}
|
||||
@@ -1 +0,0 @@
|
||||
import{_ as a}from"./_plugin-vue_export-helper-c27b6911.js";import{o,c,w as s,V as i,a as t,b as e,d as l,e as r,f as d}from"./index-5ac7c267.js";const n="/assets/img-error-bg-41f65efa.svg",_="/assets/img-error-blue-f50c8e77.svg",m="/assets/img-error-text-630dc36d.svg",g="/assets/img-error-purple-b97a483b.svg";const p={},u={class:"text-center"},f=e("div",{class:"CardMediaWrapper"},[e("img",{src:n,alt:"grid",class:"w-100"}),e("img",{src:_,alt:"grid",class:"CardMediaParts"}),e("img",{src:m,alt:"build",class:"CardMediaBuild"}),e("img",{src:g,alt:"build",class:"CardMediaBuild"})],-1),h=e("h1",{class:"text-h1"},"Something is wrong",-1),v=e("p",null,[e("small",null,[r("The page you are looking was moved, removed, "),e("br"),r("renamed, or might never exist! ")])],-1);function x(b,V){return o(),c(i,{"no-gutters":"",class:"h-100vh"},{default:s(()=>[t(d,{class:"d-flex align-center justify-center"},{default:s(()=>[e("div",u,[f,h,v,t(l,{variant:"flat",color:"primary",class:"mt-4",to:"/","prepend-icon":"mdi-home"},{default:s(()=>[r(" Home")]),_:1})])]),_:1})]),_:1})}const C=a(p,[["render",x]]);export{C as default};
|
||||
File diff suppressed because one or more lines are too long
1
dashboard/dist/assets/ExtensionPage-d720ef03.js
vendored
Normal file
1
dashboard/dist/assets/ExtensionPage-d720ef03.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
dashboard/dist/assets/FullLayout-35e69863.js
vendored
1
dashboard/dist/assets/FullLayout-35e69863.js
vendored
File diff suppressed because one or more lines are too long
1
dashboard/dist/assets/FullLayout-8b7c2f13.js
vendored
Normal file
1
dashboard/dist/assets/FullLayout-8b7c2f13.js
vendored
Normal file
File diff suppressed because one or more lines are too long
5
dashboard/dist/assets/LoginPage-5c692a20.js
vendored
5
dashboard/dist/assets/LoginPage-5c692a20.js
vendored
File diff suppressed because one or more lines are too long
5
dashboard/dist/assets/LoginPage-7b23780a.js
vendored
Normal file
5
dashboard/dist/assets/LoginPage-7b23780a.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
|
||||
import{aw as _,x as d,D as n,o as c,s as m,a as f,w as p,Q as r,b as a,R as o,B as t,ax as h}from"./index-5ac7c267.js";const s={Sidebar_drawer:!0,Customizer_drawer:!1,mini_sidebar:!1,fontTheme:"Roboto",inputBg:!1},l=_({id:"customizer",state:()=>({Sidebar_drawer:s.Sidebar_drawer,Customizer_drawer:s.Customizer_drawer,mini_sidebar:s.mini_sidebar,fontTheme:"Poppins",inputBg:s.inputBg}),getters:{},actions:{SET_SIDEBAR_DRAWER(){this.Sidebar_drawer=!this.Sidebar_drawer},SET_MINI_SIDEBAR(e){this.mini_sidebar=e},SET_FONT(e){this.fontTheme=e}}}),u={class:"logo",style:{display:"flex","align-items":"center"}},b={style:{"font-size":"24px","font-weight":"1000"}},w={style:{"font-size":"20px","font-weight":"1000"}},S={style:{"font-size":"20px"}},z=d({__name:"LogoDark",setup(e){n("rgb(var(--v-theme-primary))"),n("rgb(var(--v-theme-secondary))");const i=l();return(g,B)=>(c(),m("div",u,[f(t(h),{to:"/",style:{"text-decoration":"none",color:"black"}},{default:p(()=>[r(a("span",b,"AstrBot 仪表盘",512),[[o,!t(i).mini_sidebar]]),r(a("span",w,"Astr",512),[[o,t(i).mini_sidebar]]),r(a("span",S,"Bot",512),[[o,t(i).mini_sidebar]])]),_:1})]))}});export{z as _,l as u};
|
||||
@@ -1 +0,0 @@
|
||||
import{_ as o}from"./BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js";import{_ as i}from"./UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js";import{x as n,D as a,o as c,s as m,a as e,w as t,f as d,b as f,V as _,F as u}from"./index-5ac7c267.js";const p=["innerHTML"],v=n({__name:"MaterialIcons",setup(b){const s=a({title:"Material Icons"}),r=a('<iframe src="https://materialdesignicons.com/" frameborder="0" width="100%" height="1000"></iframe>'),l=a([{title:"Icons",disabled:!1,href:"#"},{title:"Material Icons",disabled:!0,href:"#"}]);return(h,M)=>(c(),m(u,null,[e(o,{title:s.value.title,breadcrumbs:l.value},null,8,["title","breadcrumbs"]),e(_,null,{default:t(()=>[e(d,{cols:"12",md:"12"},{default:t(()=>[e(i,{title:"Material Icons"},{default:t(()=>[f("div",{innerHTML:r.value},null,8,p)]),_:1})]),_:1})]),_:1})],64))}});export{v as default};
|
||||
@@ -1 +0,0 @@
|
||||
.custom-devider{border-color:#00000014!important}.googleBtn{border-color:#00000014;margin:30px 0 20px}.outlinedInput .v-field{border:1px solid rgba(0,0,0,.08);box-shadow:none}.orbtn{padding:2px 40px;border-color:#00000014;margin:20px 15px}.pwdInput{position:relative}.pwdInput .v-input__append{position:absolute;right:10px;top:50%;transform:translateY(-50%)}.loginBox{max-width:475px;margin:0 auto}
|
||||
@@ -1 +0,0 @@
|
||||
import{_ as B}from"./LogoDark.vue_vue_type_script_setup_true_lang-d555e5be.js";import{x as y,D as o,o as b,s as U,a as e,w as a,b as n,B as $,d as u,f as d,A as _,e as f,V as r,O as m,aq as q,av as A,F as E,c as F,N as T,J as V,L as P}from"./index-5ac7c267.js";const z="/assets/social-google-9b2fa67a.svg",N=["src"],S=n("span",{class:"ml-2"},"Sign up with Google",-1),D=n("h5",{class:"text-h5 text-center my-4 mb-8"},"Sign up with Email address",-1),G={class:"d-sm-inline-flex align-center mt-2 mb-7 mb-sm-0 font-weight-bold"},L=n("a",{href:"#",class:"ml-1 text-lightText"},"Terms and Condition",-1),O={class:"mt-5 text-right"},j=y({__name:"AuthRegister",setup(w){const c=o(!1),i=o(!1),p=o(""),v=o(""),g=o(),h=o(""),x=o(""),k=o([s=>!!s||"Password is required",s=>s&&s.length<=10||"Password must be less than 10 characters"]),C=o([s=>!!s||"E-mail is required",s=>/.+@.+\..+/.test(s)||"E-mail must be valid"]);function R(){g.value.validate()}return(s,l)=>(b(),U(E,null,[e(u,{block:"",color:"primary",variant:"outlined",class:"text-lightText googleBtn"},{default:a(()=>[n("img",{src:$(z),alt:"google"},null,8,N),S]),_:1}),e(r,null,{default:a(()=>[e(d,{class:"d-flex align-center"},{default:a(()=>[e(_,{class:"custom-devider"}),e(u,{variant:"outlined",class:"orbtn",rounded:"md",size:"small"},{default:a(()=>[f("OR")]),_:1}),e(_,{class:"custom-devider"})]),_:1})]),_:1}),D,e(A,{ref_key:"Regform",ref:g,"lazy-validation":"",action:"/dashboards/analytical",class:"mt-7 loginForm"},{default:a(()=>[e(r,null,{default:a(()=>[e(d,{cols:"12",sm:"6"},{default:a(()=>[e(m,{modelValue:h.value,"onUpdate:modelValue":l[0]||(l[0]=t=>h.value=t),density:"comfortable","hide-details":"auto",variant:"outlined",color:"primary",label:"Firstname"},null,8,["modelValue"])]),_:1}),e(d,{cols:"12",sm:"6"},{default:a(()=>[e(m,{modelValue:x.value,"onUpdate:modelValue":l[1]||(l[1]=t=>x.value=t),density:"comfortable","hide-details":"auto",variant:"outlined",color:"primary",label:"Lastname"},null,8,["modelValue"])]),_:1})]),_:1}),e(m,{modelValue:v.value,"onUpdate:modelValue":l[2]||(l[2]=t=>v.value=t),rules:C.value,label:"Email Address / Username",class:"mt-4 mb-4",required:"",density:"comfortable","hide-details":"auto",variant:"outlined",color:"primary"},null,8,["modelValue","rules"]),e(m,{modelValue:p.value,"onUpdate:modelValue":l[3]||(l[3]=t=>p.value=t),rules:k.value,label:"Password",required:"",density:"comfortable",variant:"outlined",color:"primary","hide-details":"auto","append-icon":i.value?"mdi-eye":"mdi-eye-off",type:i.value?"text":"password","onClick:append":l[4]||(l[4]=t=>i.value=!i.value),class:"pwdInput"},null,8,["modelValue","rules","append-icon","type"]),n("div",G,[e(q,{modelValue:c.value,"onUpdate:modelValue":l[5]||(l[5]=t=>c.value=t),rules:[t=>!!t||"You must agree to continue!"],label:"Agree with?",required:"",color:"primary",class:"ms-n2","hide-details":""},null,8,["modelValue","rules"]),L]),e(u,{color:"secondary",block:"",class:"mt-2",variant:"flat",size:"large",onClick:l[6]||(l[6]=t=>R())},{default:a(()=>[f("Sign Up")]),_:1})]),_:1},512),n("div",O,[e(_),e(u,{variant:"plain",to:"/auth/login",class:"mt-2 text-capitalize mr-n2"},{default:a(()=>[f("Already have an account?")]),_:1})])],64))}});const I={class:"pa-7 pa-sm-12"},J=n("h2",{class:"text-secondary text-h2 mt-8"},"Sign up",-1),Y=n("h4",{class:"text-disabled text-h4 mt-3"},"Enter credentials to continue",-1),M=y({__name:"RegisterPage",setup(w){return(c,i)=>(b(),F(r,{class:"h-100vh","no-gutters":""},{default:a(()=>[e(d,{cols:"12",class:"d-flex align-center bg-lightprimary"},{default:a(()=>[e(T,null,{default:a(()=>[n("div",I,[e(r,{justify:"center"},{default:a(()=>[e(d,{cols:"12",lg:"10",xl:"6",md:"7"},{default:a(()=>[e(V,{elevation:"0",class:"loginBox"},{default:a(()=>[e(V,{variant:"outlined"},{default:a(()=>[e(P,{class:"pa-9"},{default:a(()=>[e(r,null,{default:a(()=>[e(d,{cols:"12",class:"text-center"},{default:a(()=>[e(B),J,Y]),_:1})]),_:1}),e(j)]),_:1})]),_:1})]),_:1})]),_:1})]),_:1})])]),_:1})]),_:1})]),_:1}))}});export{M as default};
|
||||
1
dashboard/dist/assets/ShadowPage-4758709f.js
vendored
1
dashboard/dist/assets/ShadowPage-4758709f.js
vendored
@@ -1 +0,0 @@
|
||||
import{_ as c}from"./BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js";import{_ as f}from"./UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js";import{x as m,D as s,o as l,s as r,a as e,w as a,f as i,V as o,F as d,u as _,J as p,X as b,b as h,t as g}from"./index-5ac7c267.js";const v=m({__name:"ShadowPage",setup(w){const n=s({title:"Shadow Page"}),u=s([{title:"Utilities",disabled:!1,href:"#"},{title:"Shadow",disabled:!0,href:"#"}]);return(V,x)=>(l(),r(d,null,[e(c,{title:n.value.title,breadcrumbs:u.value},null,8,["title","breadcrumbs"]),e(o,null,{default:a(()=>[e(i,{cols:"12",md:"12"},{default:a(()=>[e(f,{title:"Basic Shadow"},{default:a(()=>[e(o,{justify:"center"},{default:a(()=>[(l(),r(d,null,_(25,t=>e(i,{key:t,cols:"auto"},{default:a(()=>[e(p,{height:"100",width:"100",class:b(["mb-5",["d-flex justify-center align-center bg-primary",`elevation-${t}`]])},{default:a(()=>[h("div",null,g(t-1),1)]),_:2},1032,["class"])]),_:2},1024)),64))]),_:1})]),_:1})]),_:1})]),_:1})],64))}});export{v as default};
|
||||
@@ -1 +0,0 @@
|
||||
import{_ as o}from"./BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js";import{_ as n}from"./UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js";import{x as c,D as a,o as i,s as m,a as e,w as t,f as d,b as f,V as _,F as u}from"./index-5ac7c267.js";const b=["innerHTML"],w=c({__name:"TablerIcons",setup(p){const s=a({title:"Tabler Icons"}),r=a('<iframe src="https://tablericons.com/" frameborder="0" width="100%" height="600"></iframe>'),l=a([{title:"Icons",disabled:!1,href:"#"},{title:"Tabler Icons",disabled:!0,href:"#"}]);return(h,T)=>(i(),m(u,null,[e(o,{title:s.value.title,breadcrumbs:l.value},null,8,["title","breadcrumbs"]),e(_,null,{default:t(()=>[e(d,{cols:"12",md:"12"},{default:t(()=>[e(n,{title:"Tabler Icons"},{default:t(()=>[f("div",{innerHTML:r.value},null,8,b)]),_:1})]),_:1})]),_:1})],64))}});export{w as default};
|
||||
@@ -1 +0,0 @@
|
||||
import{_ as m}from"./BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js";import{_ as v}from"./UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js";import{x as f,o as i,c as g,w as e,a,a8 as y,K as b,e as w,t as d,A as C,L as V,a9 as L,J as _,D as o,s as h,f as k,b as t,F as x,u as B,X as H,V as T}from"./index-5ac7c267.js";const s=f({__name:"UiChildCard",props:{title:String},setup(r){const l=r;return(n,c)=>(i(),g(_,{variant:"outlined"},{default:e(()=>[a(y,{class:"py-3"},{default:e(()=>[a(b,{class:"text-h5"},{default:e(()=>[w(d(l.title),1)]),_:1})]),_:1}),a(C),a(V,null,{default:e(()=>[L(n.$slots,"default")]),_:3})]),_:3}))}}),D={class:"d-flex flex-column gap-1"},S={class:"text-caption pa-2 bg-lightprimary"},z=t("div",{class:"text-grey"},"Class",-1),N={class:"font-weight-medium"},$=t("div",null,[t("p",{class:"text-left"},"Left aligned on all viewport sizes."),t("p",{class:"text-center"},"Center aligned on all viewport sizes."),t("p",{class:"text-right"},"Right aligned on all viewport sizes."),t("p",{class:"text-sm-left"},"Left aligned on viewports SM (small) or wider."),t("p",{class:"text-right text-md-left"},"Left aligned on viewports MD (medium) or wider."),t("p",{class:"text-right text-lg-left"},"Left aligned on viewports LG (large) or wider."),t("p",{class:"text-right text-xl-left"},"Left aligned on viewports XL (extra-large) or wider.")],-1),M=t("div",{class:"d-flex justify-space-between flex-row"},[t("a",{href:"#",class:"text-decoration-none"},"Non-underlined link"),t("div",{class:"text-decoration-line-through"},"Line-through text"),t("div",{class:"text-decoration-overline"},"Overline text"),t("div",{class:"text-decoration-underline"},"Underline text")],-1),O=t("div",null,[t("p",{class:"text-high-emphasis"},"High-emphasis has an opacity of 87% in light theme and 100% in dark."),t("p",{class:"text-medium-emphasis"},"Medium-emphasis text and hint text have opacities of 60% in light theme and 70% in dark."),t("p",{class:"text-disabled"},"Disabled text has an opacity of 38% in light theme and 50% in dark.")],-1),j=f({__name:"TypographyPage",setup(r){const l=o({title:"Typography Page"}),n=o([["Heading 1","text-h1"],["Heading 2","text-h2"],["Heading 3","text-h3"],["Heading 4","text-h4"],["Heading 5","text-h5"],["Heading 6","text-h6"],["Subtitle 1","text-subtitle-1"],["Subtitle 2","text-subtitle-2"],["Body 1","text-body-1"],["Body 2","text-body-2"],["Button","text-button"],["Caption","text-caption"],["Overline","text-overline"]]),c=o([{title:"Utilities",disabled:!1,href:"#"},{title:"Typography",disabled:!0,href:"#"}]);return(U,F)=>(i(),h(x,null,[a(m,{title:l.value.title,breadcrumbs:c.value},null,8,["title","breadcrumbs"]),a(T,null,{default:e(()=>[a(k,{cols:"12",md:"12"},{default:e(()=>[a(v,{title:"Basic Typography"},{default:e(()=>[a(s,{title:"Heading"},{default:e(()=>[t("div",D,[(i(!0),h(x,null,B(n.value,([p,u])=>(i(),g(_,{variant:"outlined",key:p,class:"my-4"},{default:e(()=>[t("div",{class:H([u,"pa-2"])},d(p),3),t("div",S,[z,t("div",N,d(u),1)])]),_:2},1024))),128))])]),_:1}),a(s,{title:"Text-alignment",class:"mt-8"},{default:e(()=>[$]),_:1}),a(s,{title:"Decoration",class:"mt-8"},{default:e(()=>[M]),_:1}),a(s,{title:"Opacity",class:"mt-8"},{default:e(()=>[O]),_:1})]),_:1})]),_:1})]),_:1})],64))}});export{j as default};
|
||||
@@ -1 +0,0 @@
|
||||
import{x as n,o,c as i,w as e,a,a8 as d,b as c,K as u,e as p,t as _,a9 as s,A as f,L as V,J as m}from"./index-5ac7c267.js";const C={class:"d-sm-flex align-center justify-space-between"},h=n({__name:"UiParentCard",props:{title:String},setup(l){const r=l;return(t,x)=>(o(),i(m,{variant:"outlined",elevation:"0",class:"withbg"},{default:e(()=>[a(d,null,{default:e(()=>[c("div",C,[a(u,null,{default:e(()=>[p(_(r.title),1)]),_:1}),s(t.$slots,"action")])]),_:3}),a(f),a(V,null,{default:e(()=>[s(t.$slots,"default")]),_:3})]),_:3}))}});export{h as _};
|
||||
34
dashboard/dist/assets/img-error-bg-41f65efa.svg
vendored
34
dashboard/dist/assets/img-error-bg-41f65efa.svg
vendored
@@ -1,34 +0,0 @@
|
||||
<svg width="676" height="391" viewBox="0 0 676 391" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g opacity="0.09">
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 4.49127 197.53)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 342.315 387.578)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 28.0057 211.105)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 365.829 374.002)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 51.52 224.68)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 389.344 360.428)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 75.0345 238.255)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 412.858 346.852)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 98.5488 251.83)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 436.372 333.277)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 122.063 265.405)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 459.887 319.703)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 145.578 278.979)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 483.401 306.127)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 169.092 292.556)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 506.916 292.551)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 192.597 306.127)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 530.43 278.977)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 216.111 319.703)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 553.944 265.402)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 239.626 333.277)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 577.459 251.827)" stroke="black"/>
|
||||
<path d="M263.231 346.905L601.064 151.871" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 600.973 238.252)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 286.654 360.428)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 624.487 224.677)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 310.169 374.002)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 648.002 211.102)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(0.866041 -0.499972 -0.866041 -0.499972 333.683 387.578)" stroke="black"/>
|
||||
<line y1="-0.5" x2="390.089" y2="-0.5" transform="matrix(-0.866041 -0.499972 -0.866041 0.499972 671.516 197.527)" stroke="black"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 3.9 KiB |
@@ -1,43 +0,0 @@
|
||||
<svg width="676" height="395" viewBox="0 0 676 395" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="26.998" height="26.8293" transform="matrix(0.866041 -0.499972 0.866041 0.499972 361.873 290.126)" fill="#E3F2FD"/>
|
||||
<rect width="24.2748" height="24.1231" transform="matrix(0.866041 -0.499972 0.866041 0.499972 364.249 291.115)" fill="#90CAF9"/>
|
||||
<rect width="26.998" height="26.8293" transform="matrix(0.866041 -0.499972 0.866041 0.499972 291.67 86.4912)" fill="#E3F2FD"/>
|
||||
<rect width="24.2748" height="24.1231" transform="matrix(0.866041 -0.499972 0.866041 0.499972 294.046 87.48)" fill="#90CAF9"/>
|
||||
<g filter="url(#filter0_d)">
|
||||
<path d="M370.694 211.828L365.394 208.768V215.835L365.404 215.829C365.459 216.281 365.785 216.724 366.383 217.069L417.03 246.308C418.347 247.068 420.481 247.068 421.798 246.308L468.671 219.248C469.374 218.842 469.702 218.301 469.654 217.77V210.861L464.282 213.962L418.024 187.257C416.708 186.497 414.573 186.497 413.257 187.257L370.694 211.828Z" fill="url(#paint0_linear)"/>
|
||||
</g>
|
||||
<rect width="59.6284" height="63.9858" rx="5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 364 208.812)" fill="#90CAF9"/>
|
||||
<rect width="59.6284" height="63.9858" rx="5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 364 208.812)" fill="url(#paint1_linear)"/>
|
||||
<rect width="56.6816" height="60.8238" rx="5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 366.645 208.761)" fill="url(#paint2_linear)"/>
|
||||
<path d="M421.238 206.161C421.238 206.434 421.62 206.655 422.092 206.655L432.159 206.656C435.164 206.656 437.6 208.063 437.601 209.798C437.602 211.533 435.166 212.939 432.162 212.938L422.09 212.937C421.62 212.937 421.24 213.157 421.24 213.428L421.241 215.814C421.241 216.087 421.624 216.308 422.096 216.308L432.689 216.309C438.917 216.31 443.967 213.395 443.965 209.799C443.964 206.202 438.914 203.286 432.684 203.286L422.086 203.284C421.617 203.284 421.236 203.504 421.237 203.775L421.238 206.161Z" fill="#1E88E5"/>
|
||||
<path d="M413.422 213.43C413.422 213.157 413.039 212.936 412.567 212.936L402.896 212.935C399.891 212.935 397.455 211.528 397.454 209.793C397.453 208.059 399.889 206.652 402.894 206.653L412.57 206.654C413.039 206.654 413.419 206.435 413.419 206.164L413.418 203.777C413.418 203.504 413.035 203.283 412.563 203.283L402.366 203.282C396.138 203.281 391.089 206.197 391.09 209.793C391.091 213.389 396.141 216.305 402.371 216.306L412.573 216.307C413.042 216.307 413.423 216.088 413.423 215.817L413.422 213.43Z" fill="#1E88E5"/>
|
||||
<path d="M407.999 198.145L411.211 201.235C411.266 201.288 411.332 201.336 411.405 201.379C411.813 201.614 412.461 201.669 412.979 201.49C413.59 201.278 413.787 200.821 413.421 200.469L410.209 197.379C409.843 197.027 409.051 196.913 408.441 197.124C407.831 197.335 407.633 197.793 407.999 198.145Z" fill="#1E88E5"/>
|
||||
<path d="M416.235 200.853C416.235 201.058 416.38 201.244 416.613 201.379C416.846 201.513 417.168 201.597 417.524 201.597C418.236 201.596 418.813 201.263 418.813 200.852L418.812 197.021C418.811 196.61 418.234 196.277 417.522 196.277C416.811 196.278 416.234 196.611 416.234 197.022L416.235 200.853Z" fill="#1E88E5"/>
|
||||
<path d="M421.627 200.47C421.317 200.769 421.412 201.143 421.82 201.379C421.893 201.421 421.977 201.459 422.069 201.491C422.68 201.703 423.472 201.588 423.838 201.236L427.047 198.147C427.413 197.794 427.215 197.337 426.605 197.126C425.994 196.915 425.203 197.029 424.836 197.381L421.627 200.47Z" fill="#1E88E5"/>
|
||||
<path d="M427.056 221.447L423.844 218.357C423.478 218.005 422.686 217.891 422.076 218.102C421.466 218.314 421.268 218.771 421.634 219.123L424.846 222.213C424.901 222.266 424.967 222.314 425.04 222.357C425.448 222.592 426.097 222.647 426.614 222.468C427.225 222.257 427.423 221.799 427.056 221.447Z" fill="#1E88E5"/>
|
||||
<path d="M418.82 218.739C418.82 218.328 418.243 217.995 417.531 217.995C416.819 217.995 416.242 218.329 416.242 218.74L416.243 222.57C416.244 222.776 416.388 222.962 416.621 223.096C416.854 223.231 417.177 223.314 417.533 223.314C418.245 223.314 418.822 222.981 418.821 222.57L418.82 218.739Z" fill="#1E88E5"/>
|
||||
<path d="M413.428 219.122C413.794 218.77 413.596 218.312 412.986 218.101C412.375 217.89 411.584 218.004 411.217 218.356L408.008 221.445C407.698 221.744 407.793 222.118 408.201 222.354C408.274 222.396 408.358 222.434 408.45 222.466C409.061 222.678 409.853 222.563 410.219 222.211L413.428 219.122Z" fill="#1E88E5"/>
|
||||
<defs>
|
||||
<filter id="filter0_d" x="301.394" y="186.687" width="232.264" height="208.191" filterUnits="userSpaceOnUse" color-interpolation-filters="sRGB">
|
||||
<feFlood flood-opacity="0" result="BackgroundImageFix"/>
|
||||
<feColorMatrix in="SourceAlpha" type="matrix" values="0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 127 0"/>
|
||||
<feOffset dy="84"/>
|
||||
<feGaussianBlur stdDeviation="32"/>
|
||||
<feColorMatrix type="matrix" values="0 0 0 0 0.129412 0 0 0 0 0.588235 0 0 0 0 0.952941 0 0 0 0.2 0"/>
|
||||
<feBlend mode="normal" in2="BackgroundImageFix" result="effect1_dropShadow"/>
|
||||
<feBlend mode="normal" in="SourceGraphic" in2="effect1_dropShadow" result="shape"/>
|
||||
</filter>
|
||||
<linearGradient id="paint0_linear" x1="417.526" y1="205.789" x2="365.394" y2="216.782" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#2196F3"/>
|
||||
<stop offset="1" stop-color="#B1DCFF"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="paint1_linear" x1="0.503035" y1="2.68177" x2="20.3032" y2="42.2842" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#FAFAFA" stop-opacity="0.74"/>
|
||||
<stop offset="1" stop-color="#91CBFA"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="paint2_linear" x1="-18.5494" y1="-44.8799" x2="14.7845" y2="40.5766" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#FAFAFA" stop-opacity="0.74"/>
|
||||
<stop offset="1" stop-color="#91CBFA"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 5.5 KiB |
@@ -1,42 +0,0 @@
|
||||
<svg width="710" height="391" viewBox="0 0 710 391" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="26.9258" height="26.7576" transform="matrix(0.866041 -0.499972 0.866041 0.499972 161.088 154.333)" fill="#EDE7F6"/>
|
||||
<rect width="24.9267" height="24.7709" transform="matrix(0.866041 -0.499972 0.866041 0.499972 162.809 155.327)" fill="#B39DDB"/>
|
||||
<rect width="26.9258" height="26.7576" transform="matrix(0.866041 -0.499972 0.866041 0.499972 536.744 181.299)" fill="#EDE7F6"/>
|
||||
<rect width="24.9267" height="24.7709" transform="matrix(0.866041 -0.499972 0.866041 0.499972 538.465 182.292)" fill="#B39DDB"/>
|
||||
<g filter="url(#filter0_d)">
|
||||
<path d="M67.7237 137.573V134.673H64.009V140.824L64.0177 140.829C64.0367 141.477 64.4743 142.121 65.3305 142.615L103.641 164.733C105.393 165.744 108.232 165.744 109.983 164.733L204.044 110.431C204.879 109.949 205.316 109.324 205.355 108.693L205.355 108.692V108.68C205.358 108.628 205.358 108.576 205.355 108.523L205.362 102.335L200.065 104.472L165.733 84.6523C163.982 83.6413 161.142 83.6413 159.391 84.6523L67.7237 137.573Z" fill="url(#paint0_linear)"/>
|
||||
</g>
|
||||
<rect width="115.933" height="51.5596" rx="5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 62.1588 134.683)" fill="#673AB7"/>
|
||||
<rect width="115.933" height="51.5596" rx="5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 62.1588 134.683)" fill="url(#paint1_linear)" fill-opacity="0.3"/>
|
||||
<mask id="mask0" mask-type="alpha" maskUnits="userSpaceOnUse" x="64" y="78" width="141" height="81">
|
||||
<rect width="115.933" height="51.5596" rx="5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 62.1588 134.683)" fill="#673AB7"/>
|
||||
</mask>
|
||||
<g mask="url(#mask0)">
|
||||
</g>
|
||||
<mask id="mask1" mask-type="alpha" maskUnits="userSpaceOnUse" x="64" y="78" width="141" height="81">
|
||||
<rect width="115.933" height="51.5596" rx="5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 62.1588 134.683)" fill="#673AB7"/>
|
||||
</mask>
|
||||
<g mask="url(#mask1)">
|
||||
<rect width="64.3732" height="64.3732" rx="5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 111.303 81.6006)" fill="#5E35B1"/>
|
||||
<rect opacity="0.7" x="0.866041" width="63.3732" height="63.3732" rx="4.5" transform="matrix(0.866041 -0.499972 0.866041 0.499972 79.1848 87.8305)" stroke="#5E35B1"/>
|
||||
</g>
|
||||
<defs>
|
||||
<filter id="filter0_d" x="0.0090332" y="83.894" width="269.353" height="229.597" filterUnits="userSpaceOnUse" color-interpolation-filters="sRGB">
|
||||
<feFlood flood-opacity="0" result="BackgroundImageFix"/>
|
||||
<feColorMatrix in="SourceAlpha" type="matrix" values="0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 127 0"/>
|
||||
<feOffset dy="84"/>
|
||||
<feGaussianBlur stdDeviation="32"/>
|
||||
<feColorMatrix type="matrix" values="0 0 0 0 0.403922 0 0 0 0 0.227451 0 0 0 0 0.717647 0 0 0 0.2 0"/>
|
||||
<feBlend mode="normal" in2="BackgroundImageFix" result="effect1_dropShadow"/>
|
||||
<feBlend mode="normal" in="SourceGraphic" in2="effect1_dropShadow" result="shape"/>
|
||||
</filter>
|
||||
<linearGradient id="paint0_linear" x1="200.346" y1="102.359" x2="71.0293" y2="158.071" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#A491C8"/>
|
||||
<stop offset="1" stop-color="#D7C5F8"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="paint1_linear" x1="8.1531" y1="-0.145767" x2="57.1962" y2="72.3003" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="white"/>
|
||||
<stop offset="1" stop-color="white" stop-opacity="0"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 3.3 KiB |
@@ -1,27 +0,0 @@
|
||||
<svg width="676" height="391" viewBox="0 0 676 391" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M267.744 237.142L279.699 230.24L300.636 242.329L288.682 249.231L313.566 263.598L286.344 279.314L261.46 264.947L215.984 291.203L197.779 282.558L169.334 211.758L169.092 211.618L196.313 195.902L267.744 237.142ZM219.359 265.077L240.523 252.859L204.445 232.029L205.487 234.589L219.359 265.077Z" fill="#FFAB91"/>
|
||||
<path d="M469.959 120.206L481.913 113.304L502.851 125.392L490.897 132.294L515.78 146.661L488.559 162.377L463.675 148.011L418.199 174.266L399.994 165.621L371.548 94.8211L371.307 94.6816L398.528 78.9654L469.959 120.206ZM421.574 148.141L442.737 135.922L406.66 115.093L407.701 117.653L421.574 148.141Z" fill="#FFAB91"/>
|
||||
<path d="M204.523 235.027V232.237L219.401 265.014L240.555 252.926V255.018L218.936 267.339L204.523 235.027Z" fill="#D84315"/>
|
||||
<path d="M406.738 118.09V115.301L421.616 148.078L442.77 135.99V138.082L421.151 150.402L406.738 118.09Z" fill="#D84315"/>
|
||||
<rect width="109.114" height="136.405" transform="matrix(0.866025 -0.5 0.866025 0.5 220.507 181.925)" fill="url(#paint0_linear)"/>
|
||||
<rect width="40.2357" height="70.0545" transform="matrix(0.866025 -0.5 0.866025 0.5 280.437 201.886)" fill="url(#paint1_linear)"/>
|
||||
<rect x="25.1147" width="80.1144" height="107.405" transform="matrix(0.866025 -0.5 0.866025 0.5 223.872 194.482)" stroke="#1565C0" stroke-width="29"/>
|
||||
<rect x="25.1147" width="80.1144" height="107.405" transform="matrix(0.866025 -0.5 0.866025 0.5 223.872 194.482)" stroke="url(#paint2_linear)" stroke-width="29"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M279.517 230.177L267.662 237.15L196.064 195.772L168.866 211.58L169.331 212.097L170.096 214.002L196.436 198.795L267.866 240.035L279.821 233.133L298.211 243.751L300.787 242.265L279.517 230.177ZM291.278 250.695L288.804 252.124L311.1 264.996L313.805 263.418L291.278 250.695Z" fill="#D84315"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M481.732 113.24L469.877 120.214L398.279 78.8359L371.081 94.6433L371.546 95.1603L372.311 97.0652L398.651 81.8581L470.081 123.099L482.036 116.196L500.426 126.814L503.002 125.328L481.732 113.24ZM493.493 133.759L491.019 135.187L513.315 148.06L516.02 146.482L493.493 133.759Z" fill="#D84315"/>
|
||||
<path d="M288.674 252.229V249.207L291.929 251.067L288.674 252.229Z" fill="#D84315"/>
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear" x1="77.7511" y1="139.902" x2="-10.8629" y2="8.75671" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#3076C8"/>
|
||||
<stop offset="0.992076" stop-color="#91CBFA"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="paint1_linear" x1="25.8162" y1="51.0447" x2="68.7073" y2="-5.41524" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#2E75C7"/>
|
||||
<stop offset="1" stop-color="#4283CC"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="paint2_linear" x1="-16.1224" y1="-47.972" x2="123.494" y2="290.853" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="white"/>
|
||||
<stop offset="1" stop-color="white" stop-opacity="0"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.9 KiB |
720
dashboard/dist/assets/index-5ac7c267.js
vendored
720
dashboard/dist/assets/index-5ac7c267.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
716
dashboard/dist/assets/index-a2f0b905.js
vendored
Normal file
716
dashboard/dist/assets/index-a2f0b905.js
vendored
Normal file
File diff suppressed because one or more lines are too long
9
dashboard/dist/assets/md5-086248bf.js
vendored
9
dashboard/dist/assets/md5-086248bf.js
vendored
File diff suppressed because one or more lines are too long
9
dashboard/dist/assets/md5-f95c7b53.js
vendored
Normal file
9
dashboard/dist/assets/md5-f95c7b53.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -1,6 +0,0 @@
|
||||
<svg width="22" height="22" viewBox="0 0 22 22" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M5.06129 13.2253L4.31871 15.9975L1.60458 16.0549C0.793457 14.5504 0.333374 12.8292 0.333374 11C0.333374 9.23119 0.763541 7.56319 1.52604 6.09448H1.52662L3.94296 6.53748L5.00146 8.93932C4.77992 9.58519 4.65917 10.2785 4.65917 11C4.65925 11.783 4.80108 12.5332 5.06129 13.2253Z" fill="#FBBB00"/>
|
||||
<path d="M21.4804 9.00732C21.6029 9.65257 21.6668 10.3189 21.6668 11C21.6668 11.7637 21.5865 12.5086 21.4335 13.2271C20.9143 15.6722 19.5575 17.8073 17.678 19.3182L17.6774 19.3177L14.6339 19.1624L14.2031 16.4734C15.4503 15.742 16.425 14.5974 16.9384 13.2271H11.2346V9.00732H17.0216H21.4804Z" fill="#518EF8"/>
|
||||
<path d="M17.6772 19.3176L17.6777 19.3182C15.8498 20.7875 13.5277 21.6666 11 21.6666C6.93783 21.6666 3.40612 19.3962 1.60449 16.0549L5.0612 13.2253C5.96199 15.6294 8.28112 17.3408 11 17.3408C12.1686 17.3408 13.2634 17.0249 14.2029 16.4734L17.6772 19.3176Z" fill="#28B446"/>
|
||||
<path d="M17.8085 2.78892L14.353 5.61792C13.3807 5.01017 12.2313 4.65908 11 4.65908C8.21963 4.65908 5.85713 6.44896 5.00146 8.93925L1.52658 6.09442H1.526C3.30125 2.67171 6.8775 0.333252 11 0.333252C13.5881 0.333252 15.9612 1.25517 17.8085 2.78892Z" fill="#F14336"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.2 KiB |
4
dashboard/dist/index.html
vendored
4
dashboard/dist/index.html
vendored
@@ -11,8 +11,8 @@
|
||||
href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Poppins:wght@400;500;600;700&family=Roboto:wght@400;500;700&display=swap"
|
||||
/>
|
||||
<title>AstrBot - 仪表盘</title>
|
||||
<script type="module" crossorigin src="/assets/index-5ac7c267.js"></script>
|
||||
<link rel="stylesheet" href="/assets/index-0f1523f3.css">
|
||||
<script type="module" crossorigin src="/assets/index-a2f0b905.js"></script>
|
||||
<link rel="stylesheet" href="/assets/index-86dd25ba.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
|
||||
@@ -1,537 +0,0 @@
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
from . import DashBoardData
|
||||
from typing import Union, Optional
|
||||
from util.cmd_config import CmdConfig
|
||||
from dataclasses import dataclass
|
||||
from util.plugin_dev.api.v1.config import update_config
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
@dataclass
|
||||
class DashBoardConfig():
|
||||
config_type: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
path: Optional[str] = None # 仅 item 才需要
|
||||
body: Optional[list['DashBoardConfig']] = None # 仅 group 才需要
|
||||
value: Optional[Union[list, dict, str, int, bool]] = None # 仅 item 才需要
|
||||
val_type: Optional[str] = None # 仅 item 才需要
|
||||
|
||||
|
||||
class DashBoardHelper():
|
||||
def __init__(self, context: Context, dashboard_data: DashBoardData):
|
||||
dashboard_data.configs = {
|
||||
"data": []
|
||||
}
|
||||
self.context = context
|
||||
self.parse_default_config(dashboard_data, context.base_config)
|
||||
|
||||
# 将 config.yaml、 中的配置解析到 dashboard_data.configs 中
|
||||
def parse_default_config(self, dashboard_data: DashBoardData, config: dict):
|
||||
|
||||
try:
|
||||
qq_official_platform_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="QQ(官方)",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启用 QQ_OFFICIAL 平台",
|
||||
description="官方的接口,仅支持 QQ 频道。详见 q.qq.com",
|
||||
value=config['qqbot']['enable'],
|
||||
path="qqbot.enable",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="QQ机器人APPID",
|
||||
description="详见 q.qq.com",
|
||||
value=config['qqbot']['appid'],
|
||||
path="qqbot.appid",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="QQ机器人令牌",
|
||||
description="详见 q.qq.com",
|
||||
value=config['qqbot']['token'],
|
||||
path="qqbot.token",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="QQ机器人 Secret",
|
||||
description="详见 q.qq.com",
|
||||
value=config['qqbot_secret'],
|
||||
path="qqbot_secret",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否允许 QQ 频道私聊",
|
||||
description="如果启用,机器人会响应私聊消息。",
|
||||
value=config['direct_message_mode'],
|
||||
path="direct_message_mode",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否接收QQ群消息",
|
||||
description="需要机器人有相应的群消息接收权限。在 q.qq.com 上查看。",
|
||||
value=config['qqofficial_enable_group_message'],
|
||||
path="qqofficial_enable_group_message",
|
||||
),
|
||||
]
|
||||
)
|
||||
qq_gocq_platform_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="QQ(nakuru)",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启用",
|
||||
description="",
|
||||
value=config['gocqbot']['enable'],
|
||||
path="gocqbot.enable",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="HTTP 服务器地址",
|
||||
description="",
|
||||
value=config['gocq_host'],
|
||||
path="gocq_host",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="HTTP 服务器端口",
|
||||
description="",
|
||||
value=config['gocq_http_port'],
|
||||
path="gocq_http_port",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="WebSocket 服务器端口",
|
||||
description="目前仅支持正向 WebSocket",
|
||||
value=config['gocq_websocket_port'],
|
||||
path="gocq_websocket_port",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否响应群消息",
|
||||
description="",
|
||||
value=config['gocq_react_group'],
|
||||
path="gocq_react_group",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否响应私聊消息",
|
||||
description="",
|
||||
value=config['gocq_react_friend'],
|
||||
path="gocq_react_friend",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否响应群成员增加消息",
|
||||
description="",
|
||||
value=config['gocq_react_group_increase'],
|
||||
path="gocq_react_group_increase",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否响应频道消息",
|
||||
description="",
|
||||
value=config['gocq_react_guild'],
|
||||
path="gocq_react_guild",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="转发阈值(字符数)",
|
||||
description="机器人回复的消息长度超出这个值后,会被折叠成转发卡片发出以减少刷屏。",
|
||||
value=config['qq_forward_threshold'],
|
||||
path="qq_forward_threshold",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
qq_aiocqhttp_platform_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="QQ(aiocqhttp)",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启用",
|
||||
description="",
|
||||
value=config['aiocqhttp']['enable'],
|
||||
path="aiocqhttp.enable",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="WebSocket 反向连接 host",
|
||||
description="",
|
||||
value=config['aiocqhttp']['ws_reverse_host'],
|
||||
path="aiocqhttp.ws_reverse_host",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="WebSocket 反向连接 port",
|
||||
description="",
|
||||
value=config['aiocqhttp']['ws_reverse_port'],
|
||||
path="aiocqhttp.ws_reverse_port",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
general_platform_detail_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="通用平台配置",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启动消息文字转图片",
|
||||
description="启动后,机器人会将消息转换为图片发送,以降低风控风险。",
|
||||
value=config['qq_pic_mode'],
|
||||
path="qq_pic_mode",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="消息限制时间",
|
||||
description="在此时间内,机器人不会回复同一个用户的消息。单位:秒",
|
||||
value=config['limit']['time'],
|
||||
path="limit.time",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="消息限制次数",
|
||||
description="在上面的时间内,如果用户发送消息超过此次数,则机器人不会回复。单位:次",
|
||||
value=config['limit']['count'],
|
||||
path="limit.count",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="回复前缀",
|
||||
description="[xxxx] 你好! 其中xxxx是你可以填写的前缀。如果为空则不显示。",
|
||||
value=config['reply_prefix'],
|
||||
path="reply_prefix",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="list",
|
||||
name="通用管理员用户 ID(支持多个管理员)。通过 !myid 指令获取。",
|
||||
description="",
|
||||
value=config['other_admins'],
|
||||
path="other_admins",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="独立会话",
|
||||
description="是否启用独立会话模式,即 1 个用户自然账号 1 个会话。",
|
||||
value=config['uniqueSessionMode'],
|
||||
path="uniqueSessionMode",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="LLM 唤醒词",
|
||||
description="如果不为空, 那么只有当消息以此词开头时,才会调用大语言模型进行回复。如设置为 /chat,那么只有当消息以 /chat 开头时,才会调用大语言模型进行回复。",
|
||||
value=config['llm_wake_prefix'],
|
||||
path="llm_wake_prefix",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
openai_official_llm_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="OpenAI 官方接口类设置",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="list",
|
||||
name="OpenAI API Key",
|
||||
description="OpenAI API 的 Key。支持使用非官方但兼容的 API(第三方中转key)。",
|
||||
value=config['openai']['key'],
|
||||
path="openai.key",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI API 节点地址(api base)",
|
||||
description="OpenAI API 的节点地址,配合非官方 API 使用。如果不想填写,那么请填写 none",
|
||||
value=config['openai']['api_base'],
|
||||
path="openai.api_base",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI model",
|
||||
description="OpenAI LLM 模型。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['model'],
|
||||
path="openai.chatGPTConfigs.model",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="OpenAI max_tokens",
|
||||
description="OpenAI 最大生成长度。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['max_tokens'],
|
||||
path="openai.chatGPTConfigs.max_tokens",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="float",
|
||||
name="OpenAI temperature",
|
||||
description="OpenAI 温度。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['temperature'],
|
||||
path="openai.chatGPTConfigs.temperature",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="float",
|
||||
name="OpenAI top_p",
|
||||
description="OpenAI top_p。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['top_p'],
|
||||
path="openai.chatGPTConfigs.top_p",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="float",
|
||||
name="OpenAI frequency_penalty",
|
||||
description="OpenAI frequency_penalty。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['frequency_penalty'],
|
||||
path="openai.chatGPTConfigs.frequency_penalty",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="float",
|
||||
name="OpenAI presence_penalty",
|
||||
description="OpenAI presence_penalty。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['presence_penalty'],
|
||||
path="openai.chatGPTConfigs.presence_penalty",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="OpenAI 总生成长度限制",
|
||||
description="OpenAI 总生成长度限制。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['total_tokens_limit'],
|
||||
path="openai.total_tokens_limit",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI 图像生成模型",
|
||||
description="OpenAI 图像生成模型。",
|
||||
value=config['openai_image_generate']['model'],
|
||||
path="openai_image_generate.model",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI 图像生成大小",
|
||||
description="OpenAI 图像生成大小。",
|
||||
value=config['openai_image_generate']['size'],
|
||||
path="openai_image_generate.size",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI 图像生成风格",
|
||||
description="OpenAI 图像生成风格。修改前请参考 OpenAI 官方文档",
|
||||
value=config['openai_image_generate']['style'],
|
||||
path="openai_image_generate.style",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI 图像生成质量",
|
||||
description="OpenAI 图像生成质量。修改前请参考 OpenAI 官方文档",
|
||||
value=config['openai_image_generate']['quality'],
|
||||
path="openai_image_generate.quality",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="问题题首提示词",
|
||||
description="如果填写了此项,在每个对大语言模型的请求中,都会在问题前加上此提示词。",
|
||||
value=config['llm_env_prompt'],
|
||||
path="llm_env_prompt",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="默认人格文本",
|
||||
description="默认人格文本",
|
||||
value=config['default_personality_str'],
|
||||
path="default_personality_str",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
baidu_aip_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="百度内容审核",
|
||||
description="需要去申请",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启动百度内容审核服务",
|
||||
description="",
|
||||
value=config['baidu_aip']['enable'],
|
||||
path="baidu_aip.enable"
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="APP ID",
|
||||
description="",
|
||||
value=config['baidu_aip']['app_id'],
|
||||
path="baidu_aip.app_id"
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="API KEY",
|
||||
description="",
|
||||
value=config['baidu_aip']['api_key'],
|
||||
path="baidu_aip.api_key"
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="SECRET KEY",
|
||||
description="",
|
||||
value=config['baidu_aip']['secret_key'],
|
||||
path="baidu_aip.secret_key"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
other_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="其他配置",
|
||||
description="其他配置描述",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="HTTP 代理地址",
|
||||
description="建议上下一致",
|
||||
value=config['http_proxy'],
|
||||
path="http_proxy",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="HTTPS 代理地址",
|
||||
description="建议上下一致",
|
||||
value=config['https_proxy'],
|
||||
path="https_proxy",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="面板用户名",
|
||||
description="是的,就是你理解的这个面板的用户名",
|
||||
value=config['dashboard_username'],
|
||||
path="dashboard_username",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
dashboard_data.configs['data'] = [
|
||||
qq_official_platform_group,
|
||||
qq_gocq_platform_group,
|
||||
general_platform_detail_group,
|
||||
openai_official_llm_group,
|
||||
other_group,
|
||||
baidu_aip_group,
|
||||
qq_aiocqhttp_platform_group
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"配置文件解析错误:{e}")
|
||||
raise e
|
||||
|
||||
def save_config(self, post_config: list, namespace: str):
|
||||
'''
|
||||
根据 path 解析并保存配置
|
||||
'''
|
||||
|
||||
queue = post_config
|
||||
while len(queue) > 0:
|
||||
config = queue.pop(0)
|
||||
if config['config_type'] == "group":
|
||||
for item in config['body']:
|
||||
queue.append(item)
|
||||
elif config['config_type'] == "item":
|
||||
if config['path'] is None or config['path'] == "":
|
||||
continue
|
||||
|
||||
path = config['path'].split('.')
|
||||
if len(path) == 0:
|
||||
continue
|
||||
|
||||
if config['val_type'] == "bool":
|
||||
self._write_config(
|
||||
namespace, config['path'], config['value'])
|
||||
elif config['val_type'] == "str":
|
||||
self._write_config(
|
||||
namespace, config['path'], config['value'])
|
||||
elif config['val_type'] == "int":
|
||||
try:
|
||||
self._write_config(
|
||||
namespace, config['path'], int(config['value']))
|
||||
except:
|
||||
raise ValueError(f"配置项 {config['name']} 的值必须是整数")
|
||||
elif config['val_type'] == "float":
|
||||
try:
|
||||
self._write_config(
|
||||
namespace, config['path'], float(config['value']))
|
||||
except:
|
||||
raise ValueError(f"配置项 {config['name']} 的值必须是浮点数")
|
||||
elif config['val_type'] == "list":
|
||||
if config['value'] is None:
|
||||
self._write_config(namespace, config['path'], [])
|
||||
elif not isinstance(config['value'], list):
|
||||
raise ValueError(f"配置项 {config['name']} 的值必须是列表")
|
||||
self._write_config(
|
||||
namespace, config['path'], config['value'])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"未知或者未实现的配置项类型:{config['val_type']}")
|
||||
|
||||
def _write_config(self, namespace: str, key: str, value):
|
||||
if namespace == "" or namespace.startswith("internal_"):
|
||||
# 机器人自带配置,存到 config.yaml
|
||||
self.context.config_helper.put_by_dot_str(key, value)
|
||||
else:
|
||||
update_config(namespace, key, value)
|
||||
17
dashboard/routes/__init__.py
Normal file
17
dashboard/routes/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .auth import AuthRoute
|
||||
from .plugin import PluginRoute
|
||||
from .config import ConfigRoute
|
||||
from .update import UpdateRoute
|
||||
from .stat import StatRoute
|
||||
from .log import LogRoute
|
||||
from .static_file import StaticFileRoute
|
||||
|
||||
__all__ = [
|
||||
"AuthRoute",
|
||||
"PluginRoute",
|
||||
"ConfigRoute",
|
||||
"UpdateRoute",
|
||||
"StatRoute",
|
||||
"LogRoute",
|
||||
"StaticFileRoute"
|
||||
]
|
||||
33
dashboard/routes/auth.py
Normal file
33
dashboard/routes/auth.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from .. import Route, Response
|
||||
from quart import Quart, request
|
||||
from type.types import Context
|
||||
|
||||
class AuthRoute(Route):
|
||||
def __init__(self, context: Context, app: Quart) -> None:
|
||||
super().__init__(context, app)
|
||||
self.routes = {
|
||||
'/auth/login': ('POST', self.login),
|
||||
'/auth/password/reset': ('POST', self.reset_password),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def login(self):
|
||||
username = self.context.config_helper.dashboard.username
|
||||
password = self.context.config_helper.dashboard.password
|
||||
post_data = await request.json
|
||||
if post_data["username"] == username and post_data["password"] == password:
|
||||
return Response().ok({
|
||||
"token": "astrbot-test-token",
|
||||
"username": username
|
||||
}).__dict__
|
||||
else:
|
||||
return Response().error("用户名或密码错误").__dict__
|
||||
|
||||
async def reset_password(self):
|
||||
password = self.context.config_helper.dashboard.password
|
||||
post_data = await request.json
|
||||
if post_data["password"] == password:
|
||||
self.context.config_helper.dashboard.password = post_data['new_password']
|
||||
return Response().ok(None).__dict__
|
||||
else:
|
||||
return Response().error("原密码错误").__dict__
|
||||
80
dashboard/routes/config.py
Normal file
80
dashboard/routes/config.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import os, json, threading
|
||||
from .. import Route, Response
|
||||
from ..utils.config import *
|
||||
from quart import Quart, request
|
||||
from type.types import Context
|
||||
from type.config import CONFIG_METADATA_2
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
|
||||
|
||||
class ConfigRoute(Route):
|
||||
def __init__(self, context: Context, app: Quart, astrbot_updator: AstrBotUpdator) -> None:
|
||||
super().__init__(context, app)
|
||||
self.config_key_dont_show = ['dashboard', 'config_version']
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.routes = {
|
||||
'/config/get': ('GET', self.get_configs),
|
||||
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
|
||||
'/config/plugin/update': ('POST', self.post_extension_configs),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def get_configs(self):
|
||||
# namespace 为空时返回 AstrBot 配置
|
||||
# 否则返回指定 namespace 的插件配置
|
||||
namespace = "" if "namespace" not in request.args else request.args["namespace"]
|
||||
if not namespace:
|
||||
return Response().ok(await self._get_astrbot_config()).__dict__
|
||||
return Response().ok(await self._get_extension_config(namespace)).__dict__
|
||||
|
||||
async def post_astrbot_configs(self):
|
||||
post_configs = await request.json
|
||||
try:
|
||||
await self._save_astrbot_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人将在 3 秒内重启以应用新的配置。").__dict__
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def post_extension_configs(self):
|
||||
post_configs = await request.json
|
||||
try:
|
||||
await self._save_extension_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人将在 3 秒内重启以应用新的配置。").__dict__
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.context.config_helper.to_dict()
|
||||
for key in self.config_key_dont_show:
|
||||
if key in config:
|
||||
del config[key]
|
||||
return {
|
||||
"metadata": CONFIG_METADATA_2,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
async def _get_extension_config(self, namespace: str):
|
||||
path = f"data/config/{namespace}.json"
|
||||
if not os.path.exists(path):
|
||||
return []
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
return [{
|
||||
"config_type": "group",
|
||||
"name": namespace + " 插件配置",
|
||||
"description": "",
|
||||
"body": list(json.load(f).values())
|
||||
},]
|
||||
|
||||
async def _save_astrbot_configs(self, post_configs: dict):
|
||||
try:
|
||||
save_astrbot_config(post_configs, self.context)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def _save_extension_configs(self, post_configs: dict):
|
||||
try:
|
||||
save_extension_config(post_configs)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start()
|
||||
except Exception as e:
|
||||
raise e
|
||||
50
dashboard/routes/log.py
Normal file
50
dashboard/routes/log.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
from quart import websocket
|
||||
from quart import Quart
|
||||
from type.types import Context
|
||||
from .. import logger
|
||||
|
||||
class Broker:
|
||||
def __init__(self) -> None:
|
||||
self.connections = set()
|
||||
|
||||
async def send(self, message: str):
|
||||
for connection in self.connections:
|
||||
try:
|
||||
await connection.send(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"发送日志失败: {e.__str__()}")
|
||||
|
||||
|
||||
class LogRoute:
|
||||
def __init__(self, context: Context, app: Quart) -> None:
|
||||
self.app = app
|
||||
self.context = context
|
||||
self.broker = Broker()
|
||||
self.app.add_url_rule('/api/live-log', view_func=self.log, methods=['GET'], websocket=True)
|
||||
|
||||
async def _receive_log_task(self):
|
||||
while True:
|
||||
message = await self.context._log_queue.get()
|
||||
await self.broker.send(message)
|
||||
|
||||
async def _get_log_history(self):
|
||||
try:
|
||||
dq = self.context._log_queue.get_cache()
|
||||
ret = ""
|
||||
for log in dq:
|
||||
log = log.replace("\n", "\n\r")
|
||||
ret += log + "\n\r"
|
||||
return ret
|
||||
except Exception as e:
|
||||
logger.warning(f"读取日志历史失败: {e.__str__()}")
|
||||
return ""
|
||||
|
||||
async def log(self):
|
||||
try:
|
||||
await websocket.send(await self._get_log_history())
|
||||
self.broker.connections.add(websocket)
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
except asyncio.CancelledError:
|
||||
self.broker.connections.remove(websocket)
|
||||
86
dashboard/routes/plugin.py
Normal file
86
dashboard/routes/plugin.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import threading, traceback, uuid
|
||||
from .. import Route, Response, logger
|
||||
from quart import Quart, request
|
||||
from type.types import Context
|
||||
from model.plugin.manager import PluginManager
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(self, context: Context, app: Quart, astrbot_updator: AstrBotUpdator, plugin_manager: PluginManager) -> None:
|
||||
super().__init__(context, app)
|
||||
self.routes = {
|
||||
'/plugin/get': ('GET', self.get_plugins),
|
||||
'/plugin/install': ('POST', self.install_plugin),
|
||||
'/plugin/install-upload': ('POST', self.install_plugin_upload),
|
||||
'/plugin/update': ('POST', self.update_plugin),
|
||||
'/plugin/uninstall': ('POST', self.uninstall_plugin),
|
||||
}
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.plugin_manager = plugin_manager
|
||||
self.register_routes()
|
||||
|
||||
async def get_plugins(self):
|
||||
_plugin_resp = []
|
||||
for plugin in self.context.cached_plugins:
|
||||
_p = plugin.metadata
|
||||
_t = {
|
||||
"name": _p.plugin_name,
|
||||
"repo": '' if _p.repo is None else _p.repo,
|
||||
"author": _p.author,
|
||||
"desc": _p.desc,
|
||||
"version": _p.version
|
||||
}
|
||||
_plugin_resp.append(_t)
|
||||
return Response().ok(_plugin_resp).__dict__
|
||||
|
||||
async def install_plugin(self):
|
||||
post_data = await request.json
|
||||
repo_url = post_data["url"]
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
await self.plugin_manager.install_plugin(repo_url)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"安装插件 {repo_url} 成功, 2秒后重启")
|
||||
return Response().ok(None, "安装成功,程序将在 2 秒内重启。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def install_plugin_upload(self):
|
||||
try:
|
||||
file = request.files['file']
|
||||
print(file.filename)
|
||||
logger.info(f"正在安装用户上传的插件 {file.filename}")
|
||||
file_path = f"data/temp/{uuid.uuid4()}.zip"
|
||||
file.save(file_path)
|
||||
self.plugin_manager.install_plugin_from_file(file_path)
|
||||
logger.info(f"安装插件 {file.filename} 成功")
|
||||
return Response().ok(None, "安装成功!!").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def uninstall_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在卸载插件 {plugin_name}")
|
||||
self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
logger.info(f"卸载插件 {plugin_name} 成功")
|
||||
return Response().ok(None, "卸载成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def update_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
await self.plugin_manager.update_plugin(plugin_name)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
|
||||
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
62
dashboard/routes/stat.py
Normal file
62
dashboard/routes/stat.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import traceback, psutil, time, datetime
|
||||
from .. import Route, Response, logger
|
||||
from quart import Quart, request
|
||||
from type.types import Context
|
||||
from astrbot.db import BaseDatabase
|
||||
from type.config import VERSION
|
||||
|
||||
class StatRoute(Route):
|
||||
def __init__(self, context: Context, app: Quart, db_helper: BaseDatabase) -> None:
|
||||
super().__init__(context, app)
|
||||
self.routes = {
|
||||
'/stat/get': ('GET', self.get_stat),
|
||||
'/stat/version': ('GET', self.get_version),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.register_routes()
|
||||
|
||||
def format_sec(self, sec: int):
|
||||
m, s = divmod(sec, 60)
|
||||
h, m = divmod(m, 60)
|
||||
return f"{h}小时{m}分{s}秒"
|
||||
|
||||
async def get_version(self):
|
||||
return Response().ok({
|
||||
"version": VERSION
|
||||
}).__dict__
|
||||
|
||||
async def get_stat(self):
|
||||
offset_sec = request.args.get('offset_sec', 86400)
|
||||
offset_sec = int(offset_sec)
|
||||
try:
|
||||
stat = self.db_helper.get_base_stats(offset_sec)
|
||||
now = int(time.time())
|
||||
start_time = now - offset_sec
|
||||
message_time_based_stats = []
|
||||
|
||||
idx = 0
|
||||
for bucket_end in range(start_time, now, 1800):
|
||||
cnt = 0
|
||||
while idx < len(stat.platform) and stat.platform[idx].timestamp < bucket_end:
|
||||
cnt += stat.platform[idx].count
|
||||
idx += 1
|
||||
message_time_based_stats.append([bucket_end, cnt])
|
||||
|
||||
stat_dict = stat.__dict__
|
||||
|
||||
stat_dict.update({
|
||||
"platform": self.db_helper.get_grouped_base_stats(offset_sec).platform,
|
||||
"message_count": self.db_helper.get_total_message_count() or 0,
|
||||
"platform_count": len(self.context.platforms),
|
||||
"message_time_series": message_time_based_stats,
|
||||
"running": self.format_sec(int(time.time() - self.context._start_running)),
|
||||
"memory": {
|
||||
"process": psutil.Process().memory_info().rss >> 20,
|
||||
"system": psutil.virtual_memory().total >> 20
|
||||
}
|
||||
})
|
||||
|
||||
return Response().ok(stat_dict).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
14
dashboard/routes/static_file.py
Normal file
14
dashboard/routes/static_file.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .. import Route
|
||||
from quart import Quart
|
||||
from type.types import Context
|
||||
|
||||
class StaticFileRoute(Route):
|
||||
def __init__(self, context: Context, app: Quart) -> None:
|
||||
super().__init__(context, app)
|
||||
|
||||
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default']
|
||||
for i in index_:
|
||||
self.app.add_url_rule(i, view_func=self.index)
|
||||
|
||||
async def index(self):
|
||||
return await self.app.send_static_file('index.html')
|
||||
45
dashboard/routes/update.py
Normal file
45
dashboard/routes/update.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import threading, traceback
|
||||
from .. import Route, Response, logger
|
||||
from quart import Quart, request
|
||||
from type.types import Context
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
|
||||
class UpdateRoute(Route):
|
||||
def __init__(self, context: Context, app: Quart, astrbot_updator: AstrBotUpdator) -> None:
|
||||
super().__init__(context, app)
|
||||
self.routes = {
|
||||
'/update/check': ('GET', self.check_update),
|
||||
'/update/do': ('POST', self.update_project),
|
||||
}
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.register_routes()
|
||||
|
||||
async def check_update(self):
|
||||
try:
|
||||
ret = await self.astrbot_updator.check_update(None, None)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
data={
|
||||
"has_new_version": ret is not None
|
||||
}
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def update_project(self):
|
||||
data = await request.json
|
||||
version = data.get('version', '')
|
||||
if version == "" or version == "latest":
|
||||
latest = True
|
||||
version = ''
|
||||
else:
|
||||
latest = False
|
||||
try:
|
||||
await self.astrbot_updator.update(latest=latest, version=version)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -1,510 +1,39 @@
|
||||
import websockets
|
||||
import json
|
||||
import threading
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from . import DashBoardData, Response
|
||||
from flask import Flask, request
|
||||
from werkzeug.serving import make_server
|
||||
from astrbot.persist.helper import dbConn
|
||||
import asyncio
|
||||
from quart import Quart
|
||||
from quart.logging import default_handler
|
||||
from type.types import Context
|
||||
from typing import List
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from dashboard.helper import DashBoardHelper
|
||||
from util.io import get_local_ip_addresses
|
||||
from .routes import *
|
||||
from . import logger
|
||||
from astrbot.db import BaseDatabase
|
||||
from model.plugin.manager import PluginManager
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from util.io import get_local_ip_addresses
|
||||
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
class AstrBotDashBoard():
|
||||
def __init__(self, context: Context, plugin_manager: PluginManager, astrbot_updator: AstrBotUpdator):
|
||||
class AstrBotDashboard():
|
||||
def __init__(self, context: Context,
|
||||
plugin_manager: PluginManager,
|
||||
astrbot_updator: AstrBotUpdator,
|
||||
db_helper: BaseDatabase) -> None:
|
||||
self.context = context
|
||||
self.plugin_manager = plugin_manager
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.dashboard_data = DashBoardData()
|
||||
self.dashboard_helper = DashBoardHelper(self.context, self.dashboard_data)
|
||||
self.app = Quart("dashboard", static_folder="dist", static_url_path="/")
|
||||
self.app.json.sort_keys = False
|
||||
|
||||
self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/")
|
||||
logging.getLogger('werkzeug').setLevel(logging.ERROR)
|
||||
self.dashboard_be.logger.setLevel(logging.ERROR)
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
|
||||
self.ws_clients = {} # remote_ip: ws
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.ar = AuthRoute(context, self.app)
|
||||
self.ur = UpdateRoute(context, self.app, astrbot_updator)
|
||||
self.sr = StatRoute(context, self.app, db_helper)
|
||||
self.pr = PluginRoute(context, self.app, astrbot_updator, plugin_manager)
|
||||
self.cr = ConfigRoute(context, self.app, astrbot_updator)
|
||||
self.lr = LogRoute(context, self.app)
|
||||
self.sfr = StaticFileRoute(context, self.app)
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
while self.context.running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
self.http_server_thread: threading.Thread = None
|
||||
|
||||
@self.dashboard_be.get("/")
|
||||
def index():
|
||||
# 返回页面
|
||||
return self.dashboard_be.send_static_file("index.html")
|
||||
|
||||
@self.dashboard_be.get("/auth/login")
|
||||
def _():
|
||||
return self.dashboard_be.send_static_file("index.html")
|
||||
|
||||
@self.dashboard_be.get("/config")
|
||||
def rt_config():
|
||||
return self.dashboard_be.send_static_file("index.html")
|
||||
|
||||
@self.dashboard_be.get("/logs")
|
||||
def rt_logs():
|
||||
return self.dashboard_be.send_static_file("index.html")
|
||||
|
||||
@self.dashboard_be.get("/extension")
|
||||
def rt_extension():
|
||||
return self.dashboard_be.send_static_file("index.html")
|
||||
|
||||
@self.dashboard_be.get("/dashboard/default")
|
||||
def rt_dashboard():
|
||||
return self.dashboard_be.send_static_file("index.html")
|
||||
|
||||
@self.dashboard_be.post("/api/authenticate")
|
||||
def authenticate():
|
||||
username = self.context.base_config.get("dashboard_username", "")
|
||||
password = self.context.base_config.get("dashboard_password", "")
|
||||
# 获得请求体
|
||||
post_data = request.json
|
||||
if post_data["username"] == username and post_data["password"] == password:
|
||||
return Response(
|
||||
status="success",
|
||||
message="登录成功。",
|
||||
data={
|
||||
"token": "astrbot-test-token",
|
||||
"username": username
|
||||
}
|
||||
).__dict__
|
||||
else:
|
||||
return Response(
|
||||
status="error",
|
||||
message="用户名或密码错误。",
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/change_password")
|
||||
def change_password():
|
||||
password = self.context.base_config.get("dashboard_password", "")
|
||||
# 获得请求体
|
||||
post_data = request.json
|
||||
if post_data["password"] == password:
|
||||
self.context.config_helper.put("dashboard_password", post_data["new_password"])
|
||||
return Response(
|
||||
status="success",
|
||||
message="修改成功。",
|
||||
data=None
|
||||
).__dict__
|
||||
else:
|
||||
return Response(
|
||||
status="error",
|
||||
message="原密码错误。",
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.get("/api/stats")
|
||||
def get_stats():
|
||||
db_inst = dbConn()
|
||||
all_session = db_inst.get_all_stat_session()
|
||||
last_24_message = db_inst.get_last_24h_stat_message()
|
||||
# last_24_platform = db_inst.get_last_24h_stat_platform()
|
||||
platforms = db_inst.get_platform_cnt_total()
|
||||
self.dashboard_data.stats["session"] = []
|
||||
self.dashboard_data.stats["session_total"] = db_inst.get_session_cnt_total(
|
||||
)
|
||||
self.dashboard_data.stats["message"] = last_24_message
|
||||
self.dashboard_data.stats["message_total"] = db_inst.get_message_cnt_total(
|
||||
)
|
||||
self.dashboard_data.stats["platform"] = platforms
|
||||
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=self.dashboard_data.stats
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.get("/api/configs")
|
||||
def get_configs():
|
||||
# 如果params中有namespace,则返回该namespace下的配置
|
||||
# 否则返回所有配置
|
||||
namespace = "" if "namespace" not in request.args else request.args["namespace"]
|
||||
conf = self._get_configs(namespace)
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=conf
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.get("/api/config_outline")
|
||||
def get_config_outline():
|
||||
outline = self._generate_outline()
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=outline
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/configs")
|
||||
def post_configs():
|
||||
post_configs = request.json
|
||||
try:
|
||||
self.on_post_configs(post_configs)
|
||||
return Response(
|
||||
status="success",
|
||||
message="保存成功~ 机器人将在 2 秒内重启以应用新的配置。",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=self.dashboard_data.configs
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.get("/api/extensions")
|
||||
def get_plugins():
|
||||
_plugin_resp = []
|
||||
for plugin in self.context.cached_plugins:
|
||||
_p = plugin.metadata
|
||||
_t = {
|
||||
"name": _p.plugin_name,
|
||||
"repo": '' if _p.repo is None else _p.repo,
|
||||
"author": _p.author,
|
||||
"desc": _p.desc,
|
||||
"version": _p.version
|
||||
}
|
||||
_plugin_resp.append(_t)
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=_plugin_resp
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/extensions/install")
|
||||
def install_plugin():
|
||||
post_data = request.json
|
||||
repo_url = post_data["url"]
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
self.plugin_manager.install_plugin(repo_url)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"安装插件 {repo_url} 成功,2秒后重启")
|
||||
return Response(
|
||||
status="success",
|
||||
message="安装成功,机器人将在 2 秒内重启。",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/install: {traceback.format_exc()}")
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/extensions/upload-install")
|
||||
def upload_install_plugin():
|
||||
try:
|
||||
file = request.files['file']
|
||||
print(file.filename)
|
||||
logger.info(f"正在安装用户上传的插件 {file.filename}")
|
||||
# save file to temp/
|
||||
file_path = f"temp/{uuid.uuid4()}.zip"
|
||||
file.save(file_path)
|
||||
self.plugin_manager.install_plugin_from_file(file_path)
|
||||
logger.info(f"安装插件 {file.filename} 成功")
|
||||
return Response(
|
||||
status="success",
|
||||
message="安装成功~",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/upload-install: {traceback.format_exc()}")
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/extensions/uninstall")
|
||||
def uninstall_plugin():
|
||||
post_data = request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在卸载插件 {plugin_name}")
|
||||
self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
logger.info(f"卸载插件 {plugin_name} 成功")
|
||||
return Response(
|
||||
status="success",
|
||||
message="卸载成功~",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/uninstall: {traceback.format_exc()}")
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/extensions/update")
|
||||
def update_plugin():
|
||||
post_data = request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
self.plugin_manager.update_plugin(plugin_name)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
|
||||
return Response(
|
||||
status="success",
|
||||
message="更新成功,机器人将在 2 秒内重启。",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/log")
|
||||
def log():
|
||||
for item in self.ws_clients:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.ws_clients[item].send(request.data.decode()), self.loop).result()
|
||||
except Exception as e:
|
||||
pass
|
||||
return 'ok'
|
||||
|
||||
@self.dashboard_be.get("/api/check_update")
|
||||
def get_update_info():
|
||||
try:
|
||||
ret = self.astrbot_updator.check_update(None, None)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
data={
|
||||
"has_new_version": ret is not None
|
||||
}
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/check_update: {traceback.format_exc()}")
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/update_project")
|
||||
def update_project_api():
|
||||
version = request.json['version']
|
||||
if version == "" or version == "latest":
|
||||
latest = True
|
||||
version = ''
|
||||
else:
|
||||
latest = False
|
||||
try:
|
||||
self.astrbot_updator.update(latest=latest, version=version)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
return Response(
|
||||
status="success",
|
||||
message="更新成功,机器人将在 3 秒内重启。",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.get("/api/llm/list")
|
||||
def llm_list():
|
||||
ret = []
|
||||
for llm in self.context.llms:
|
||||
ret.append(llm.llm_name)
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=ret
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.get("/api/llm")
|
||||
def llm():
|
||||
text = request.args["text"]
|
||||
llm = request.args["llm"]
|
||||
for llm_ in self.context.llms:
|
||||
if llm_.llm_name == llm:
|
||||
try:
|
||||
ret = asyncio.run_coroutine_threadsafe(
|
||||
llm_.llm_instance.text_chat(text), self.loop).result()
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=ret
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
return Response(
|
||||
status="error",
|
||||
message="LLM not found.",
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
def on_post_configs(self, post_configs: dict):
|
||||
try:
|
||||
if 'base_config' in post_configs:
|
||||
self.dashboard_helper.save_config(
|
||||
post_configs['base_config'], namespace='') # 基础配置
|
||||
self.dashboard_helper.save_config(
|
||||
post_configs['config'], namespace=post_configs['namespace']) # 选定配置
|
||||
self.dashboard_helper.parse_default_config(
|
||||
self.dashboard_data, self.context.config_helper.get_all())
|
||||
# 重启
|
||||
threading.Thread(target=self.astrbot_updator._reboot,
|
||||
args=(2, self.context), daemon=True).start()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _get_configs(self, namespace: str):
|
||||
if namespace == "":
|
||||
ret = [self.dashboard_data.configs['data'][4],
|
||||
self.dashboard_data.configs['data'][5],]
|
||||
elif namespace == "internal_platform_qq_official":
|
||||
ret = [self.dashboard_data.configs['data'][0],]
|
||||
elif namespace == "internal_platform_qq_gocq":
|
||||
ret = [self.dashboard_data.configs['data'][1],]
|
||||
elif namespace == "internal_platform_general": # 全局平台配置
|
||||
ret = [self.dashboard_data.configs['data'][2],]
|
||||
elif namespace == "internal_llm_openai_official":
|
||||
ret = [self.dashboard_data.configs['data'][3],]
|
||||
elif namespace == "internal_platform_qq_aiocqhttp":
|
||||
ret = [self.dashboard_data.configs['data'][6],]
|
||||
else:
|
||||
path = f"data/config/{namespace}.json"
|
||||
if not os.path.exists(path):
|
||||
return []
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
ret = [{
|
||||
"config_type": "group",
|
||||
"name": namespace + " 插件配置",
|
||||
"description": "",
|
||||
"body": list(json.load(f).values())
|
||||
},]
|
||||
return ret
|
||||
|
||||
def _generate_outline(self):
|
||||
'''
|
||||
生成配置大纲。目前分为 platform(消息平台配置) 和 llm(语言模型配置) 两大类。
|
||||
插件的info函数中如果带了plugin_type字段,则会被归类到对应的大纲中。目前仅支持 platform 和 llm 两种类型。
|
||||
'''
|
||||
outline = [
|
||||
{
|
||||
"type": "platform",
|
||||
"name": "配置通用消息平台",
|
||||
"body": [
|
||||
{
|
||||
"title": "通用",
|
||||
"desc": "通用平台配置",
|
||||
"namespace": "internal_platform_general",
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(官方)",
|
||||
"desc": "QQ官方API。支持频道、群、私聊(需获得群权限)",
|
||||
"namespace": "internal_platform_qq_official",
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(nakuru)",
|
||||
"desc": "适用于 go-cqhttp",
|
||||
"namespace": "internal_platform_qq_gocq",
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(aiocqhttp)",
|
||||
"desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。",
|
||||
"namespace": "internal_platform_qq_aiocqhttp",
|
||||
"tag": ""
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "llm",
|
||||
"name": "配置 LLM",
|
||||
"body": [
|
||||
{
|
||||
"title": "OpenAI Official",
|
||||
"desc": "也支持使用官方接口的中转服务",
|
||||
"namespace": "internal_llm_openai_official",
|
||||
"tag": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
for plugin in self.context.cached_plugins:
|
||||
for item in outline:
|
||||
if item['type'] == plugin.metadata.plugin_type:
|
||||
item['body'].append({
|
||||
"title": plugin.metadata.plugin_name,
|
||||
"desc": plugin.metadata.desc,
|
||||
"namespace": plugin.metadata.plugin_name,
|
||||
"tag": plugin.metadata.plugin_name
|
||||
})
|
||||
return outline
|
||||
|
||||
async def get_log_history(self):
|
||||
try:
|
||||
with open("logs/astrbot/astrbot.log", "r", encoding="utf-8") as f:
|
||||
return f.readlines()[-100:]
|
||||
except Exception as e:
|
||||
logger.warning(f"读取日志历史失败: {e.__str__()}")
|
||||
return []
|
||||
|
||||
async def __handle_msg(self, websocket, path):
|
||||
address = websocket.remote_address
|
||||
self.ws_clients[address] = websocket
|
||||
data = await self.get_log_history()
|
||||
data = ''.join(data).replace('\n', '\r\n')
|
||||
await websocket.send(data)
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
# logger.info(f"和 {address} 的 websocket 连接已断开")
|
||||
del self.ws_clients[address]
|
||||
break
|
||||
except Exception as e:
|
||||
# logger.info(f"和 {path} 的 websocket 连接发生了错误: {e.__str__()}")
|
||||
del self.ws_clients[address]
|
||||
break
|
||||
|
||||
async def ws_server(self):
|
||||
ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186)
|
||||
logger.info("WebSocket 服务器已启动。")
|
||||
await ws_server
|
||||
|
||||
def http_server(self):
|
||||
http_server = make_server(
|
||||
'0.0.0.0', 6185, self.dashboard_be, threaded=True)
|
||||
http_server.serve_forever()
|
||||
|
||||
def run_http_server(self):
|
||||
self.http_server_thread = threading.Thread(target=self.http_server, daemon=True).start()
|
||||
ip_address = get_local_ip_addresses()
|
||||
ip_str = f"http://{ip_address}:6185"
|
||||
logger.info(f"HTTP 服务器已启动,可访问: {ip_str} 等来登录可视化面板。")
|
||||
def run(self):
|
||||
ip_addr = get_local_ip_addresses()
|
||||
logger.info(f"仪表盘已启动,可访问 http://{ip_addr}:6185 登录。")
|
||||
return self.app.run_task(host="0.0.0.0", port=6185, shutdown_trigger=self.shutdown_trigger_placeholder)
|
||||
98
dashboard/utils/config.py
Normal file
98
dashboard/utils/config.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from dataclasses import asdict
|
||||
from util.plugin_dev.api.v1.config import update_config
|
||||
from type.config import CONFIG_METADATA_2
|
||||
from type.types import Context
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
return int(value)
|
||||
elif type_ == "float" and isinstance(value, str) \
|
||||
and value.replace(".", "", 1).isdigit():
|
||||
return float(value)
|
||||
elif type_ == "float" and isinstance(value, int):
|
||||
return float(value)
|
||||
|
||||
def get_default_val_by_type(type_: str):
|
||||
if type_ == "int":
|
||||
return 0
|
||||
elif type_ == "float":
|
||||
return 0.0
|
||||
elif type_ == "bool":
|
||||
return False
|
||||
elif type_ == "string":
|
||||
return ""
|
||||
elif type_ == "list":
|
||||
return []
|
||||
elif type_ == "object":
|
||||
return {}
|
||||
|
||||
|
||||
def validate_config(data, context: Context):
|
||||
errors = []
|
||||
def validate(data, metadata=CONFIG_METADATA_2, path=""):
|
||||
for key, meta in metadata.items():
|
||||
if key not in data:
|
||||
continue
|
||||
value = data[key]
|
||||
# null 转换
|
||||
if value is None:
|
||||
data[key] = get_default_val_by_type(meta["type"])
|
||||
continue
|
||||
# 递归验证
|
||||
if meta["type"] == "list" and isinstance(value, list):
|
||||
for item in value:
|
||||
validate(item, meta["items"], path=f"{path}{key}.")
|
||||
elif meta["type"] == "object" and isinstance(value, dict):
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
|
||||
if meta["type"] == "int" and not isinstance(value, int):
|
||||
casted = try_cast(value, "int")
|
||||
if casted is None:
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}")
|
||||
data[key] = casted
|
||||
elif meta["type"] == "float" and not isinstance(value, float):
|
||||
casted = try_cast(value, "float")
|
||||
if casted is None:
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}")
|
||||
data[key] = casted
|
||||
elif meta["type"] == "bool" and not isinstance(value, bool):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "string" and not isinstance(value, str):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "object" and not isinstance(value, dict):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
validate(data)
|
||||
|
||||
# hardcode warning
|
||||
data['config_version'] = context.config_helper.config_version
|
||||
data['dashboard'] = asdict(context.config_helper.dashboard)
|
||||
|
||||
return errors
|
||||
|
||||
def save_astrbot_config(post_config: dict, context: Context):
|
||||
'''验证并保存配置'''
|
||||
errors = validate_config(post_config, context)
|
||||
if errors:
|
||||
raise ValueError(f"格式校验未通过: {errors}")
|
||||
context.config_helper.flush_config(post_config)
|
||||
|
||||
def save_extension_config(post_config: dict):
|
||||
if 'namespace' not in post_config:
|
||||
raise ValueError("Missing key: namespace")
|
||||
if 'config' not in post_config:
|
||||
raise ValueError("Missing key: config")
|
||||
|
||||
namespace = post_config['namespace']
|
||||
config: list = post_config['config'][0]['body']
|
||||
for item in config:
|
||||
key = item['path']
|
||||
value = item['value']
|
||||
typ = item['val_type']
|
||||
if typ == 'int':
|
||||
if not value.isdigit():
|
||||
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
|
||||
value = int(value)
|
||||
update_config(namespace, key, value)
|
||||
16
main.py
16
main.py
@@ -6,8 +6,7 @@ import warnings
|
||||
import traceback
|
||||
import mimetypes
|
||||
from astrbot.bootstrap import AstrBotBootstrap
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
from util.log import LogManager
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
logo_tmpl = r"""
|
||||
@@ -27,6 +26,8 @@ def main():
|
||||
# delete qqbotpy's logger
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
logger.info(logo_tmpl)
|
||||
|
||||
bootstrap = AstrBotBootstrap()
|
||||
asyncio.run(bootstrap.run())
|
||||
@@ -42,7 +43,8 @@ def check_env():
|
||||
exit()
|
||||
|
||||
os.makedirs("data/config", exist_ok=True)
|
||||
os.makedirs("temp", exist_ok=True)
|
||||
os.makedirs("data/plugins", exist_ok=True)
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
# workaround for issue #181
|
||||
mimetypes.add_type("text/javascript", ".js")
|
||||
@@ -51,11 +53,5 @@ def check_env():
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_env()
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
logger.info(logo_tmpl)
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
main()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import aiohttp
|
||||
import aiohttp, os
|
||||
|
||||
from model.command.manager import CommandManager
|
||||
from model.plugin.manager import PluginManager
|
||||
@@ -6,9 +6,9 @@ from type.message_event import AstrMessageEvent
|
||||
from type.command import CommandResult
|
||||
from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.web_searcher import search_from_bing, fetch_website_content
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -19,23 +19,30 @@ class InternalCommandHandler:
|
||||
self.plugin_manager = plugin_manager
|
||||
|
||||
self.manager.register("help", "查看帮助", 10, self.help)
|
||||
self.manager.register("wake", "设置机器人唤醒词", 10, self.set_nick)
|
||||
self.manager.register("update", "更新 AstrBot", 10, self.update)
|
||||
self.manager.register("wake", "唤醒前缀", 10, self.set_nick)
|
||||
self.manager.register("update", "更新管理", 10, self.update)
|
||||
self.manager.register("plugin", "插件管理", 10, self.plugin)
|
||||
self.manager.register("reboot", "重启 AstrBot", 10, self.reboot)
|
||||
self.manager.register("websearch", "网页搜索开关", 10, self.web_search)
|
||||
self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle)
|
||||
self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid)
|
||||
self.manager.register("provider", "查看和切换当前使用的 LLM 资源来源", 10, self.provider)
|
||||
self.manager.register("websearch", "网页搜索", 10, self.web_search)
|
||||
self.manager.register("t2i", "文转图", 10, self.t2i_toggle)
|
||||
self.manager.register("myid", "用户ID", 10, self.myid)
|
||||
self.manager.register("provider", "LLM 接入源", 10, self.provider)
|
||||
|
||||
def _check_auth(self, message: AstrMessageEvent, context: Context):
|
||||
if os.environ.get("TEST_MODE", "off") == "on":
|
||||
return
|
||||
if message.role != "admin":
|
||||
user_id = message.message_obj.sender.user_id
|
||||
raise Exception(f"用户(ID: {user_id}) 没有足够的权限使用该指令。")
|
||||
|
||||
def provider(self, message: AstrMessageEvent, context: Context):
|
||||
if len(context.llms) == 0:
|
||||
return CommandResult().message("当前没有加载任何 LLM 资源。")
|
||||
return CommandResult().message("当前没有加载任何 LLM 接入源。")
|
||||
|
||||
tokens = self.manager.command_parser.parse(message.message_str)
|
||||
|
||||
if tokens.len == 1:
|
||||
ret = "## 当前载入的 LLM 资源\n"
|
||||
ret = "## 当前载入的 LLM 接入源\n"
|
||||
for idx, llm in enumerate(context.llms):
|
||||
ret += f"{idx}. {llm.llm_name}"
|
||||
if llm.origin:
|
||||
@@ -44,7 +51,7 @@ class InternalCommandHandler:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
ret += "\n使用 provider <序号> 切换 LLM 资源。"
|
||||
ret += "\n使用 provider <序号> 切换 LLM 接入源。"
|
||||
return CommandResult().message(ret)
|
||||
else:
|
||||
try:
|
||||
@@ -52,58 +59,48 @@ class InternalCommandHandler:
|
||||
if idx >= len(context.llms):
|
||||
return CommandResult().message("provider: 无效的序号。")
|
||||
context.message_handler.set_provider(context.llms[idx].llm_instance)
|
||||
return CommandResult().message(f"已经成功切换到 LLM 资源 {context.llms[idx].llm_name}。")
|
||||
return CommandResult().message(f"已经成功切换到 LLM 接入源 {context.llms[idx].llm_name}。")
|
||||
except BaseException as e:
|
||||
return CommandResult().message("provider: 参数错误。")
|
||||
|
||||
def set_nick(self, message: AstrMessageEvent, context: Context):
|
||||
self._check_auth(message, context)
|
||||
message_str = message.message_str
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("你没有权限使用该指令。")
|
||||
l = message_str.split(" ")
|
||||
if len(l) == 1:
|
||||
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.nick}")
|
||||
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词是:{context.config_helper.wake_prefix[0]}")
|
||||
nick = l[1].strip()
|
||||
if not nick:
|
||||
return CommandResult().message("wake: 请指定唤醒词。")
|
||||
context.config_helper.put("nick_qq", nick)
|
||||
context.nick = tuple(nick)
|
||||
context.config_helper.wake_prefix = [nick]
|
||||
context.config_helper.save_config()
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain=f"已经成功将唤醒词设定为 {nick}。",
|
||||
message_chain=f"已经成功将唤醒前缀设定为 {nick}。",
|
||||
)
|
||||
|
||||
def update(self, message: AstrMessageEvent, context: Context):
|
||||
async def update(self, message: AstrMessageEvent, context: Context):
|
||||
self._check_auth(message, context)
|
||||
tokens = self.manager.command_parser.parse(message.message_str)
|
||||
if message.role != "admin":
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain="你没有权限使用该指令",
|
||||
)
|
||||
update_info = context.updator.check_update(None, None)
|
||||
update_info = await context.updator.check_update(None, None)
|
||||
if tokens.len == 1:
|
||||
ret = ""
|
||||
if not update_info:
|
||||
ret = f"当前已经是最新版本 v{VERSION}。"
|
||||
else:
|
||||
ret = f"发现新版本 {update_info.version},更新内容如下:\n---\n{update_info.body}\n---\n- 使用 /update latest 更新到最新版本。\n- 使用 /update vX.X.X 更新到指定版本。"
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain=ret,
|
||||
)
|
||||
return CommandResult().message(ret)
|
||||
else:
|
||||
if tokens.get(1) == "latest":
|
||||
try:
|
||||
context.updator.update()
|
||||
await context.updator.update()
|
||||
return CommandResult().message(f"已经成功更新到最新版本 v{update_info.version}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启")
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"更新失败。原因:{str(e)}")
|
||||
elif tokens.get(1).startswith("v"):
|
||||
try:
|
||||
context.updator.update(version=tokens.get(1))
|
||||
await context.updator.update(version=tokens.get(1))
|
||||
return CommandResult().message(f"已经成功更新到版本 v{tokens.get(1)}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启")
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"更新失败。原因:{str(e)}")
|
||||
@@ -111,12 +108,7 @@ class InternalCommandHandler:
|
||||
return CommandResult().message("update: 参数错误。")
|
||||
|
||||
def reboot(self, message: AstrMessageEvent, context: Context):
|
||||
if message.role != "admin":
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain="你没有权限使用该指令",
|
||||
)
|
||||
self._check_auth(message, context)
|
||||
context.updator._reboot(3, context)
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
@@ -124,7 +116,7 @@ class InternalCommandHandler:
|
||||
message_chain="AstrBot 将在 3s 后重启。",
|
||||
)
|
||||
|
||||
def plugin(self, message: AstrMessageEvent, context: Context):
|
||||
async def plugin(self, message: AstrMessageEvent, context: Context):
|
||||
tokens = self.manager.command_parser.parse(message.message_str)
|
||||
if tokens.len == 1:
|
||||
ret = "# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n"
|
||||
@@ -137,10 +129,10 @@ class InternalCommandHandler:
|
||||
if plugin_list_info.strip() == "":
|
||||
return CommandResult().message("plugin v: 没有找到插件。")
|
||||
return CommandResult().message(plugin_list_info)
|
||||
|
||||
self._check_auth(message, context)
|
||||
|
||||
elif tokens.get(1) == "d":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin d: 你没有权限使用该指令。")
|
||||
if tokens.get(1) == "d":
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin d: 请指定要卸载的插件名。")
|
||||
plugin_name = tokens.get(2)
|
||||
@@ -151,25 +143,21 @@ class InternalCommandHandler:
|
||||
return CommandResult().message(f"plugin d: 已经成功卸载插件 {plugin_name}。")
|
||||
|
||||
elif tokens.get(1) == "i":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin i: 你没有权限使用该指令。")
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin i: 请指定要安装的插件的 Github 地址,或者前往可视化面板安装。")
|
||||
plugin_url = tokens.get(2)
|
||||
try:
|
||||
self.plugin_manager.install_plugin(plugin_url)
|
||||
await self.plugin_manager.install_plugin(plugin_url)
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"plugin i: 安装插件失败。原因:{str(e)}")
|
||||
return CommandResult().message("plugin i: 已经成功安装插件。")
|
||||
|
||||
elif tokens.get(1) == "u":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin u: 你没有权限使用该指令。")
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin u: 请指定要更新的插件名。")
|
||||
plugin_name = tokens.get(2)
|
||||
try:
|
||||
self.plugin_manager.update_plugin(plugin_name)
|
||||
await context.plugin_updator.update(plugin_name)
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"plugin u: 更新插件失败。原因:{str(e)}")
|
||||
return CommandResult().message(f"plugin u: 已经成功更新插件 {plugin_name}。")
|
||||
@@ -183,20 +171,20 @@ class InternalCommandHandler:
|
||||
async with session.get("https://soulter.top/channelbot/notice.json") as resp:
|
||||
notice = (await resp.json())["notice"]
|
||||
except BaseException as e:
|
||||
logger.warn("An error occurred while fetching astrbot notice. Never mind, it's not important.")
|
||||
logger.warning("An error occurred while fetching astrbot notice. Never mind, it's not important.")
|
||||
|
||||
msg = "# Help Center\n## 指令列表\n"
|
||||
msg = "# 帮助中心\n## 指令\n"
|
||||
for key, value in self.manager.commands_handler.items():
|
||||
if value.plugin_metadata:
|
||||
msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n"
|
||||
else: msg += f"- `{key}`: {value.description}\n"
|
||||
# plugins
|
||||
if context.cached_plugins != None:
|
||||
if context.cached_plugins:
|
||||
plugin_list_info = ""
|
||||
for plugin in context.cached_plugins:
|
||||
plugin_list_info += f"- `{plugin.metadata.plugin_name}` {plugin.metadata.desc}\n"
|
||||
if plugin_list_info.strip() != "":
|
||||
msg += "\n## 插件列表\n> 使用plugin v 插件名 查看插件帮助\n"
|
||||
msg += "\n## 插件\n> 使用plugin v 插件名 查看插件帮助\n"
|
||||
msg += plugin_list_info
|
||||
msg += notice
|
||||
|
||||
@@ -208,17 +196,39 @@ class InternalCommandHandler:
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain=f"网页搜索功能当前状态: {context.web_search}",
|
||||
message_chain=f"网页搜索功能当前状态: {context.config_helper.llm_settings.web_search}",
|
||||
)
|
||||
elif l[1] == 'on':
|
||||
context.web_search = True
|
||||
context.config_helper.llm_settings.web_search = True
|
||||
context.config_helper.save_config()
|
||||
context.register_llm_tool("web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"description": "搜索关键词"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
search_from_bing
|
||||
)
|
||||
context.register_llm_tool("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "要获取内容的网页链接"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain="已开启网页搜索",
|
||||
)
|
||||
elif l[1] == 'off':
|
||||
context.web_search = False
|
||||
context.config_helper.llm_settings.web_search = False
|
||||
context.config_helper.save_config()
|
||||
context.unregister_llm_tool("web_search")
|
||||
context.unregister_llm_tool("fetch_website_content")
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
@@ -232,17 +242,17 @@ class InternalCommandHandler:
|
||||
)
|
||||
|
||||
def t2i_toggle(self, message: AstrMessageEvent, context: Context):
|
||||
p = context.t2i_mode
|
||||
p = context.config_helper.t2i
|
||||
if p:
|
||||
context.config_helper.put("qq_pic_mode", False)
|
||||
context.t2i_mode = False
|
||||
context.config_helper.t2i = False
|
||||
context.config_helper.save_config()
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain="已关闭文本转图片模式。",
|
||||
)
|
||||
context.config_helper.put("qq_pic_mode", True)
|
||||
context.t2i_mode = True
|
||||
context.config_helper.t2i = True
|
||||
context.config_helper.save_config()
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
@@ -262,5 +272,5 @@ class InternalCommandHandler:
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain=f"在 {message.platform} 上获取你的ID失败,原因: {str(e)}",
|
||||
message_chain=f"获取失败,原因: {str(e)}",
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from type.command import CommandResult
|
||||
from type.register import RegisteredPlugins
|
||||
from model.command.parser import CommandParser
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -124,8 +124,11 @@ class CommandManager():
|
||||
else:
|
||||
command_result = handler(message_event, context)
|
||||
|
||||
if not isinstance(command_result, CommandResult):
|
||||
raise ValueError(f"Command {command} handler should return CommandResult.")
|
||||
# if not isinstance(command_result, CommandResult):
|
||||
# raise ValueError(f"Command {command} handler should return CommandResult.")
|
||||
|
||||
if not command_result:
|
||||
return
|
||||
|
||||
context.metrics_uploader.command_stats[command] += 1
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from model.command.manager import CommandManager
|
||||
from type.message_event import AstrMessageEvent
|
||||
from type.command import CommandResult
|
||||
from type.types import Context
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
|
||||
|
||||
@@ -5,6 +5,9 @@ from type.astrbot_message import AstrBotMessage
|
||||
from type.command import CommandResult
|
||||
from type.astrbot_message import MessageType
|
||||
|
||||
class T2IException(Exception):
|
||||
def __init__(self, message: str = "文本转图片时发生错误") -> None:
|
||||
super().__init__(message)
|
||||
|
||||
class Platform():
|
||||
def __init__(self, platform_name: str, context) -> None:
|
||||
@@ -40,14 +43,18 @@ class Platform():
|
||||
'''
|
||||
pass
|
||||
|
||||
def parse_message_outline(self, message: AstrBotMessage) -> str:
|
||||
def parse_message_outline(self, message: Union[AstrBotMessage, list]) -> str:
|
||||
'''
|
||||
将消息解析成大纲消息形式,如: xxxxx[图片]xxxxx。用于输出日志等。
|
||||
'''
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
ret = ''
|
||||
parsed = message if isinstance(message, list) else message.message
|
||||
if isinstance(message, list):
|
||||
parsed = message
|
||||
elif isinstance(message, AstrBotMessage):
|
||||
parsed = message.message
|
||||
elif isinstance(message, str):
|
||||
return message
|
||||
|
||||
try:
|
||||
for node in parsed:
|
||||
if isinstance(node, Plain):
|
||||
@@ -61,13 +68,14 @@ class Platform():
|
||||
return ret[:100] if len(ret) > 100 else ret
|
||||
|
||||
def check_nick(self, message_str: str) -> bool:
|
||||
if self.context.nick:
|
||||
for nick in self.context.nick:
|
||||
if nick and message_str.strip().startswith(nick):
|
||||
return True
|
||||
w = self.context.config_helper.wake_prefix
|
||||
if not w: return False
|
||||
for nick in w:
|
||||
if nick and message_str.strip().startswith(nick):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def convert_to_t2i_chain(self, message_result: list) -> list:
|
||||
async def convert_to_t2i_chain(self, message_result: list) -> Union[List[Image], None]:
|
||||
plain_str = ""
|
||||
rendered_images = []
|
||||
for i in message_result:
|
||||
|
||||
@@ -3,9 +3,15 @@ import asyncio
|
||||
from util.io import port_checker
|
||||
from type.register import RegisteredPlatform
|
||||
from type.types import Context
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import (
|
||||
PlatformConfig,
|
||||
AiocqhttpPlatformConfig,
|
||||
NakuruPlatformConfig,
|
||||
QQOfficialPlatformConfig
|
||||
)
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -13,36 +19,40 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
class PlatformManager():
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
self.context = context
|
||||
self.config = context.base_config
|
||||
self.msg_handler = message_handler
|
||||
|
||||
def load_platforms(self):
|
||||
tasks = []
|
||||
|
||||
if 'gocqbot' in self.config and self.config['gocqbot']['enable']:
|
||||
logger.info("启用 QQ(nakuru 适配器)")
|
||||
tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter"))
|
||||
|
||||
if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']:
|
||||
logger.info("启用 QQ(aiocqhttp 适配器)")
|
||||
tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter"))
|
||||
platforms = self.context.config_helper.platform
|
||||
logger.info(f"加载 {len(platforms)} 个机器人消息平台...")
|
||||
for platform in platforms:
|
||||
if not platform.enable:
|
||||
continue
|
||||
if platform.name == "qq_official":
|
||||
assert isinstance(platform, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
|
||||
logger.info(f"加载 QQ官方 机器人消息平台 (appid: {platform.appid})")
|
||||
tasks.append(asyncio.create_task(self.qqofficial_bot(platform), name="qqofficial-adapter"))
|
||||
elif platform.name == "nakuru":
|
||||
assert isinstance(platform, NakuruPlatformConfig), "nakuru: 无法识别的配置类型。"
|
||||
logger.info(f"加载 QQ(nakuru) 机器人消息平台 ({platform.host}, {platform.websocket_port}, {platform.port})")
|
||||
tasks.append(asyncio.create_task(self.nakuru_bot(platform), name="nakuru-adapter"))
|
||||
elif platform.name == "aiocqhttp":
|
||||
assert isinstance(platform, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
|
||||
logger.info("加载 QQ(aiocqhttp) 机器人消息平台")
|
||||
tasks.append(asyncio.create_task(self.aiocq_bot(platform), name="aiocqhttp-adapter"))
|
||||
|
||||
# QQ频道
|
||||
if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None:
|
||||
logger.info("启用 QQ(官方 API) 机器人消息平台")
|
||||
tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter"))
|
||||
|
||||
return tasks
|
||||
|
||||
async def gocq_bot(self):
|
||||
async def nakuru_bot(self, config: NakuruPlatformConfig):
|
||||
'''
|
||||
运行 QQ(nakuru 适配器)
|
||||
'''
|
||||
from model.platform.qq_nakuru import QQGOCQ
|
||||
from model.platform.qq_nakuru import QQNakuru
|
||||
noticed = False
|
||||
host = self.config.get("gocq_host", "127.0.0.1")
|
||||
port = self.config.get("gocq_websocket_port", 6700)
|
||||
http_port = self.config.get("gocq_http_port", 5700)
|
||||
host = config.host
|
||||
port = config.websocket_port
|
||||
http_port = config.port
|
||||
logger.info(
|
||||
f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}")
|
||||
while True:
|
||||
@@ -56,30 +66,30 @@ class PlatformManager():
|
||||
logger.info("nakuru 适配器已连接。")
|
||||
break
|
||||
try:
|
||||
qq_gocq = QQGOCQ(self.context, self.msg_handler)
|
||||
qq_gocq = QQNakuru(self.context, self.msg_handler, config)
|
||||
self.context.platforms.append(RegisteredPlatform(
|
||||
platform_name="nakuru", platform_instance=qq_gocq, origin="internal"))
|
||||
await qq_gocq.run()
|
||||
except BaseException as e:
|
||||
logger.error("启动 nakuru 适配器时出现错误: " + str(e))
|
||||
|
||||
def aiocq_bot(self):
|
||||
def aiocq_bot(self, config):
|
||||
'''
|
||||
运行 QQ(aiocqhttp 适配器)
|
||||
'''
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler)
|
||||
qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler, config)
|
||||
self.context.platforms.append(RegisteredPlatform(
|
||||
platform_name="aiocqhttp", platform_instance=qq_aiocqhttp, origin="internal"))
|
||||
return qq_aiocqhttp.run_aiocqhttp()
|
||||
|
||||
def qqchan_bot(self):
|
||||
def qqofficial_bot(self, config):
|
||||
'''
|
||||
运行 QQ 官方机器人适配器
|
||||
'''
|
||||
try:
|
||||
from model.platform.qq_official import QQOfficial
|
||||
qqchannel_bot = QQOfficial(self.context, self.msg_handler)
|
||||
qqchannel_bot = QQOfficial(self.context, self.msg_handler, config)
|
||||
self.context.platforms.append(RegisteredPlatform(
|
||||
platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal"))
|
||||
return qqchannel_bot.run()
|
||||
|
||||
@@ -4,28 +4,32 @@ import traceback
|
||||
import logging
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
from . import Platform
|
||||
from . import Platform, T2IException
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from type.command import *
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
class AIOCQHTTP(Platform):
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
def __init__(self, context: Context,
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig) -> None:
|
||||
super().__init__("aiocqhttp", context)
|
||||
assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.waiting = {}
|
||||
self.context = context
|
||||
self.unique_session = self.context.unique_session
|
||||
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
|
||||
self.host = self.context.base_config['aiocqhttp']['ws_reverse_host']
|
||||
self.port = self.context.base_config['aiocqhttp']['ws_reverse_port']
|
||||
self.config = platform_config
|
||||
self.unique_session = context.config_helper.platform_settings.unique_session
|
||||
self.host = platform_config.ws_reverse_host
|
||||
self.port = platform_config.ws_reverse_port
|
||||
|
||||
def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
|
||||
@@ -80,7 +84,7 @@ class AIOCQHTTP(Platform):
|
||||
def run_aiocqhttp(self):
|
||||
if not self.host or not self.port:
|
||||
return
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp')
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
@@ -105,7 +109,7 @@ class AIOCQHTTP(Platform):
|
||||
while self.context.running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def pre_check(self, message: AstrBotMessage) -> bool:
|
||||
async def pre_check(self, message: AstrBotMessage) -> bool:
|
||||
# if message chain contains Plain components or
|
||||
# At components which points to self_id, return True
|
||||
if message.type == MessageType.FRIEND_MESSAGE:
|
||||
@@ -114,7 +118,7 @@ class AIOCQHTTP(Platform):
|
||||
if isinstance(comp, At) and str(comp.qq) == message.self_id:
|
||||
return True, "at"
|
||||
# check commands which ignore prefix
|
||||
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
|
||||
if await self.context.command_manager.check_command_ignore_prefix(message.message_str):
|
||||
return True, "command"
|
||||
# check nicks
|
||||
if self.check_nick(message.message_str):
|
||||
@@ -125,17 +129,9 @@ class AIOCQHTTP(Platform):
|
||||
logger.info(
|
||||
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
|
||||
|
||||
ok, reason = self.pre_check(message)
|
||||
ok, reason = await self.pre_check(message)
|
||||
if not ok:
|
||||
return
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.sender.user_id)
|
||||
if sender_id == self.context.base_config.get('admin_qq', '') or \
|
||||
sender_id in self.context.base_config.get('other_admins', []):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
|
||||
# parse unified message origin
|
||||
unified_msg_origin = None
|
||||
@@ -152,7 +148,6 @@ class AIOCQHTTP(Platform):
|
||||
self.context,
|
||||
"aiocqhttp",
|
||||
message.session_id,
|
||||
role,
|
||||
unified_msg_origin,
|
||||
reason == "command") # only_command
|
||||
|
||||
@@ -164,11 +159,8 @@ class AIOCQHTTP(Platform):
|
||||
if message_result.callback:
|
||||
message_result.callback()
|
||||
|
||||
# 如果是等待回复的消息
|
||||
if message.session_id in self.waiting and self.waiting[message.session_id] == '':
|
||||
self.waiting[message.session_id] = message
|
||||
|
||||
|
||||
return message_result
|
||||
|
||||
async def reply_msg(self,
|
||||
message: AstrBotMessage,
|
||||
result_message: list,
|
||||
@@ -176,31 +168,35 @@ class AIOCQHTTP(Platform):
|
||||
"""
|
||||
回复用户唤醒机器人的消息。(被动回复)
|
||||
"""
|
||||
logger.info(
|
||||
f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
|
||||
|
||||
res = result_message
|
||||
|
||||
if isinstance(res, str):
|
||||
res = [Plain(text=res), ]
|
||||
try:
|
||||
await self._reply(message, result_message, use_t2i)
|
||||
except T2IException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(f"文本转图片时发生错误,将使用纯文本发送。")
|
||||
await self._reply(message, result_message, False)
|
||||
return result_message
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
await self._reply(message, rendered_images)
|
||||
return
|
||||
except BaseException as e:
|
||||
logger.warn(traceback.format_exc())
|
||||
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
|
||||
|
||||
await self._reply(message, res)
|
||||
|
||||
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
|
||||
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent], use_t2i: bool = None):
|
||||
await self.record_metrics()
|
||||
if isinstance(message_chain, str):
|
||||
message_chain = [Plain(text=message_chain), ]
|
||||
|
||||
# 文转图处理
|
||||
if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(message_chain, list):
|
||||
try:
|
||||
message_chain = await self.convert_to_t2i_chain(message_chain)
|
||||
if not message_chain: raise T2IException()
|
||||
except BaseException as e:
|
||||
raise T2IException()
|
||||
|
||||
# log
|
||||
if isinstance(message, AstrBotMessage):
|
||||
logger.info(
|
||||
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}")
|
||||
else:
|
||||
logger.info(f"回复消息: {message_chain}")
|
||||
|
||||
# 解析成 OneBot json 格式并发送
|
||||
ret = []
|
||||
image_idx = []
|
||||
for idx, segment in enumerate(message_chain):
|
||||
@@ -210,33 +206,42 @@ class AIOCQHTTP(Platform):
|
||||
if isinstance(segment, Image):
|
||||
image_idx.append(idx)
|
||||
ret.append(d)
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
logger.info(f"回复消息: {ret}")
|
||||
return
|
||||
try:
|
||||
if isinstance(message, AstrBotMessage):
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
if isinstance(message, dict):
|
||||
if 'group_id' in message:
|
||||
await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
|
||||
elif 'user_id' in message:
|
||||
await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
|
||||
else:
|
||||
raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
|
||||
await self._reply_wrapper(message, ret)
|
||||
except ActionFailed as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
if e.retcode == 1200:
|
||||
# ENOENT
|
||||
if not image_idx:
|
||||
raise e
|
||||
logger.info("检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。")
|
||||
logger.warning("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。")
|
||||
for idx in image_idx:
|
||||
if ret[idx]['data']['file'].startswith('file://'):
|
||||
logger.info(f"正在上传图片: {ret[idx]['data']['path']}")
|
||||
# 除了上传到图床,想不到更好的办法。
|
||||
image_url = await self.context.image_uploader.upload_image(ret[idx]['data']['path'])
|
||||
logger.info(f"上传成功。")
|
||||
ret[idx]['data']['file'] = image_url
|
||||
ret[idx]['data']['path'] = image_url
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
|
||||
await self._reply_wrapper(message, ret)
|
||||
else:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
raise e
|
||||
|
||||
async def _reply_wrapper(self, message: Union[AstrBotMessage, Dict], ret: List):
|
||||
if isinstance(message, AstrBotMessage):
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
if isinstance(message, dict):
|
||||
if 'group_id' in message:
|
||||
await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
|
||||
elif 'user_id' in message:
|
||||
await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
|
||||
else:
|
||||
raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
|
||||
|
||||
async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
|
||||
'''
|
||||
以主动的方式给QQ用户、QQ群发送一条消息。
|
||||
@@ -247,8 +252,12 @@ class AIOCQHTTP(Platform):
|
||||
- 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号;
|
||||
|
||||
'''
|
||||
|
||||
await self._reply(target, result_message.message_chain)
|
||||
try:
|
||||
await self._reply(target, result_message.message_chain, result_message.is_use_t2i)
|
||||
except T2IException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(f"文本转图片时发生错误,将使用纯文本发送。")
|
||||
await self._reply(target, result_message.message_chain, False)
|
||||
|
||||
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
|
||||
if message_type == MessageType.GROUP_MESSAGE:
|
||||
|
||||
@@ -11,13 +11,14 @@ from nakuru import (
|
||||
)
|
||||
from typing import Union, List, Dict
|
||||
from type.types import Context
|
||||
from . import Platform
|
||||
from . import Platform, T2IException
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from type.command import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, NakuruPlatformConfig
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -28,47 +29,43 @@ class FakeSource:
|
||||
self.group_id = group_id
|
||||
|
||||
|
||||
class QQGOCQ(Platform):
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
class QQNakuru(Platform):
|
||||
def __init__(self, context: Context,
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig) -> None:
|
||||
super().__init__("nakuru", context)
|
||||
assert isinstance(platform_config, NakuruPlatformConfig), "gocq: 无法识别的配置类型。"
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.waiting = {}
|
||||
self.context = context
|
||||
self.unique_session = self.context.unique_session
|
||||
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
|
||||
|
||||
self.unique_session = context.config_helper.platform_settings.unique_session
|
||||
self.config = platform_config
|
||||
|
||||
self.client = CQHTTP(
|
||||
host=self.context.base_config.get("gocq_host", "127.0.0.1"),
|
||||
port=self.context.base_config.get("gocq_websocket_port", 6700),
|
||||
http_port=self.context.base_config.get("gocq_http_port", 5700),
|
||||
host=self.config.host,
|
||||
port=self.config.websocket_port,
|
||||
http_port=self.config.port
|
||||
)
|
||||
gocq_app = self.client
|
||||
|
||||
@gocq_app.receiver("GroupMessage")
|
||||
async def _(app: CQHTTP, source: GroupMessage):
|
||||
if self.context.base_config.get("gocq_react_group", True):
|
||||
if self.config.enable_group:
|
||||
abm = self.convert_message(source)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@gocq_app.receiver("FriendMessage")
|
||||
async def _(app: CQHTTP, source: FriendMessage):
|
||||
if self.context.base_config.get("gocq_react_friend", True):
|
||||
if self.config.enable_direct_message:
|
||||
abm = self.convert_message(source)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@gocq_app.receiver("GroupMemberIncrease")
|
||||
async def _(app: CQHTTP, source: GroupMemberIncrease):
|
||||
if self.context.base_config.get("gocq_react_group_increase", True):
|
||||
await app.sendGroupMessage(source.group_id, [
|
||||
Plain(text=self.announcement)
|
||||
])
|
||||
|
||||
@gocq_app.receiver("GuildMessage")
|
||||
async def _(app: CQHTTP, source: GuildMessage):
|
||||
if self.cc.get("gocq_react_guild", True):
|
||||
if self.config.enable_guild:
|
||||
abm = self.convert_message(source)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@@ -114,14 +111,6 @@ class QQGOCQ(Platform):
|
||||
session_id = message.raw_message.user_id
|
||||
|
||||
message.session_id = session_id
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.raw_message.user_id)
|
||||
if sender_id == self.context.base_config.get('admin_qq', '') or \
|
||||
sender_id in self.context.base_config.get('other_admins', []):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
|
||||
# parse unified message origin
|
||||
unified_msg_origin = None
|
||||
@@ -143,7 +132,6 @@ class QQGOCQ(Platform):
|
||||
self.context,
|
||||
"nakuru",
|
||||
session_id,
|
||||
role,
|
||||
unified_msg_origin,
|
||||
reason == 'command') # only_command
|
||||
|
||||
@@ -155,49 +143,47 @@ class QQGOCQ(Platform):
|
||||
if message_result.callback:
|
||||
message_result.callback()
|
||||
|
||||
# 如果是等待回复的消息
|
||||
if session_id in self.waiting and self.waiting[session_id] == '':
|
||||
self.waiting[session_id] = message
|
||||
|
||||
async def reply_msg(self,
|
||||
message: AstrBotMessage,
|
||||
result_message: List[BaseMessageComponent],
|
||||
use_t2i: bool = None):
|
||||
"""
|
||||
回复用户唤醒机器人的消息。(被动回复)
|
||||
"""
|
||||
source = message.raw_message
|
||||
res = result_message
|
||||
"""
|
||||
assert isinstance(message.raw_message, (GroupMessage, FriendMessage, GuildMessage))
|
||||
|
||||
try:
|
||||
await self._reply(message, result_message, use_t2i)
|
||||
except T2IException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(f"文本转图片时发生错误,将使用纯文本发送。")
|
||||
await self._reply(message, result_message, False)
|
||||
return result_message
|
||||
|
||||
assert isinstance(source,
|
||||
(GroupMessage, FriendMessage, GuildMessage))
|
||||
|
||||
logger.info(
|
||||
f"{source.user_id} <- {self.parse_message_outline(res)}")
|
||||
|
||||
if isinstance(res, str):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
await self._reply(source, rendered_images)
|
||||
return
|
||||
except BaseException as e:
|
||||
logger.warn(traceback.format_exc())
|
||||
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
|
||||
|
||||
await self._reply(source, res)
|
||||
|
||||
async def _reply(self, source, message_chain: List[BaseMessageComponent]):
|
||||
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent], use_t2i: bool = None):
|
||||
await self.record_metrics()
|
||||
if isinstance(message_chain, str):
|
||||
message_chain = [Plain(text=message_chain), ]
|
||||
|
||||
# 文转图处理
|
||||
if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(message_chain, list):
|
||||
try:
|
||||
message_chain = await self.convert_to_t2i_chain(message_chain)
|
||||
if not message_chain: raise T2IException()
|
||||
except BaseException as e:
|
||||
raise T2IException()
|
||||
|
||||
# log
|
||||
if isinstance(message, AstrBotMessage):
|
||||
logger.info(
|
||||
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}")
|
||||
else:
|
||||
logger.info(f"回复消息: {message_chain}")
|
||||
|
||||
source = message.raw_message
|
||||
is_dict = isinstance(source, dict)
|
||||
|
||||
# 发消息
|
||||
typ = None
|
||||
if is_dict:
|
||||
if "group_id" in source:
|
||||
@@ -226,7 +212,7 @@ class QQGOCQ(Platform):
|
||||
plain_text_len += len(i.text)
|
||||
elif isinstance(i, Image):
|
||||
image_num += 1
|
||||
if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
|
||||
if plain_text_len > self.context.config_helper.platform_settings.forward_threshold or image_num > 1:
|
||||
# 删除At
|
||||
for i in message_chain:
|
||||
if isinstance(i, At):
|
||||
@@ -252,7 +238,13 @@ class QQGOCQ(Platform):
|
||||
|
||||
guild_id 不是频道号。
|
||||
'''
|
||||
await self._reply(target, result_message.message_chain)
|
||||
try:
|
||||
await self._reply(target, result_message.message_chain, result_message.is_use_t2i)
|
||||
except T2IException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(f"文本转图片时发生错误,将使用纯文本发送。")
|
||||
await self._reply(target, result_message.message_chain, False)
|
||||
return result_message
|
||||
|
||||
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
|
||||
'''
|
||||
@@ -292,21 +284,4 @@ class QQGOCQ(Platform):
|
||||
)
|
||||
abm.tag = "nakuru"
|
||||
abm.message = message.message
|
||||
return abm
|
||||
|
||||
def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]:
|
||||
'''
|
||||
等待下一条消息,超时 300s 后抛出异常
|
||||
'''
|
||||
self.waiting[group_id] = ''
|
||||
cnt = 0
|
||||
while True:
|
||||
if group_id in self.waiting and self.waiting[group_id] != '':
|
||||
# 去掉
|
||||
ret = self.waiting[group_id]
|
||||
del self.waiting[group_id]
|
||||
return ret
|
||||
cnt += 1
|
||||
if cnt > 300:
|
||||
raise Exception("等待消息超时。")
|
||||
time.sleep(1)
|
||||
return abm
|
||||
@@ -16,9 +16,10 @@ from type.message_event import *
|
||||
from type.command import *
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -52,32 +53,35 @@ class botClient(Client):
|
||||
|
||||
class QQOfficial(Platform):
|
||||
|
||||
def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
|
||||
def __init__(self, context: Context,
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig,
|
||||
test_mode = False) -> None:
|
||||
super().__init__("qqofficial", context)
|
||||
assert isinstance(platform_config, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.waiting: dict = {}
|
||||
self.context = context
|
||||
self.config = platform_config
|
||||
|
||||
self.appid = context.base_config['qqbot']['appid']
|
||||
self.token = context.base_config['qqbot']['token']
|
||||
self.secret = context.base_config['qqbot_secret']
|
||||
self.unique_session = context.unique_session
|
||||
qq_group = context.base_config['qqofficial_enable_group_message']
|
||||
|
||||
self.appid = platform_config.appid
|
||||
self.secret = platform_config.secret
|
||||
self.unique_session = context.config_helper.platform_settings.unique_session
|
||||
qq_group = platform_config.enable_group_c2c
|
||||
guild_dm = platform_config.enable_guild_direct_message
|
||||
|
||||
if qq_group:
|
||||
self.intents = botpy.Intents(
|
||||
public_messages=True,
|
||||
public_guild_messages=True,
|
||||
direct_message=context.base_config['direct_message_mode']
|
||||
direct_message=guild_dm
|
||||
)
|
||||
else:
|
||||
self.intents = botpy.Intents(
|
||||
public_guild_messages=True,
|
||||
direct_message=context.base_config['direct_message_mode']
|
||||
direct_message=guild_dm
|
||||
)
|
||||
self.client = botClient(
|
||||
intents=self.intents,
|
||||
@@ -87,7 +91,7 @@ class QQOfficial(Platform):
|
||||
|
||||
self.client.set_platform(self)
|
||||
|
||||
self.test_mode = test_mode
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False):
|
||||
plain_text = ""
|
||||
@@ -169,24 +173,10 @@ class QQOfficial(Platform):
|
||||
return abm
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
return self.client.start(
|
||||
appid=self.appid,
|
||||
secret=self.secret
|
||||
)
|
||||
except BaseException as e:
|
||||
# 早期的 qq-botpy 版本使用 token 登录。
|
||||
logger.error(traceback.format_exc())
|
||||
self.client = botClient(
|
||||
intents=self.intents,
|
||||
bot_log=False,
|
||||
timeout=20,
|
||||
)
|
||||
self.client.set_platform(self)
|
||||
return self.client.start(
|
||||
appid=self.appid,
|
||||
token=self.token
|
||||
)
|
||||
return self.client.start(
|
||||
appid=self.appid,
|
||||
secret=self.secret
|
||||
)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
assert isinstance(message.raw_message, (botpy.message.Message,
|
||||
@@ -209,16 +199,8 @@ class QQOfficial(Platform):
|
||||
session_id = str(message.raw_message.author.id)
|
||||
message.session_id = session_id
|
||||
|
||||
# 解析出 role
|
||||
sender_id = message.sender.user_id
|
||||
if sender_id == self.context.base_config.get('admin_qqchan', None) or \
|
||||
sender_id in self.context.base_config.get('other_admins', None):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
|
||||
# construct astrbot message event
|
||||
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id, role)
|
||||
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id)
|
||||
|
||||
message_result = await self.message_handler.handle(ame)
|
||||
if not message_result:
|
||||
@@ -228,10 +210,6 @@ class QQOfficial(Platform):
|
||||
if message_result.callback:
|
||||
message_result.callback()
|
||||
|
||||
# 如果是等待回复的消息
|
||||
if session_id in self.waiting and self.waiting[session_id] == '':
|
||||
self.waiting[session_id] = message
|
||||
|
||||
return ret
|
||||
|
||||
async def reply_msg(self,
|
||||
@@ -250,10 +228,15 @@ class QQOfficial(Platform):
|
||||
plain_text = ''
|
||||
image_path = ''
|
||||
msg_ref = None
|
||||
rendered_images = []
|
||||
rendered_images = None
|
||||
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(result_message)
|
||||
if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list):
|
||||
try:
|
||||
rendered_images = await self.convert_to_t2i_chain(result_message)
|
||||
except BaseException as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(f"文本转图片时发生错误: {e},将尝试默认方式。")
|
||||
rendered_images = None
|
||||
|
||||
if isinstance(result_message, list):
|
||||
plain_text, image_path = await self._parse_to_qqofficial(result_message, message.type == MessageType.GROUP_MESSAGE)
|
||||
@@ -395,20 +378,3 @@ class QQOfficial(Platform):
|
||||
|
||||
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
|
||||
raise NotImplementedError("qqofficial 不支持此方法。")
|
||||
|
||||
def wait_for_message(self, channel_id: int) -> AstrBotMessage:
|
||||
'''
|
||||
等待指定 channel_id 的下一条信息,超时 300s 后抛出异常
|
||||
'''
|
||||
self.waiting[channel_id] = ''
|
||||
cnt = 0
|
||||
while True:
|
||||
if channel_id in self.waiting and self.waiting[channel_id] != '':
|
||||
# 去掉
|
||||
ret = self.waiting[channel_id]
|
||||
del self.waiting[channel_id]
|
||||
return ret
|
||||
cnt += 1
|
||||
if cnt > 300:
|
||||
raise Exception("等待消息超时。")
|
||||
time.sleep(1)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from type.register import RegisteredPlugins
|
||||
from typing import List, Union, Callable
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
@@ -24,4 +24,3 @@ class PluginCommandBridge():
|
||||
|
||||
def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False):
|
||||
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix))
|
||||
|
||||
@@ -5,7 +5,7 @@ import traceback
|
||||
import uuid
|
||||
import shutil
|
||||
import yaml
|
||||
import subprocess
|
||||
import logging
|
||||
|
||||
from util.updator.plugin_updator import PluginUpdator
|
||||
from util.io import remove_dir, download_file
|
||||
@@ -13,8 +13,9 @@ from types import ModuleType
|
||||
from type.types import Context
|
||||
from type.plugin import *
|
||||
from type.register import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from pip import main as pip_main
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -83,55 +84,63 @@ class PluginManager():
|
||||
self.update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
|
||||
|
||||
def update_plugin_dept(self, path):
|
||||
mirror = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
py = sys.executable
|
||||
# os.system(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com")
|
||||
pip_main(['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/'])
|
||||
# mirror = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
# py = sys.executable
|
||||
# cmd = f"{py} -m pip install -r {path} -i {mirror} --trusted-host mirrors.aliyun.com"
|
||||
# if break_system_package:
|
||||
# cmd += " --break-system-package"
|
||||
# process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, universal_newlines=True)
|
||||
|
||||
process = subprocess.Popen(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com",
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, universal_newlines=True)
|
||||
|
||||
while True:
|
||||
output = process.stdout.readline()
|
||||
if output == '' and process.poll() is not None:
|
||||
break
|
||||
if output:
|
||||
output = output.strip()
|
||||
if output.startswith("Requirement already satisfied"):
|
||||
continue
|
||||
if output.startswith("Using cached"):
|
||||
continue
|
||||
if output.startswith("Looking in indexes"):
|
||||
continue
|
||||
logger.info(output)
|
||||
|
||||
rc = process.poll()
|
||||
# while True:
|
||||
# output = process.stdout.readline()
|
||||
# err = process.stderr.readline()
|
||||
# if err:
|
||||
# err = err.strip()
|
||||
# logger.error(err)
|
||||
# if "no such option: --break-system-package" in err:
|
||||
# self.update_plugin_dept(path, break_system_package=False)
|
||||
# break
|
||||
# if output == '' and process.poll() is not None:
|
||||
# break
|
||||
# if output:
|
||||
# output = output.strip()
|
||||
# if output.startswith("Requirement already satisfied"):
|
||||
# continue
|
||||
# if output.startswith("Using cached"):
|
||||
# continue
|
||||
# if output.startswith("Looking in indexes"):
|
||||
# continue
|
||||
# logger.info(output)
|
||||
|
||||
# rc = process.poll()
|
||||
|
||||
|
||||
def install_plugin(self, repo_url: str):
|
||||
async def install_plugin(self, repo_url: str):
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
# we no longer use Git anymore :)
|
||||
# Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
|
||||
|
||||
plugin_path = self.updator.update(repo_url)
|
||||
plugin_path = await self.updator.update(repo_url)
|
||||
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
|
||||
f.write(repo_url)
|
||||
|
||||
self.check_plugin_dept_update()
|
||||
# self.check_plugin_dept_update()
|
||||
|
||||
return plugin_path
|
||||
# ok, err = self.plugin_reload()
|
||||
# if not ok:
|
||||
# raise Exception(err)
|
||||
|
||||
def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
logger.info(f"正在下载插件 {repo} ...")
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
releases = self.updator.fetch_release_info(url=release_url)
|
||||
releases = await self.updator.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warn(f"未在插件 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
@@ -139,7 +148,7 @@ class PluginManager():
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
download_file(release_url, target_path + ".zip")
|
||||
await download_file(release_url, target_path + ".zip")
|
||||
|
||||
def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin:
|
||||
for p in self.context.cached_plugins:
|
||||
@@ -156,12 +165,12 @@ class PluginManager():
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
|
||||
def update_plugin(self, plugin_name: str):
|
||||
async def update_plugin(self, plugin_name: str):
|
||||
plugin = self.get_registered_plugin(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
|
||||
self.updator.update(plugin)
|
||||
await self.updator.update(plugin)
|
||||
|
||||
def plugin_reload(self):
|
||||
cached_plugins = self.context.cached_plugins
|
||||
@@ -182,9 +191,16 @@ class PluginManager():
|
||||
|
||||
logger.info(f"正在加载插件 {root_dir_name} ...")
|
||||
|
||||
self.check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
|
||||
module = __import__("addons.plugins." +
|
||||
# self.check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
|
||||
try:
|
||||
module = __import__("data.plugins." +
|
||||
root_dir_name + "." + p, fromlist=[p])
|
||||
except (ModuleNotFoundError, ImportError) as e:
|
||||
# 尝试安装插件依赖
|
||||
logger.error(f"尝试安装插件依赖。")
|
||||
self.check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
module = __import__("data.plugins." +
|
||||
root_dir_name + "." + p, fromlist=[p])
|
||||
|
||||
cls = self.get_classes(module)
|
||||
@@ -216,6 +232,11 @@ class PluginManager():
|
||||
traceback.print_exc()
|
||||
fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
|
||||
if not fail_rec:
|
||||
return True, None
|
||||
else:
|
||||
@@ -252,7 +273,7 @@ class PluginManager():
|
||||
# remove the temp dir
|
||||
remove_dir(temp_dir)
|
||||
|
||||
self.check_plugin_dept_update()
|
||||
# self.check_plugin_dept_update()
|
||||
|
||||
# ok, err = self.plugin_reload()
|
||||
# if not ok:
|
||||
|
||||
@@ -8,18 +8,18 @@ import traceback
|
||||
import base64
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import *
|
||||
from util.io import download_image_by_url
|
||||
|
||||
from astrbot.persist.helper import dbConn
|
||||
from astrbot.db import BaseDatabase
|
||||
from model.provider.provider import Provider
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.cmd_config import LLMConfig
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from typing import List, Dict
|
||||
|
||||
from type.types import Context
|
||||
from dataclasses import asdict
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -47,22 +47,16 @@ MODELS = {
|
||||
}
|
||||
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(self, context: Context) -> None:
|
||||
def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase) -> None:
|
||||
super().__init__()
|
||||
|
||||
os.makedirs("data/openai", exist_ok=True)
|
||||
|
||||
self.context = context
|
||||
self.key_data_path = "data/openai/keys.json"
|
||||
self.api_keys = []
|
||||
self.chosen_api_key = None
|
||||
self.base_url = None
|
||||
self.llm_config = llm_config
|
||||
self.keys_data = {} # 记录超额
|
||||
|
||||
cfg = context.base_config['openai']
|
||||
|
||||
if cfg['key']: self.api_keys = cfg['key']
|
||||
if cfg['api_base']: self.base_url = cfg['api_base']
|
||||
if llm_config.key: self.api_keys = llm_config.key
|
||||
if llm_config.api_base: self.base_url = llm_config.api_base
|
||||
if not self.api_keys:
|
||||
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
|
||||
else:
|
||||
@@ -75,53 +69,45 @@ class ProviderOpenAIOfficial(Provider):
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
self.model_configs: Dict = cfg['chatGPTConfigs']
|
||||
super().set_curr_model(self.model_configs['model'])
|
||||
self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
|
||||
super().set_curr_model(llm_config.model_config.model)
|
||||
if llm_config.image_generation_model_config:
|
||||
self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config)
|
||||
self.session_memory: Dict[str, List] = {} # 会话记忆
|
||||
self.session_memory_lock = threading.Lock()
|
||||
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
|
||||
self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小
|
||||
|
||||
logger.info("正在载入分词器 cl100k_base...")
|
||||
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
|
||||
logger.info("分词器载入完成。")
|
||||
|
||||
self.DEFAULT_PERSONALITY = context.default_personality
|
||||
self.DEFAULT_PERSONALITY = {
|
||||
"prompt": self.llm_config.default_personality,
|
||||
"name": "default"
|
||||
}
|
||||
self.curr_personality = self.DEFAULT_PERSONALITY
|
||||
self.session_personality = {} # 记录了某个session是否已设置人格。
|
||||
# 从 SQLite DB 读取历史记录
|
||||
# 读取历史记录
|
||||
self.db_helper = db_helper
|
||||
try:
|
||||
db1 = dbConn()
|
||||
for session in db1.get_all_session():
|
||||
for history in db_helper.get_llm_history():
|
||||
self.session_memory_lock.acquire()
|
||||
self.session_memory[session[0]] = json.loads(session[1])['data']
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
self.session_memory_lock.release()
|
||||
except BaseException as e:
|
||||
logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
logger.warning(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
# 定时保存历史记录
|
||||
threading.Thread(target=self.dump_history, daemon=True).start()
|
||||
|
||||
def dump_history(self):
|
||||
'''
|
||||
转储历史记录
|
||||
'''
|
||||
time.sleep(10)
|
||||
db = dbConn()
|
||||
'''转储历史记录'''
|
||||
time.sleep(30)
|
||||
while True:
|
||||
try:
|
||||
for key in self.session_memory:
|
||||
data = self.session_memory[key]
|
||||
data_json = {
|
||||
'data': data
|
||||
}
|
||||
if db.check_session(key):
|
||||
db.update_session(key, json.dumps(data_json))
|
||||
else:
|
||||
db.insert_session(key, json.dumps(data_json))
|
||||
logger.debug("已保存 OpenAI 会话历史记录")
|
||||
for session_id, content in self.session_memory.items():
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(content))
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
logger.error("保存 LLM 历史记录失败: " + str(e))
|
||||
finally:
|
||||
time.sleep(10*60)
|
||||
|
||||
@@ -133,7 +119,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.session_personality = {} # 重置
|
||||
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
|
||||
tokens_num = len(encoded_prompt)
|
||||
model = self.model_configs['model']
|
||||
model = self.get_curr_model()
|
||||
if model in MODELS and tokens_num > MODELS[model] - 500:
|
||||
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
|
||||
|
||||
@@ -172,7 +158,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for record in self.session_memory[session_id]:
|
||||
if "user" in record and record['user']:
|
||||
if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
|
||||
logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
|
||||
logger.warn(f"由于当前模型 {self.get_curr_model()} 不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
|
||||
continue
|
||||
context.append(record['user'])
|
||||
if "AI" in record and record['AI']:
|
||||
@@ -184,7 +170,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
是否是 LVM
|
||||
'''
|
||||
return self.model_configs['model'].startswith("gpt-4")
|
||||
return self.get_curr_model().startswith("gpt-4")
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
@@ -237,7 +223,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.session_memory[session_id].append(message)
|
||||
|
||||
# 根据 模型的上下文窗口 淘汰掉多余的记录
|
||||
curr_model = self.model_configs['model']
|
||||
curr_model = self.get_curr_model()
|
||||
if curr_model in MODELS:
|
||||
maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion
|
||||
# if message['usage_tokens'] > maxium_tokens_num:
|
||||
@@ -296,6 +282,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
extra_conf: Dict = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on":
|
||||
return "这是一个测试消息。"
|
||||
|
||||
super().accu_model_stat()
|
||||
if not session_id:
|
||||
session_id = "unknown"
|
||||
@@ -313,7 +302,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 1. 可以保证之后 pop 的时候不会出现问题
|
||||
# 2. 可以保证不会超过最大 token 数
|
||||
_encoded_prompt = self.tokenizer.encode(prompt)
|
||||
curr_model = self.model_configs['model']
|
||||
curr_model = self.get_curr_model()
|
||||
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
|
||||
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
|
||||
prompt = self.tokenizer.decode(_encoded_prompt)
|
||||
@@ -324,7 +313,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 获取上下文,openai 格式
|
||||
contexts = await self.retrieve_context(session_id)
|
||||
|
||||
conf = self.model_configs
|
||||
conf = asdict(self.llm_config.model_config)
|
||||
if extra_conf: conf.update(extra_conf)
|
||||
|
||||
# start request
|
||||
@@ -336,12 +325,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if tools:
|
||||
completion_coro = self.client.chat.completions.create(
|
||||
messages=contexts,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
**conf
|
||||
)
|
||||
else:
|
||||
completion_coro = self.client.chat.completions.create(
|
||||
messages=contexts,
|
||||
stream=False,
|
||||
**conf
|
||||
)
|
||||
try:
|
||||
@@ -355,10 +346,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if ok: continue
|
||||
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
|
||||
except BadRequestError as e:
|
||||
retry += 1
|
||||
logger.warn(f"OpenAI 请求异常:{e}。")
|
||||
if "image_url is only supported by certain models." in str(e):
|
||||
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
|
||||
raise e
|
||||
raise Exception(f"当前模型 { self.get_curr_model() } 不支持图片输入,请更换模型。")
|
||||
except RateLimitError as e:
|
||||
if "You exceeded your current quota" in str(e):
|
||||
self.keys_data[self.chosen_api_key] = False
|
||||
@@ -434,11 +425,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
retry = 0
|
||||
conf = self.image_generator_model_configs
|
||||
super().accu_model_stat(model=conf['model'])
|
||||
if not conf:
|
||||
logger.error("OpenAI 图片生成模型配置不存在。")
|
||||
raise Exception("OpenAI 图片生成模型配置不存在。")
|
||||
|
||||
super().accu_model_stat(model=conf['model'])
|
||||
while retry < 3:
|
||||
try:
|
||||
images_response = await self.client.images.generate(
|
||||
@@ -494,12 +484,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
return contexts_str, len(self.session_memory[session_id])
|
||||
|
||||
def set_model(self, model: str):
|
||||
self.model_configs['model'] = model
|
||||
self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
|
||||
# TODO: 更新配置文件
|
||||
super().set_curr_model(model)
|
||||
|
||||
def get_configs(self):
|
||||
return self.model_configs
|
||||
return asdict(self.llm_config)
|
||||
|
||||
def get_keys_data(self):
|
||||
return self.keys_data
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
pydantic~=1.10.4
|
||||
aiohttp
|
||||
requests
|
||||
openai
|
||||
qq-botpy
|
||||
chardet~=5.1.0
|
||||
@@ -10,10 +9,8 @@ beautifulsoup4
|
||||
googlesearch-python
|
||||
tiktoken
|
||||
readability-lxml
|
||||
baidu-aip
|
||||
websockets
|
||||
flask
|
||||
quart
|
||||
psutil
|
||||
lxml_html_clean
|
||||
SparkleLogging
|
||||
colorlog
|
||||
aiocqhttp
|
||||
|
||||
18
tests/mocks/onebot.py
Normal file
18
tests/mocks/onebot.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import copy
|
||||
from aiocqhttp import Event
|
||||
|
||||
class MockOneBotMessage():
|
||||
def __init__(self):
|
||||
# 这些数据不是敏感的
|
||||
self.group_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882500, 'message_id': -2147480159, 'message_seq': -2147480159, 'real_id': -2147480159, 'message_type': 'group', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': '', 'role': 'owner'}, 'raw_message': '[CQ:at,qq=3430871669] just reply me `ok`', 'font': 14, 'sub_type': 'normal', 'message': [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': ' just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message', 'group_id': 849750470})
|
||||
self.friend_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882599, 'message_id': -2147480157, 'message_seq': -2147480157, 'real_id': -2147480157, 'message_type': 'private', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': ''}, 'raw_message': 'just reply me `ok`', 'font': 14, 'sub_type': 'friend', 'message': [{'data': {'text': 'just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message'})
|
||||
|
||||
def create_random_group_message(self):
|
||||
return self.group_event_sample
|
||||
|
||||
def create_random_direct_message(self):
|
||||
return self.friend_event_sample
|
||||
|
||||
def create_msg(self, text: str):
|
||||
self.group_event_sample.message = [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': text}, 'type': 'text'}]
|
||||
return self.group_event_sample
|
||||
54
tests/mocks/qq_official.py
Normal file
54
tests/mocks/qq_official.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import botpy.message
|
||||
|
||||
class MockQQOfficialMessage():
|
||||
def __init__(self):
|
||||
# 这些数据已经经过去敏处理
|
||||
self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T19:58:52+08:00'}
|
||||
self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:06:32+08:00'}
|
||||
self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:15:24+08:00'}
|
||||
self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849"
|
||||
|
||||
self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'}
|
||||
self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
|
||||
self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
|
||||
self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
|
||||
|
||||
self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'}
|
||||
self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
|
||||
self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
|
||||
self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
|
||||
|
||||
def create_random_group_message(self):
|
||||
mocked = botpy.message.GroupMessage(
|
||||
api=None,
|
||||
event_id=self.group_event_id_sample,
|
||||
data=self.group_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
def create_random_guild_message(self):
|
||||
mocked = botpy.message.Message(
|
||||
api=None,
|
||||
event_id=self.guild_event_id_sample,
|
||||
data=self.guild_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
def create_random_direct_message(self):
|
||||
mocked = botpy.message.DirectMessage(
|
||||
api=None,
|
||||
event_id=self.direct_event_id_sample,
|
||||
data=self.direct_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
def create_msg(self, text: str):
|
||||
sample = self.group_plain_text_sample.copy()
|
||||
sample['content'] = text
|
||||
mocked = botpy.message.Message(
|
||||
api=None,
|
||||
event_id=self.group_event_id_sample,
|
||||
data=sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
51
tests/test_http_server.py
Normal file
51
tests/test_http_server.py
Normal file
@@ -0,0 +1,51 @@
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
BASE_URL = "http://0.0.0.0:6185/api"
|
||||
|
||||
async def get_url(url):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
return await response.json()
|
||||
|
||||
async def post_url(url, data):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=data) as response:
|
||||
return await response.json()
|
||||
|
||||
class TestHTTPServer:
|
||||
@pytest.mark.asyncio
|
||||
async def test_config(self):
|
||||
configs = await get_url(f"{BASE_URL}/configs")
|
||||
assert 'data' in configs and 'metadata' in configs['data'] \
|
||||
and 'config' in configs['data']
|
||||
config = configs['data']['config']
|
||||
# test post config
|
||||
await post_url(f"{BASE_URL}/astrbot-configs", config)
|
||||
# text post config with invalid data
|
||||
assert 'rate_limit' in config['platform_settings']
|
||||
config['platform_settings']['rate_limit'] = "invalid"
|
||||
ret = await post_url(f"{BASE_URL}/astrbot-configs", config)
|
||||
assert 'status' in ret and ret['status'] == 'error'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update(self):
|
||||
await get_url(f"{BASE_URL}/check_update")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(self):
|
||||
pname = "astrbot_plugin_bilibili"
|
||||
url = f"https://github.com/Soulter/{pname}"
|
||||
|
||||
await get_url(f"{BASE_URL}/extensions")
|
||||
|
||||
# test install plugin
|
||||
await post_url(f"{BASE_URL}/extensions/install", {
|
||||
"url": url
|
||||
})
|
||||
|
||||
# test uninstall plugin
|
||||
await post_url(f"{BASE_URL}/extensions/uninstall", {
|
||||
"name": pname
|
||||
})
|
||||
155
tests/test_message.py
Normal file
155
tests/test_message.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
import os
|
||||
|
||||
from tests.mocks.qq_official import MockQQOfficialMessage
|
||||
from tests.mocks.onebot import MockOneBotMessage
|
||||
|
||||
from astrbot.bootstrap import AstrBotBootstrap
|
||||
from model.platform.qq_official import QQOfficial
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from util.log import LogManager
|
||||
|
||||
from util.cmd_config import QQOfficialPlatformConfig, AiocqhttpPlatformConfig
|
||||
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
os.environ['TEST_MODE'] = 'on'
|
||||
bootstrap = AstrBotBootstrap()
|
||||
|
||||
llm_config = bootstrap.context.config_helper.llm[0]
|
||||
llm_config.api_base = os.environ['OPENAI_API_BASE']
|
||||
llm_config.key = [os.environ['OPENAI_API_KEY']]
|
||||
llm_config.model_config.model = os.environ['LLM_MODEL']
|
||||
llm_config.model_config.max_tokens = 1000
|
||||
asyncio.run(bootstrap.run())
|
||||
llm_provider = ProviderOpenAIOfficial(llm_config, bootstrap.db_helper)
|
||||
bootstrap.message_handler.provider = llm_provider
|
||||
bootstrap.config_helper.wake_prefix = ["/"]
|
||||
bootstrap.config_helper.admins_id = ["905617992"]
|
||||
|
||||
for p_config in bootstrap.context.config_helper.platform:
|
||||
if isinstance(p_config, QQOfficialPlatformConfig):
|
||||
qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler, p_config)
|
||||
elif isinstance(p_config, AiocqhttpPlatformConfig):
|
||||
aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler, p_config)
|
||||
|
||||
class TestBasicMessageHandle():
|
||||
@pytest.mark.asyncio
|
||||
async def test_qqofficial_group_message(self):
|
||||
group_message = MockQQOfficialMessage().create_random_group_message()
|
||||
abm = qq_official._parse_from_qqofficial(group_message, MessageType.GROUP_MESSAGE)
|
||||
ret = await qq_official.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qqofficial_guild_message(self):
|
||||
guild_message = MockQQOfficialMessage().create_random_guild_message()
|
||||
abm = qq_official._parse_from_qqofficial(guild_message, MessageType.GUILD_MESSAGE)
|
||||
ret = await qq_official.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
# 有共同性,为了节约开销,不测试频道私聊。
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_qqofficial_private_message(self):
|
||||
# private_message = MockQQOfficialMessage().create_random_direct_message()
|
||||
# abm = qq_official._parse_from_qqofficial(private_message, MessageType.FRIEND_MESSAGE)
|
||||
# ret = await qq_official.handle_msg(abm)
|
||||
# print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiocqhttp_group_message(self):
|
||||
event = MockOneBotMessage().create_random_group_message()
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiocqhttp_direct_message(self):
|
||||
event = MockOneBotMessage().create_random_direct_message()
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
class TestInteralCommandHsandle():
|
||||
def create(self, text: str):
|
||||
event = MockOneBotMessage().create_msg(text)
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
return abm
|
||||
|
||||
async def fast_test(self, text: str):
|
||||
abm = self.create(text)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(f"Command: {text}, Result: {ret.result_message}")
|
||||
return ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_save(self):
|
||||
abm = self.create("/websearch on")
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
assert bootstrap.context.config_helper.llm_settings.web_search \
|
||||
== bootstrap.config_helper.get("llm_settings")['web_search']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websearch(self):
|
||||
await self.fast_test("/websearch")
|
||||
await self.fast_test("/websearch on")
|
||||
await self.fast_test("/websearch off")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help(self):
|
||||
await self.fast_test("/help")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_myid(self):
|
||||
await self.fast_test("/myid")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wake(self):
|
||||
await self.fast_test("/wake")
|
||||
await self.fast_test("/wake #")
|
||||
assert "#" in bootstrap.context.config_helper.wake_prefix
|
||||
assert "#" in bootstrap.context.config_helper.get("wake_prefix")
|
||||
await self.fast_test("#wake /")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sleep(self):
|
||||
await self.fast_test("/provider")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update(self):
|
||||
await self.fast_test("/update")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_t2i(self):
|
||||
if not bootstrap.context.config_helper.t2i:
|
||||
abm = self.create("/t2i")
|
||||
await aiocqhttp.handle_msg(abm)
|
||||
await self.fast_test("/help")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin(self):
|
||||
pname = "astrbot_plugin_bilibili"
|
||||
url = f"https://github.com/Soulter/{pname}"
|
||||
await self.fast_test("/plugin")
|
||||
await self.fast_test(f"/plugin l")
|
||||
await self.fast_test(f"/plugin i {url}")
|
||||
await self.fast_test(f"/plugin u {url}")
|
||||
await self.fast_test(f"/plugin d {pname}")
|
||||
|
||||
class TestLLMChat():
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_chat(self):
|
||||
os.environ["TEST_LLM"] = "on"
|
||||
ret = await llm_provider.text_chat("Just reply `ok`", "test")
|
||||
print(ret)
|
||||
event = MockOneBotMessage().create_msg("Just reply `ok`")
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
os.environ["TEST_LLM"] = "off"
|
||||
|
||||
28
type/cached_queue.py
Normal file
28
type/cached_queue.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from asyncio import Queue
|
||||
from collections import deque
|
||||
from typing import Deque
|
||||
|
||||
class CachedQueue(Queue):
|
||||
def __init__(self, maxsize: int = 0, cachesize: int = 200):
|
||||
super().__init__(maxsize)
|
||||
self.cache = deque(maxlen=cachesize)
|
||||
|
||||
def put_nowait(self, item):
|
||||
self.cache.append(item)
|
||||
super().put_nowait(item)
|
||||
|
||||
def get_nowait(self):
|
||||
item = super().get_nowait()
|
||||
return item
|
||||
|
||||
def get(self):
|
||||
item = super().get()
|
||||
return item
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
with self.mutex:
|
||||
self._queue.clear()
|
||||
|
||||
def get_cache(self) -> Deque:
|
||||
return self.cache
|
||||
350
type/config.py
350
type/config.py
@@ -1,75 +1,287 @@
|
||||
VERSION = '3.3.9'
|
||||
VERSION = '3.3.17'
|
||||
DB_PATH = 'data/data_v2.db'
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"qqbot": {
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"token": "",
|
||||
},
|
||||
"gocqbot": {
|
||||
"enable": False,
|
||||
},
|
||||
"uniqueSessionMode": False,
|
||||
"dump_history_interval": 10,
|
||||
"limit": {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
},
|
||||
"notice": "",
|
||||
"direct_message_mode": True,
|
||||
"reply_prefix": "",
|
||||
"baidu_aip": {
|
||||
"enable": False,
|
||||
"app_id": "",
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
},
|
||||
"openai": {
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"chatGPTConfigs": {
|
||||
"model": "gpt-4o",
|
||||
"max_tokens": 6000,
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
# 新版本配置文件,摈弃旧版本令人困惑的配置项 :D
|
||||
DEFAULT_CONFIG_VERSION_2 = {
|
||||
"config_version": 2,
|
||||
"platform": [
|
||||
{
|
||||
"id": "default",
|
||||
"name": "qq_official",
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"enable_group_c2c": True,
|
||||
"enable_guild_direct_message": True,
|
||||
},
|
||||
"total_tokens_limit": 10000,
|
||||
{
|
||||
"id": "default",
|
||||
"name": "nakuru",
|
||||
"enable": False,
|
||||
"host": "172.0.0.1",
|
||||
"port": 5700,
|
||||
"websocket_port": 6700,
|
||||
"enable_group": True,
|
||||
"enable_guild": True,
|
||||
"enable_direct_message": True,
|
||||
"enable_group_increase": True,
|
||||
},
|
||||
{
|
||||
"id": "default",
|
||||
"name": "aiocqhttp",
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 6199,
|
||||
}
|
||||
],
|
||||
"platform_settings": {
|
||||
"unique_session": False,
|
||||
"rate_limit": {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
},
|
||||
"reply_prefix": "",
|
||||
"forward_threshold": 200, # 转发消息的阈值
|
||||
},
|
||||
"qq_forward_threshold": 200,
|
||||
"qq_welcome": "",
|
||||
"qq_pic_mode": True,
|
||||
"gocq_host": "127.0.0.1",
|
||||
"gocq_http_port": 5700,
|
||||
"gocq_websocket_port": 6700,
|
||||
"gocq_react_group": True,
|
||||
"gocq_react_guild": True,
|
||||
"gocq_react_friend": True,
|
||||
"gocq_react_group_increase": True,
|
||||
"other_admins": [],
|
||||
"CHATGPT_BASE_URL": "",
|
||||
"qqbot_secret": "",
|
||||
"qqofficial_enable_group_message": False,
|
||||
"admin_qq": "",
|
||||
"nick_qq": ["/", "!"],
|
||||
"admin_qqchan": "",
|
||||
"llm_env_prompt": "",
|
||||
"llm_wake_prefix": "",
|
||||
"default_personality_str": "",
|
||||
"openai_image_generate": {
|
||||
"model": "dall-e-3",
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
"llm": [
|
||||
{
|
||||
"id": "default",
|
||||
"name": "openai",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"prompt_prefix": "",
|
||||
"default_personality": "",
|
||||
"model_config": {
|
||||
"model": "gpt-4o",
|
||||
"max_tokens": 6000,
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
},
|
||||
"image_generation_model_config": {
|
||||
"enable": True,
|
||||
"model": "dall-e-3",
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
}
|
||||
},
|
||||
],
|
||||
"llm_settings": {
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"identifier": False,
|
||||
},
|
||||
"http_proxy": "",
|
||||
"content_safety": {
|
||||
"baidu_aip": {
|
||||
"enable": False,
|
||||
"app_id": "",
|
||||
"api_key": "",
|
||||
"secret_key": "",
|
||||
},
|
||||
"internal_keywords": {
|
||||
"enable": True,
|
||||
"extra_keywords": [],
|
||||
}
|
||||
},
|
||||
"wake_prefix": ["/"],
|
||||
"t2i": True,
|
||||
"dump_history_interval": 10,
|
||||
"admins_id": [],
|
||||
"https_proxy": "",
|
||||
"dashboard_username": "",
|
||||
"dashboard_password": "",
|
||||
"aiocqhttp": {
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 0,
|
||||
}
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
"username": "",
|
||||
"password": "",
|
||||
},
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
}
|
||||
|
||||
# 这个是用于迁移旧版本配置文件的映射表
|
||||
MAPPINGS_1_2 = [
|
||||
[["qqbot", "enable"], ["platform", 0, "enable"]],
|
||||
[["qqbot", "appid"], ["platform", 0, "appid"]],
|
||||
[["qqbot", "token"], ["platform", 0, "secret"]],
|
||||
[["qqofficial_enable_group_message"], ["platform", 0, "enable_group_c2c"]],
|
||||
[["direct_message_mode"], ["platform", 0, "enable_guild_direct_message"]],
|
||||
[["gocqbot", "enable"], ["platform", 1, "enable"]],
|
||||
[["gocq_host"], ["platform", 1, "host"]],
|
||||
[["gocq_http_port"], ["platform", 1, "port"]],
|
||||
[["gocq_websocket_port"], ["platform", 1, "websocket_port"]],
|
||||
[["gocq_react_group"], ["platform", 1, "enable_group"]],
|
||||
[["gocq_react_guild"], ["platform", 1, "enable_guild"]],
|
||||
[["gocq_react_friend"], ["platform", 1, "enable_direct_message"]],
|
||||
[["gocq_react_group_increase"], ["platform", 1, "enable_group_increase"]],
|
||||
[["aiocqhttp", "enable"], ["platform", 2, "enable"]],
|
||||
[["aiocqhttp", "ws_reverse_host"], ["platform", 2, "ws_reverse_host"]],
|
||||
[["aiocqhttp", "ws_reverse_port"], ["platform", 2, "ws_reverse_port"]],
|
||||
[["uniqueSessionMode"], ["platform_settings", "unique_session"]],
|
||||
[["limit", "time"], ["platform_settings", "rate_limit", "time"]],
|
||||
[["limit", "count"], ["platform_settings", "rate_limit", "count"]],
|
||||
[["reply_prefix"], ["platform_settings", "reply_prefix"]],
|
||||
[["qq_forward_threshold"], ["platform_settings", "forward_threshold"]],
|
||||
|
||||
[["openai", "key"], ["llm", 0, "key"]],
|
||||
[["openai", "api_base"], ["llm", 0, "api_base"]],
|
||||
[["openai", "chatGPTConfigs", "model"], ["llm", 0, "model_config", "model"]],
|
||||
[["openai", "chatGPTConfigs", "max_tokens"], ["llm", 0, "model_config", "max_tokens"]],
|
||||
[["openai", "chatGPTConfigs", "temperature"], ["llm", 0, "model_config", "temperature"]],
|
||||
[["openai", "chatGPTConfigs", "top_p"], ["llm", 0, "model_config", "top_p"]],
|
||||
[["openai", "chatGPTConfigs", "frequency_penalty"], ["llm", 0, "model_config", "frequency_penalty"]],
|
||||
[["openai", "chatGPTConfigs", "presence_penalty"], ["llm", 0, "model_config", "presence_penalty"]],
|
||||
|
||||
[["default_personality_str"], ["llm", 0, "default_personality"]],
|
||||
[["llm_env_prompt"], ["llm", 0, "prompt_prefix"]],
|
||||
[["openai_image_generate", "model"], ["llm", 0, "image_generation_model_config", "model"]],
|
||||
[["openai_image_generate", "size"], ["llm", 0, "image_generation_model_config", "size"]],
|
||||
[["openai_image_generate", "style"], ["llm", 0, "image_generation_model_config", "style"]],
|
||||
[["openai_image_generate", "quality"], ["llm", 0, "image_generation_model_config", "quality"]],
|
||||
|
||||
[["llm_wake_prefix"], ["llm_settings", "wake_prefix"]],
|
||||
|
||||
[["baidu_aip", "enable"], ["content_safety", "baidu_aip", "enable"]],
|
||||
[["baidu_aip", "app_id"], ["content_safety", "baidu_aip", "app_id"]],
|
||||
[["baidu_aip", "api_key"], ["content_safety", "baidu_aip", "api_key"]],
|
||||
[["baidu_aip", "secret_key"], ["content_safety", "baidu_aip", "secret_key"]],
|
||||
|
||||
[["qq_pic_mode"], ["t2i"]],
|
||||
[["dump_history_interval"], ["dump_history_interval"]],
|
||||
[["other_admins"], ["admins_id"]],
|
||||
[["http_proxy"], ["http_proxy"]],
|
||||
[["https_proxy"], ["https_proxy"]],
|
||||
[["dashboard_username"], ["dashboard", "username"]],
|
||||
[["dashboard_password"], ["dashboard", "password"]],
|
||||
[["nick_qq"], ["wake_prefix"]],
|
||||
]
|
||||
|
||||
# 配置项的中文描述、值类型
|
||||
CONFIG_METADATA_2 = {
|
||||
"config_version": {"description": "配置版本", "type": "int"},
|
||||
"platform": {
|
||||
"description": "平台配置",
|
||||
"type": "list",
|
||||
"items": {
|
||||
"name": {"description": "平台名称", "type": "string"},
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
"appid": {"description": "appid", "type": "string"},
|
||||
"secret": {"description": "secret", "type": "string"},
|
||||
"enable_group_c2c": {"description": "启用消息列表单聊", "type": "bool"},
|
||||
"enable_guild_direct_message": {"description": "启用频道私聊", "type": "bool"},
|
||||
"host": {"description": "主机地址", "type": "string"},
|
||||
"port": {"description": "端口", "type": "int"},
|
||||
"websocket_port": {"description": "Websocket 端口", "type": "int"},
|
||||
"ws_reverse_host": {"description": "反向 Websocket 主机地址", "type": "string"},
|
||||
"ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int"},
|
||||
"enable_group": {"description": "接收群组消息", "type": "bool"},
|
||||
"enable_guild": {"description": "接收频道消息", "type": "bool"},
|
||||
"enable_direct_message": {"description": "接收频道私聊", "type": "bool"},
|
||||
"enable_group_increase": {"description": "接收群组成员增加事件", "type": "bool"},
|
||||
}
|
||||
},
|
||||
"platform_settings": {
|
||||
"description": "平台设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"unique_session": {"description": "会话隔离到单个人", "type": "bool"},
|
||||
"rate_limit": {
|
||||
"description": "速率限制",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"time": {"description": "消息速率限制时间", "type": "int"},
|
||||
"count": {"description": "消息速率限制计数", "type": "int"},
|
||||
}
|
||||
},
|
||||
"reply_prefix": {"description": "回复前缀", "type": "string"},
|
||||
"forward_threshold": {"description": "转发消息的字数阈值", "type": "int"},
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
"description": "大语言模型配置",
|
||||
"type": "list",
|
||||
"items": {
|
||||
"name": {"description": "模型名称", "type": "string"},
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
"key": {"description": "API Key", "type": "list", "items": {"type": "string"}},
|
||||
"api_base": {"description": "API Base URL", "type": "string"},
|
||||
"prompt_prefix": {"description": "Prompt 前缀", "type": "string"},
|
||||
"default_personality": {"description": "默认人格", "type": "string"},
|
||||
"model_config": {
|
||||
"description": "模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {"description": "模型名称", "type": "string"},
|
||||
"max_tokens": {"description": "最大令牌数", "type": "int"},
|
||||
"temperature": {"description": "温度", "type": "float"},
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
"frequency_penalty": {"description": "频率惩罚", "type": "float"},
|
||||
"presence_penalty": {"description": "存在惩罚", "type": "float"},
|
||||
}
|
||||
},
|
||||
"image_generation_model_config": {
|
||||
"description": "图像生成模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用(需要该提供商支持图像生成模型)", "type": "bool"},
|
||||
"model": {"description": "模型名称", "type": "string"},
|
||||
"size": {"description": "图像尺寸", "type": "string"},
|
||||
"style": {"description": "图像风格", "type": "string"},
|
||||
"quality": {"description": "图像质量", "type": "string"},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"llm_settings": {
|
||||
"description": "大语言模型设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"wake_prefix": {"description": "LLM 聊天额外唤醒前缀", "type": "string"},
|
||||
"web_search": {"description": "启用网页搜索(能访问 Google 时效果最佳)", "type": "bool"},
|
||||
"identifier": {"description": "启动识别群员(略微增加token开销)", "type": "bool"},
|
||||
}
|
||||
},
|
||||
"content_safety": {
|
||||
"description": "内容安全",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"baidu_aip": {
|
||||
"description": "百度内容审核配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用百度内容审核(需手动安装 baidu-aip 库)", "type": "bool"},
|
||||
"app_id": {"description": "APP ID", "type": "string"},
|
||||
"api_key": {"description": "API Key", "type": "string"},
|
||||
"secret_key": {"description": "Secret Key", "type": "string"},
|
||||
}
|
||||
},
|
||||
"internal_keywords": {
|
||||
"description": "内部关键词过滤",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用内部关键词过滤", "type": "bool"},
|
||||
"extra_keywords": {"description": "额外关键词(支持正则)", "type": "list", "items": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"wake_prefix": {"description": "唤醒前缀列表", "type": "list", "items": {"type": "string"}},
|
||||
"t2i": {"description": "文本转图像功能", "type": "bool"},
|
||||
"dump_history_interval": {"description": "历史记录转储间隔", "type": "int"},
|
||||
"admins_id": {"description": "管理员ID列表", "type": "list", "items": {"type": "int"}},
|
||||
"https_proxy": {"description": "HTTPS代理", "type": "string"},
|
||||
"http_proxy": {"description": "HTTP代理", "type": "string"},
|
||||
"dashboard": {
|
||||
"description": "仪表盘配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
"username": {"description": "用户名", "type": "string"},
|
||||
"password": {"description": "密码", "type": "string"},
|
||||
}
|
||||
},
|
||||
"log_level": {"description": "控制台日志级别(DEBUG, INFO, WARNING, ERROR)", "type": "string"},
|
||||
"t2i_endpoint": {"description": "文本转图像服务接口(为空时使用公共服务器)", "type": "string"},
|
||||
}
|
||||
|
||||
@@ -47,10 +47,17 @@ class AstrMessageEvent():
|
||||
context: Context,
|
||||
platform_name: str,
|
||||
session_id: str,
|
||||
role: str = "member",
|
||||
|
||||
unified_msg_origin: str = None,
|
||||
only_command: bool = False):
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.sender.user_id)
|
||||
if sender_id in context.config_helper.admins_id:
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
|
||||
ame = AstrMessageEvent(message.message_str,
|
||||
message,
|
||||
context.find_platform(platform_name),
|
||||
|
||||
8
type/middleware.py
Normal file
8
type/middleware.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Middleware():
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
origin: str = "" # 注册来源
|
||||
func: callable = None
|
||||
@@ -1,17 +1,19 @@
|
||||
import asyncio
|
||||
import asyncio, os, time
|
||||
from asyncio import Task
|
||||
from type.register import *
|
||||
from typing import List, Awaitable
|
||||
from logging import Logger
|
||||
from util.cmd_config import CmdConfig
|
||||
from util.cmd_config import AstrBotConfig
|
||||
from util.t2i.renderer import TextToImageRenderer
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from util.image_uploader import ImageUploader
|
||||
from util.updator.plugin_updator import PluginUpdator
|
||||
from type.command import CommandResult
|
||||
from type.middleware import Middleware
|
||||
from type.astrbot_message import MessageType
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from model.provider.provider import Provider
|
||||
from type.cached_queue import CachedQueue
|
||||
|
||||
|
||||
class Context:
|
||||
@@ -20,19 +22,19 @@ class Context:
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
self.running = True
|
||||
self.logger: Logger = None
|
||||
self.base_config: dict = None # 配置(期望启动机器人后是不变的)
|
||||
self.config_helper: CmdConfig = None
|
||||
self.config_helper: AstrBotConfig = None
|
||||
self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件
|
||||
self.platforms: List[RegisteredPlatform] = []
|
||||
self.llms: List[RegisteredLLM] = []
|
||||
self.default_personality: dict = None
|
||||
|
||||
self.unique_session = False # 独立会话
|
||||
self.version: str = None # 机器人版本
|
||||
self.nick: tuple = None # gocq 的唤醒词
|
||||
self.t2i_mode = False
|
||||
self.web_search = False # 是否开启了网页搜索
|
||||
# self.unique_session = False # 独立会话
|
||||
# self.version: str = None # 机器人版本
|
||||
# self.nick: tuple = None # gocq 的唤醒词
|
||||
# self.t2i_mode = False
|
||||
# self.web_search = False # 是否开启了网页搜索
|
||||
|
||||
self.metrics_uploader = None
|
||||
self.updator: AstrBotUpdator = None
|
||||
@@ -42,12 +44,17 @@ class Context:
|
||||
self.image_uploader = ImageUploader()
|
||||
self.message_handler = None # see astrbot/message/handler.py
|
||||
self.ext_tasks: List[Task] = []
|
||||
self.middlewares: List[Middleware] = []
|
||||
|
||||
self.command_manager = None
|
||||
self.running = True
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._start_running = int(time.time())
|
||||
|
||||
self._log_queue = CachedQueue()
|
||||
|
||||
# useless
|
||||
self.reply_prefix = ""
|
||||
# self.reply_prefix = ""
|
||||
|
||||
def register_commands(self,
|
||||
plugin_name: str,
|
||||
@@ -97,13 +104,39 @@ class Context:
|
||||
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
|
||||
'''
|
||||
self.llms.append(RegisteredLLM(llm_name, provider, origin))
|
||||
|
||||
def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable):
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 处理函数
|
||||
'''
|
||||
self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
|
||||
|
||||
def unregister_llm_tool(self, tool_name: str):
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
self.message_handler.llm_tools.remove_func(tool_name)
|
||||
|
||||
def register_middleware(self, middleware: Middleware):
|
||||
'''
|
||||
注册一个中间件。所有的消息事件都会经过中间件处理,然后再进入 LLM 聊天模块。
|
||||
|
||||
在 AstrBot 中,会对到来的消息事件首先检查指令,然后再检查中间件。触发指令后将不会进入 LLM 聊天模块,而中间件会。
|
||||
'''
|
||||
self.middlewares.append(middleware)
|
||||
|
||||
def find_platform(self, platform_name: str) -> RegisteredPlatform:
|
||||
for platform in self.platforms:
|
||||
if platform_name == platform.platform_name:
|
||||
return platform
|
||||
|
||||
raise ValueError("couldn't find the platform you specified")
|
||||
|
||||
if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错
|
||||
raise ValueError("couldn't find the platform you specified")
|
||||
|
||||
async def send_message(self, unified_msg_origin: str, message: CommandResult):
|
||||
'''
|
||||
@@ -118,4 +151,9 @@ class Context:
|
||||
platform_name, message_type, id = l
|
||||
platform = self.find_platform(platform_name)
|
||||
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
|
||||
|
||||
|
||||
def get_current_llm_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前的 LLM Provider。
|
||||
'''
|
||||
return self.message_handler.provider
|
||||
@@ -1,6 +1,5 @@
|
||||
from model.provider.provider import Provider
|
||||
import json
|
||||
import time
|
||||
import textwrap
|
||||
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
@@ -23,10 +22,21 @@ class FuncCall():
|
||||
def __init__(self, provider: Provider) -> None:
|
||||
self.func_list = []
|
||||
self.provider = provider
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
|
||||
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 处理函数
|
||||
'''
|
||||
params = {
|
||||
"type": "object", # hardcore here
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {}
|
||||
}
|
||||
for param in func_args:
|
||||
@@ -34,14 +44,23 @@ class FuncCall():
|
||||
"type": param['type'],
|
||||
"description": param['description']
|
||||
}
|
||||
self._func = {
|
||||
_func = {
|
||||
"name": name,
|
||||
"parameters": params,
|
||||
"description": desc,
|
||||
"func_obj": func_obj,
|
||||
}
|
||||
self.func_list.append(self._func)
|
||||
|
||||
self.func_list.append(_func)
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
for i, f in enumerate(self.func_list):
|
||||
if f["name"] == name:
|
||||
self.func_list.pop(i)
|
||||
break
|
||||
|
||||
def func_dump(self) -> str:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
@@ -65,7 +84,10 @@ class FuncCall():
|
||||
})
|
||||
return _l
|
||||
|
||||
async def func_call(self, question: str, func_definition: str, session_id: str=None):
|
||||
async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider = None) -> tuple:
|
||||
|
||||
if not provider:
|
||||
provider = self.provider
|
||||
|
||||
prompt = textwrap.dedent(f"""
|
||||
ROLE:
|
||||
@@ -91,7 +113,7 @@ class FuncCall():
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
res = await self.provider.text_chat(prompt, session_id)
|
||||
res = await provider.text_chat(prompt, session_id)
|
||||
print(res)
|
||||
if res.find('```') != -1:
|
||||
res = res[res.find('```json') + 7: res.rfind('```')]
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import traceback
|
||||
import random
|
||||
import json
|
||||
import aiohttp
|
||||
import os
|
||||
|
||||
@@ -14,8 +12,10 @@ from util.websearch.bing import Bing
|
||||
from util.websearch.sogo import Sogo
|
||||
from util.websearch.google import Google
|
||||
from model.provider.provider import Provider
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
from type.message_event import AstrMessageEvent
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -31,24 +31,7 @@ def tidy_text(text: str) -> str:
|
||||
'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
# def special_fetch_zhihu(link: str) -> str:
|
||||
# '''
|
||||
# function-calling 函数, 用于获取知乎文章的内容
|
||||
# '''
|
||||
# response = requests.get(link, headers=HEADERS)
|
||||
# response.encoding = "utf-8"
|
||||
# soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
# if "zhuanlan.zhihu.com" in link:
|
||||
# r = soup.find(class_="Post-RichTextContainer")
|
||||
# else:
|
||||
# r = soup.find(class_="List-item").find(class_="RichContent-inner")
|
||||
# if r is None:
|
||||
# print("debug: zhihu none")
|
||||
# raise Exception("zhihu none")
|
||||
# return tidy_text(r.text)
|
||||
|
||||
async def search_from_bing(keyword: str) -> str:
|
||||
async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str) -> str:
|
||||
'''
|
||||
tools, 从 bing 搜索引擎搜索
|
||||
'''
|
||||
@@ -84,10 +67,11 @@ async def search_from_bing(keyword: str) -> str:
|
||||
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
|
||||
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
return ret
|
||||
|
||||
return await summarize(context, ame, ret)
|
||||
|
||||
|
||||
async def fetch_website_content(url):
|
||||
async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: str):
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -97,81 +81,13 @@ async def fetch_website_content(url):
|
||||
ret = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(ret, 'html.parser')
|
||||
ret = tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
|
||||
async def web_search(prompt: str, provider: Provider, session_id: str, official_fc: bool=False):
|
||||
'''
|
||||
@param official_fc: 使用官方 function-calling
|
||||
'''
|
||||
new_func_call = FuncCall(provider)
|
||||
|
||||
new_func_call.add_func("web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"description": "搜索关键词"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
search_from_bing
|
||||
)
|
||||
new_func_call.add_func("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "要获取内容的网页链接"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
return await summarize(context, ame, ret)
|
||||
|
||||
has_func = False
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
# we use official function-calling
|
||||
try:
|
||||
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
|
||||
except BadRequestError as e:
|
||||
# seems dont support function-calling
|
||||
logger.error(f"error: {e}. Try to use local function-calling implementation")
|
||||
return await web_search(prompt, provider, session_id, official_fc=False)
|
||||
if isinstance(result, Function):
|
||||
logger.debug(f"function-calling: {result}")
|
||||
func_obj = None
|
||||
for i in new_func_call.func_list:
|
||||
if i["name"] == result.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(result.arguments)
|
||||
function_invoked_ret = await func_obj(**args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
return result
|
||||
else:
|
||||
# we use our own function-calling
|
||||
try:
|
||||
args = {
|
||||
'question': prompt,
|
||||
'func_definition': new_func_call.func_dump(),
|
||||
}
|
||||
function_invoked_ret, has_func = await new_func_call.func_call(**args)
|
||||
|
||||
if not has_func:
|
||||
return await provider.text_chat(prompt, session_id)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return await provider.text_chat(prompt, session_id) + "(网页搜索失败, 此为默认回复)"
|
||||
|
||||
if has_func:
|
||||
await provider.forget(session_id=session_id)
|
||||
summary_prompt = f"""
|
||||
async def summarize(context: Context, ame: AstrMessageEvent, text: str):
|
||||
|
||||
summary_prompt = f"""
|
||||
你是一个专业且高效的助手,你的任务是
|
||||
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
|
||||
1. 根据下面的相关材料对用户的问题 `{ame.message_str}` 进行总结;
|
||||
2. 简单地发表你对这个问题的看法。
|
||||
|
||||
# 例子
|
||||
@@ -183,7 +99,7 @@ async def web_search(prompt: str, provider: Provider, session_id: str, official_
|
||||
2. 请**直接输出总结**,不要输出多余的内容和提示语。
|
||||
|
||||
# 相关材料
|
||||
{function_invoked_ret}"""
|
||||
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
|
||||
return ret
|
||||
return function_invoked_ret
|
||||
{text}"""
|
||||
|
||||
provider = context.get_current_llm_provider()
|
||||
return await provider.text_chat(prompt=summary_prompt, session_id=ame.session_id)
|
||||
@@ -1,33 +1,244 @@
|
||||
import os
|
||||
import json
|
||||
from type.config import DEFAULT_CONFIG
|
||||
import shutil
|
||||
import logging
|
||||
from util.io import on_error
|
||||
from type.config import DEFAULT_CONFIG_VERSION_2, MAPPINGS_1_2
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
cpath = "data/cmd_config.json"
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
def check_exist():
|
||||
if not os.path.exists(cpath):
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump({}, f, ensure_ascii=False)
|
||||
f.flush()
|
||||
@dataclass
|
||||
class RateLimit:
|
||||
time: int = 60
|
||||
count: int = 30
|
||||
|
||||
@dataclass
|
||||
class PlatformSettings:
|
||||
unique_session: bool = False
|
||||
rate_limit: RateLimit = field(default_factory=RateLimit)
|
||||
reply_prefix: str = ""
|
||||
forward_threshold: int = 200
|
||||
|
||||
def __post_init__(self):
|
||||
self.rate_limit = RateLimit(**self.rate_limit)
|
||||
|
||||
@dataclass
|
||||
class PlatformConfig():
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
enable: bool = False
|
||||
|
||||
@dataclass
|
||||
class QQOfficialPlatformConfig(PlatformConfig):
|
||||
appid: str = ""
|
||||
secret: str = ""
|
||||
enable_group_c2c: bool = True
|
||||
enable_guild_direct_message: bool = True
|
||||
|
||||
@dataclass
|
||||
class NakuruPlatformConfig(PlatformConfig):
|
||||
host: str = "172.0.0.1",
|
||||
port: int = 5700,
|
||||
websocket_port: int = 6700,
|
||||
enable_group: bool = True,
|
||||
enable_guild: bool = True,
|
||||
enable_direct_message: bool = True,
|
||||
enable_group_increase: bool = True
|
||||
|
||||
@dataclass
|
||||
class AiocqhttpPlatformConfig(PlatformConfig):
|
||||
ws_reverse_host: str = ""
|
||||
ws_reverse_port: int = 6199
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
model: str = "gpt-4o"
|
||||
max_tokens: int = 6000
|
||||
temperature: float = 0.9
|
||||
top_p: float = 1
|
||||
frequency_penalty: float = 0
|
||||
presence_penalty: float = 0
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationModelConfig:
|
||||
enable: bool = True
|
||||
model: str = "dall-e-3"
|
||||
size: str = "1024x1024"
|
||||
style: str = "vivid"
|
||||
quality: str = "standard"
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
id: str = ""
|
||||
name: str = "openai"
|
||||
enable: bool = True
|
||||
key: List[str] = field(default_factory=list)
|
||||
api_base: str = ""
|
||||
prompt_prefix: str = ""
|
||||
default_personality: str = ""
|
||||
model_config: ModelConfig = field(default_factory=ModelConfig)
|
||||
image_generation_model_config: Optional[ImageGenerationModelConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_config = ModelConfig(**self.model_config)
|
||||
if self.image_generation_model_config:
|
||||
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config)
|
||||
@dataclass
|
||||
class LLMSettings:
|
||||
wake_prefix: str = ""
|
||||
web_search: bool = False
|
||||
identifier: bool = False
|
||||
|
||||
@dataclass
|
||||
class BaiduAIPConfig:
|
||||
enable: bool = False
|
||||
app_id: str = ""
|
||||
api_key: str = ""
|
||||
secret_key: str = ""
|
||||
|
||||
@dataclass
|
||||
class InternalKeywordsConfig:
|
||||
enable: bool = True
|
||||
extra_keywords: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class ContentSafetyConfig:
|
||||
baidu_aip: BaiduAIPConfig = field(default_factory=BaiduAIPConfig)
|
||||
internal_keywords: InternalKeywordsConfig = field(default_factory=InternalKeywordsConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
self.baidu_aip = BaiduAIPConfig(**self.baidu_aip)
|
||||
self.internal_keywords = InternalKeywordsConfig(**self.internal_keywords)
|
||||
|
||||
@dataclass
|
||||
class DashboardConfig:
|
||||
enable: bool = True
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
|
||||
@dataclass
|
||||
class AstrBotConfig():
|
||||
config_version: int = 2
|
||||
platform_settings: PlatformSettings = field(default_factory=PlatformSettings)
|
||||
llm: List[LLMConfig] = field(default_factory=list)
|
||||
llm_settings: LLMSettings = field(default_factory=LLMSettings)
|
||||
content_safety: ContentSafetyConfig = field(default_factory=ContentSafetyConfig)
|
||||
t2i: bool = True
|
||||
dump_history_interval: int = 10
|
||||
admins_id: List[str] = field(default_factory=list)
|
||||
https_proxy: str = ""
|
||||
http_proxy: str = ""
|
||||
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
|
||||
platform: List[PlatformConfig] = field(default_factory=list)
|
||||
wake_prefix: List[str] = field(default_factory=list)
|
||||
log_level: str = "INFO"
|
||||
t2i_endpoint: str = ""
|
||||
|
||||
class CmdConfig():
|
||||
def __init__(self) -> None:
|
||||
self.cached_config: dict = {}
|
||||
self.init_configs()
|
||||
|
||||
# compability
|
||||
if isinstance(self.wake_prefix, str):
|
||||
self.wake_prefix = [self.wake_prefix]
|
||||
|
||||
if len(self.wake_prefix) == 0:
|
||||
self.wake_prefix.append("/")
|
||||
|
||||
def load_from_dict(self, data: Dict):
|
||||
'''从字典中加载配置到对象。
|
||||
|
||||
@note: 适用于 version 2 配置文件。
|
||||
'''
|
||||
self.config_version=data.get("version", 2)
|
||||
self.platform=[]
|
||||
for p in data.get("platform", []):
|
||||
if 'name' not in p:
|
||||
logger.warning("A platform config missing name, skipping.")
|
||||
continue
|
||||
if p["name"] == "qq_official":
|
||||
self.platform.append(QQOfficialPlatformConfig(**p))
|
||||
elif p["name"] == "nakuru":
|
||||
self.platform.append(NakuruPlatformConfig(**p))
|
||||
elif p["name"] == "aiocqhttp":
|
||||
self.platform.append(AiocqhttpPlatformConfig(**p))
|
||||
else:
|
||||
self.platform.append(PlatformConfig(**p))
|
||||
self.platform_settings=PlatformSettings(**data.get("platform_settings", {}))
|
||||
self.llm=[LLMConfig(**l) for l in data.get("llm", [])]
|
||||
self.llm_settings=LLMSettings(**data.get("llm_settings", {}))
|
||||
self.content_safety=ContentSafetyConfig(**data.get("content_safety", {}))
|
||||
self.t2i=data.get("t2i", True)
|
||||
self.dump_history_interval=data.get("dump_history_interval", 10)
|
||||
self.admins_id=data.get("admins_id", [])
|
||||
self.https_proxy=data.get("https_proxy", "")
|
||||
self.http_proxy=data.get("http_proxy", "")
|
||||
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
|
||||
self.wake_prefix=data.get("wake_prefix", ["/"])
|
||||
self.log_level=data.get("log_level", "INFO")
|
||||
self.t2i_endpoint=data.get("t2i_endpoint", "")
|
||||
|
||||
def migrate_config_1_2(self, old: dict) -> dict:
|
||||
'''将配置文件从版本 1 迁移至版本 2'''
|
||||
logger.info("正在更新配置文件到 version 2...")
|
||||
new_config = DEFAULT_CONFIG_VERSION_2
|
||||
mappings = MAPPINGS_1_2
|
||||
|
||||
def set_nested_value(d, keys, value):
|
||||
cursor = d
|
||||
for key in keys[:-1]:
|
||||
cursor = cursor[key]
|
||||
cursor[keys[-1]] = value
|
||||
|
||||
for old_path, new_path in mappings:
|
||||
value = old
|
||||
try:
|
||||
for key in old_path:
|
||||
value = value[key] # soooooo convenient!!
|
||||
set_nested_value(new_config, new_path, value)
|
||||
except KeyError:
|
||||
# 如果旧配置中没有这个键,跳过,即使用新配置的默认值
|
||||
continue
|
||||
|
||||
logger.info("配置文件更新完成。")
|
||||
return new_config
|
||||
|
||||
def flush_config(self, config: dict = None):
|
||||
'''将配置写入文件, 如果没有传入配置,则写入默认配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def save_config(self):
|
||||
'''将现存配置写入文件'''
|
||||
self.flush_config(self.to_dict())
|
||||
|
||||
def init_configs(self):
|
||||
'''
|
||||
初始化必需的配置项
|
||||
'''
|
||||
self.init_config_items(DEFAULT_CONFIG)
|
||||
'''初始化必需的配置项'''
|
||||
config = None
|
||||
|
||||
if not self.check_exist():
|
||||
self.flush_config()
|
||||
config = DEFAULT_CONFIG_VERSION_2
|
||||
else:
|
||||
config = self.get_all()
|
||||
# check if the config is outdated
|
||||
if 'config_version' not in config: # version 1
|
||||
config = self.migrate_config_1_2(config)
|
||||
self.flush_config(config)
|
||||
|
||||
# 加载配置到对象
|
||||
self.load_from_dict(config)
|
||||
# 保存到文件
|
||||
# 这一步操作是为了保证配置文件中的字段的完整性。
|
||||
# 在版本变动新增配置项时,将对象中新增的配置项的默认值写入文件。
|
||||
self.save_config()
|
||||
|
||||
@staticmethod
|
||||
def get(key, default=None):
|
||||
'''
|
||||
从文件系统中直接获取配置
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
def get(self, key: str, default=None):
|
||||
'''从文件系统中直接获取配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
if key in d:
|
||||
return d[key]
|
||||
@@ -35,58 +246,49 @@ class CmdConfig():
|
||||
return default
|
||||
|
||||
def get_all(self):
|
||||
'''
|
||||
从文件系统中获取所有配置
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
'''从文件系统中获取所有配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
if not conf_str:
|
||||
return {}
|
||||
conf = json.loads(conf_str)
|
||||
return conf
|
||||
|
||||
def put(self, key, value):
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
d[key] = value
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return asdict(self)
|
||||
|
||||
self.cached_config[key] = value
|
||||
def check_exist(self) -> bool:
|
||||
return os.path.exists(ASTRBOT_CONFIG_PATH)
|
||||
|
||||
def try_migrate():
|
||||
'''
|
||||
- 将 cmd_config.json 迁移至 data/cmd_config.json (如果存在)
|
||||
- 将 addons/plugins 迁移至 data/plugins (如果存在)
|
||||
'''
|
||||
if os.path.exists("cmd_config.json") and not os.path.exists("data/cmd_config.json"):
|
||||
try:
|
||||
shutil.move("cmd_config.json", "data/cmd_config.json")
|
||||
except:
|
||||
logger.error("迁移 cmd_config.json 失败。")
|
||||
|
||||
@staticmethod
|
||||
def put_by_dot_str(key: str, value):
|
||||
'''
|
||||
根据点分割的字符串,将值写入配置文件
|
||||
'''
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
_d = d
|
||||
_ks = key.split(".")
|
||||
for i in range(len(_ks)):
|
||||
if i == len(_ks) - 1:
|
||||
_d[_ks[i]] = value
|
||||
else:
|
||||
_d = _d[_ks[i]]
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def init_config_items(self, d: dict):
|
||||
conf = self.get_all()
|
||||
|
||||
if not self.cached_config:
|
||||
self.cached_config = conf
|
||||
|
||||
_tag = False
|
||||
|
||||
for key, val in d.items():
|
||||
if key not in conf:
|
||||
conf[key] = val
|
||||
_tag = True
|
||||
if _tag:
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(conf, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
if os.path.exists("addons/plugins"):
|
||||
if os.path.exists("data/plugins"):
|
||||
try:
|
||||
shutil.rmtree("data/plugins", onerror=on_error)
|
||||
except:
|
||||
logger.error("删除 data/plugins 失败。")
|
||||
try:
|
||||
shutil.move("addons/plugins", "data/")
|
||||
shutil.rmtree("addons", onerror=on_error)
|
||||
except:
|
||||
logger.error("迁移 addons/plugins 失败。")
|
||||
@@ -1,16 +0,0 @@
|
||||
import json, os
|
||||
from util.cmd_config import CmdConfig
|
||||
|
||||
def try_migrate_config():
|
||||
'''
|
||||
将 cmd_config.json 迁移至 data/cmd_config.json (如果存在的话)
|
||||
'''
|
||||
if os.path.exists("cmd_config.json"):
|
||||
with open("cmd_config.json", "r", encoding="utf-8-sig") as f:
|
||||
data = json.load(f)
|
||||
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
try:
|
||||
os.remove("cmd_config.json")
|
||||
except Exception as e:
|
||||
pass
|
||||
27
util/io.py
27
util/io.py
@@ -4,10 +4,9 @@ import shutil
|
||||
import socket
|
||||
import time
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from PIL import Image
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
@@ -47,13 +46,11 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
|
||||
|
||||
def save_temp_img(img: Image) -> str:
|
||||
if not os.path.exists("temp"):
|
||||
os.makedirs("temp")
|
||||
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
# 获得文件创建时间,清除超过1小时的
|
||||
try:
|
||||
for f in os.listdir("temp"):
|
||||
path = os.path.join("temp", f)
|
||||
for f in os.listdir("data/temp"):
|
||||
path = os.path.join("data/temp", f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600:
|
||||
@@ -63,7 +60,7 @@ def save_temp_img(img: Image) -> str:
|
||||
|
||||
# 获得时间戳
|
||||
timestamp = int(time.time())
|
||||
p = f"temp/{timestamp}.jpg"
|
||||
p = f"data/temp/{timestamp}.jpg"
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(p)
|
||||
@@ -101,16 +98,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def download_file(url: str, path: str):
|
||||
async def download_file(url: str, path: str):
|
||||
'''
|
||||
从指定 url 下载文件到指定路径 path
|
||||
'''
|
||||
try:
|
||||
logger.info(f"下载文件: {url}")
|
||||
with requests.get(url, stream=True) as r:
|
||||
with open(path, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
51
util/log.py
Normal file
51
util/log.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import logging, asyncio, colorlog
|
||||
from type.cached_queue import CachedQueue
|
||||
|
||||
log_color_config = {
|
||||
'DEBUG': 'bold_blue', 'INFO': 'bold_cyan',
|
||||
'WARNING': 'bold_yellow', 'ERROR': 'red',
|
||||
'CRITICAL': 'bold_red', 'RESET': 'reset',
|
||||
'asctime': 'green'
|
||||
}
|
||||
|
||||
class LogQueueHandler(logging.Handler):
|
||||
def __init__(self, log_queue: CachedQueue):
|
||||
super().__init__()
|
||||
self.log_queue = log_queue
|
||||
|
||||
def emit(self, record):
|
||||
log_entry = self.format(record)
|
||||
try:
|
||||
self.log_queue.put_nowait(log_entry)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
class LogManager:
|
||||
|
||||
@classmethod
|
||||
def GetLogger(cls, log_name: str = 'default'):
|
||||
logger = logging.getLogger(log_name)
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_formatter = colorlog.ColoredFormatter(
|
||||
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(funcName)s|%(filename)s:%(lineno)d]: %(message)s %(reset)s',
|
||||
datefmt='%H:%M:%S',
|
||||
log_colors=log_color_config
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
@classmethod
|
||||
def set_queue_handler(cls, logger: logging.Logger, log_queue: CachedQueue):
|
||||
handler = LogQueueHandler(log_queue)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
if logger.handlers:
|
||||
handler.setFormatter(logger.handlers[0].formatter)
|
||||
else:
|
||||
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
||||
logger.addHandler(handler)
|
||||
@@ -1,23 +1,29 @@
|
||||
import asyncio
|
||||
import requests
|
||||
import aiohttp
|
||||
import json
|
||||
import sys
|
||||
import logging
|
||||
|
||||
from astrbot.db import BaseDatabase
|
||||
from type.types import Context
|
||||
from collections import defaultdict
|
||||
from type.config import VERSION
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
class MetricUploader():
|
||||
def __init__(self, context: Context) -> None:
|
||||
def __init__(self, context: Context, db_helper: BaseDatabase) -> None:
|
||||
self.platform_stats = {}
|
||||
self.llm_stats = {}
|
||||
self.plugin_stats = {}
|
||||
self.command_stats = defaultdict(int)
|
||||
self.context = context
|
||||
|
||||
self.db_helper = db_helper
|
||||
|
||||
async def upload_metrics(self):
|
||||
'''
|
||||
上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。
|
||||
|
||||
|
||||
这些数据包含:
|
||||
- AstrBot 版本
|
||||
- OS 版本
|
||||
@@ -25,7 +31,7 @@ class MetricUploader():
|
||||
- LLM 模型名称、调用次数
|
||||
- 加载的插件的元数据
|
||||
'''
|
||||
await asyncio.sleep(10)
|
||||
await asyncio.sleep(30)
|
||||
context = self.context
|
||||
while True:
|
||||
for llm in context.llms:
|
||||
@@ -33,7 +39,7 @@ class MetricUploader():
|
||||
for k in stat:
|
||||
self.llm_stats[llm.llm_name + "#" + k] = stat[k]
|
||||
llm.llm_instance.reset_model_stat()
|
||||
|
||||
|
||||
for plugin in context.cached_plugins:
|
||||
self.plugin_stats[plugin.metadata.plugin_name] = {
|
||||
"metadata": {
|
||||
@@ -45,32 +51,40 @@ class MetricUploader():
|
||||
"repo": plugin.metadata.repo,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
res = {
|
||||
"stat_version": "moon",
|
||||
"version": VERSION, # 版本号
|
||||
"platform_stats": self.platform_stats, # 过去 30 分钟各消息平台交互消息数
|
||||
"llm_stats": self.llm_stats,
|
||||
"plugin_stats": self.plugin_stats,
|
||||
"command_stats": self.command_stats,
|
||||
"sys": sys.platform, # 系统版本
|
||||
}
|
||||
|
||||
try:
|
||||
res = {
|
||||
"stat_version": "moon",
|
||||
"version": context.version, # 版本号
|
||||
"platform_stats": self.platform_stats, # 过去 30 分钟各消息平台交互消息数
|
||||
"llm_stats": self.llm_stats,
|
||||
"plugin_stats": self.plugin_stats,
|
||||
"command_stats": self.command_stats,
|
||||
"sys": sys.platform, # 系统版本
|
||||
}
|
||||
resp = requests.post(
|
||||
'https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
|
||||
if resp.status_code == 200:
|
||||
ok = resp.json()
|
||||
if ok['status'] == 'ok':
|
||||
self.clear()
|
||||
self.db_helper.insert_base_metrics(res)
|
||||
except BaseException as e:
|
||||
logger.debug("指标数据保存到数据库失败: " + str(e))
|
||||
await asyncio.sleep(30*60)
|
||||
continue
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=10) as resp:
|
||||
pass
|
||||
except BaseException as e:
|
||||
pass
|
||||
|
||||
self.clear()
|
||||
await asyncio.sleep(30*60)
|
||||
|
||||
|
||||
def increment_platform_stat(self, platform_name: str):
|
||||
self.platform_stats[platform_name] = self.platform_stats.get(platform_name, 0) + 1
|
||||
self.platform_stats[platform_name] = self.platform_stats.get(
|
||||
platform_name, 0) + 1
|
||||
|
||||
def clear(self):
|
||||
self.platform_stats.clear()
|
||||
self.llm_stats.clear()
|
||||
self.plugin_stats.clear()
|
||||
self.command_stats.clear()
|
||||
self.command_stats.clear()
|
||||
|
||||
7
util/plugin_dev/api/v1/__init__.py
Normal file
7
util/plugin_dev/api/v1/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .bot import *
|
||||
from .config import *
|
||||
from .llm import *
|
||||
from .message import *
|
||||
from .platform import *
|
||||
from .register import *
|
||||
from .types import *
|
||||
@@ -7,6 +7,6 @@ Platform类是消息平台的抽象类,定义了消息平台的基本接口。
|
||||
|
||||
from model.platform import Platform
|
||||
|
||||
from model.platform.qq_nakuru import QQGOCQ
|
||||
from model.platform.qq_nakuru import QQNakuru
|
||||
from model.platform.qq_official import QQOfficial
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user