refactor: split methods to improve code readability

Raven95676
2025-04-12 00:23:57 +08:00
parent bd24cf3ea4
commit 44dbe475af


@@ -55,12 +55,18 @@ class ProviderGoogleGenAI(Provider):
         )
         self.api_keys: List = provider_config.get("key", [])
         self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
-        self.timeout: int = provider_config.get("timeout", 180)
+        self.timeout: int = int(provider_config.get("timeout", 180))
         self.api_base: Optional[str] = provider_config.get("api_base", None)
         if self.api_base and self.api_base.endswith("/"):
             self.api_base = self.api_base[:-1]
-        if isinstance(self.timeout, str):
-            self.timeout = int(self.timeout)
+        self._init_client()
+        self.set_model(provider_config["model_config"]["model"])
+        self._init_safety_settings()
+
+    def _init_client(self) -> None:
+        """Initialize the Gemini client."""
         self.client = genai.Client(
             api_key=self.chosen_api_key,
             http_options=types.HttpOptions(
@@ -68,8 +74,9 @@ class ProviderGoogleGenAI(Provider):
                 timeout=self.timeout * 1000,  # milliseconds
             ),
         ).aio
-        self.set_model(provider_config["model_config"]["model"])
 
+    def _init_safety_settings(self) -> None:
+        """Initialize the safety settings."""
         user_safety_config = self.provider_config.get("gm_safety_settings", {})
         self.safety_settings = [
             types.SafetySetting(
@@ -80,16 +87,59 @@ class ProviderGoogleGenAI(Provider):
             and threshold_str in self.THRESHOLD_MAPPING
         ]
 
-    async def get_models(self):
-        try:
-            models = await self.client.models.list()
-            return [
-                m.name.replace("models/", "")
-                for m in models
-                if "generateContent" in m.supported_actions
-            ]
-        except APIError as e:
-            raise Exception(f"Failed to fetch the model list: {e.message}")
+    async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool:
+        """Handle an API error and return whether a retry is needed."""
+        if e.code == 429 or "API key not valid" in e.message:
+            keys.remove(self.chosen_api_key)
+            if len(keys) > 0:
+                self.set_key(random.choice(keys))
+                logger.info(
+                    f"Key error detected ({e.message}); switching API key and retrying... current key: {self.chosen_api_key[:12]}..."
+                )
+                await asyncio.sleep(1)
+                return True
+            else:
+                logger.error(
+                    f"Key error detected ({e.message}) and no usable keys remain. Current key: {self.chosen_api_key[:12]}..."
+                )
+                raise Exception("Gemini rate limit reached, please try again later...")
+        else:
+            logger.error(
+                f"An error occurred (gemini_source). Provider config: {self.provider_config}"
+            )
+            raise e
+
+    async def _prepare_query_config(
+        self,
+        tools: Optional[FuncCall] = None,
+        system_instruction: Optional[str] = None,
+        temperature: Optional[float] = 0.7,
+        modalities: Optional[List[str]] = None,
+    ) -> types.GenerateContentConfig:
+        """Prepare the query config."""
+        if not modalities:
+            modalities = ["Text"]
+            if self.provider_config.get("gm_resp_image_modal", False):
+                modalities.append("Image")
+
+        tool_list = None
+        if tools:
+            func_desc = tools.get_func_desc_google_genai_style()
+            if func_desc:
+                tool_list = [
+                    types.Tool(function_declarations=func_desc["function_declarations"])
+                ]
+
+        return types.GenerateContentConfig(
+            system_instruction=system_instruction,
+            temperature=temperature,
+            response_modalities=modalities,
+            tools=tool_list,
+            safety_settings=self.safety_settings if self.safety_settings else None,
+            automatic_function_calling=types.AutomaticFunctionCallingConfig(
+                disable=True
+            ),
+        )
 
     @staticmethod
     def _prepare_conversation(payloads: Dict) -> List[types.Content]:
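
The new helper establishes a simple retry contract: it returns True after rotating to another key, and raises when the error is not key-related or no usable key remains. A minimal sketch of a caller, where do_request is a hypothetical stand-in for _query or _query_stream and the surrounding provider class is assumed:

    keys = self.api_keys.copy()  # working copy; exhausted keys are removed in place
    for _ in range(10):          # bounded retries, matching text_chat below
        try:
            return await do_request()  # hypothetical stand-in for self._query(...)
        except APIError as e:
            if await self._handle_api_error(e, keys):
                continue  # key rotated; retry the request
            break  # defensive: the helper raises rather than returning False
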
@@ -165,165 +215,6 @@ class ProviderGoogleGenAI(Provider):
         return gemini_contents
 
-    async def _query(
-        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
-    ) -> LLMResponse:
-        """Non-streaming request to the Gemini API."""
-        tool_list = None
-        if tools:
-            func_desc = tools.get_func_desc_google_genai_style()
-            if func_desc:
-                tool_list = [
-                    types.Tool(function_declarations=func_desc["function_declarations"])
-                ]
-
-        system_instruction = next(
-            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
-            None,
-        )
-        conversation = self._prepare_conversation(payloads)
-
-        modalities = ["Text"]
-        if self.provider_config.get("gm_resp_image_modal", False):
-            modalities.append("Image")
-
-        result: Optional[types.GenerateContentResponse] = None
-        while True:
-            try:
-                result = await self.client.models.generate_content(
-                    model=self.get_model(),
-                    contents=conversation,
-                    config=types.GenerateContentConfig(
-                        system_instruction=system_instruction,
-                        temperature=temperature,
-                        response_modalities=modalities,
-                        tools=tool_list,
-                        safety_settings=self.safety_settings
-                        if self.safety_settings
-                        else None,
-                        automatic_function_calling=types.AutomaticFunctionCallingConfig(
-                            disable=True
-                        ),
-                    ),
-                )
-                if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
-                    if temperature > 2:
-                        raise Exception("RECITATION still occurred although temperature exceeded the maximum of 2")
-                    temperature += 0.2
-                    logger.warning(
-                        f"RECITATION occurred; raising temperature to {temperature:.1f} and retrying..."
-                    )
-                    continue
-                break
-            except APIError as e:
-                if "Developer instruction is not enabled" in e.message:
-                    logger.warning(
-                        f"{self.get_model()} does not support system prompts; removed automatically (this affects persona settings)"
-                    )
-                    system_instruction = None
-                elif "Function calling is not enabled" in e.message:
-                    logger.warning(f"{self.get_model()} does not support function calling; removed automatically")
-                    tool_list = None
-                elif (
-                    "Multi-modal output is not supported" in e.message
-                    or "Model does not support the requested response modalities" in e.message
-                ):
-                    logger.warning(
-                        f"{self.get_model()} does not support multi-modal output; falling back to text-only"
-                    )
-                    modalities = ["Text"]
-                else:
-                    raise
-                continue
-
-        llm_response = LLMResponse("assistant")
-        llm_response.result_chain = self._process_content_parts(result, llm_response)
-        return llm_response
-
-    async def _query_stream(
-        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
-    ) -> AsyncGenerator[LLMResponse, None]:
-        """Streaming request to the Gemini API."""
-        tool_list = None
-        if tools:
-            func_desc = tools.get_func_desc_google_genai_style()
-            if func_desc:
-                tool_list = [
-                    types.Tool(function_declarations=func_desc["function_declarations"])
-                ]
-
-        system_instruction = next(
-            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
-            None,
-        )
-        conversation = self._prepare_conversation(payloads)
-
-        result = None
-        while True:
-            try:
-                result = await self.client.models.generate_content_stream(
-                    model=self.get_model(),
-                    contents=conversation,
-                    config=types.GenerateContentConfig(
-                        system_instruction=system_instruction,
-                        temperature=temperature,
-                        tools=tool_list,
-                        safety_settings=self.safety_settings
-                        if self.safety_settings
-                        else None,
-                        automatic_function_calling=types.AutomaticFunctionCallingConfig(
-                            disable=True
-                        ),
-                    ),
-                )
-                break
-            except APIError as e:
-                if "Developer instruction is not enabled" in e.message:
-                    logger.warning(
-                        f"{self.get_model()} does not support system prompts; removed automatically (this affects persona settings)"
-                    )
-                    system_instruction = None
-                elif "Function calling is not enabled" in e.message:
-                    logger.warning(f"{self.get_model()} does not support function calling; removed automatically")
-                    tool_list = None
-                else:
-                    raise
-                continue
-
-        if not result:
-            raise Exception("Abnormal API response")
-
-        async for chunk in result:
-            llm_response = LLMResponse("assistant", is_chunk=True)
-            if chunk.candidates[0].content.parts and any(
-                part.function_call for part in chunk.candidates[0].content.parts
-            ):
-                response = LLMResponse("assistant", is_chunk=False)
-                response.result_chain = self._process_content_parts(chunk, response)
-                yield response
-                break
-            if chunk.text:
-                llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
-                yield llm_response
-            if chunk.candidates[0].finish_reason:
-                llm_response = LLMResponse("assistant", is_chunk=False)
-                llm_response.result_chain = self._process_content_parts(
-                    chunk, llm_response
-                )
-                yield llm_response
-                break
 
     @staticmethod
     def _process_content_parts(
         result: types.GenerateContentResponse, llm_response: LLMResponse
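
A quick worked trace of the RECITATION back-off above (the logic moves unchanged into the relocated _query in the next hunk): starting from the default temperature of 0.7, each retry adds 0.2, so the loop retries at

    0.7 -> 0.9 -> 1.1 -> 1.3 -> 1.5 -> 1.7 -> 1.9 -> 2.1

and the next RECITATION after 2.1 raises, because the `temperature > 2` guard is checked before the increment.
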
@@ -361,6 +252,129 @@ class ProviderGoogleGenAI(Provider):
             chain.append(Comp.Image.fromBytes(part.inline_data.data))
         return MessageChain(chain=chain)
 
+    async def _query(
+        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
+    ) -> LLMResponse:
+        """Non-streaming request to the Gemini API."""
+        system_instruction = next(
+            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
+            None,
+        )
+        modalities = ["Text"]
+        if self.provider_config.get("gm_resp_image_modal", False):
+            modalities.append("Image")
+        conversation = self._prepare_conversation(payloads)
+
+        result: Optional[types.GenerateContentResponse] = None
+        while True:
+            try:
+                config = await self._prepare_query_config(
+                    tools, system_instruction, temperature, modalities
+                )
+                result = await self.client.models.generate_content(
+                    model=self.get_model(),
+                    contents=conversation,
+                    config=config,
+                )
+                if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
+                    if temperature > 2:
+                        raise Exception("RECITATION still occurred although temperature exceeded the maximum of 2")
+                    temperature += 0.2
+                    logger.warning(
+                        f"RECITATION occurred; raising temperature to {temperature:.1f} and retrying..."
+                    )
+                    continue
+                break
+            except APIError as e:
+                if "Developer instruction is not enabled" in e.message:
+                    logger.warning(
+                        f"{self.get_model()} does not support system prompts; removed automatically (this affects persona settings)"
+                    )
+                    system_instruction = None
+                elif "Function calling is not enabled" in e.message:
+                    logger.warning(f"{self.get_model()} does not support function calling; removed automatically")
+                    tools = None
+                elif (
+                    "Multi-modal output is not supported" in e.message
+                    or "Model does not support the requested response modalities" in e.message
+                ):
+                    logger.warning(
+                        f"{self.get_model()} does not support multi-modal output; falling back to text-only"
+                    )
+                    modalities = ["Text"]
+                else:
+                    raise
+                continue
+
+        llm_response = LLMResponse("assistant")
+        llm_response.result_chain = self._process_content_parts(result, llm_response)
+        return llm_response
+
+    async def _query_stream(
+        self, payloads: dict, tools: FuncCall, temperature: float = 0.7
+    ) -> AsyncGenerator[LLMResponse, None]:
+        """Streaming request to the Gemini API."""
+        system_instruction = next(
+            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
+            None,
+        )
+        conversation = self._prepare_conversation(payloads)
+
+        result = None
+        while True:
+            try:
+                config = await self._prepare_query_config(
+                    tools, system_instruction, temperature
+                )
+                result = await self.client.models.generate_content_stream(
+                    model=self.get_model(),
+                    contents=conversation,
+                    config=config,
+                )
+                break
+            except APIError as e:
+                if "Developer instruction is not enabled" in e.message:
+                    logger.warning(
+                        f"{self.get_model()} does not support system prompts; removed automatically (this affects persona settings)"
+                    )
+                    system_instruction = None
+                elif "Function calling is not enabled" in e.message:
+                    logger.warning(f"{self.get_model()} does not support function calling; removed automatically")
+                    tools = None
+                else:
+                    raise
+                continue
+
+        async for chunk in result:
+            llm_response = LLMResponse("assistant", is_chunk=True)
+            if chunk.candidates[0].content.parts and any(
+                part.function_call for part in chunk.candidates[0].content.parts
+            ):
+                response = LLMResponse("assistant", is_chunk=False)
+                response.result_chain = self._process_content_parts(chunk, response)
+                yield response
+                break
+            if chunk.text:
+                llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
+                yield llm_response
+            if chunk.candidates[0].finish_reason:
+                llm_response = LLMResponse("assistant", is_chunk=False)
+                llm_response.result_chain = self._process_content_parts(
+                    chunk, llm_response
+                )
+                yield llm_response
+                break
 
     async def text_chat(
         self,
         prompt: str,
@@ -389,7 +403,6 @@ class ProviderGoogleGenAI(Provider):
model_config["model"] = self.get_model() model_config["model"] = self.get_model()
payloads = {"messages": context_query, **model_config} payloads = {"messages": context_query, **model_config}
llm_response = None
retry = 10 retry = 10
keys = self.api_keys.copy() keys = self.api_keys.copy()
@@ -397,30 +410,11 @@ class ProviderGoogleGenAI(Provider):
         for _ in range(retry):
             try:
-                llm_response = await self._query(payloads, func_tool, temp)
-                break
+                return await self._query(payloads, func_tool, temp)
             except APIError as e:
-                if e.code == 429 or "API key not valid" in e.message:
-                    keys.remove(self.chosen_api_key)
-                    if len(keys) > 0:
-                        self.set_key(random.choice(keys))
-                        logger.info(
-                            f"Key error detected ({e.message}); switching API key and retrying... current key: {self.chosen_api_key[:12]}..."
-                        )
-                        await asyncio.sleep(1)
-                        continue
-                    else:
-                        logger.error(
-                            f"Key error detected ({e.message}) and no usable keys remain. Current key: {self.chosen_api_key[:12]}..."
-                        )
-                        raise Exception("Gemini rate limit reached, please try again later...")
-                else:
-                    logger.error(
-                        f"An error occurred (gemini_source). Provider config: {self.provider_config}"
-                    )
-                    raise e
-
-        return llm_response
+                if await self._handle_api_error(e, keys):
+                    continue
+                break
 
     async def text_chat_stream(
         self,
@@ -461,25 +455,20 @@ class ProviderGoogleGenAI(Provider):
                     yield response
                     break
             except APIError as e:
-                if e.code == 429 or "API key not valid" in e.message:
-                    keys.remove(self.chosen_api_key)
-                    if len(keys) > 0:
-                        self.set_key(random.choice(keys))
-                        logger.info(
-                            f"Key error detected ({e.message}); switching API key and retrying... current key: {self.chosen_api_key[:12]}..."
-                        )
-                        await asyncio.sleep(1)
-                        continue
-                    else:
-                        logger.error(
-                            f"Key error detected ({e.message}) and no usable keys remain. Current key: {self.chosen_api_key[:12]}..."
-                        )
-                        raise Exception("Gemini rate limit reached, please try again later...")
-                else:
-                    logger.error(
-                        f"An error occurred (gemini_source). Provider config: {self.provider_config}"
-                    )
-                    raise e
+                if await self._handle_api_error(e, keys):
+                    continue
+                break
+
+    async def get_models(self):
+        try:
+            models = await self.client.models.list()
+            return [
+                m.name.replace("models/", "")
+                for m in models
+                if "generateContent" in m.supported_actions
+            ]
+        except APIError as e:
+            raise Exception(f"Failed to fetch the model list: {e.message}")
 
     def get_current_key(self) -> str:
         return self.chosen_api_key
@@ -489,14 +478,7 @@ class ProviderGoogleGenAI(Provider):
     def set_key(self, key):
         self.chosen_api_key = key
-        # re-initialize the client
-        self.client = genai.Client(
-            api_key=self.chosen_api_key,
-            http_options=types.HttpOptions(
-                base_url=self.api_base,
-                timeout=self.timeout * 1000,  # milliseconds
-            ),
-        ).aio
+        self._init_client()
 
     async def assemble_context(self, text: str, image_urls: List[str] = None):
         """