diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py
index a3f91fbacc..d5ad330484 100644
--- a/litellm/llms/vertex_ai/common_utils.py
+++ b/litellm/llms/vertex_ai/common_utils.py
@@ -3,6 +3,7 @@
 from typing import Dict, List, Literal, Optional, Tuple, Union
 
 import httpx
+import litellm
 from litellm import supports_response_schema, supports_system_messages, verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.vertex_ai import PartType
@@ -28,6 +29,10 @@ def get_supports_system_message(
         supports_system_message = supports_system_messages(
             model=model, custom_llm_provider=_custom_llm_provider
         )
+
+        # Vertex Models called in the `/gemini` request/response format also support system messages
+        if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model):
+            supports_system_message = True
     except Exception as e:
         verbose_logger.warning(
             "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
index 73215b4048..cbbbae5b34 100644
--- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
@@ -207,7 +207,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             "extra_headers",
             "seed",
             "logprobs",
-            "top_logprobs" # Added this to list of supported openAI params
+            "top_logprobs", # Added this to list of supported openAI params
         ]
 
     def map_tool_choice_values(
@@ -419,6 +419,17 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             "europe-west9",
         ]
 
+    @staticmethod
+    def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool:
+        """
+        Returns true if user is trying to call custom model in `/gemini` request/response format
+        """
+        if model is None:
+            return False
+        if "gemini/" in model:
+            return True
+        return False
+
     def get_flagged_finish_reasons(self) -> Dict[str, str]:
         """
         Return Dictionary of finish reasons which indicate response was flagged
diff --git a/litellm/main.py b/litellm/main.py
index 8ebf66fd06..b96a054118 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1097,10 +1097,7 @@ def completion( # type: ignore # noqa: PLR0915
            logit_bias=logit_bias,
            user=user,
            # params to identify the model
-            model=LitellmCoreRequestUtils.select_model_for_request_transformation(
-                model=model,
-                base_model=base_model,
-            ),
+            model=model,
            custom_llm_provider=custom_llm_provider,
            response_format=response_format,
            seed=seed,
diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py
index 7134c3536e..3687212e84 100644
--- a/tests/local_testing/test_amazing_vertex_completion.py
+++ b/tests/local_testing/test_amazing_vertex_completion.py
@@ -3380,8 +3380,7 @@ def test_gemini_fine_tuned_model_request_consistency():
     with patch.object(client, "post", new=MagicMock()) as mock_post_1:
         try:
             response_1 = completion(
-                model="vertex_ai/ft-uuid",
-                base_model="vertex_ai/gemini-2.0-flash-001",
+                model="vertex_ai/gemini/ft-uuid",
                 messages=messages,
                 tools=tools,
                 tool_choice="auto",
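
Usage sketch (illustrative note, not part of the diff): the new helper keys off the literal "gemini/" segment in the model string, which is why the test above now passes model="vertex_ai/gemini/ft-uuid" instead of a separate base_model. A minimal check of that behaviour, assuming litellm exposes VertexGeminiConfig at the package level as the common_utils change implies:

import litellm

# "vertex_ai/gemini/ft-uuid" contains "gemini/", so it is treated as a /gemini-spec model
assert litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model("vertex_ai/gemini/ft-uuid") is True
# the old-style model string has no "gemini/" segment and is not
assert litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model("vertex_ai/ft-uuid") is False
# None is handled explicitly
assert litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(None) is False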