From 610cdc71ef7ac68a74f16f451df664cc7e45dcf5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 16 Apr 2025 19:35:49 -0700 Subject: [PATCH] feat(llm_passthrough_endpoints.py): add vertex endpoint specific passthrough handlers --- litellm/llms/vertex_ai/common_utils.py | 1 + .../llm_passthrough_endpoints.py | 83 ++++++++++-- .../test_llm_pass_through_endpoints.py | 119 +++++++++++++++++- 3 files changed, 195 insertions(+), 8 deletions(-) diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index 752ba777af..314fb81901 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -401,6 +401,7 @@ def construct_target_url( Constructed Url: POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents """ + new_base_url = httpx.URL(base_url) if "locations" in requested_route: # contains the target project id + location if vertex_project and vertex_location: diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index e899293f58..21df55c084 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -578,16 +578,70 @@ async def azure_proxy_route( ) +from abc import ABC, abstractmethod + + +class BaseVertexAIPassThroughHandler(ABC): + @staticmethod + @abstractmethod + def get_default_base_target_url(vertex_location: Optional[str]) -> str: + pass + + @staticmethod + @abstractmethod + def update_base_target_url_with_credential_location( + base_target_url: str, vertex_location: Optional[str] + ) -> str: + pass + + +class VertexAIDiscoveryPassThroughHandler(BaseVertexAIPassThroughHandler): + @staticmethod + def get_default_base_target_url(vertex_location: Optional[str]) -> str: + return "https://discoveryengine.googleapis.com/" + + @staticmethod + def update_base_target_url_with_credential_location( + base_target_url: str, vertex_location: Optional[str] + ) -> str: + return base_target_url + + +class VertexAIPassThroughHandler(BaseVertexAIPassThroughHandler): + @staticmethod + def get_default_base_target_url(vertex_location: Optional[str]) -> str: + return f"https://{vertex_location}-aiplatform.googleapis.com/" + + @staticmethod + def update_base_target_url_with_credential_location( + base_target_url: str, vertex_location: Optional[str] + ) -> str: + return f"https://{vertex_location}-aiplatform.googleapis.com/" + + +def get_vertex_pass_through_handler( + call_type: Literal["discovery", "aiplatform"] +) -> BaseVertexAIPassThroughHandler: + if call_type == "discovery": + return VertexAIDiscoveryPassThroughHandler() + elif call_type == "aiplatform": + return VertexAIPassThroughHandler() + else: + raise ValueError(f"Invalid call type: {call_type}") + + async def _base_vertex_proxy_route( endpoint: str, request: Request, fastapi_response: Response, - base_target_url: str, + get_vertex_pass_through_handler: BaseVertexAIPassThroughHandler, user_api_key_dict: Optional[UserAPIKeyAuth] = None, ): """ Base function for Vertex AI passthrough routes. Handles common logic for all Vertex AI services. + + Default base_target_url is `https://{vertex_location}-aiplatform.googleapis.com/` """ from litellm.llms.vertex_ai.common_utils import ( construct_target_url, @@ -598,6 +652,11 @@ async def _base_vertex_proxy_route( encoded_endpoint = httpx.URL(endpoint).path verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} + api_key_to_use = get_litellm_virtual_key(request=request) + user_api_key_dict = await user_api_key_auth( + request=request, + api_key=api_key_to_use, + ) if user_api_key_dict is None: api_key_to_use = get_litellm_virtual_key(request=request) @@ -613,6 +672,10 @@ async def _base_vertex_proxy_route( location=vertex_location, ) + base_target_url = get_vertex_pass_through_handler.get_default_base_target_url( + vertex_location + ) + headers_passed_through = False # Use headers from the incoming request if no vertex credentials are found if vertex_credentials is None or vertex_credentials.vertex_project is None: @@ -650,6 +713,13 @@ async def _base_vertex_proxy_route( "Authorization": f"Bearer {auth_header}", } + base_target_url = get_vertex_pass_through_handler.update_base_target_url_with_credential_location( + base_target_url, vertex_location + ) + + if base_target_url is None: + base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" + request_route = encoded_endpoint verbose_proxy_logger.debug("request_route %s", request_route) @@ -713,11 +783,13 @@ async def vertex_discovery_proxy_route( Target url: `https://discoveryengine.googleapis.com` """ + + discovery_handler = get_vertex_pass_through_handler(call_type="discovery") return await _base_vertex_proxy_route( endpoint=endpoint, request=request, fastapi_response=fastapi_response, - base_target_url="https://discoveryengine.googleapis.com/", + get_vertex_pass_through_handler=discovery_handler, ) @@ -743,16 +815,13 @@ async def vertex_proxy_route( [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) """ - from litellm.llms.vertex_ai.common_utils import get_vertex_location_from_url - - vertex_location: Optional[str] = get_vertex_location_from_url(endpoint) - base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" + ai_platform_handler = get_vertex_pass_through_handler(call_type="aiplatform") return await _base_vertex_proxy_route( endpoint=endpoint, request=request, fastapi_response=fastapi_response, - base_target_url=base_target_url, + get_vertex_pass_through_handler=ai_platform_handler, user_api_key_dict=user_api_key_dict, ) diff --git a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py index da08dea605..501b4364da 100644 --- a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py +++ b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py @@ -19,13 +19,13 @@ from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( BaseOpenAIPassThroughHandler, RouteChecks, create_pass_through_route, + vertex_discovery_proxy_route, vertex_proxy_route, ) from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials class TestBaseOpenAIPassThroughHandler: - def test_join_url_paths(self): print("\nTesting _join_url_paths method...") @@ -456,3 +456,120 @@ class TestVertexAIPassThroughHandler: mock_auth.assert_called_once() call_args = mock_auth.call_args[1] assert call_args["api_key"] == "Bearer test-key-123" + + +class TestVertexAIDiscoveryPassThroughHandler: + """ + Test cases for Vertex AI Discovery passthrough endpoint + """ + + @pytest.mark.asyncio + async def test_vertex_discovery_passthrough_with_credentials(self, monkeypatch): + """ + Test that when passthrough credentials are set, they are correctly used in the request + """ + from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import ( + PassthroughEndpointRouter, + ) + + vertex_project = "test-project" + vertex_location = "us-central1" + vertex_credentials = "test-creds" + + pass_through_router = PassthroughEndpointRouter() + + pass_through_router.add_vertex_credentials( + project_id=vertex_project, + location=vertex_location, + vertex_credentials=vertex_credentials, + ) + + monkeypatch.setattr( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router", + pass_through_router, + ) + + endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/dataStores/default/servingConfigs/default:search" + + # Mock request + mock_request = Request( + scope={ + "type": "http", + "method": "POST", + "path": endpoint, + "headers": [ + (b"Authorization", b"Bearer test-creds"), + (b"Content-Type", b"application/json"), + ], + } + ) + + # Mock response + mock_response = Response() + + # Mock vertex credentials + test_project = vertex_project + test_location = vertex_location + test_token = vertex_credentials + + with mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async" + ) as mock_ensure_token, mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url" + ) as mock_get_token, mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route" + ) as mock_create_route: + mock_ensure_token.return_value = ("test-auth-header", test_project) + mock_get_token.return_value = (test_token, "") + + # Call the route + try: + await vertex_discovery_proxy_route( + endpoint=endpoint, + request=mock_request, + fastapi_response=mock_response, + ) + except Exception as e: + print(f"Error: {e}") + + # Verify create_pass_through_route was called with correct arguments + mock_create_route.assert_called_once_with( + endpoint=endpoint, + target=f"https://discoveryengine.googleapis.com/v1/projects/{test_project}/locations/{test_location}/dataStores/default/servingConfigs/default:search", + custom_headers={"Authorization": f"Bearer {test_token}"}, + ) + + @pytest.mark.asyncio + async def test_vertex_discovery_proxy_route_api_key_auth(self): + """ + Test that the route correctly handles API key authentication + """ + # Mock dependencies + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + mock_request.method = "POST" + mock_response = Mock() + + with patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth" + ) as mock_auth: + mock_auth.return_value = {"api_key": "test-key-123"} + + with patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route" + ) as mock_pass_through: + mock_pass_through.return_value = AsyncMock( + return_value={"status": "success"} + ) + + # Call the function + result = await vertex_discovery_proxy_route( + endpoint="v1/projects/test-project/locations/us-central1/dataStores/default/servingConfigs/default:search", + request=mock_request, + fastapi_response=mock_response, + ) + + # Verify user_api_key_auth was called with the correct Bearer token + mock_auth.assert_called_once() + call_args = mock_auth.call_args[1] + assert call_args["api_key"] == "Bearer test-key-123"