From 5f07bcc8e691572ea301e2c6dec9f16b89ad706c Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sat, 31 May 2025 14:13:58 +0800 Subject: [PATCH] feat: add Gemini embedding provider and update OpenAI provider to support timeout configuration --- astrbot/core/config/default.py | 16 ++++- .../sources/gemini_embedding_source.py | 63 +++++++++++++++++++ .../sources/openai_embedding_source.py | 1 + 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 astrbot/core/provider/sources/gemini_embedding_source.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index dce1da52..6af27337 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -873,6 +873,17 @@ CONFIG_METADATA_2 = { "embedding_dimensions": 1536, "timeout": 20, }, + "Gemini Embedding": { + "id": "gemini_embedding", + "type": "gemini_embedding", + "provider_type": "embedding", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "", + "embedding_model": "gemini-embedding-exp-03-07", + "embedding_dimensions": 768, + "timeout": 20, + }, }, "items": { "embedding_dimensions": { @@ -888,7 +899,10 @@ CONFIG_METADATA_2 = { "embedding_api_key": { "description": "API Key", "type": "string", - "hint": "API Key", + }, + "embedding_api_base": { + "description": "API Base URL", + "type": "string", }, "volcengine_cluster": { "type": "string", diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py new file mode 100644 index 00000000..baccf52a --- /dev/null +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -0,0 +1,63 @@ +from google import genai +from google.genai import types +from google.genai.errors import APIError +from ..provider import EmbeddingProvider +from ..register import register_provider_adapter +from ..entities import ProviderType + + +@register_provider_adapter( + "gemini_embedding", + "Google Gemini Embedding 提供商适配器", + provider_type=ProviderType.EMBEDDING, +) +class GeminiEmbeddingProvider(EmbeddingProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + + api_key: str = provider_config.get("embedding_api_key") + api_base: str = provider_config.get("embedding_api_base", None) + timeout: int = int(provider_config.get("timeout", 20)) + + http_options = types.HttpOptions(timeout=timeout * 1000) + if api_base: + if api_base.endswith("/"): + api_base = api_base[:-1] + http_options.base_url = api_base + + self.client = genai.Client(api_key=api_key, http_options=http_options).aio + + self.model = provider_config.get( + "embedding_model", "gemini-embedding-exp-03-07" + ) + self.dimension = provider_config.get("embedding_dimensions", 768) + + async def get_embedding(self, text: str) -> list[float]: + """ + 获取文本的嵌入 + """ + try: + result = await self.client.models.embed_content( + model=self.model, contents=text + ) + return result.embeddings[0].values + except APIError as e: + raise Exception(f"Gemini Embedding API请求失败: {e.message}") + + async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + """ + 批量获取文本的嵌入 + """ + try: + result = await self.client.models.embed_content( + model=self.model, contents=texts + ) + return [embedding.values for embedding in result.embeddings] + except APIError as e: + raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") + + def get_dim(self) -> int: + """获取向量的维度""" + return self.dimension diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 2d339e57..f4315247 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -19,6 +19,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider): base_url=provider_config.get( "embedding_api_base", "https://api.openai.com/v1" ), + timeout=int(provider_config.get("timeout", 20)), ) self.model = provider_config.get("embedding_model", "text-embedding-3-small") self.dimension = provider_config.get("embedding_dimensions", 1536)