refactor: split methods to improve code readability
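
Extract client construction into _init_client(), safety-setting assembly into _init_safety_settings(), key-rotation error handling into _handle_api_error(), and request-config assembly into _prepare_query_config(), which _query and _query_stream now share. Also coerce the configured timeout with int() at read time and move get_models() next to the other accessors.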
@@ -55,12 +55,18 @@ class ProviderGoogleGenAI(Provider):
         )
         self.api_keys: List = provider_config.get("key", [])
         self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
-        self.timeout: int = provider_config.get("timeout", 180)
+        self.timeout: int = int(provider_config.get("timeout", 180))
 
         self.api_base: Optional[str] = provider_config.get("api_base", None)
         if self.api_base and self.api_base.endswith("/"):
             self.api_base = self.api_base[:-1]
-        if isinstance(self.timeout, str):
-            self.timeout = int(self.timeout)
 
+        self._init_client()
+        self.set_model(provider_config["model_config"]["model"])
+        self._init_safety_settings()
+
+    def _init_client(self) -> None:
+        """Initialize the Gemini client."""
         self.client = genai.Client(
             api_key=self.chosen_api_key,
             http_options=types.HttpOptions(
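
Review note (not part of the diff): the int() coercion above subsumes the removed isinstance check, accepting both string and integer config values in one step. A minimal sketch of the behavior, assuming a plain dict config:

    def read_timeout(cfg: dict, default: int = 180) -> int:
        # int() accepts both int and numeric str: "90" -> 90, 180 -> 180
        return int(cfg.get("timeout", default))

    assert read_timeout({"timeout": "90"}) == 90
    assert read_timeout({}) == 180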
@@ -68,8 +74,9 @@ class ProviderGoogleGenAI(Provider):
                 timeout=self.timeout * 1000,  # milliseconds
             ),
         ).aio
-        self.set_model(provider_config["model_config"]["model"])
 
+    def _init_safety_settings(self) -> None:
+        """Initialize safety settings."""
         user_safety_config = self.provider_config.get("gm_safety_settings", {})
         self.safety_settings = [
             types.SafetySetting(
@@ -80,16 +87,59 @@ class ProviderGoogleGenAI(Provider):
             and threshold_str in self.THRESHOLD_MAPPING
         ]
 
-    async def get_models(self):
-        try:
-            models = await self.client.models.list()
-            return [
-                m.name.replace("models/", "")
-                for m in models
-                if "generateContent" in m.supported_actions
-            ]
-        except APIError as e:
-            raise Exception(f"Failed to fetch the model list: {e.message}")
+    async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool:
+        """Handle an API error and return whether to retry."""
+        if e.code == 429 or "API key not valid" in e.message:
+            keys.remove(self.chosen_api_key)
+            if len(keys) > 0:
+                self.set_key(random.choice(keys))
+                logger.info(
+                    f"Key error detected ({e.message}); switching API key and retrying... Current key: {self.chosen_api_key[:12]}..."
+                )
+                await asyncio.sleep(1)
+                return True
+            else:
+                logger.error(
+                    f"Key error detected ({e.message}) and no usable keys remain. Current key: {self.chosen_api_key[:12]}..."
+                )
+                raise Exception("Gemini rate limit reached, please try again later...")
+        else:
+            logger.error(
+                f"An error occurred (gemini_source). Provider config: {self.provider_config}"
+            )
+            raise e
+
+    async def _prepare_query_config(
+        self,
+        tools: Optional[FuncCall] = None,
+        system_instruction: Optional[str] = None,
+        temperature: Optional[float] = 0.7,
+        modalities: Optional[List[str]] = None,
+    ) -> types.GenerateContentConfig:
+        """Prepare the query config."""
+        if not modalities:
+            modalities = ["Text"]
+            if self.provider_config.get("gm_resp_image_modal", False):
+                modalities.append("Image")
+
+        tool_list = None
+        if tools:
+            func_desc = tools.get_func_desc_google_genai_style()
+            if func_desc:
+                tool_list = [
+                    types.Tool(function_declarations=func_desc["function_declarations"])
+                ]
+
+        return types.GenerateContentConfig(
+            system_instruction=system_instruction,
+            temperature=temperature,
+            response_modalities=modalities,
+            tools=tool_list,
+            safety_settings=self.safety_settings if self.safety_settings else None,
+            automatic_function_calling=types.AutomaticFunctionCallingConfig(
+                disable=True
+            ),
+        )
+
     @staticmethod
     def _prepare_conversation(payloads: Dict) -> List[types.Content]:
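
Review note (not part of the diff): _handle_api_error centralizes the rotate-key-on-429 logic that text_chat and text_chat_stream previously duplicated. A standalone sketch of that pattern with hypothetical stand-in names (ApiError, rotate_or_raise), not the project's classes:

    import random

    class ApiError(Exception):
        def __init__(self, code: int, message: str):
            self.code, self.message = code, message

    def rotate_or_raise(e: ApiError, keys: list, current: str) -> str:
        """Drop the failing key and pick another to retry with, or re-raise."""
        if e.code == 429 or "API key not valid" in e.message:
            keys.remove(current)
            if keys:
                return random.choice(keys)  # caller retries with this key
            raise RuntimeError("rate limited and no usable keys remain")
        raise e  # unrelated errors propagate unchanged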
@@ -165,165 +215,6 @@ class ProviderGoogleGenAI(Provider):
 
         return gemini_contents
 
-    async def _query(
-        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
-    ) -> LLMResponse:
-        """Non-streaming request to the Gemini API."""
-        tool_list = None
-        if tools:
-            func_desc = tools.get_func_desc_google_genai_style()
-            if func_desc:
-                tool_list = [
-                    types.Tool(function_declarations=func_desc["function_declarations"])
-                ]
-
-        system_instruction = next(
-            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
-            None,
-        )
-
-        conversation = self._prepare_conversation(payloads)
-
-        modalities = ["Text"]
-        if self.provider_config.get("gm_resp_image_modal", False):
-            modalities.append("Image")
-
-        result: Optional[types.GenerateContentResponse] = None
-        while True:
-            try:
-                result = await self.client.models.generate_content(
-                    model=self.get_model(),
-                    contents=conversation,
-                    config=types.GenerateContentConfig(
-                        system_instruction=system_instruction,
-                        temperature=temperature,
-                        response_modalities=modalities,
-                        tools=tool_list,
-                        safety_settings=self.safety_settings
-                        if self.safety_settings
-                        else None,
-                        automatic_function_calling=types.AutomaticFunctionCallingConfig(
-                            disable=True
-                        ),
-                    ),
-                )
-
-                if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
-                    if temperature > 2:
-                        raise Exception("Temperature exceeds the maximum of 2, but recitation still occurs")
-                    temperature += 0.2
-                    logger.warning(
-                        f"Recitation occurred; raising temperature to {temperature:.1f} and retrying..."
-                    )
-                    continue
-
-                break
-
-            except APIError as e:
-                if "Developer instruction is not enabled" in e.message:
-                    logger.warning(
-                        f"{self.get_model()} does not support system prompts; removed automatically (affects persona settings)"
-                    )
-                    system_instruction = None
-                elif "Function calling is not enabled" in e.message:
-                    logger.warning(f"{self.get_model()} does not support function calling; removed automatically")
-                    tool_list = None
-                elif (
-                    "Multi-modal output is not supported"
-                    or "Model does not support the requested response modalities"
-                    in e.message
-                ):
-                    logger.warning(
-                        f"{self.get_model()} does not support multi-modal output; falling back to the text modality"
-                    )
-                    modalities = ["Text"]
-                else:
-                    raise
-                continue
-
-        llm_response = LLMResponse("assistant")
-        llm_response.result_chain = self._process_content_parts(result, llm_response)
-        return llm_response
-
-    async def _query_stream(
-        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
-    ) -> AsyncGenerator[LLMResponse, None]:
-        """Streaming request to the Gemini API."""
-        tool_list = None
-        if tools:
-            func_desc = tools.get_func_desc_google_genai_style()
-            if func_desc:
-                tool_list = [
-                    types.Tool(function_declarations=func_desc["function_declarations"])
-                ]
-
-        system_instruction = next(
-            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
-            None,
-        )
-
-        conversation = self._prepare_conversation(payloads)
-
-        result = None
-        while True:
-            try:
-                result = await self.client.models.generate_content_stream(
-                    model=self.get_model(),
-                    contents=conversation,
-                    config=types.GenerateContentConfig(
-                        system_instruction=system_instruction,
-                        temperature=temperature,
-                        tools=tool_list,
-                        safety_settings=self.safety_settings
-                        if self.safety_settings
-                        else None,
-                        automatic_function_calling=types.AutomaticFunctionCallingConfig(
-                            disable=True
-                        ),
-                    ),
-                )
-
-                break
-
-            except APIError as e:
-                if "Developer instruction is not enabled" in e.message:
-                    logger.warning(
-                        f"{self.get_model()} does not support system prompts; removed automatically (affects persona settings)"
-                    )
-                    system_instruction = None
-                elif "Function calling is not enabled" in e.message:
-                    logger.warning(f"{self.get_model()} does not support function calling; removed automatically")
-                    tool_list = None
-                else:
-                    raise
-                continue
-
-        if not result:
-            raise Exception("Unexpected API response")
-
-        async for chunk in result:
-            llm_response = LLMResponse("assistant", is_chunk=True)
-
-            if chunk.candidates[0].content.parts and any(
-                part.function_call for part in chunk.candidates[0].content.parts
-            ):
-                response = LLMResponse("assistant", is_chunk=False)
-                response.result_chain = self._process_content_parts(chunk, response)
-                yield response
-                break
-
-            if chunk.text:
-                llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
-                yield llm_response
-
-            if chunk.candidates[0].finish_reason:
-                llm_response = LLMResponse("assistant", is_chunk=False)
-                llm_response.result_chain = self._process_content_parts(
-                    chunk, llm_response
-                )
-                yield llm_response
-                break
 
     @staticmethod
     def _process_content_parts(
         result: types.GenerateContentResponse, llm_response: LLMResponse
@@ -361,6 +252,129 @@ class ProviderGoogleGenAI(Provider):
                 chain.append(Comp.Image.fromBytes(part.inline_data.data))
         return MessageChain(chain=chain)
 
+    async def _query(
+        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
+    ) -> LLMResponse:
+        """Non-streaming request to the Gemini API."""
+        system_instruction = next(
+            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
+            None,
+        )
+
+        modalities = ["Text"]
+        if self.provider_config.get("gm_resp_image_modal", False):
+            modalities.append("Image")
+
+        conversation = self._prepare_conversation(payloads)
+
+        result: Optional[types.GenerateContentResponse] = None
+        while True:
+            try:
+                config = await self._prepare_query_config(
+                    tools, system_instruction, temperature, modalities
+                )
+                result = await self.client.models.generate_content(
+                    model=self.get_model(),
+                    contents=conversation,
+                    config=config,
+                )
+
+                if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
+                    if temperature > 2:
+                        raise Exception("Temperature exceeds the maximum of 2, but recitation still occurs")
+                    temperature += 0.2
+                    logger.warning(
+                        f"Recitation occurred; raising temperature to {temperature:.1f} and retrying..."
+                    )
+                    continue
+
+                break
+
+            except APIError as e:
+                if "Developer instruction is not enabled" in e.message:
+                    logger.warning(
+                        f"{self.get_model()} does not support system prompts; removed automatically (affects persona settings)"
+                    )
+                    system_instruction = None
+                elif "Function calling is not enabled" in e.message:
+                    logger.warning(f"{self.get_model()} does not support function calling; removed automatically")
+                    tools = None
+                elif (
+                    "Multi-modal output is not supported" in e.message
+                    or "Model does not support the requested response modalities"
+                    in e.message
+                ):
+                    logger.warning(
+                        f"{self.get_model()} does not support multi-modal output; falling back to the text modality"
+                    )
+                    modalities = ["Text"]
+                else:
+                    raise
+                continue
+
+        llm_response = LLMResponse("assistant")
+        llm_response.result_chain = self._process_content_parts(result, llm_response)
+        return llm_response
+
+    async def _query_stream(
+        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
+    ) -> AsyncGenerator[LLMResponse, None]:
+        """Streaming request to the Gemini API."""
+        system_instruction = next(
+            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
+            None,
+        )
+
+        conversation = self._prepare_conversation(payloads)
+
+        result = None
+        while True:
+            try:
+                config = await self._prepare_query_config(
+                    tools, system_instruction, temperature
+                )
+                result = await self.client.models.generate_content_stream(
+                    model=self.get_model(),
+                    contents=conversation,
+                    config=config,
+                )
+                break
+            except APIError as e:
+                if "Developer instruction is not enabled" in e.message:
+                    logger.warning(
+                        f"{self.get_model()} does not support system prompts; removed automatically (affects persona settings)"
+                    )
+                    system_instruction = None
+                elif "Function calling is not enabled" in e.message:
+                    logger.warning(f"{self.get_model()} does not support function calling; removed automatically")
+                    tools = None
+                else:
+                    raise
+                continue
+
+        async for chunk in result:
+            llm_response = LLMResponse("assistant", is_chunk=True)
+
+            if chunk.candidates[0].content.parts and any(
+                part.function_call for part in chunk.candidates[0].content.parts
+            ):
+                response = LLMResponse("assistant", is_chunk=False)
+                response.result_chain = self._process_content_parts(chunk, response)
+                yield response
+                break
+
+            if chunk.text:
+                llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
+                yield llm_response
+
+            if chunk.candidates[0].finish_reason:
+                llm_response = LLMResponse("assistant", is_chunk=False)
+                llm_response.result_chain = self._process_content_parts(
+                    chunk, llm_response
+                )
+                yield llm_response
+                break
 
     async def text_chat(
         self,
         prompt: str,
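
Review note (not part of the diff): both rebuilt methods fetch a fresh config from _prepare_query_config inside the retry loop, so degradations (dropping the system prompt, tools, or the image modality) take effect on the next attempt. A minimal driver using only SDK calls that appear in this diff; the model name and API key are placeholders:

    import asyncio
    from google import genai
    from google.genai import types

    async def main() -> None:
        client = genai.Client(api_key="YOUR_KEY").aio  # async client, as in _init_client
        config = types.GenerateContentConfig(
            temperature=0.7,
            response_modalities=["Text"],
            automatic_function_calling=types.AutomaticFunctionCallingConfig(disable=True),
        )
        # Non-streaming call, as in _query:
        result = await client.models.generate_content(
            model="gemini-2.0-flash", contents="Hello", config=config
        )
        print(result.text)
        # Streaming call, as in _query_stream:
        stream = await client.models.generate_content_stream(
            model="gemini-2.0-flash", contents="Hello", config=config
        )
        async for chunk in stream:
            if chunk.text:
                print(chunk.text, end="")

    asyncio.run(main())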
@@ -389,7 +403,6 @@ class ProviderGoogleGenAI(Provider):
         model_config["model"] = self.get_model()
 
         payloads = {"messages": context_query, **model_config}
-        llm_response = None
 
         retry = 10
         keys = self.api_keys.copy()
@@ -397,30 +410,11 @@ class ProviderGoogleGenAI(Provider):
 
         for _ in range(retry):
             try:
-                llm_response = await self._query(payloads, func_tool, temp)
-                break
+                return await self._query(payloads, func_tool, temp)
             except APIError as e:
-                if e.code == 429 or "API key not valid" in e.message:
-                    keys.remove(self.chosen_api_key)
-                    if len(keys) > 0:
-                        self.set_key(random.choice(keys))
-                        logger.info(
-                            f"Key error detected ({e.message}); switching API key and retrying... Current key: {self.chosen_api_key[:12]}..."
-                        )
-                        await asyncio.sleep(1)
-                        continue
-                    else:
-                        logger.error(
-                            f"Key error detected ({e.message}) and no usable keys remain. Current key: {self.chosen_api_key[:12]}..."
-                        )
-                        raise Exception("Gemini rate limit reached, please try again later...")
-                else:
-                    logger.error(
-                        f"An error occurred (gemini_source). Provider config: {self.provider_config}"
-                    )
-                    raise e
-
-        return llm_response
+                if await self._handle_api_error(e, keys):
+                    continue
+                break
 
     async def text_chat_stream(
         self,
@@ -461,25 +455,20 @@ class ProviderGoogleGenAI(Provider):
                     yield response
                 break
             except APIError as e:
-                if e.code == 429 or "API key not valid" in e.message:
-                    keys.remove(self.chosen_api_key)
-                    if len(keys) > 0:
-                        self.set_key(random.choice(keys))
-                        logger.info(
-                            f"Key error detected ({e.message}); switching API key and retrying... Current key: {self.chosen_api_key[:12]}..."
-                        )
-                        await asyncio.sleep(1)
-                        continue
-                    else:
-                        logger.error(
-                            f"Key error detected ({e.message}) and no usable keys remain. Current key: {self.chosen_api_key[:12]}..."
-                        )
-                        raise Exception("Gemini rate limit reached, please try again later...")
-                else:
-                    logger.error(
-                        f"An error occurred (gemini_source). Provider config: {self.provider_config}"
-                    )
-                    raise e
+                if await self._handle_api_error(e, keys):
+                    continue
+                break
+
+    async def get_models(self):
+        try:
+            models = await self.client.models.list()
+            return [
+                m.name.replace("models/", "")
+                for m in models
+                if "generateContent" in m.supported_actions
+            ]
+        except APIError as e:
+            raise Exception(f"Failed to fetch the model list: {e.message}")
 
     def get_current_key(self) -> str:
         return self.chosen_api_key
@@ -489,14 +478,7 @@ class ProviderGoogleGenAI(Provider):
 
     def set_key(self, key):
         self.chosen_api_key = key
-        # Re-initialize the client
-        self.client = genai.Client(
-            api_key=self.chosen_api_key,
-            http_options=types.HttpOptions(
-                base_url=self.api_base,
-                timeout=self.timeout * 1000,  # milliseconds
-            ),
-        ).aio
+        self._init_client()
 
     async def assemble_context(self, text: str, image_urls: List[str] = None):
         """