diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index b5bad6b858..2361939d81 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -75,8 +75,8 @@ def _get_vertex_url( ) -> Tuple[str, str]: url: Optional[str] = None endpoint: Optional[str] = None - if litellm.VertexGeminiConfig._is_model_gemini_gemini_spec_model(model): - model = litellm.VertexGeminiConfig._get_model_name_from_gemini_spec_model(model) + + model = litellm.VertexGeminiConfig.get_model_for_vertex_ai_url(model=model) if mode == "chat": ### SET RUNTIME ENDPOINT ### endpoint = "generateContent" diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 0e48c690ba..8aecfffd86 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -419,6 +419,25 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): "europe-west9", ] + @staticmethod + def get_model_for_vertex_ai_url(model: str) -> str: + """ + Returns the model name to use in the request to Vertex AI + + Handles 2 cases: + 1. User passed `model="vertex_ai/gemini/ft-uuid"`, we need to return `ft-uuid` for the request to Vertex AI + 2. User passed `model="vertex_ai/gemini-2.0-flash-001"`, we need to return `gemini-2.0-flash-001` for the request to Vertex AI + + Args: + model (str): The model name to use in the request to Vertex AI + + Returns: + str: The model name to use in the request to Vertex AI + """ + if VertexGeminiConfig._is_model_gemini_gemini_spec_model(model): + return VertexGeminiConfig._get_model_name_from_gemini_spec_model(model) + return model + @staticmethod def _is_model_gemini_gemini_spec_model(model: Optional[str]) -> bool: """ diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 5d99ecb5cd..83c99d766b 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -3395,6 +3395,14 @@ def test_gemini_fine_tuned_model_request_consistency(): first_request_body = mock_post_1.call_args.kwargs["json"] print("first_request_body", first_request_body) + # Validate correct `model` is added to the request to Vertex AI + print("final URL=", mock_post_1.call_args.kwargs["url"]) + # Validate the request url + assert ( + "publishers/google/models/ft-uuid:generateContent" + in mock_post_1.call_args.kwargs["url"] + ) + # Second request with patch.object(client, "post", new=MagicMock()) as mock_post_2: try: