_is_model_gemini_gemini_spec_model

Ishaan Jaff 2025-03-26 10:53:23 -07:00
parent fb31006cd8
commit bbe69a47a9
4 changed files with 19 additions and 7 deletions


@@ -3,6 +3,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType
@@ -28,6 +29,10 @@ def get_supports_system_message(
supports_system_message = supports_system_messages(
model=model, custom_llm_provider=_custom_llm_provider
)
# Vertex Models called in the `/gemini` request/response format also support system messages
if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
supports_system_message = True
except Exception as e:
verbose_logger.warning(
"Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(

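For context, a minimal standalone sketch of the decision flow this hunk produces (the wrapper name supports_system_message_for is hypothetical; only supports_system_messages and VertexGeminiConfig come from the diff):

import litellm

def supports_system_message_for(model: str, custom_llm_provider: str) -> bool:
    # Consult the model cost map first, as the existing code does.
    try:
        supported = litellm.supports_system_messages(
            model=model, custom_llm_provider=custom_llm_provider
        )
    except Exception:
        # Unknown models fall back to False, mirroring the warning branch above.
        return False
    # New in this commit: Vertex models addressed in the `/gemini`
    # request/response format are always treated as supporting system messages.
    if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
        supported = True
    return supported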

@@ -207,7 +207,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
"extra_headers",
"seed",
"logprobs",
"top_logprobs" # Added this to list of supported openAI params
"top_logprobs", # Added this to list of supported openAI params
]
def map_tool_choice_values(
@@ -419,6 +419,17 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
"europe-west9",
]
@staticmethod
def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
"""
Returns true if user is trying to call custom model in `/gemini` request/response format
"""
if model is None:
return False
if "gemini/" in model:
return True
return False
def get_flagged_finish_reasons(self) -> Dict[str, str]:
"""
Return Dictionary of finish reasons which indicate response was flagged

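The new helper keys purely off a "gemini/" substring in the model name; a quick illustration of how it classifies a few strings (a sketch, not part of the diff):

import litellm

# Custom/fine-tuned model addressed in the `/gemini` request/response format
assert litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model("gemini/ft-uuid") is True
# Plain Vertex Gemini model name, no `gemini/` segment
assert litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model("gemini-2.0-flash-001") is False
# Missing model defaults to False
assert litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(None) is False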

@@ -1097,10 +1097,7 @@ def completion(  # type: ignore # noqa: PLR0915
logit_bias=logit_bias,
user=user,
# params to identify the model
model=LitellmCoreRequestUtils.select_model_for_request_transformation(
model=model,
base_model=base_model,
),
custom_llm_provider=custom_llm_provider,
response_format=response_format,
seed=seed,


@@ -3380,8 +3380,7 @@ def test_gemini_fine_tuned_model_request_consistency():
with patch.object(client, "post", new=MagicMock()) as mock_post_1:
try:
response_1 = completion(
model="vertex_ai/ft-uuid",
base_model="vertex_ai/gemini-2.0-flash-001",
model="vertex_ai/gemini/ft-uuid",
messages=messages,
tools=tools,
tool_choice="auto",
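After this change, the fine-tuned Vertex endpoint is selected through a single model string rather than a separate base_model argument; a rough usage sketch mirroring the test above (the messages payload is a placeholder):

from litellm import completion

response = completion(
    model="vertex_ai/gemini/ft-uuid",  # `gemini/` segment opts into the /gemini request/response format
    messages=[{"role": "user", "content": "hello"}],
)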