import logging
import sqlite3
import time
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

from astrbot.core.db.po import Platform, Stats

_logger = logging.getLogger(__name__)


@dataclass
class Conversation:
    """Stored LLM conversation.

    For web chat, ``history`` stores every message, including commands,
    replies and images. For other platforms, non-LLM replies are not stored,
    on the grounds that the platforms already keep them.
    """

    user_id: str
    cid: str
    # The message list in string form; new conversations start as "[]".
    history: str = ""
    created_at: int = 0
    updated_at: int = 0
    title: str = ""
    persona_id: str = ""


# Schema bootstrap executed on every start-up (all statements are idempotent).
# In webchat_conversation, user_id is the session id and cid the conversation id.
INIT_SQL = """
CREATE TABLE IF NOT EXISTS platform(
    name VARCHAR(32),
    count INTEGER,
    timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
    name VARCHAR(32),
    count INTEGER,
    timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
    name VARCHAR(32),
    count INTEGER,
    timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
    name VARCHAR(32),
    count INTEGER,
    timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
    provider_type VARCHAR(32),
    session_id VARCHAR(32),
    content TEXT
);

-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
    id TEXT,
    url_or_path TEXT,
    caption TEXT,
    is_meme BOOLEAN,
    keywords TEXT,
    platform_name VARCHAR(32),
    session_id VARCHAR(32),
    sender_nickname VARCHAR(32),
    timestamp INTEGER
);

CREATE TABLE IF NOT EXISTS webchat_conversation(
    user_id TEXT, -- 会话 id
    cid TEXT, -- 对话 id
    history TEXT,
    created_at INTEGER,
    updated_at INTEGER,
    title TEXT,
    persona_id TEXT
);

PRAGMA encoding = 'UTF-8';
"""


class SQLiteDatabase:
    """SQLite-backed store for platform/LLM metrics and web chat conversations."""

    def __init__(self, db_path: str) -> None:
        """Open ``db_path``, create missing tables and migrate old schemas.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        super().__init__()
        self.db_path = db_path

        self.conn = self._get_conn(self.db_path)
        c = self.conn.cursor()
        try:
            c.executescript(INIT_SQL)
            self.conn.commit()

            # Databases created before these columns existed lack them;
            # add them in place so existing conversations survive the upgrade.
            c.execute("PRAGMA table_info(webchat_conversation)")
            existing_columns = {row[1] for row in c.fetchall()}
            for column in ("title", "persona_id"):
                if column not in existing_columns:
                    # ``column`` comes from the fixed tuple above, never user input.
                    c.execute(
                        f"ALTER TABLE webchat_conversation ADD COLUMN {column} TEXT;",
                    )
                    self.conn.commit()
        finally:
            c.close()

    def _get_conn(self, db_path: str) -> sqlite3.Connection:
        """Open a new connection to ``db_path`` that returns TEXT values as ``str``."""
        conn = sqlite3.connect(db_path)
        conn.text_factory = str
        return conn

    @contextmanager
    def _cursor(self) -> Iterator[tuple[sqlite3.Connection, sqlite3.Cursor]]:
        """Yield ``(connection, cursor)`` and release them on exit.

        The shared ``self.conn`` is used when it can hand out a cursor. If it
        raises ``sqlite3.ProgrammingError`` (which sqlite3 raises, for example,
        for a closed connection or use from a different thread), a temporary
        connection is opened instead and closed on exit, so it does not leak.
        The cursor is always closed, even when the ``with`` body raises.
        """
        conn = self.conn
        owns_conn = False
        try:
            cursor = conn.cursor()
        except sqlite3.ProgrammingError:
            conn = self._get_conn(self.db_path)
            owns_conn = True
            cursor = conn.cursor()
        try:
            yield conn, cursor
        finally:
            cursor.close()
            if owns_conn:
                conn.close()

    def _exec_sql(self, sql: str, params: tuple | None = None) -> None:
        """Execute a single write statement and commit it.

        Args:
            sql: SQL statement, using ``?`` placeholders for values.
            params: Values bound to the placeholders. When falsy, the
                statement is executed without parameters.
        """
        with self._cursor() as (conn, c):
            if params:
                c.execute(sql, params)
            else:
                c.execute(sql)
            conn.commit()

    def insert_platform_metrics(self, metrics: dict):
        """Insert one ``platform`` row per ``name -> count`` entry, stamped with now."""
        for k, v in metrics.items():
            self._exec_sql(
                """
                INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
                """,
                (k, v, int(time.time())),
            )

    def insert_llm_metrics(self, metrics: dict):
        """Insert one ``llm`` row per ``name -> count`` entry, stamped with now."""
        for k, v in metrics.items():
            self._exec_sql(
                """
                INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
                """,
                (k, v, int(time.time())),
            )

    def get_base_stats(self, offset_sec: int = 86400) -> Stats:
        """Return the raw platform rows recorded in the last ``offset_sec`` seconds."""
        since = int(time.time()) - offset_sec
        with self._cursor() as (_, c):
            c.execute("SELECT * FROM platform WHERE timestamp >= ?", (since,))
            platform = [Platform(*row) for row in c.fetchall()]
        return Stats(platform=platform)

    def get_total_message_count(self) -> int:
        """Return the sum of ``count`` over all platform rows (0 if there are none)."""
        with self._cursor() as (_, c):
            c.execute("SELECT SUM(count) FROM platform")
            total = c.fetchone()[0]
        # SUM() over zero rows yields NULL; report 0 to honour the int return type.
        return total or 0

    def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
        """Return platform counts from the last ``offset_sec`` seconds, summed per name."""
        since = int(time.time()) - offset_sec
        with self._cursor() as (_, c):
            c.execute(
                """
                SELECT name, SUM(count), timestamp FROM platform
                WHERE timestamp >= ?
                GROUP BY name
                """,
                (since,),
            )
            platform = [Platform(*row) for row in c.fetchall()]
        return Stats(platform, [], [])

    def get_conversation_by_user_id(
        self,
        user_id: str,
        cid: str,
    ) -> Conversation | None:
        """Return the conversation ``cid`` of session ``user_id``, or ``None`` if absent."""
        with self._cursor() as (_, c):
            # Explicit column list keeps the mapping onto Conversation's field
            # order independent of the physical column order of the table.
            c.execute(
                """
                SELECT user_id, cid, history, created_at, updated_at, title, persona_id
                FROM webchat_conversation
                WHERE user_id = ? AND cid = ?
                """,
                (user_id, cid),
            )
            row = c.fetchone()
        if not row:
            return None
        return Conversation(*row)

    def new_conversation(self, user_id: str, cid: str):
        """Create an empty conversation (history ``"[]"``) with both timestamps set to now."""
        history = "[]"
        updated_at = int(time.time())
        created_at = updated_at
        self._exec_sql(
            """
            INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            (user_id, cid, history, updated_at, created_at),
        )

    def get_conversations(self, user_id: str) -> list[Conversation]:
        """List the conversations of session ``user_id``, most recently updated first.

        The returned objects are lightweight: ``user_id`` is ``""`` and
        ``history`` is ``"[]"``. Fetch a single conversation for its history.
        """
        with self._cursor() as (_, c):
            c.execute(
                """
                SELECT cid, created_at, updated_at, title, persona_id
                FROM webchat_conversation
                WHERE user_id = ?
                ORDER BY updated_at DESC
                """,
                (user_id,),
            )
            rows = c.fetchall()
        return [
            Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
            for cid, created_at, updated_at, title, persona_id in rows
        ]

    def update_conversation(self, user_id: str, cid: str, history: str):
        """Replace the conversation history and bump ``updated_at`` to now."""
        updated_at = int(time.time())
        self._exec_sql(
            """
            UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
            """,
            (history, updated_at, user_id, cid),
        )

    def update_conversation_title(self, user_id: str, cid: str, title: str):
        """Set the title of a conversation (``updated_at`` is left unchanged)."""
        self._exec_sql(
            """
            UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
            """,
            (title, user_id, cid),
        )

    def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
        """Set the persona id of a conversation (``updated_at`` is left unchanged)."""
        self._exec_sql(
            """
            UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
            """,
            (persona_id, user_id, cid),
        )

    def delete_conversation(self, user_id: str, cid: str):
        """Delete the conversation ``cid`` of session ``user_id``."""
        self._exec_sql(
            """
            DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
            """,
            (user_id, cid),
        )

    @staticmethod
    def _conversation_row_to_dict(row: tuple) -> dict[str, Any]:
        """Convert a ``(user_id, cid, created_at, updated_at, title, persona_id)`` row.

        NULL columns are replaced by empty/zero defaults, and a missing title
        falls back to a label built from the first 8 characters of the cid.
        """
        user_id, cid, created_at, updated_at, title, persona_id = row
        safe_cid = str(cid) if cid else "unknown"
        return {
            "user_id": user_id or "",
            "cid": safe_cid,
            "title": title or f"对话 {safe_cid[:8]}",
            "persona_id": persona_id or "",
            "created_at": created_at or 0,
            "updated_at": updated_at or 0,
        }

    @staticmethod
    def _build_conversation_filter(
        platforms: list[str] | None,
        message_types: list[str] | None,
        search_query: str | None,
        exclude_ids: list[str] | None,
        exclude_platforms: list[str] | None,
    ) -> tuple[str, list[str]]:
        """Build the ``WHERE`` clause (with a leading space, or ``""``) and its parameters.

        Clauses are combined with AND. Parameters are appended in the same
        order as their placeholders.
        """
        where_clauses: list[str] = []
        params: list[str] = []

        # NOTE(review): the LIKE patterns assume user_id is laid out as
        # "<platform>:<message_type>:..." — confirm against the writers.
        if platforms:
            where_clauses.append(
                "(" + " OR ".join(["user_id LIKE ?"] * len(platforms)) + ")",
            )
            params.extend(f"{platform}:%" for platform in platforms)

        if message_types:
            where_clauses.append(
                "(" + " OR ".join(["user_id LIKE ?"] * len(message_types)) + ")",
            )
            params.extend(f"%:{msg_type}:%" for msg_type in message_types)

        if search_query:
            # History is searched in unicode-escaped form (presumably because it
            # is stored with non-ASCII characters escaped), while titles are
            # stored verbatim. Match either form so that non-ASCII title
            # searches work too; for plain ASCII queries both forms coincide.
            patterns = [f"%{search_query}%"]
            escaped = search_query.encode("unicode_escape").decode("utf-8")
            if escaped != search_query:
                patterns.append(f"%{escaped}%")
            field_match = "title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?"
            where_clauses.append("(" + " OR ".join([field_match] * len(patterns)) + ")")
            for pattern in patterns:
                params.extend([pattern] * 4)

        for exclude_id in exclude_ids or []:
            where_clauses.append("user_id NOT LIKE ?")
            params.append(f"{exclude_id}%")

        for exclude_platform in exclude_platforms or []:
            where_clauses.append("user_id NOT LIKE ?")
            params.append(f"{exclude_platform}:%")

        where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
        return where_sql, params

    def get_all_conversations(
        self,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of all conversations, most recently updated first.

        Returns:
            ``(conversations, total_count)``, or ``([], 0)`` if the query fails
            (the error is logged).
        """
        # Without filters, the filtered query is exactly the unfiltered one.
        return self.get_filtered_conversations(page=page, page_size=page_size)

    def get_filtered_conversations(
        self,
        page: int = 1,
        page_size: int = 20,
        platforms: list[str] | None = None,
        message_types: list[str] | None = None,
        search_query: str | None = None,
        exclude_ids: list[str] | None = None,
        exclude_platforms: list[str] | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of conversations matching the given filters.

        Args:
            page: 1-based page number.
            page_size: Number of conversations per page.
            platforms: Keep sessions whose user_id starts with any ``"<platform>:"``.
            message_types: Keep sessions whose user_id contains any ``":<type>:"``.
            search_query: Substring searched in title, user_id, cid and history.
            exclude_ids: Drop sessions whose user_id starts with any of these.
            exclude_platforms: Drop sessions whose user_id starts with ``"<platform>:"``.

        Returns:
            ``(conversations, total_count)``, where ``total_count`` counts all
            matches before pagination. On any error, ``([], 0)`` is returned so
            callers always get a usable value; the error is logged.
        """
        try:
            where_sql, params = self._build_conversation_filter(
                platforms,
                message_types,
                search_query,
                exclude_ids,
                exclude_platforms,
            )
            offset = (page - 1) * page_size
            with self._cursor() as (_, c):
                c.execute(f"SELECT COUNT(*) FROM webchat_conversation{where_sql}", params)
                total_count = c.fetchone()[0]

                c.execute(
                    f"""
                    SELECT user_id, cid, created_at, updated_at, title, persona_id
                    FROM webchat_conversation
                    {where_sql}
                    ORDER BY updated_at DESC
                    LIMIT ? OFFSET ?
                    """,
                    [*params, page_size, offset],
                )
                rows = c.fetchall()

            conversations = [self._conversation_row_to_dict(row) for row in rows]
            return conversations, total_count
        except Exception:
            # Deliberate best-effort: keep returning an empty page, but make
            # the failure visible instead of swallowing it silently.
            _logger.exception("Failed to query webchat conversations")
            return [], 0