# (Source listing: Python, 211 lines, 5.9 KiB)
import os
import sqlite3
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Tuple

from astrbot.core.db.po import (
    Platform,
    Command,
    Provider,
    Stats,
    LLMHistory,
)
from . import BaseDatabase


class SQLiteDatabase(BaseDatabase):
    '''SQLite-backed implementation of ``BaseDatabase``.

    Metric counters are written to the ``platform``, ``command`` and ``llm``
    tables as ``(name, count, timestamp)`` rows, and per-session LLM history
    is kept in ``llm_history``. The schema comes from ``sqlite_init.sql``
    (next to this module), which is executed on construction.

    ``self.conn`` is opened in ``__init__``. When ``sqlite3`` refuses to use
    it with ``ProgrammingError`` (for example, when a connection is used from
    a thread other than the one that created it, or after it was closed), each
    operation falls back to a short-lived connection of its own.
    '''

    def __init__(self, db_path: str) -> None:
        '''Open ``db_path`` and apply the schema script ``sqlite_init.sql``.

        Args:
            db_path: Path of the SQLite database file.

        Raises:
            OSError: If ``sqlite_init.sql`` cannot be read.
            sqlite3.Error: If connecting or running the schema script fails.
        '''
        super().__init__()
        self.db_path = db_path

        # os.path.join avoids yielding "/sqlite_init.sql" (filesystem root)
        # when dirname(__file__) is an empty string.
        init_script = os.path.join(os.path.dirname(__file__), "sqlite_init.sql")
        with open(init_script, "r", encoding="utf-8") as f:
            sql = f.read()

        # Initialise the database schema.
        self.conn = self._get_conn(self.db_path)
        c = self.conn.cursor()
        try:
            c.executescript(sql)
        finally:
            c.close()
        self.conn.commit()

    def _get_conn(self, db_path: str) -> sqlite3.Connection:
        '''Open a new connection to ``db_path`` that returns TEXT values as ``str``.'''
        conn = sqlite3.connect(db_path)
        conn.text_factory = str
        return conn

    @contextmanager
    def _cursor(self) -> Iterator[Tuple[sqlite3.Connection, sqlite3.Cursor]]:
        '''Yield ``(connection, cursor)`` and release them afterwards.

        Uses the shared ``self.conn`` when possible. If ``sqlite3`` rejects it
        with ``ProgrammingError``, a fresh connection is opened instead. The
        cursor is always closed, and a fallback connection is closed as well,
        so neither leaks when the statement raises.
        '''
        conn = self.conn
        owns_conn = False
        try:
            c = conn.cursor()
        except sqlite3.ProgrammingError:
            conn = self._get_conn(self.db_path)
            owns_conn = True
            c = conn.cursor()
        try:
            yield conn, c
        finally:
            c.close()
            if owns_conn:
                conn.close()

    def _exec_sql(self, sql: str, params: Optional[Tuple] = None) -> None:
        '''Execute a single write statement and commit it.

        Args:
            sql: Statement to run; may contain ``?`` placeholders.
            params: Values bound to the placeholders. ``None`` or an empty
                sequence runs the statement without parameters.
        '''
        with self._cursor() as (conn, c):
            c.execute(sql, params or ())
            conn.commit()

    def _fetchall(self, sql: str, params: Tuple = ()) -> List[tuple]:
        '''Run a query with bound ``params`` and return all result rows.'''
        with self._cursor() as (_conn, c):
            c.execute(sql, params)
            return c.fetchall()

    def _insert_metrics(self, table: str, metrics: dict) -> None:
        '''Insert one ``(name, count, timestamp)`` row per ``metrics`` item.

        All rows of one call share the same timestamp and a single commit.
        ``table`` is interpolated into the SQL text (identifiers cannot be
        bound as parameters), so it must only ever be one of the hard-coded
        table names passed by the ``insert_*_metrics`` methods, never
        external input.
        '''
        if not metrics:
            return
        now = int(time.time())
        rows = [(name, count, now) for name, count in metrics.items()]
        with self._cursor() as (conn, c):
            c.executemany(
                f"INSERT INTO {table}(name, count, timestamp) VALUES (?, ?, ?)",
                rows,
            )
            conn.commit()

    def insert_platform_metrics(self, metrics: dict) -> None:
        '''Record ``{name: count}`` platform message counts in ``platform``.'''
        self._insert_metrics("platform", metrics)

    def insert_plugin_metrics(self, metrics: dict) -> None:
        '''Plugin metrics are not stored by this backend; this is a no-op.'''
        pass

    def insert_command_metrics(self, metrics: dict) -> None:
        '''Record ``{name: count}`` command usage counts in ``command``.'''
        self._insert_metrics("command", metrics)

    def insert_llm_metrics(self, metrics: dict) -> None:
        '''Record ``{name: count}`` LLM usage counts in ``llm``.'''
        self._insert_metrics("llm", metrics)

    def update_llm_history(self, session_id: str, content: str) -> None:
        '''Store ``content`` as the LLM history of ``session_id``.

        Updates the session's existing ``llm_history`` row(s), or inserts a
        new row when the session has none yet.
        '''
        if self.get_llm_history(session_id):
            self._exec_sql(
                "UPDATE llm_history SET content = ? WHERE session_id = ?",
                (content, session_id),
            )
        else:
            self._exec_sql(
                "INSERT INTO llm_history(session_id, content) VALUES (?, ?)",
                (session_id, content),
            )

    def get_llm_history(self, session_id: Optional[str] = None) -> List[LLMHistory]:
        '''Return LLM history rows as ``LLMHistory`` objects.

        Args:
            session_id: Only return rows for this session; ``None`` returns
                every row in ``llm_history``.

        Returns:
            A (possibly empty) list of ``LLMHistory`` built from each row.
        '''
        if session_id is None:
            rows = self._fetchall("SELECT * FROM llm_history")
        else:
            # Bound parameter: session ids may contain quotes or other SQL.
            rows = self._fetchall(
                "SELECT * FROM llm_history WHERE session_id = ?", (session_id,)
            )
        return [LLMHistory(*row) for row in rows]

    def get_base_stats(self, offset_sec: int = 86400) -> Stats:
        '''Return the base statistics from ``offset_sec`` seconds ago until now.

        Only platform rows are queried; the command and LLM parts of the
        returned ``Stats`` are empty lists.
        '''
        since = int(time.time()) - offset_sec
        rows = self._fetchall(
            "SELECT * FROM platform WHERE timestamp >= ?", (since,)
        )
        platform = [Platform(*row) for row in rows]
        return Stats(platform, [], [])

    def get_total_message_count(self) -> int:
        '''Return the sum of ``count`` over all ``platform`` rows (0 if none).'''
        # SUM() over zero rows yields NULL; COALESCE keeps the result an int.
        rows = self._fetchall("SELECT COALESCE(SUM(count), 0) FROM platform")
        return rows[0][0]

    def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
        '''Return base statistics since ``offset_sec`` seconds ago, merged per name.

        Platform rows are grouped by ``name`` with their counts summed; each
        group carries its most recent timestamp. The command and LLM parts of
        the returned ``Stats`` are empty lists.
        '''
        since = int(time.time()) - offset_sec
        # MAX(timestamp) instead of a bare column: SQLite would otherwise
        # pick the timestamp of an arbitrary row in each group.
        rows = self._fetchall(
            "SELECT name, SUM(count), MAX(timestamp) FROM platform "
            "WHERE timestamp >= ? GROUP BY name",
            (since,),
        )
        platform = [Platform(*row) for row in rows]
        return Stats(platform, [], [])