diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 3b58567881..7444e3d1c1 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -15,7 +15,10 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( from litellm.secret_managers.main import get_secret_str from litellm.types.passthrough_endpoints.vertex_ai import * +from .vertex_passthrough_router import VertexPassThroughRouter + router = APIRouter() +vertex_pass_through_router = VertexPassThroughRouter() default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials() @@ -59,7 +62,14 @@ def set_default_vertex_config(config: Optional[dict] = None): if isinstance(value, str) and value.startswith("os.environ/"): config[key] = litellm.get_secret(value) - default_vertex_config = VertexPassThroughCredentials(**config) + _set_default_vertex_config(VertexPassThroughCredentials(**config)) + + +def _set_default_vertex_config( + vertex_pass_through_credentials: VertexPassThroughCredentials, +): + global default_vertex_config + default_vertex_config = vertex_pass_through_credentials def exception_handler(e: Exception): @@ -147,9 +157,6 @@ async def vertex_proxy_route( [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) """ encoded_endpoint = httpx.URL(endpoint).path - - import re - verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} api_key_to_use = get_litellm_virtual_key(request=request) @@ -158,31 +165,37 @@ async def vertex_proxy_route( api_key=api_key_to_use, ) - vertex_project = None - vertex_location = None - # Use headers from the incoming request if default_vertex_config is not set - if default_vertex_config.vertex_project is None: + vertex_project: Optional[str] = ( + VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint) + ) + vertex_location: Optional[str] = ( + VertexPassThroughRouter._get_vertex_location_from_url(endpoint) + ) + vertex_credentials = vertex_pass_through_router.get_vertex_credentials( + project_id=vertex_project, + location=vertex_location, + ) + + # Use headers from the incoming request if no vertex credentials are found + if vertex_credentials.vertex_project is None: headers = dict(request.headers) or {} verbose_proxy_logger.debug( "default_vertex_config not set, incoming request headers %s", headers ) - # extract location from endpoint, endpoint - # "v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent" - match = re.search(r"/locations/([^/]+)", endpoint) - vertex_location = match.group(1) if match else None base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" headers.pop("content-length", None) headers.pop("host", None) else: - vertex_project = default_vertex_config.vertex_project - vertex_location = default_vertex_config.vertex_location - vertex_credentials = default_vertex_config.vertex_credentials + vertex_project = vertex_credentials.vertex_project + vertex_location = vertex_credentials.vertex_location + vertex_credentials_str = vertex_credentials.vertex_credentials + # Construct base URL for the target endpoint base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" _auth_header, vertex_project = ( await vertex_fine_tuning_apis_instance._ensure_access_token_async( - credentials=vertex_credentials, + credentials=vertex_credentials_str, project_id=vertex_project, custom_llm_provider="vertex_ai_beta", ) @@ -192,7 +205,7 @@ async def vertex_proxy_route( model="", auth_header=_auth_header, gemini_api_key=None, - vertex_credentials=vertex_credentials, + vertex_credentials=vertex_credentials_str, vertex_project=vertex_project, vertex_location=vertex_location, stream=False, diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py b/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py new file mode 100644 index 0000000000..fdba424765 --- /dev/null +++ b/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py @@ -0,0 +1,120 @@ +import json +import re +from typing import Dict, Optional + +from litellm._logging import verbose_proxy_logger +from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + VertexPassThroughCredentials, +) + + +class VertexPassThroughRouter: + """ + Vertex Pass Through Router for Vertex AI pass-through endpoints + + + - if request specifies a project-id, location -> use credentials corresponding to the project-id, location + - if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION + """ + + def __init__(self): + """ + Initialize the VertexPassThroughRouter + Stores the vertex credentials for each deployment key + ``` + { + "project_id-location": VertexPassThroughCredentials, + "adroit-crow-us-central1": VertexPassThroughCredentials, + } + ``` + """ + self.deployment_key_to_vertex_credentials: Dict[ + str, VertexPassThroughCredentials + ] = {} + pass + + def get_vertex_credentials( + self, project_id: Optional[str], location: Optional[str] + ) -> VertexPassThroughCredentials: + """ + Get the vertex credentials for the given project-id, location + """ + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + default_vertex_config, + ) + + deployment_key = self._get_deployment_key( + project_id=project_id, + location=location, + ) + if deployment_key is None: + return default_vertex_config + if deployment_key in self.deployment_key_to_vertex_credentials: + return self.deployment_key_to_vertex_credentials[deployment_key] + else: + return default_vertex_config + + def add_vertex_credentials( + self, + project_id: str, + location: str, + vertex_credentials: str, + ): + """ + Add the vertex credentials for the given project-id, location + """ + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + _set_default_vertex_config, + ) + + deployment_key = self._get_deployment_key( + project_id=project_id, + location=location, + ) + if deployment_key is None: + verbose_proxy_logger.debug( + "No deployment key found for project-id, location" + ) + return + vertex_pass_through_credentials = VertexPassThroughCredentials( + vertex_project=project_id, + vertex_location=location, + vertex_credentials=vertex_credentials, + ) + self.deployment_key_to_vertex_credentials[deployment_key] = ( + vertex_pass_through_credentials + ) + verbose_proxy_logger.debug( + f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}" + ) + _set_default_vertex_config(vertex_pass_through_credentials) + + def _get_deployment_key( + self, project_id: Optional[str], location: Optional[str] + ) -> Optional[str]: + """ + Get the deployment key for the given project-id, location + """ + if project_id is None or location is None: + return None + return f"{project_id}-{location}" + + @staticmethod + def _get_vertex_project_id_from_url(url: str) -> Optional[str]: + """ + Get the vertex project id from the url + + `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` + """ + match = re.search(r"/projects/([^/]+)", url) + return match.group(1) if match else None + + @staticmethod + def _get_vertex_location_from_url(url: str) -> Optional[str]: + """ + Get the vertex location from the url + + `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` + """ + match = re.search(r"/locations/([^/]+)", url) + return match.group(1) if match else None diff --git a/litellm/router.py b/litellm/router.py index 58809197ee..63669ef588 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4133,8 +4133,48 @@ class Router: litellm_router_instance=self, model=deployment.to_json(exclude_none=True) ) + self._initialize_deployment_for_pass_through( + deployment=deployment, + custom_llm_provider=custom_llm_provider, + model=deployment.litellm_params.model, + ) + return deployment + def _initialize_deployment_for_pass_through( + self, deployment: Deployment, custom_llm_provider: str, model: str + ): + """ + Optional: Initialize deployment for pass-through endpoints if `deployment.litellm_params.use_in_pass_through` is True + + Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints + """ + if deployment.litellm_params.use_in_pass_through is True: + if custom_llm_provider == "vertex_ai": + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + vertex_pass_through_router, + ) + + if ( + deployment.litellm_params.vertex_project is None + or deployment.litellm_params.vertex_location is None + or deployment.litellm_params.vertex_credentials is None + ): + raise ValueError( + "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints" + ) + vertex_pass_through_router.add_vertex_credentials( + project_id=deployment.litellm_params.vertex_project, + location=deployment.litellm_params.vertex_location, + vertex_credentials=deployment.litellm_params.vertex_credentials, + ) + else: + verbose_router_logger.error( + f"Unsupported provider - {custom_llm_provider} for pass-through endpoints" + ) + pass + pass + def add_deployment(self, deployment: Deployment) -> Optional[Deployment]: """ Parameters: diff --git a/litellm/types/router.py b/litellm/types/router.py index 9393bb2213..f59c3ce671 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -176,7 +176,7 @@ class GenericLiteLLMParams(BaseModel): # Deployment budgets max_budget: Optional[float] = None budget_duration: Optional[str] = None - + use_in_pass_through: Optional[bool] = False model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) def __init__( @@ -215,6 +215,8 @@ class GenericLiteLLMParams(BaseModel): # Deployment budgets max_budget: Optional[float] = None, budget_duration: Optional[str] = None, + # Pass through params + use_in_pass_through: Optional[bool] = False, **params, ): args = locals() @@ -276,6 +278,8 @@ class LiteLLM_Params(GenericLiteLLMParams): # OpenAI / Azure Whisper # set a max-size of file that can be passed to litellm proxy max_file_size_mb: Optional[float] = None, + # will use deployment on pass-through endpoints if True + use_in_pass_through: Optional[bool] = False, **params, ): args = locals() diff --git a/litellm/types/utils.py b/litellm/types/utils.py index b2b198a4ff..a1c19dab1b 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1729,6 +1729,7 @@ all_litellm_params = [ "max_fallbacks", "max_budget", "budget_duration", + "use_in_pass_through", ] + list(StandardCallbackDynamicParams.__annotations__.keys()) diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py index 4c66f69934..d82cba8a11 100644 --- a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, Mock, patch sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system-path import httpx @@ -23,6 +23,9 @@ from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( VertexPassThroughCredentials, default_vertex_config, ) +from litellm.proxy.vertex_ai_endpoints.vertex_passthrough_router import ( + VertexPassThroughRouter, +) @pytest.mark.asyncio @@ -167,3 +170,123 @@ async def test_set_default_vertex_config(): del os.environ["DEFAULT_VERTEXAI_LOCATION"] del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] del os.environ["GOOGLE_CREDS"] + + +@pytest.mark.asyncio +async def test_vertex_passthrough_router_init(): + """Test VertexPassThroughRouter initialization""" + router = VertexPassThroughRouter() + assert isinstance(router.deployment_key_to_vertex_credentials, dict) + assert len(router.deployment_key_to_vertex_credentials) == 0 + + +@pytest.mark.asyncio +async def test_get_vertex_credentials_none(): + """Test get_vertex_credentials with various inputs""" + from litellm.proxy.vertex_ai_endpoints import vertex_endpoints + + setattr(vertex_endpoints, "default_vertex_config", VertexPassThroughCredentials()) + router = VertexPassThroughRouter() + + # Test with None project_id and location - should return default config + creds = router.get_vertex_credentials(None, None) + assert isinstance(creds, VertexPassThroughCredentials) + + # Test with valid project_id and location but no stored credentials + creds = router.get_vertex_credentials("test-project", "us-central1") + assert isinstance(creds, VertexPassThroughCredentials) + assert creds.vertex_project is None + assert creds.vertex_location is None + assert creds.vertex_credentials is None + + +@pytest.mark.asyncio +async def test_get_vertex_credentials_stored(): + """Test get_vertex_credentials with stored credentials""" + router = VertexPassThroughRouter() + router.add_vertex_credentials( + project_id="test-project", + location="us-central1", + vertex_credentials="test-creds", + ) + + creds = router.get_vertex_credentials( + project_id="test-project", location="us-central1" + ) + assert creds.vertex_project == "test-project" + assert creds.vertex_location == "us-central1" + assert creds.vertex_credentials == "test-creds" + + +@pytest.mark.asyncio +async def test_add_vertex_credentials(): + """Test add_vertex_credentials functionality""" + router = VertexPassThroughRouter() + + # Test adding valid credentials + router.add_vertex_credentials( + project_id="test-project", + location="us-central1", + vertex_credentials="test-creds", + ) + + assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials + creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"] + assert creds.vertex_project == "test-project" + assert creds.vertex_location == "us-central1" + assert creds.vertex_credentials == "test-creds" + + # Test adding with None values + router.add_vertex_credentials( + project_id=None, location=None, vertex_credentials="test-creds" + ) + # Should not add None values + assert len(router.deployment_key_to_vertex_credentials) == 1 + + +@pytest.mark.asyncio +async def test_get_deployment_key(): + """Test _get_deployment_key with various inputs""" + router = VertexPassThroughRouter() + + # Test with valid inputs + key = router._get_deployment_key("test-project", "us-central1") + assert key == "test-project-us-central1" + + # Test with None values + key = router._get_deployment_key(None, "us-central1") + assert key is None + + key = router._get_deployment_key("test-project", None) + assert key is None + + key = router._get_deployment_key(None, None) + assert key is None + + +@pytest.mark.asyncio +async def test_get_vertex_project_id_from_url(): + """Test _get_vertex_project_id_from_url with various URLs""" + # Test with valid URL + url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent" + project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url) + assert project_id == "test-project" + + # Test with invalid URL + url = "https://invalid-url.com" + project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url) + assert project_id is None + + +@pytest.mark.asyncio +async def test_get_vertex_location_from_url(): + """Test _get_vertex_location_from_url with various URLs""" + # Test with valid URL + url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent" + location = VertexPassThroughRouter._get_vertex_location_from_url(url) + assert location == "us-central1" + + # Test with invalid URL + url = "https://invalid-url.com" + location = VertexPassThroughRouter._get_vertex_location_from_url(url) + assert location is None diff --git a/tests/router_unit_tests/test_router_adding_deployments.py b/tests/router_unit_tests/test_router_adding_deployments.py new file mode 100644 index 0000000000..b5e2d4a526 --- /dev/null +++ b/tests/router_unit_tests/test_router_adding_deployments.py @@ -0,0 +1,170 @@ +import sys, os +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from litellm import Router +from litellm.router import Deployment, LiteLLM_Params +from unittest.mock import patch +import json + + +def test_initialize_deployment_for_pass_through_success(): + """ + Test successful initialization of a Vertex AI pass-through deployment + """ + router = Router(model_list=[]) + deployment = Deployment( + model_name="vertex-test", + litellm_params=LiteLLM_Params( + model="vertex_ai/test-model", + vertex_project="test-project", + vertex_location="us-central1", + vertex_credentials=json.dumps( + {"type": "service_account", "project_id": "test"} + ), + use_in_pass_through=True, + ), + ) + + # Test the initialization + router._initialize_deployment_for_pass_through( + deployment=deployment, + custom_llm_provider="vertex_ai", + model="vertex_ai/test-model", + ) + + # Verify the credentials were properly set + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + vertex_pass_through_router, + ) + + vertex_creds = vertex_pass_through_router.get_vertex_credentials( + project_id="test-project", location="us-central1" + ) + assert vertex_creds.vertex_project == "test-project" + assert vertex_creds.vertex_location == "us-central1" + assert vertex_creds.vertex_credentials == json.dumps( + {"type": "service_account", "project_id": "test"} + ) + + +def test_initialize_deployment_for_pass_through_missing_params(): + """ + Test initialization fails when required Vertex AI parameters are missing + """ + router = Router(model_list=[]) + deployment = Deployment( + model_name="vertex-test", + litellm_params=LiteLLM_Params( + model="vertex_ai/test-model", + # Missing required parameters + use_in_pass_through=True, + ), + ) + + # Test that initialization raises ValueError + with pytest.raises( + ValueError, + match="vertex_project, vertex_location, and vertex_credentials must be set", + ): + router._initialize_deployment_for_pass_through( + deployment=deployment, + custom_llm_provider="vertex_ai", + model="vertex_ai/test-model", + ) + + +def test_initialize_deployment_for_pass_through_unsupported_provider(): + """ + Test initialization with an unsupported provider + """ + router = Router(model_list=[]) + deployment = Deployment( + model_name="unsupported-test", + litellm_params=LiteLLM_Params( + model="unsupported/test-model", + use_in_pass_through=True, + ), + ) + + # Should not raise an error, but log a warning + router._initialize_deployment_for_pass_through( + deployment=deployment, + custom_llm_provider="unsupported_provider", + model="unsupported/test-model", + ) + + +def test_initialize_deployment_when_pass_through_disabled(): + """ + Test that initialization simply exits when use_in_pass_through is False + """ + router = Router(model_list=[]) + deployment = Deployment( + model_name="vertex-test", + litellm_params=LiteLLM_Params( + model="vertex_ai/test-model", + ), + ) + + # This should exit without error, even with missing vertex parameters + router._initialize_deployment_for_pass_through( + deployment=deployment, + custom_llm_provider="vertex_ai", + model="vertex_ai/test-model", + ) + + # If we reach this point, the test passes as the method exited without raising any errors + assert True + + +def test_add_vertex_pass_through_deployment(): + """ + Test adding a Vertex AI deployment with pass-through configuration + """ + router = Router(model_list=[]) + + # Create a deployment with Vertex AI pass-through settings + deployment = Deployment( + model_name="vertex-test", + litellm_params=LiteLLM_Params( + model="vertex_ai/test-model", + vertex_project="test-project", + vertex_location="us-central1", + vertex_credentials=json.dumps( + {"type": "service_account", "project_id": "test"} + ), + use_in_pass_through=True, + ), + ) + + # Add deployment to router + router.add_deployment(deployment) + + # Get the vertex credentials from the router + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + vertex_pass_through_router, + ) + + # current state of pass-through vertex router + print("\n vertex_pass_through_router.deployment_key_to_vertex_credentials\n\n") + print( + json.dumps( + vertex_pass_through_router.deployment_key_to_vertex_credentials, + indent=4, + default=str, + ) + ) + + vertex_creds = vertex_pass_through_router.get_vertex_credentials( + project_id="test-project", location="us-central1" + ) + + # Verify the credentials were properly set + assert vertex_creds.vertex_project == "test-project" + assert vertex_creds.vertex_location == "us-central1" + assert vertex_creds.vertex_credentials == json.dumps( + {"type": "service_account", "project_id": "test"} + )