import asyncio import threading import typing as T from datetime import datetime, timedelta, timezone from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import col, delete, desc, func, or_, select, text, update from astrbot.core.db import BaseDatabase from astrbot.core.db.po import ( Attachment, ConversationV2, Persona, PlatformMessageHistory, PlatformSession, PlatformStat, Preference, SQLModel, ) from astrbot.core.db.po import ( Platform as DeprecatedPlatformStat, ) from astrbot.core.db.po import ( Stats as DeprecatedStats, ) NOT_GIVEN = T.TypeVar("NOT_GIVEN") class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.inited = False super().__init__() async def initialize(self) -> None: """Initialize the database by creating tables if they do not exist.""" async with self.engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) await conn.execute(text("PRAGMA journal_mode=WAL")) await conn.execute(text("PRAGMA synchronous=NORMAL")) await conn.execute(text("PRAGMA cache_size=20000")) await conn.execute(text("PRAGMA temp_store=MEMORY")) await conn.execute(text("PRAGMA mmap_size=134217728")) await conn.execute(text("PRAGMA optimize")) await conn.commit() # ==== # Platform Statistics # ==== async def insert_platform_stats( self, platform_id, platform_type, count=1, timestamp=None, ) -> None: """Insert a new platform statistic record.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): if timestamp is None: timestamp = datetime.now().replace( minute=0, second=0, microsecond=0, ) current_hour = timestamp await session.execute( text(""" INSERT INTO platform_stats (timestamp, platform_id, platform_type, count) VALUES (:timestamp, :platform_id, :platform_type, :count) ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET count = platform_stats.count + EXCLUDED.count """), { "timestamp": current_hour, "platform_id": platform_id, "platform_type": platform_type, "count": count, }, ) async def count_platform_stats(self) -> int: """Count the number of platform statistics records.""" async with self.get_db() as session: session: AsyncSession result = await session.execute( select(func.count(col(PlatformStat.platform_id))).select_from( PlatformStat, ), ) count = result.scalar_one_or_none() return count if count is not None else 0 async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: """Get platform statistics within the specified offset in seconds and group by platform_id.""" async with self.get_db() as session: session: AsyncSession now = datetime.now() start_time = now - timedelta(seconds=offset_sec) result = await session.execute( text(""" SELECT * FROM platform_stats WHERE timestamp >= :start_time ORDER BY timestamp DESC GROUP BY platform_id """), {"start_time": start_time}, ) return list(result.scalars().all()) # ==== # Conversation Management # ==== async def get_conversations(self, user_id=None, platform_id=None): async with self.get_db() as session: session: AsyncSession query = select(ConversationV2) if user_id: query = query.where(ConversationV2.user_id == user_id) if platform_id: query = query.where(ConversationV2.platform_id == platform_id) # order by query = query.order_by(desc(ConversationV2.created_at)) result = await session.execute(query) return result.scalars().all() async def get_conversation_by_id(self, cid): async with self.get_db() as session: session: AsyncSession query = select(ConversationV2).where(ConversationV2.conversation_id == cid) result = await session.execute(query) return result.scalar_one_or_none() async def get_all_conversations(self, page=1, page_size=20): async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size result = await session.execute( select(ConversationV2) .order_by(desc(ConversationV2.created_at)) .offset(offset) .limit(page_size), ) return result.scalars().all() async def get_filtered_conversations( self, page=1, page_size=20, platform_ids=None, search_query="", **kwargs, ): async with self.get_db() as session: session: AsyncSession # Build the base query with filters base_query = select(ConversationV2) if platform_ids: base_query = base_query.where( col(ConversationV2.platform_id).in_(platform_ids), ) if search_query: search_query = search_query.encode("unicode_escape").decode("utf-8") base_query = base_query.where( or_( col(ConversationV2.title).ilike(f"%{search_query}%"), col(ConversationV2.content).ilike(f"%{search_query}%"), col(ConversationV2.user_id).ilike(f"%{search_query}%"), col(ConversationV2.conversation_id).ilike(f"%{search_query}%"), ), ) if "message_types" in kwargs and len(kwargs["message_types"]) > 0: for msg_type in kwargs["message_types"]: base_query = base_query.where( col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"), ) if "platforms" in kwargs and len(kwargs["platforms"]) > 0: base_query = base_query.where( col(ConversationV2.platform_id).in_(kwargs["platforms"]), ) # Get total count matching the filters count_query = select(func.count()).select_from(base_query.subquery()) total_count = await session.execute(count_query) total = total_count.scalar_one() # Get paginated results offset = (page - 1) * page_size result_query = ( base_query.order_by(desc(ConversationV2.created_at)) .offset(offset) .limit(page_size) ) result = await session.execute(result_query) conversations = result.scalars().all() return conversations, total async def create_conversation( self, user_id, platform_id, content=None, title=None, persona_id=None, cid=None, created_at=None, updated_at=None, ): kwargs = {} if cid: kwargs["conversation_id"] = cid if created_at: kwargs["created_at"] = created_at if updated_at: kwargs["updated_at"] = updated_at async with self.get_db() as session: session: AsyncSession async with session.begin(): new_conversation = ConversationV2( user_id=user_id, content=content or [], platform_id=platform_id, title=title, persona_id=persona_id, **kwargs, ) session.add(new_conversation) return new_conversation async def update_conversation(self, cid, title=None, persona_id=None, content=None): async with self.get_db() as session: session: AsyncSession async with session.begin(): query = update(ConversationV2).where( col(ConversationV2.conversation_id) == cid, ) values = {} if title is not None: values["title"] = title if persona_id is not None: values["persona_id"] = persona_id if content is not None: values["content"] = content if not values: return None query = query.values(**values) await session.execute(query) return await self.get_conversation_by_id(cid) async def delete_conversation(self, cid): async with self.get_db() as session: session: AsyncSession async with session.begin(): await session.execute( delete(ConversationV2).where( col(ConversationV2.conversation_id) == cid, ), ) async def delete_conversations_by_user_id(self, user_id: str) -> None: async with self.get_db() as session: session: AsyncSession async with session.begin(): await session.execute( delete(ConversationV2).where( col(ConversationV2.user_id) == user_id ), ) async def get_session_conversations( self, page=1, page_size=20, search_query=None, platform=None, ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details.""" async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size base_query = ( select( col(Preference.scope_id).label("session_id"), func.json_extract(Preference.value, "$.val").label( "conversation_id", ), # type: ignore col(ConversationV2.persona_id).label("persona_id"), col(ConversationV2.title).label("title"), col(Persona.persona_id).label("persona_name"), ) .select_from(Preference) .outerjoin( ConversationV2, func.json_extract(Preference.value, "$.val") == ConversationV2.conversation_id, ) .outerjoin( Persona, col(ConversationV2.persona_id) == Persona.persona_id, ) .where(Preference.scope == "umo", Preference.key == "sel_conv_id") ) # 搜索筛选 if search_query: search_pattern = f"%{search_query}%" base_query = base_query.where( or_( col(Preference.scope_id).ilike(search_pattern), col(ConversationV2.title).ilike(search_pattern), col(Persona.persona_id).ilike(search_pattern), ), ) # 平台筛选 if platform: platform_pattern = f"{platform}:%" base_query = base_query.where( col(Preference.scope_id).like(platform_pattern), ) # 排序 base_query = base_query.order_by(Preference.scope_id) # 分页结果 result_query = base_query.offset(offset).limit(page_size) result = await session.execute(result_query) rows = result.fetchall() # 查询总数(应用相同的筛选条件) count_base_query = ( select(func.count(col(Preference.scope_id))) .select_from(Preference) .outerjoin( ConversationV2, func.json_extract(Preference.value, "$.val") == ConversationV2.conversation_id, ) .outerjoin( Persona, col(ConversationV2.persona_id) == Persona.persona_id, ) .where(Preference.scope == "umo", Preference.key == "sel_conv_id") ) # 应用相同的搜索和平台筛选条件到计数查询 if search_query: search_pattern = f"%{search_query}%" count_base_query = count_base_query.where( or_( col(Preference.scope_id).ilike(search_pattern), col(ConversationV2.title).ilike(search_pattern), col(Persona.persona_id).ilike(search_pattern), ), ) if platform: platform_pattern = f"{platform}:%" count_base_query = count_base_query.where( col(Preference.scope_id).like(platform_pattern), ) total_result = await session.execute(count_base_query) total = total_result.scalar() or 0 sessions_data = [ { "session_id": row.session_id, "conversation_id": row.conversation_id, "persona_id": row.persona_id, "title": row.title, "persona_name": row.persona_name, } for row in rows ] return sessions_data, total async def insert_platform_message_history( self, platform_id, user_id, content, sender_id=None, sender_name=None, ): """Insert a new platform message history record.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): new_history = PlatformMessageHistory( platform_id=platform_id, user_id=user_id, content=content, sender_id=sender_id, sender_name=sender_name, ) session.add(new_history) return new_history async def delete_platform_message_offset( self, platform_id, user_id, offset_sec=86400, ): """Delete platform message history records newer than the specified offset.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): now = datetime.now() cutoff_time = now - timedelta(seconds=offset_sec) await session.execute( delete(PlatformMessageHistory).where( col(PlatformMessageHistory.platform_id) == platform_id, col(PlatformMessageHistory.user_id) == user_id, col(PlatformMessageHistory.created_at) >= cutoff_time, ), ) async def get_platform_message_history( self, platform_id, user_id, page=1, page_size=20, ): """Get platform message history records.""" async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size query = ( select(PlatformMessageHistory) .where( PlatformMessageHistory.platform_id == platform_id, PlatformMessageHistory.user_id == user_id, ) .order_by(desc(PlatformMessageHistory.created_at)) ) result = await session.execute(query.offset(offset).limit(page_size)) return result.scalars().all() async def insert_attachment(self, path, type, mime_type): """Insert a new attachment record.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): new_attachment = Attachment( path=path, type=type, mime_type=mime_type, ) session.add(new_attachment) return new_attachment async def get_attachment_by_id(self, attachment_id): """Get an attachment by its ID.""" async with self.get_db() as session: session: AsyncSession query = select(Attachment).where(Attachment.attachment_id == attachment_id) result = await session.execute(query) return result.scalar_one_or_none() async def insert_persona( self, persona_id, system_prompt, begin_dialogs=None, tools=None, ): """Insert a new persona record.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): new_persona = Persona( persona_id=persona_id, system_prompt=system_prompt, begin_dialogs=begin_dialogs or [], tools=tools, ) session.add(new_persona) return new_persona async def get_persona_by_id(self, persona_id): """Get a persona by its ID.""" async with self.get_db() as session: session: AsyncSession query = select(Persona).where(Persona.persona_id == persona_id) result = await session.execute(query) return result.scalar_one_or_none() async def get_personas(self): """Get all personas for a specific bot.""" async with self.get_db() as session: session: AsyncSession query = select(Persona) result = await session.execute(query) return result.scalars().all() async def update_persona( self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN, ): """Update a persona's system prompt or begin dialogs.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): query = update(Persona).where(col(Persona.persona_id) == persona_id) values = {} if system_prompt is not None: values["system_prompt"] = system_prompt if begin_dialogs is not None: values["begin_dialogs"] = begin_dialogs if tools is not NOT_GIVEN: values["tools"] = tools if not values: return None query = query.values(**values) await session.execute(query) return await self.get_persona_by_id(persona_id) async def delete_persona(self, persona_id): """Delete a persona by its ID.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): await session.execute( delete(Persona).where(col(Persona.persona_id) == persona_id), ) async def insert_preference_or_update(self, scope, scope_id, key, value): """Insert a new preference record or update if it exists.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): query = select(Preference).where( Preference.scope == scope, Preference.scope_id == scope_id, Preference.key == key, ) result = await session.execute(query) existing_preference = result.scalar_one_or_none() if existing_preference: existing_preference.value = value else: new_preference = Preference( scope=scope, scope_id=scope_id, key=key, value=value, ) session.add(new_preference) return existing_preference or new_preference async def get_preference(self, scope, scope_id, key): """Get a preference by key.""" async with self.get_db() as session: session: AsyncSession query = select(Preference).where( Preference.scope == scope, Preference.scope_id == scope_id, Preference.key == key, ) result = await session.execute(query) return result.scalar_one_or_none() async def get_preferences(self, scope, scope_id=None, key=None): """Get all preferences for a specific scope ID or key.""" async with self.get_db() as session: session: AsyncSession query = select(Preference).where(Preference.scope == scope) if scope_id is not None: query = query.where(Preference.scope_id == scope_id) if key is not None: query = query.where(Preference.key == key) result = await session.execute(query) return result.scalars().all() async def remove_preference(self, scope, scope_id, key): """Remove a preference by scope ID and key.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): await session.execute( delete(Preference).where( col(Preference.scope) == scope, col(Preference.scope_id) == scope_id, col(Preference.key) == key, ), ) await session.commit() async def clear_preferences(self, scope, scope_id): """Clear all preferences for a specific scope ID.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): await session.execute( delete(Preference).where( col(Preference.scope) == scope, col(Preference.scope_id) == scope_id, ), ) await session.commit() # ==== # Deprecated Methods # ==== def get_base_stats(self, offset_sec=86400): """Get base statistics within the specified offset in seconds.""" async def _inner(): async with self.get_db() as session: session: AsyncSession now = datetime.now() start_time = now - timedelta(seconds=offset_sec) result = await session.execute( select(PlatformStat).where(PlatformStat.timestamp >= start_time), ) all_datas = result.scalars().all() deprecated_stats = DeprecatedStats() for data in all_datas: deprecated_stats.platform.append( DeprecatedPlatformStat( name=data.platform_id, count=data.count, timestamp=int(data.timestamp.timestamp()), ), ) return deprecated_stats result = None def runner(): nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() return result def get_total_message_count(self): """Get the total message count from platform statistics.""" async def _inner(): async with self.get_db() as session: session: AsyncSession result = await session.execute( select(func.sum(PlatformStat.count)).select_from(PlatformStat), ) total_count = result.scalar_one_or_none() return total_count if total_count is not None else 0 result = None def runner(): nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() return result def get_grouped_base_stats(self, offset_sec=86400): # group by platform_id async def _inner(): async with self.get_db() as session: session: AsyncSession now = datetime.now() start_time = now - timedelta(seconds=offset_sec) result = await session.execute( select(PlatformStat.platform_id, func.sum(PlatformStat.count)) .where(PlatformStat.timestamp >= start_time) .group_by(PlatformStat.platform_id), ) grouped_stats = result.all() deprecated_stats = DeprecatedStats() for platform_id, count in grouped_stats: deprecated_stats.platform.append( DeprecatedPlatformStat( name=platform_id, count=count, timestamp=int(start_time.timestamp()), ), ) return deprecated_stats result = None def runner(): nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() return result # ==== # Platform Session Management # ==== async def create_platform_session( self, creator: str, platform_id: str = "webchat", session_id: str | None = None, display_name: str | None = None, is_group: int = 0, ) -> PlatformSession: """Create a new Platform session.""" kwargs = {} if session_id: kwargs["session_id"] = session_id async with self.get_db() as session: session: AsyncSession async with session.begin(): new_session = PlatformSession( creator=creator, platform_id=platform_id, display_name=display_name, is_group=is_group, **kwargs, ) session.add(new_session) await session.flush() await session.refresh(new_session) return new_session async def get_platform_session_by_id( self, session_id: str ) -> PlatformSession | None: """Get a Platform session by its ID.""" async with self.get_db() as session: session: AsyncSession query = select(PlatformSession).where( PlatformSession.session_id == session_id, ) result = await session.execute(query) return result.scalar_one_or_none() async def get_platform_sessions_by_creator( self, creator: str, platform_id: str | None = None, page: int = 1, page_size: int = 20, ) -> list[PlatformSession]: """Get all Platform sessions for a specific creator (username) and optionally platform.""" async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size query = select(PlatformSession).where(PlatformSession.creator == creator) if platform_id: query = query.where(PlatformSession.platform_id == platform_id) query = ( query.order_by(desc(PlatformSession.updated_at)) .offset(offset) .limit(page_size) ) result = await session.execute(query) return list(result.scalars().all()) async def update_platform_session( self, session_id: str, display_name: str | None = None, ) -> None: """Update a Platform session's updated_at timestamp and optionally display_name.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} if display_name is not None: values["display_name"] = display_name await session.execute( update(PlatformSession) .where(col(PlatformSession.session_id) == session_id) .values(**values), ) async def delete_platform_session(self, session_id: str) -> None: """Delete a Platform session by its ID.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): await session.execute( delete(PlatformSession).where( col(PlatformSession.session_id) == session_id, ), )