Merge pull request #163 from Soulter/stat-upload-perf

优化统计记录数据结构
This commit is contained in:
Soulter
2024-05-25 18:28:08 +08:00
committed by GitHub
12 changed files with 98 additions and 32 deletions

View File

@@ -22,6 +22,7 @@ from util.cmd_config import init_astrbot_config_items
from type.types import GlobalObject
from type.register import *
from type.message import AstrBotMessage
from type.config import *
from addons.dashboard.helper import DashBoardHelper
from addons.dashboard.server import DashBoardData
from persist.session import dbConn
@@ -38,9 +39,6 @@ frequency_time = 60
# 计数默认值
frequency_count = 10
# 版本
version = '3.1.13'
# 语言模型
OPENAI_OFFICIAL = 'openai_official'
NONE_LLM = 'none_llm'
@@ -61,8 +59,6 @@ init_astrbot_config_items()
# 全局对象
_global_object: GlobalObject = None
# 语言模型选择
def privider_chooser(cfg):
l = []
@@ -70,13 +66,10 @@ def privider_chooser(cfg):
l.append('openai_official')
return l
'''
初始化机器人
'''
def init():
'''
初始化机器人
'''
global llm_instance, llm_command_instance
global baidu_judge, chosen_provider
global frequency_count, frequency_time
@@ -92,9 +85,9 @@ def init():
# 初始化 global_object
_global_object = GlobalObject()
_global_object.version = version
_global_object.version = VERSION
_global_object.base_config = cfg
logger.info("AstrBot v"+version)
logger.info("AstrBot v" + VERSION)
if 'reply_prefix' in cfg:
# 适配旧版配置
@@ -319,7 +312,6 @@ async def record_message(platform: str, session_id: str):
db_inst.increment_stat_session(platform, session_id, 1)
db_inst.increment_stat_message(curr_ts, 1)
db_inst.increment_stat_platform(curr_ts, platform, 1)
_global_object.cnt_total += 1
async def oper_msg(message: AstrBotMessage,

View File

@@ -73,6 +73,7 @@ class Command:
else:
raise TypeError("插件返回值格式错误。")
if hit:
plugin.trig()
logger.debug("hit plugin: " + plugin.metadata.plugin_name)
return True, res
except TypeError as e:

View File

@@ -14,34 +14,40 @@ class Platform():
初始化平台的各种接口
'''
self.message_handler = message_handler
self.cnt_receive = 0
self.cnt_reply = 0
pass
@abc.abstractmethod
async def handle_msg():
async def handle_msg(self):
'''
处理到来的消息
'''
self.cnt_receive += 1
pass
@abc.abstractmethod
async def reply_msg():
async def reply_msg(self):
'''
回复消息(被动发送)
'''
self.cnt_reply += 1
pass
@abc.abstractmethod
async def send_msg(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
async def send_msg(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
'''
发送消息(主动发送)
'''
self.cnt_reply += 1
pass
@abc.abstractmethod
async def send(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
async def send(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
'''
发送消息(主动发送)同 send_msg()
'''
self.cnt_reply += 1
pass
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str:

View File

@@ -104,6 +104,7 @@ class QQGOCQ(Platform):
self.client.run()
async def handle_msg(self, message: AstrBotMessage):
await super().handle_msg()
logger.info(
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
@@ -176,6 +177,7 @@ class QQGOCQ(Platform):
async def reply_msg(self,
message: Union[AstrBotMessage, GuildMessage, GroupMessage, FriendMessage],
result_message: list):
await super().reply_msg()
"""
插件开发者请使用send方法, 可以不用直接调用这个方法。
"""
@@ -254,6 +256,7 @@ class QQGOCQ(Platform):
提供给插件的发送QQ消息接口。
参数说明第一个参数可以是消息对象也可以是QQ群号。第二个参数是消息内容消息内容可以是消息链列表也可以是纯文字信息
'''
await super().reply_msg()
try:
await self.reply_msg(message, result_message)
except BaseException as e:
@@ -265,6 +268,7 @@ class QQGOCQ(Platform):
'''
同 send_msg()
'''
await super().reply_msg()
await self.reply_msg(to, res)
def create_text_image(title: str, text: str, max_width=30, font_size=20):

View File

@@ -102,6 +102,7 @@ class QQOfficial(Platform):
)
async def handle_msg(self, message: AstrBotMessage):
await super().handle_msg()
assert isinstance(message.raw_message, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage))
is_group = message.type != MessageType.FRIEND_MESSAGE
@@ -154,6 +155,7 @@ class QQOfficial(Platform):
'''
回复频道消息
'''
await super().reply_msg()
if isinstance(message, AstrBotMessage):
source = message.raw_message
else:

View File

@@ -73,6 +73,7 @@ class ProviderOpenAIOfficial(Provider):
base_url=self.base_url
)
self.model_configs: Dict = cfg['chatGPTConfigs']
super().set_curr_model(self.model_configs['model'])
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
@@ -289,6 +290,7 @@ class ProviderOpenAIOfficial(Provider):
extra_conf: Dict = None,
**kwargs
) -> str:
super().accu_model_stat()
if not session_id:
session_id = "unknown"
if "unknown" in self.session_memory:
@@ -421,6 +423,7 @@ class ProviderOpenAIOfficial(Provider):
'''
retry = 0
conf = self.image_generator_model_configs
super().accu_model_stat(model=conf['model'])
if not conf:
logger.error("OpenAI 图片生成模型配置不存在。")
raise Exception("OpenAI 图片生成模型配置不存在。")
@@ -481,6 +484,7 @@ class ProviderOpenAIOfficial(Provider):
def set_model(self, model: str):
self.model_configs['model'] = model
super().set_curr_model(model)
def get_configs(self):
return self.model_configs

View File

@@ -1,4 +1,27 @@
from collections import defaultdict
class Provider:
def __init__(self) -> None:
self.model_stat = defaultdict(int) # 用于记录 LLM Model 使用数据
self.curr_model_name = "unknown"
def reset_model_stat(self):
self.model_stat.clear()
def set_curr_model(self, model_name: str):
self.curr_model_name = model_name
def get_curr_model(self):
'''
返回当前正在使用的 LLM
'''
return self.curr_model_name
def accu_model_stat(self, model: str = None):
if not model:
model = self.get_curr_model()
self.model_stat[model] += 1
async def text_chat(self,
prompt: str,
session_id: str,
@@ -18,7 +41,7 @@ class Provider:
extra_conf: 额外配置
default_personality: 默认人格
'''
raise NotImplementedError
raise NotImplementedError()
async def image_generate(self, prompt, session_id, **kwargs) -> str:
'''
@@ -26,10 +49,10 @@ class Provider:
prompt: 提示词
session_id: 会话id
'''
raise NotImplementedError
raise NotImplementedError()
async def forget(self, session_id=None) -> bool:
'''
重置会话
'''
raise NotImplementedError
raise NotImplementedError()

1
type/config.py Normal file
View File

@@ -0,0 +1 @@
VERSION = '3.1.13'

View File

@@ -15,6 +15,13 @@ class RegisteredPlugin:
module_path: str
module: ModuleType
root_dir_name: str
trig_cnt: int = 0
def reset_trig_cnt(self):
self.trig_cnt = 0
def trig(self):
self.trig_cnt += 1
def __str__(self) -> str:
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"

View File

@@ -15,7 +15,6 @@ class GlobalObject:
web_search: bool # 是否开启了网页搜索
reply_prefix: str # 回复前缀
unique_session: bool # 是否开启了独立会话
cnt_total: int # 总消息数
default_personality: dict
dashboard_data = None
@@ -26,7 +25,6 @@ class GlobalObject:
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = None
self.unique_session = False
self.cnt_total = 0
self.platforms = []
self.llms = []
self.default_personality = None

View File

@@ -15,6 +15,7 @@ from PIL import Image, ImageDraw, ImageFont
from type.types import GlobalObject
from SparkleLogging.utils.core import LogManager
from logging import Logger
from collections import defaultdict
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
@@ -466,15 +467,41 @@ def get_sys_info(global_object: GlobalObject):
def upload(_global_object: GlobalObject):
'''
上传相关非敏感统计数据
'''
time.sleep(10)
while True:
addr_ip = ''
platform_stats = {}
llm_stats = {}
plugin_stats = {}
for platform in _global_object.platforms:
platform_stats[platform.platform_name] = {
"cnt_receive": platform.platform_instance.cnt_receive,
"cnt_reply": platform.platform_instance.cnt_reply
}
for llm in _global_object.llms:
stat = llm.llm_instance.model_stat
for k in stat:
llm_stats[llm.llm_name + "#" + k] = stat[k]
llm.llm_instance.reset_model_stat()
for plugin in _global_object.cached_plugins:
plugin_stats[plugin.metadata.plugin_name] = {
"metadata": plugin.metadata,
"trig_cnt": plugin.trig_cnt
}
plugin.reset_trig_cnt()
try:
res = {
"version": _global_object.version,
"count": _global_object.cnt_total,
"ip": addr_ip,
"sys": sys.platform,
"admin": "null",
"stat_version": "moon",
"version": _global_object.version, # 版本号
"platform_stats": platform_stats, # 过去 30 分钟各消息平台交互消息数
"llm_stats": llm_stats,
"plugin_stats": plugin_stats,
"sys": sys.platform, # 系统版本
}
resp = requests.post(
'https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
@@ -484,7 +511,7 @@ def upload(_global_object: GlobalObject):
_global_object.cnt_total = 0
except BaseException as e:
pass
time.sleep(10*60)
time.sleep(30*60)
def retry(n: int = 3):
'''

View File

@@ -6,6 +6,7 @@ except BaseException as e:
has_git = False
import sys, os
import requests
from type.config import VERSION
def _reboot():
py = sys.executable
@@ -78,11 +79,11 @@ def check_update() -> str:
print(f"当前版本: {curr_commit}")
print(f"最新版本: {new_commit}")
if curr_commit.startswith(new_commit):
return "当前已经是最新版本"
return f"当前已经是最新版本: v{VERSION}"
else:
update_info = f"""有新版本可用。
=== 当前版本 ===
{curr_commit}
v{VERSION}
=== 新版本 ===
{update_data[0]['version']}