refactor: split methods to improve code readability
@@ -55,12 +55,18 @@ class ProviderGoogleGenAI(Provider):
         )
         self.api_keys: List = provider_config.get("key", [])
         self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
-        self.timeout: int = provider_config.get("timeout", 180)
+        self.timeout: int = int(provider_config.get("timeout", 180))

         self.api_base: Optional[str] = provider_config.get("api_base", None)
         if self.api_base and self.api_base.endswith("/"):
             self.api_base = self.api_base[:-1]
-        if isinstance(self.timeout, str):
-            self.timeout = int(self.timeout)
+        self._init_client()
+        self.set_model(provider_config["model_config"]["model"])
+        self._init_safety_settings()

+    def _init_client(self) -> None:
+        """初始化Gemini客户端"""
         self.client = genai.Client(
             api_key=self.chosen_api_key,
             http_options=types.HttpOptions(
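A note on the coercion introduced above: `provider_config.get("timeout", 180)` can return a string when the value comes from user config, which the old code only normalized later through the now-removed `isinstance` check. Converting at read time keeps `self.timeout * 1000` (the millisecond value handed to `types.HttpOptions`) valid for both shapes. A minimal sketch:

    # Hedged sketch: string and int config values normalize identically.
    for raw in ("240", 240):              # e.g. YAML/JSON config vs. a literal int
        timeout = int(raw)                # mirrors the new int(...) coercion
        assert timeout * 1000 == 240_000  # millisecond value for types.HttpOptions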
@@ -68,8 +74,9 @@ class ProviderGoogleGenAI(Provider):
                 timeout=self.timeout * 1000,  # 毫秒
             ),
         ).aio
-        self.set_model(provider_config["model_config"]["model"])

+    def _init_safety_settings(self) -> None:
+        """初始化安全设置"""
         user_safety_config = self.provider_config.get("gm_safety_settings", {})
         self.safety_settings = [
             types.SafetySetting(
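`_init_safety_settings` now owns the construction of `self.safety_settings` from the `gm_safety_settings` provider option, filtering entries through the class's `THRESHOLD_MAPPING` (visible as context in the next hunk). The config shape below is an assumed illustration only; the real category and threshold names come from `THRESHOLD_MAPPING` and the `google.genai` enums:

    # Hypothetical config sketch -- the key and value names here are assumptions.
    provider_config = {
        "gm_safety_settings": {
            "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_ONLY_HIGH",
        },
    }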
@@ -80,16 +87,59 @@ class ProviderGoogleGenAI(Provider):
                 and threshold_str in self.THRESHOLD_MAPPING
             ]

-    async def get_models(self):
-        try:
-            models = await self.client.models.list()
-            return [
-                m.name.replace("models/", "")
-                for m in models
-                if "generateContent" in m.supported_actions
-            ]
-        except APIError as e:
-            raise Exception(f"获取模型列表失败: {e.message}")
+    async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool:
+        """处理API错误,返回是否需要重试"""
+        if e.code == 429 or "API key not valid" in e.message:
+            keys.remove(self.chosen_api_key)
+            if len(keys) > 0:
+                self.set_key(random.choice(keys))
+                logger.info(
+                    f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
+                )
+                await asyncio.sleep(1)
+                return True
+            else:
+                logger.error(
+                    f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
+                )
+                raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
+        else:
+            logger.error(
+                f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
+            )
+            raise e
+
+    async def _prepare_query_config(
+        self,
+        tools: Optional[FuncCall] = None,
+        system_instruction: Optional[str] = None,
+        temperature: Optional[float] = 0.7,
+        modalities: Optional[List[str]] = None,
+    ) -> types.GenerateContentConfig:
+        """准备查询配置"""
+        if not modalities:
+            modalities = ["Text"]
+            if self.provider_config.get("gm_resp_image_modal", False):
+                modalities.append("Image")
+
+        tool_list = None
+        if tools:
+            func_desc = tools.get_func_desc_google_genai_style()
+            if func_desc:
+                tool_list = [
+                    types.Tool(function_declarations=func_desc["function_declarations"])
+                ]
+
+        return types.GenerateContentConfig(
+            system_instruction=system_instruction,
+            temperature=temperature,
+            response_modalities=modalities,
+            tools=tool_list,
+            safety_settings=self.safety_settings if self.safety_settings else None,
+            automatic_function_calling=types.AutomaticFunctionCallingConfig(
+                disable=True
+            ),
+        )

     @staticmethod
     def _prepare_conversation(payloads: Dict) -> List[types.Content]:
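`_handle_api_error` extracts the key-rotation logic that was previously duplicated in `text_chat` and `text_chat_stream`: on a 429 or invalid-key `APIError` it drops the failing key, switches to a random remaining one, sleeps a second, and returns `True` to request a retry; with no keys left, or for any unrelated error, it raises instead. Both call sites consume it with the same loop shape, excerpted from the later hunks of this commit:

    # Caller pattern (simplified excerpt from the text_chat hunk below).
    for _ in range(retry):
        try:
            return await self._query(payloads, func_tool, temp)
        except APIError as e:
            if await self._handle_api_error(e, keys):  # True => key rotated
                continue
            break  # defensive: non-retryable paths raise inside the helper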
@@ -165,165 +215,6 @@ class ProviderGoogleGenAI(Provider):

         return gemini_contents

-    async def _query(
-        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
-    ) -> LLMResponse:
-        """非流式请求 Gemini API"""
-        tool_list = None
-        if tools:
-            func_desc = tools.get_func_desc_google_genai_style()
-            if func_desc:
-                tool_list = [
-                    types.Tool(function_declarations=func_desc["function_declarations"])
-                ]
-
-        system_instruction = next(
-            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
-            None,
-        )
-
-        conversation = self._prepare_conversation(payloads)
-
-        modalities = ["Text"]
-        if self.provider_config.get("gm_resp_image_modal", False):
-            modalities.append("Image")
-
-        result: Optional[types.GenerateContentResponse] = None
-        while True:
-            try:
-                result = await self.client.models.generate_content(
-                    model=self.get_model(),
-                    contents=conversation,
-                    config=types.GenerateContentConfig(
-                        system_instruction=system_instruction,
-                        temperature=temperature,
-                        response_modalities=modalities,
-                        tools=tool_list,
-                        safety_settings=self.safety_settings
-                        if self.safety_settings
-                        else None,
-                        automatic_function_calling=types.AutomaticFunctionCallingConfig(
-                            disable=True
-                        ),
-                    ),
-                )
-
-                if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
-                    if temperature > 2:
-                        raise Exception("温度参数已超过最大值2,仍然发生recitation")
-                    temperature += 0.2
-                    logger.warning(
-                        f"发生了recitation,正在提高温度至{temperature:.1f}重试..."
-                    )
-                    continue
-
-                break
-
-            except APIError as e:
-                if "Developer instruction is not enabled" in e.message:
-                    logger.warning(
-                        f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
-                    )
-                    system_instruction = None
-                elif "Function calling is not enabled" in e.message:
-                    logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
-                    tool_list = None
-                elif (
-                    "Multi-modal output is not supported"
-                    or "Model does not support the requested response modalities"
-                    in e.message
-                ):
-                    logger.warning(
-                        f"{self.get_model()} 不支持多模态输出,降级为文本模态"
-                    )
-                    modalities = ["Text"]
-                else:
-                    raise
-                continue
-
-        llm_response = LLMResponse("assistant")
-        llm_response.result_chain = self._process_content_parts(result, llm_response)
-        return llm_response
-
-    async def _query_stream(
-        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
-    ) -> AsyncGenerator[LLMResponse, None]:
-        """流式请求 Gemini API"""
-        tool_list = None
-        if tools:
-            func_desc = tools.get_func_desc_google_genai_style()
-            if func_desc:
-                tool_list = [
-                    types.Tool(function_declarations=func_desc["function_declarations"])
-                ]
-
-        system_instruction = next(
-            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
-            None,
-        )
-
-        conversation = self._prepare_conversation(payloads)
-
-        result = None
-        while True:
-            try:
-                result = await self.client.models.generate_content_stream(
-                    model=self.get_model(),
-                    contents=conversation,
-                    config=types.GenerateContentConfig(
-                        system_instruction=system_instruction,
-                        temperature=temperature,
-                        tools=tool_list,
-                        safety_settings=self.safety_settings
-                        if self.safety_settings
-                        else None,
-                        automatic_function_calling=types.AutomaticFunctionCallingConfig(
-                            disable=True
-                        ),
-                    ),
-                )
-
-                break
-
-            except APIError as e:
-                if "Developer instruction is not enabled" in e.message:
-                    logger.warning(
-                        f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
-                    )
-                    system_instruction = None
-                elif "Function calling is not enabled" in e.message:
-                    logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
-                    tool_list = None
-                else:
-                    raise
-                continue
-
-        if not result:
-            raise Exception("API 返回异常")
-
-        async for chunk in result:
-            llm_response = LLMResponse("assistant", is_chunk=True)
-
-            if chunk.candidates[0].content.parts and any(
-                part.function_call for part in chunk.candidates[0].content.parts
-            ):
-                response = LLMResponse("assistant", is_chunk=False)
-                response.result_chain = self._process_content_parts(chunk, response)
-                yield response
-                break
-
-            if chunk.text:
-                llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
-                yield llm_response
-
-            if chunk.candidates[0].finish_reason:
-                llm_response = LLMResponse("assistant", is_chunk=False)
-                llm_response.result_chain = self._process_content_parts(
-                    chunk, llm_response
-                )
-                yield llm_response
-                break
-
     @staticmethod
     def _process_content_parts(
         result: types.GenerateContentResponse, llm_response: LLMResponse
@@ -361,6 +252,129 @@ class ProviderGoogleGenAI(Provider):
                     chain.append(Comp.Image.fromBytes(part.inline_data.data))
         return MessageChain(chain=chain)

+    async def _query(
+        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
+    ) -> LLMResponse:
+        """非流式请求 Gemini API"""
+        system_instruction = next(
+            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
+            None,
+        )
+
+        modalities = ["Text"]
+        if self.provider_config.get("gm_resp_image_modal", False):
+            modalities.append("Image")
+
+        conversation = self._prepare_conversation(payloads)
+
+        result: Optional[types.GenerateContentResponse] = None
+        while True:
+            try:
+                config = await self._prepare_query_config(
+                    tools, system_instruction, temperature, modalities
+                )
+                result = await self.client.models.generate_content(
+                    model=self.get_model(),
+                    contents=conversation,
+                    config=config,
+                )
+
+                if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
+                    if temperature > 2:
+                        raise Exception("温度参数已超过最大值2,仍然发生recitation")
+                    temperature += 0.2
+                    logger.warning(
+                        f"发生了recitation,正在提高温度至{temperature:.1f}重试..."
+                    )
+                    continue
+
+                break
+
+            except APIError as e:
+                if "Developer instruction is not enabled" in e.message:
+                    logger.warning(
+                        f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
+                    )
+                    system_instruction = None
+                elif "Function calling is not enabled" in e.message:
+                    logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
+                    tools = None
+                elif (
+                    "Multi-modal output is not supported"
+                    or "Model does not support the requested response modalities"
+                    in e.message
+                ):
+                    logger.warning(
+                        f"{self.get_model()} 不支持多模态输出,降级为文本模态"
+                    )
+                    modalities = ["Text"]
+                else:
+                    raise
+                continue
+
+        llm_response = LLMResponse("assistant")
+        llm_response.result_chain = self._process_content_parts(result, llm_response)
+        return llm_response
+
+    async def _query_stream(
+        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
+    ) -> AsyncGenerator[LLMResponse, None]:
+        """流式请求 Gemini API"""
+        system_instruction = next(
+            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
+            None,
+        )
+
+        conversation = self._prepare_conversation(payloads)
+
+        result = None
+        while True:
+            try:
+                config = await self._prepare_query_config(
+                    tools, system_instruction, temperature
+                )
+                result = await self.client.models.generate_content_stream(
+                    model=self.get_model(),
+                    contents=conversation,
+                    config=config,
+                )
+                break
+            except APIError as e:
+                if "Developer instruction is not enabled" in e.message:
+                    logger.warning(
+                        f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
+                    )
+                    system_instruction = None
+                elif "Function calling is not enabled" in e.message:
+                    logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
+                    tools = None
+                else:
+                    raise
+                continue
+
+        async for chunk in result:
+            llm_response = LLMResponse("assistant", is_chunk=True)
+
+            if chunk.candidates[0].content.parts and any(
+                part.function_call for part in chunk.candidates[0].content.parts
+            ):
+                response = LLMResponse("assistant", is_chunk=False)
+                response.result_chain = self._process_content_parts(chunk, response)
+                yield response
+                break
+
+            if chunk.text:
+                llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
+                yield llm_response
+
+            if chunk.candidates[0].finish_reason:
+                llm_response = LLMResponse("assistant", is_chunk=False)
+                llm_response.result_chain = self._process_content_parts(
+                    chunk, llm_response
+                )
+                yield llm_response
+                break
+
     async def text_chat(
         self,
         prompt: str,
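With the move, `_query` and `_query_stream` shrink to request/response handling and delegate all `GenerateContentConfig` assembly to `_prepare_query_config`; the streaming path omits the `modalities` argument, so it gets the helper's default of `["Text"]` (plus `"Image"` when `gm_resp_image_modal` is set). A minimal sketch of the delegation, with placeholder values:

    # Sketch: one helper builds the config for both code paths.
    config = await self._prepare_query_config(
        tools=None,                # a FuncCall instance when tools are enabled
        system_instruction=None,   # dropped automatically if the model rejects it
        temperature=0.7,
    )

One quirk carried over unchanged from the old code: in the multimodal fallback, the condition `("Multi-modal output is not supported" or "Model does not support the requested response modalities" in e.message)` evaluates the first string literal as always truthy, so that `elif` matches any `APIError` not caught by the earlier branches.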
@@ -389,7 +403,6 @@ class ProviderGoogleGenAI(Provider):
         model_config["model"] = self.get_model()

         payloads = {"messages": context_query, **model_config}
-        llm_response = None

         retry = 10
         keys = self.api_keys.copy()
@@ -397,30 +410,11 @@ class ProviderGoogleGenAI(Provider):

         for _ in range(retry):
             try:
-                llm_response = await self._query(payloads, func_tool, temp)
-                break
+                return await self._query(payloads, func_tool, temp)
             except APIError as e:
-                if e.code == 429 or "API key not valid" in e.message:
-                    keys.remove(self.chosen_api_key)
-                    if len(keys) > 0:
-                        self.set_key(random.choice(keys))
-                        logger.info(
-                            f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
-                        )
-                        await asyncio.sleep(1)
-                        continue
-                    else:
-                        logger.error(
-                            f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
-                        )
-                        raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
-                else:
-                    logger.error(
-                        f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
-                    )
-                    raise e
-
-        return llm_response
+                if await self._handle_api_error(e, keys):
+                    continue
+                break

     async def text_chat_stream(
         self,
@@ -461,25 +455,20 @@ class ProviderGoogleGenAI(Provider):
                 yield response
                 break
             except APIError as e:
-                if e.code == 429 or "API key not valid" in e.message:
-                    keys.remove(self.chosen_api_key)
-                    if len(keys) > 0:
-                        self.set_key(random.choice(keys))
-                        logger.info(
-                            f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
-                        )
-                        await asyncio.sleep(1)
-                        continue
-                    else:
-                        logger.error(
-                            f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
-                        )
-                        raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
-                else:
-                    logger.error(
-                        f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
-                    )
-                    raise e
+                if await self._handle_api_error(e, keys):
+                    continue
+                break
+
+    async def get_models(self):
+        try:
+            models = await self.client.models.list()
+            return [
+                m.name.replace("models/", "")
+                for m in models
+                if "generateContent" in m.supported_actions
+            ]
+        except APIError as e:
+            raise Exception(f"获取模型列表失败: {e.message}")

     def get_current_key(self) -> str:
         return self.chosen_api_key
@@ -489,14 +478,7 @@ class ProviderGoogleGenAI(Provider):

     def set_key(self, key):
         self.chosen_api_key = key
-        # 重新初始化客户端
-        self.client = genai.Client(
-            api_key=self.chosen_api_key,
-            http_options=types.HttpOptions(
-                base_url=self.api_base,
-                timeout=self.timeout * 1000,  # 毫秒
-            ),
-        ).aio
+        self._init_client()

     async def assemble_context(self, text: str, image_urls: List[str] = None):
         """
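`set_key` previously rebuilt the `genai.Client` inline, duplicating the constructor logic; it now delegates to `_init_client`, so the base URL and millisecond timeout are defined in exactly one place, and key rotation in `_handle_api_error` goes through the same path:

    # Rotation path after this commit (sketch):
    self.set_key(random.choice(keys))  # updates chosen_api_key, then calls
                                       # _init_client() to recreate
                                       # genai.Client(...).aio with the new key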