From 06e69a414e7af0ca8a0ef9234fe9d75b636aaa70 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 22 Mar 2025 11:32:01 -0700
Subject: [PATCH] fix(vertex_ai/common_utils.py): fix handling constructed url with default vertex config

---
Notes (text between the "---" separator and the first "diff --git" line is
ignored when the patch is applied):

* construct_target_url(): when the requested route already contains
  "locations" and both vertex_project and vertex_location are truthy, the
  /projects/<p>/locations/<l>/ segment is rewritten to the supplied values
  via the new replace_project_and_location_in_route(); otherwise the route
  is passed through unchanged. The default_vertex_location /
  default_vertex_project parameters are renamed to vertex_location /
  vertex_project, and the one call site in this patch is updated to match.
* vertex_proxy_route(): no longer raises when get_vertex_credentials()
  returns None; that case now takes the same "use incoming request headers"
  branch as a missing vertex_project.
* PassthroughEndpointRouter.get_vertex_credentials(): falls back to
  self.default_vertex_config instead of a local variable that was always
  None.
* NOTE(review): re.sub() interprets backslash escapes in the replacement
  string, so the f-string replacement assumes the project and location
  values contain no backslashes -- confirm against the allowed id format.
* NOTE(review): the new test prints and swallows any exception raised by
  vertex_proxy_route() before asserting on the mock.
* NOTE(review): this file was restored from a whitespace-collapsed copy.
  Line breaks follow the hunk line counts and the diffstat; indentation
  follows Python block structure; one blank context line in the
  passthrough_endpoint_router.py hunk could not be placed unambiguously
  (it is assumed to follow the _get_deployment_key(...) call); the From:
  header carries no address. Run `git apply --check` before relying on it.

 litellm/llms/vertex_ai/common_utils.py        | 29 ++++++--
 .../llm_passthrough_endpoints.py              | 11 +--
 .../passthrough_endpoint_router.py            |  6 +-
 .../test_llm_pass_through_endpoints.py        | 69 +++++++++++++++++++
 4 files changed, 98 insertions(+), 17 deletions(-)

diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py
index 4a4c428941..a3f91fbacc 100644
--- a/litellm/llms/vertex_ai/common_utils.py
+++ b/litellm/llms/vertex_ai/common_utils.py
@@ -303,11 +303,26 @@ def get_vertex_location_from_url(url: str) -> Optional[str]:
     return match.group(1) if match else None
 
 
+def replace_project_and_location_in_route(
+    requested_route: str, vertex_project: str, vertex_location: str
+) -> str:
+    """
+    Replace project and location values in the route with the provided values
+    """
+    # Replace project and location values while keeping route structure
+    modified_route = re.sub(
+        r"/projects/[^/]+/locations/[^/]+/",
+        f"/projects/{vertex_project}/locations/{vertex_location}/",
+        requested_route,
+    )
+    return modified_route
+
+
 def construct_target_url(
     base_url: str,
     requested_route: str,
-    default_vertex_location: Optional[str],
-    default_vertex_project: Optional[str],
+    vertex_location: Optional[str],
+    vertex_project: Optional[str],
 ) -> httpx.URL:
     """
     Allow user to specify their own project id / location.
@@ -321,8 +336,12 @@ def construct_target_url(
     """
     new_base_url = httpx.URL(base_url)
     if "locations" in requested_route:  # contains the target project id + location
-        updated_url = new_base_url.copy_with(path=requested_route)
-        return updated_url
+        if vertex_project and vertex_location:
+            requested_route = replace_project_and_location_in_route(
+                requested_route, vertex_project, vertex_location
+            )
+        return new_base_url.copy_with(path=requested_route)
+
     """
     - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
     - Add default project id
@@ -333,7 +352,7 @@ def construct_target_url(
         vertex_version = "v1beta1"
 
     base_requested_route = "{}/projects/{}/locations/{}".format(
-        vertex_version, default_vertex_project, default_vertex_location
+        vertex_version, vertex_project, vertex_location
     )
 
     updated_requested_route = "/" + base_requested_route + requested_route
diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
index 24ab08c167..0fae1e6f0b 100644
--- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
@@ -463,13 +463,8 @@ async def vertex_proxy_route(
         location=vertex_location,
     )
 
-    if vertex_credentials is None:
-        raise Exception(
-            f"No matching vertex credentials found, for project_id: {vertex_project}, location: {vertex_location}. No default_vertex_config set either."
-        )
-
     # Use headers from the incoming request if no vertex credentials are found
-    if vertex_credentials.vertex_project is None:
+    if vertex_credentials is None or vertex_credentials.vertex_project is None:
         headers = dict(request.headers) or {}
         verbose_proxy_logger.debug(
             "default_vertex_config not set, incoming request headers %s", headers
@@ -518,8 +513,8 @@ async def vertex_proxy_route(
     updated_url = construct_target_url(
         base_url=base_target_url,
         requested_route=encoded_endpoint,
-        default_vertex_location=vertex_location,
-        default_vertex_project=vertex_project,
+        vertex_location=vertex_location,
+        vertex_project=vertex_project,
     )
     # base_url = httpx.URL(base_target_url)
     # updated_url = base_url.copy_with(path=encoded_endpoint)
diff --git a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py
index 897faa1717..89cccfc071 100644
--- a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py
+++ b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py
@@ -150,19 +150,17 @@ class PassthroughEndpointRouter:
         """
         Get the vertex credentials for the given project-id, location
         """
-        default_vertex_config: Optional[VertexPassThroughCredentials] = None
-
         deployment_key = self._get_deployment_key(
             project_id=project_id,
             location=location,
         )
 
         if deployment_key is None:
-            return default_vertex_config
+            return self.default_vertex_config
         if deployment_key in self.deployment_key_to_vertex_credentials:
             return self.deployment_key_to_vertex_credentials[deployment_key]
         else:
-            return default_vertex_config
+            return self.default_vertex_config
 
     def _get_credential_name_for_provider(
         self,
diff --git a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py
index 8f8fbbe9de..ea5b908796 100644
--- a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py
+++ b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py
@@ -1,6 +1,7 @@
 import json
 import os
 import sys
+import traceback
 from unittest import mock
 from unittest.mock import MagicMock, patch
 
@@ -262,3 +263,71 @@ class TestVertexAIPassThroughHandler:
                 target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                 custom_headers={"Authorization": f"Bearer {test_token}"},
             )
+
+    @pytest.mark.asyncio
+    async def test_vertex_passthrough_with_default_credentials(self, monkeypatch):
+        """
+        Test that when no passthrough credentials are set, default credentials are used in the request
+        """
+        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
+            PassthroughEndpointRouter,
+        )
+
+        # Setup default credentials
+        default_project = "default-project"
+        default_location = "us-central1"
+        default_credentials = "default-creds"
+
+        pass_through_router = PassthroughEndpointRouter()
+        pass_through_router.default_vertex_config = VertexPassThroughCredentials(
+            vertex_project=default_project,
+            vertex_location=default_location,
+            vertex_credentials=default_credentials,
+        )
+
+        monkeypatch.setattr(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
+            pass_through_router,
+        )
+
+        # Use different project/location in request than the default
+        request_project = "non-existing-project"
+        request_location = "bad-location"
+        endpoint = f"/v1/projects/{request_project}/locations/{request_location}/publishers/google/models/gemini-1.5-flash:generateContent"
+
+        mock_request = Request(
+            scope={
+                "type": "http",
+                "method": "POST",
+                "path": endpoint,
+                "headers": {},
+            }
+        )
+        mock_response = Response()
+
+        with mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
+        ) as mock_ensure_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
+        ) as mock_get_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
+        ) as mock_create_route:
+            mock_ensure_token.return_value = ("test-auth-header", default_project)
+            mock_get_token.return_value = (default_credentials, "")
+
+            try:
+                await vertex_proxy_route(
+                    endpoint=endpoint,
+                    request=mock_request,
+                    fastapi_response=mock_response,
+                )
+            except Exception as e:
+                traceback.print_exc()
+                print(f"Error: {e}")
+
+            # Verify default credentials were used
+            mock_create_route.assert_called_once_with(
+                endpoint=endpoint,
+                target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
+                custom_headers={"Authorization": f"Bearer {default_credentials}"},
+            )