162 lines
5.2 KiB
Python
162 lines
5.2 KiB
Python
import base64
|
||
import os
|
||
import traceback
|
||
from io import BytesIO
|
||
from astrbot.api import logger
|
||
from astrbot.core.knowledge_base.kb_helper import KBHelper
|
||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||
|
||
|
||
async def generate_tsne_visualization(
    query: str, kb_names: list[str], kb_manager: KnowledgeBaseManager
) -> str | None:
    """Render a t-SNE scatter plot of a knowledge base's vectors plus the query.

    Only the first name in ``kb_names`` that ``kb_manager`` resolves to a
    knowledge base is visualized; later names are ignored. Every vector stored
    in that knowledge base's FAISS index is projected to 2-D together with the
    embedding of ``query`` (computed with the knowledge base's own embedding
    provider), and the query point is highlighted with a red "X".

    The blocking work (reading the index, t-SNE, drawing) runs in a worker
    thread via ``asyncio.to_thread`` so the event loop is not stalled.

    Args:
        query: Query text to embed and place on the plot.
        kb_names: Knowledge base names, tried in order.
        kb_manager: Knowledge base manager used to look the names up.

    Returns:
        The PNG image as a base64-encoded ``str`` (no ``data:`` URI prefix),
        or ``None`` when no knowledge base is found, its index file is missing
        or empty, the query embedding's dimension differs from the index's,
        or any other error occurs (logged together with its traceback).

    Raises:
        Exception: If faiss, numpy, matplotlib or scikit-learn cannot be
            imported.
    """
    import asyncio

    try:
        import faiss  # noqa: F401  (availability check; used by the helpers)
        import numpy as np
        import matplotlib

        matplotlib.use("Agg")  # non-interactive backend
        from matplotlib.figure import Figure  # noqa: F401
        from sklearn.manifold import TSNE  # noqa: F401
    except ImportError as e:
        # f-string so the underlying ImportError text is actually included.
        raise Exception(
            f"缺少必要的库以生成 t-SNE 可视化。请安装 matplotlib 和 scikit-learn: {e}"
        ) from e

    try:
        # Use the first knowledge base that can be resolved by name.
        kb_helper: KBHelper | None = None
        for kb_name in kb_names:
            kb_helper = await kb_manager.get_kb_by_name(kb_name)
            if kb_helper:
                break

        if not kb_helper:
            logger.warning("未找到知识库")
            return None

        kb = kb_helper.kb
        # NOTE(review): relative path, resolved against the process's current
        # working directory — confirm it matches where the KB stores its index.
        index_path = f"data/knowledge_base/{kb.kb_id}/index.faiss"
        if not os.path.exists(index_path):
            logger.warning(f"FAISS 索引不存在: {index_path}")
            return None

        # Reading the index and reconstructing all vectors is blocking work.
        vectors = await asyncio.to_thread(_load_index_vectors, index_path)
        if vectors is None:
            logger.warning("索引为空")
            return None

        # Embed the query with the same provider the knowledge base uses.
        vec_db: FaissVecDB = kb_helper.vec_db  # type: ignore
        query_embedding = await vec_db.embedding_provider.get_embedding(query)
        query_vector = np.array([query_embedding], dtype=np.float32)
        # Fail with a clear message instead of an opaque np.vstack error.
        if query_vector.shape[1] != vectors.shape[1]:
            logger.warning(
                f"查询向量维度 ({query_vector.shape[1]}) 与索引维度 "
                f"({vectors.shape[1]}) 不一致"
            )
            return None

        # t-SNE and rendering are CPU-bound; keep them off the event loop.
        return await asyncio.to_thread(
            _render_tsne_png_base64, vectors, query_vector, kb.kb_name
        )

    except Exception as e:
        logger.error(f"生成 t-SNE 可视化时出错: {e}")
        logger.error(traceback.format_exc())
        return None


def _load_index_vectors(index_path: str):
    """Read a FAISS index file and return all stored vectors, or None if empty.

    Args:
        index_path: Path of the FAISS index file to read.

    Returns:
        An array with one row per stored vector (``index.ntotal`` rows of
        ``index.d`` columns), or ``None`` when the index holds no vectors.
    """
    import faiss
    import numpy as np

    index = faiss.read_index(index_path)
    if index.ntotal == 0:
        return None

    logger.info(f"提取 {index.ntotal} 个向量用于可视化...")
    # An IndexIDMap only maps external ids; the vectors themselves live in the
    # wrapped index under sequential internal ids 0..ntotal-1.
    source = (
        faiss.downcast_index(index.index)
        if isinstance(index, faiss.IndexIDMap)
        else index
    )
    if hasattr(source, "reconstruct_n"):
        return source.reconstruct_n(0, index.ntotal)

    # Fallback: reconstruct one vector at a time into a preallocated buffer.
    vectors = np.zeros((index.ntotal, index.d), dtype=np.float32)
    for i in range(index.ntotal):
        source.reconstruct(i, vectors[i])
    return vectors


def _render_tsne_png_base64(kb_vectors, query_vector, kb_name: str) -> str:
    """Project KB vectors plus one query vector to 2-D with t-SNE and plot them.

    Args:
        kb_vectors: 2-D array, one row per knowledge base vector.
        query_vector: 2-D array holding the single query embedding as one row
            with the same number of columns as ``kb_vectors``.
        kb_name: Knowledge base name shown in the plot title.

    Returns:
        The rendered PNG image encoded as a base64 ``str``.
    """
    import numpy as np
    from matplotlib.figure import Figure
    from sklearn.manifold import TSNE

    num_vectors, dim = kb_vectors.shape
    # The query is appended as the last row so it can be split off afterwards.
    all_vectors = np.vstack([kb_vectors, query_vector])

    logger.info("开始 t-SNE 降维...")
    # scikit-learn requires perplexity < n_samples; 30 is its default value.
    perplexity = min(30, all_vectors.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    vectors_2d = tsne.fit_transform(all_vectors)
    kb_vectors_2d, query_vector_2d = vectors_2d[:-1], vectors_2d[-1]

    logger.info("生成可视化图表...")
    # Object-oriented Figure API instead of pyplot: nothing is registered in
    # pyplot's global figure registry (so nothing leaks if drawing fails), and
    # it avoids pyplot's non-thread-safe global state in a worker thread.
    fig = Figure(figsize=(14, 10))
    ax = fig.add_subplot()

    # Knowledge base vectors, colored by their position in the index.
    scatter = ax.scatter(
        kb_vectors_2d[:, 0],
        kb_vectors_2d[:, 1],
        alpha=0.5,
        s=40,
        c=np.arange(num_vectors),
        cmap="viridis",
        label="Knowledge Base Vectors",
    )

    # The query vector as a large red "X" drawn above the other points.
    ax.scatter(
        query_vector_2d[0],
        query_vector_2d[1],
        c="red",
        s=300,
        marker="X",
        edgecolors="black",
        linewidths=2,
        label="Query",
        zorder=5,
    )
    ax.annotate(
        "Query",
        (query_vector_2d[0], query_vector_2d[1]),
        xytext=(10, 10),
        textcoords="offset points",
        fontsize=10,
        bbox={"boxstyle": "round,pad=0.5", "fc": "yellow", "alpha": 0.7},
        arrowprops={"arrowstyle": "->", "connectionstyle": "arc3,rad=0"},
    )

    fig.colorbar(scatter, ax=ax, label="Vector Index")
    ax.set_title(
        f"t-SNE Visualization: Query in Knowledge Base\n"
        f"({num_vectors} vectors, {dim} dimensions, KB: {kb_name})",
        fontsize=14,
        pad=20,
    )
    ax.set_xlabel("t-SNE Dimension 1", fontsize=12)
    ax.set_ylabel("t-SNE Dimension 2", fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=10, loc="upper right")

    buffer = BytesIO()
    fig.savefig(buffer, format="png", dpi=150, bbox_inches="tight")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
|