feat: add Gemini embedding provider and update OpenAI provider to support timeout configuration

This commit is contained in:
Raven95676
2025-05-31 14:13:58 +08:00
parent 705cf2ea1b
commit 5f07bcc8e6
3 changed files with 79 additions and 1 deletions

View File

@@ -873,6 +873,17 @@ CONFIG_METADATA_2 = {
"embedding_dimensions": 1536,
"timeout": 20,
},
"Gemini Embedding": {
"id": "gemini_embedding",
"type": "gemini_embedding",
"provider_type": "embedding",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "",
"embedding_model": "gemini-embedding-exp-03-07",
"embedding_dimensions": 768,
"timeout": 20,
},
},
"items": {
"embedding_dimensions": {
@@ -888,7 +899,10 @@ CONFIG_METADATA_2 = {
"embedding_api_key": {
"description": "API Key",
"type": "string",
"hint": "API Key",
},
"embedding_api_base": {
"description": "API Base URL",
"type": "string",
},
"volcengine_cluster": {
"type": "string",

View File

@@ -0,0 +1,63 @@
from google import genai
from google.genai import types
from google.genai.errors import APIError
from ..provider import EmbeddingProvider
from ..register import register_provider_adapter
from ..entities import ProviderType
@register_provider_adapter(
"gemini_embedding",
"Google Gemini Embedding 提供商适配器",
provider_type=ProviderType.EMBEDDING,
)
class GeminiEmbeddingProvider(EmbeddingProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
api_key: str = provider_config.get("embedding_api_key")
api_base: str = provider_config.get("embedding_api_base", None)
timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000)
if api_base:
if api_base.endswith("/"):
api_base = api_base[:-1]
http_options.base_url = api_base
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
self.model = provider_config.get(
"embedding_model", "gemini-embedding-exp-03-07"
)
self.dimension = provider_config.get("embedding_dimensions", 768)
async def get_embedding(self, text: str) -> list[float]:
"""
获取文本的嵌入
"""
try:
result = await self.client.models.embed_content(
model=self.model, contents=text
)
return result.embeddings[0].values
except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
批量获取文本的嵌入
"""
try:
result = await self.client.models.embed_content(
model=self.model, contents=texts
)
return [embedding.values for embedding in result.embeddings]
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
def get_dim(self) -> int:
"""获取向量的维度"""
return self.dimension

View File

@@ -19,6 +19,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
base_url=provider_config.get(
"embedding_api_base", "https://api.openai.com/v1"
),
timeout=int(provider_config.get("timeout", 20)),
)
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
self.dimension = provider_config.get("embedding_dimensions", 1536)