_is_model_gemini_gemini_spec_model

Ishaan Jaff 2025-03-26 10:53:23 -07:00
parent fb31006cd8
commit bbe69a47a9
4 changed files with 19 additions and 7 deletions


@@ -3,6 +3,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
 
 import httpx
 
+import litellm
 from litellm import supports_response_schema, supports_system_messages, verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.vertex_ai import PartType
@@ -28,6 +29,10 @@ def get_supports_system_message(
         supports_system_message = supports_system_messages(
             model=model, custom_llm_provider=_custom_llm_provider
         )
+
+        # Vertex Models called in the `/gemini` request/response format also support system messages
+        if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
+            supports_system_message = True
     except Exception as e:
         verbose_logger.warning(
             "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(


@@ -207,7 +207,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             "extra_headers",
             "seed",
             "logprobs",
-            "top_logprobs",  # Added this to list of supported openAI params
+            "top_logprobs",  # Added this to list of supported openAI params
         ]
 
     def map_tool_choice_values(
@@ -419,6 +419,17 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             "europe-west9",
         ]
 
+    @staticmethod
+    def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
+        """
+        Returns true if user is trying to call custom model in `/gemini` request/response format
+        """
+        if model is None:
+            return False
+        if "gemini/" in model:
+            return True
+        return False
+
     def get_flagged_finish_reasons(self) -> Dict[str, str]:
         """
         Return Dictionary of finish reasons which indicate response was flagged
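As a rough illustration of the helper's behavior: it is a plain substring check on the model string (calling the private static method directly here is for demonstration only; the model strings are taken from the updated test below).

import litellm

# True whenever "gemini/" appears anywhere in the model string; None is handled explicitly.
litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model("vertex_ai/gemini/ft-uuid")  # True
litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model("vertex_ai/ft-uuid")         # False
litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(None)                        # False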


@@ -1097,10 +1097,7 @@ def completion(  # type: ignore # noqa: PLR0915
         logit_bias=logit_bias,
         user=user,
         # params to identify the model
-        model=LitellmCoreRequestUtils.select_model_for_request_transformation(
-            model=model,
-            base_model=base_model,
-        ),
+        model=model,
         custom_llm_provider=custom_llm_provider,
         response_format=response_format,
         seed=seed,


@@ -3380,8 +3380,7 @@ def test_gemini_fine_tuned_model_request_consistency():
     with patch.object(client, "post", new=MagicMock()) as mock_post_1:
         try:
             response_1 = completion(
-                model="vertex_ai/ft-uuid",
-                base_model="vertex_ai/gemini-2.0-flash-001",
+                model="vertex_ai/gemini/ft-uuid",
                 messages=messages,
                 tools=tools,
                 tool_choice="auto",