@@ -22,6 +22,7 @@ from util.cmd_config import init_astrbot_config_items
|
||||
from type.types import GlobalObject
|
||||
from type.register import *
|
||||
from type.message import AstrBotMessage
|
||||
from type.config import *
|
||||
from addons.dashboard.helper import DashBoardHelper
|
||||
from addons.dashboard.server import DashBoardData
|
||||
from persist.session import dbConn
|
||||
@@ -38,9 +39,6 @@ frequency_time = 60
|
||||
# 计数默认值
|
||||
frequency_count = 10
|
||||
|
||||
# 版本
|
||||
version = '3.1.13'
|
||||
|
||||
# 语言模型
|
||||
OPENAI_OFFICIAL = 'openai_official'
|
||||
NONE_LLM = 'none_llm'
|
||||
@@ -61,8 +59,6 @@ init_astrbot_config_items()
|
||||
# 全局对象
|
||||
_global_object: GlobalObject = None
|
||||
|
||||
# 语言模型选择
|
||||
|
||||
|
||||
def privider_chooser(cfg):
|
||||
l = []
|
||||
@@ -70,13 +66,10 @@ def privider_chooser(cfg):
|
||||
l.append('openai_official')
|
||||
return l
|
||||
|
||||
|
||||
'''
|
||||
初始化机器人
|
||||
'''
|
||||
|
||||
|
||||
def init():
|
||||
'''
|
||||
初始化机器人
|
||||
'''
|
||||
global llm_instance, llm_command_instance
|
||||
global baidu_judge, chosen_provider
|
||||
global frequency_count, frequency_time
|
||||
@@ -92,9 +85,9 @@ def init():
|
||||
|
||||
# 初始化 global_object
|
||||
_global_object = GlobalObject()
|
||||
_global_object.version = version
|
||||
_global_object.version = VERSION
|
||||
_global_object.base_config = cfg
|
||||
logger.info("AstrBot v"+version)
|
||||
logger.info("AstrBot v" + VERSION)
|
||||
|
||||
if 'reply_prefix' in cfg:
|
||||
# 适配旧版配置
|
||||
@@ -319,7 +312,6 @@ async def record_message(platform: str, session_id: str):
|
||||
db_inst.increment_stat_session(platform, session_id, 1)
|
||||
db_inst.increment_stat_message(curr_ts, 1)
|
||||
db_inst.increment_stat_platform(curr_ts, platform, 1)
|
||||
_global_object.cnt_total += 1
|
||||
|
||||
|
||||
async def oper_msg(message: AstrBotMessage,
|
||||
|
||||
@@ -73,6 +73,7 @@ class Command:
|
||||
else:
|
||||
raise TypeError("插件返回值格式错误。")
|
||||
if hit:
|
||||
plugin.trig()
|
||||
logger.debug("hit plugin: " + plugin.metadata.plugin_name)
|
||||
return True, res
|
||||
except TypeError as e:
|
||||
|
||||
@@ -14,34 +14,40 @@ class Platform():
|
||||
初始化平台的各种接口
|
||||
'''
|
||||
self.message_handler = message_handler
|
||||
self.cnt_receive = 0
|
||||
self.cnt_reply = 0
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle_msg():
|
||||
async def handle_msg(self):
|
||||
'''
|
||||
处理到来的消息
|
||||
'''
|
||||
self.cnt_receive += 1
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def reply_msg():
|
||||
async def reply_msg(self):
|
||||
'''
|
||||
回复消息(被动发送)
|
||||
'''
|
||||
self.cnt_reply += 1
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send_msg(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
|
||||
async def send_msg(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
|
||||
'''
|
||||
发送消息(主动发送)
|
||||
'''
|
||||
self.cnt_reply += 1
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
|
||||
async def send(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
|
||||
'''
|
||||
发送消息(主动发送)同 send_msg()
|
||||
'''
|
||||
self.cnt_reply += 1
|
||||
pass
|
||||
|
||||
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str:
|
||||
|
||||
@@ -104,6 +104,7 @@ class QQGOCQ(Platform):
|
||||
self.client.run()
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
await super().handle_msg()
|
||||
logger.info(
|
||||
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
|
||||
|
||||
@@ -176,6 +177,7 @@ class QQGOCQ(Platform):
|
||||
async def reply_msg(self,
|
||||
message: Union[AstrBotMessage, GuildMessage, GroupMessage, FriendMessage],
|
||||
result_message: list):
|
||||
await super().reply_msg()
|
||||
"""
|
||||
插件开发者请使用send方法, 可以不用直接调用这个方法。
|
||||
"""
|
||||
@@ -254,6 +256,7 @@ class QQGOCQ(Platform):
|
||||
提供给插件的发送QQ消息接口。
|
||||
参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。
|
||||
'''
|
||||
await super().reply_msg()
|
||||
try:
|
||||
await self.reply_msg(message, result_message)
|
||||
except BaseException as e:
|
||||
@@ -265,6 +268,7 @@ class QQGOCQ(Platform):
|
||||
'''
|
||||
同 send_msg()
|
||||
'''
|
||||
await super().reply_msg()
|
||||
await self.reply_msg(to, res)
|
||||
|
||||
def create_text_image(title: str, text: str, max_width=30, font_size=20):
|
||||
|
||||
@@ -102,6 +102,7 @@ class QQOfficial(Platform):
|
||||
)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
await super().handle_msg()
|
||||
assert isinstance(message.raw_message, (botpy.message.Message,
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage))
|
||||
is_group = message.type != MessageType.FRIEND_MESSAGE
|
||||
@@ -154,6 +155,7 @@ class QQOfficial(Platform):
|
||||
'''
|
||||
回复频道消息
|
||||
'''
|
||||
await super().reply_msg()
|
||||
if isinstance(message, AstrBotMessage):
|
||||
source = message.raw_message
|
||||
else:
|
||||
|
||||
@@ -73,6 +73,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
base_url=self.base_url
|
||||
)
|
||||
self.model_configs: Dict = cfg['chatGPTConfigs']
|
||||
super().set_curr_model(self.model_configs['model'])
|
||||
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
|
||||
self.session_memory: Dict[str, List] = {} # 会话记忆
|
||||
self.session_memory_lock = threading.Lock()
|
||||
@@ -289,6 +290,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
extra_conf: Dict = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
super().accu_model_stat()
|
||||
if not session_id:
|
||||
session_id = "unknown"
|
||||
if "unknown" in self.session_memory:
|
||||
@@ -421,6 +423,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
retry = 0
|
||||
conf = self.image_generator_model_configs
|
||||
super().accu_model_stat(model=conf['model'])
|
||||
if not conf:
|
||||
logger.error("OpenAI 图片生成模型配置不存在。")
|
||||
raise Exception("OpenAI 图片生成模型配置不存在。")
|
||||
@@ -481,6 +484,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
def set_model(self, model: str):
|
||||
self.model_configs['model'] = model
|
||||
super().set_curr_model(model)
|
||||
|
||||
def get_configs(self):
|
||||
return self.model_configs
|
||||
|
||||
@@ -1,4 +1,27 @@
|
||||
from collections import defaultdict
|
||||
|
||||
class Provider:
|
||||
def __init__(self) -> None:
|
||||
self.model_stat = defaultdict(int) # 用于记录 LLM Model 使用数据
|
||||
self.curr_model_name = "unknown"
|
||||
|
||||
def reset_model_stat(self):
|
||||
self.model_stat.clear()
|
||||
|
||||
def set_curr_model(self, model_name: str):
|
||||
self.curr_model_name = model_name
|
||||
|
||||
def get_curr_model(self):
|
||||
'''
|
||||
返回当前正在使用的 LLM
|
||||
'''
|
||||
return self.curr_model_name
|
||||
|
||||
def accu_model_stat(self, model: str = None):
|
||||
if not model:
|
||||
model = self.get_curr_model()
|
||||
self.model_stat[model] += 1
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
@@ -18,7 +41,7 @@ class Provider:
|
||||
extra_conf: 额外配置
|
||||
default_personality: 默认人格
|
||||
'''
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
async def image_generate(self, prompt, session_id, **kwargs) -> str:
|
||||
'''
|
||||
@@ -26,10 +49,10 @@ class Provider:
|
||||
prompt: 提示词
|
||||
session_id: 会话id
|
||||
'''
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
async def forget(self, session_id=None) -> bool:
|
||||
'''
|
||||
重置会话
|
||||
'''
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
1
type/config.py
Normal file
1
type/config.py
Normal file
@@ -0,0 +1 @@
|
||||
VERSION = '3.1.13'
|
||||
@@ -15,6 +15,13 @@ class RegisteredPlugin:
|
||||
module_path: str
|
||||
module: ModuleType
|
||||
root_dir_name: str
|
||||
trig_cnt: int = 0
|
||||
|
||||
def reset_trig_cnt(self):
|
||||
self.trig_cnt = 0
|
||||
|
||||
def trig(self):
|
||||
self.trig_cnt += 1
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
|
||||
|
||||
@@ -15,7 +15,6 @@ class GlobalObject:
|
||||
web_search: bool # 是否开启了网页搜索
|
||||
reply_prefix: str # 回复前缀
|
||||
unique_session: bool # 是否开启了独立会话
|
||||
cnt_total: int # 总消息数
|
||||
default_personality: dict
|
||||
dashboard_data = None
|
||||
|
||||
@@ -26,7 +25,6 @@ class GlobalObject:
|
||||
self.web_search = False # 是否开启了网页搜索
|
||||
self.reply_prefix = None
|
||||
self.unique_session = False
|
||||
self.cnt_total = 0
|
||||
self.platforms = []
|
||||
self.llms = []
|
||||
self.default_personality = None
|
||||
|
||||
@@ -15,6 +15,7 @@ from PIL import Image, ImageDraw, ImageFont
|
||||
from type.types import GlobalObject
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from collections import defaultdict
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||
|
||||
@@ -466,15 +467,41 @@ def get_sys_info(global_object: GlobalObject):
|
||||
|
||||
|
||||
def upload(_global_object: GlobalObject):
|
||||
'''
|
||||
上传相关非敏感统计数据
|
||||
'''
|
||||
time.sleep(10)
|
||||
while True:
|
||||
addr_ip = ''
|
||||
platform_stats = {}
|
||||
llm_stats = {}
|
||||
plugin_stats = {}
|
||||
for platform in _global_object.platforms:
|
||||
platform_stats[platform.platform_name] = {
|
||||
"cnt_receive": platform.platform_instance.cnt_receive,
|
||||
"cnt_reply": platform.platform_instance.cnt_reply
|
||||
}
|
||||
|
||||
for llm in _global_object.llms:
|
||||
stat = llm.llm_instance.model_stat
|
||||
for k in stat:
|
||||
llm_stats[llm.llm_name + "#" + k] = stat[k]
|
||||
llm.llm_instance.reset_model_stat()
|
||||
|
||||
for plugin in _global_object.cached_plugins:
|
||||
plugin_stats[plugin.metadata.plugin_name] = {
|
||||
"metadata": plugin.metadata,
|
||||
"trig_cnt": plugin.trig_cnt
|
||||
}
|
||||
plugin.reset_trig_cnt()
|
||||
|
||||
try:
|
||||
res = {
|
||||
"version": _global_object.version,
|
||||
"count": _global_object.cnt_total,
|
||||
"ip": addr_ip,
|
||||
"sys": sys.platform,
|
||||
"admin": "null",
|
||||
"stat_version": "moon",
|
||||
"version": _global_object.version, # 版本号
|
||||
"platform_stats": platform_stats, # 过去 30 分钟各消息平台交互消息数
|
||||
"llm_stats": llm_stats,
|
||||
"plugin_stats": plugin_stats,
|
||||
"sys": sys.platform, # 系统版本
|
||||
}
|
||||
resp = requests.post(
|
||||
'https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
|
||||
@@ -484,7 +511,7 @@ def upload(_global_object: GlobalObject):
|
||||
_global_object.cnt_total = 0
|
||||
except BaseException as e:
|
||||
pass
|
||||
time.sleep(10*60)
|
||||
time.sleep(30*60)
|
||||
|
||||
def retry(n: int = 3):
|
||||
'''
|
||||
|
||||
@@ -6,6 +6,7 @@ except BaseException as e:
|
||||
has_git = False
|
||||
import sys, os
|
||||
import requests
|
||||
from type.config import VERSION
|
||||
|
||||
def _reboot():
|
||||
py = sys.executable
|
||||
@@ -78,11 +79,11 @@ def check_update() -> str:
|
||||
print(f"当前版本: {curr_commit}")
|
||||
print(f"最新版本: {new_commit}")
|
||||
if curr_commit.startswith(new_commit):
|
||||
return "当前已经是最新版本。"
|
||||
return f"当前已经是最新版本: v{VERSION}"
|
||||
else:
|
||||
update_info = f"""有新版本可用。
|
||||
=== 当前版本 ===
|
||||
{curr_commit}
|
||||
v{VERSION}
|
||||
|
||||
=== 新版本 ===
|
||||
{update_data[0]['version']}
|
||||
|
||||
Reference in New Issue
Block a user