diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index be3a903dcc..24ab08c167 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -456,7 +456,6 @@ async def vertex_proxy_route( request=request, api_key=api_key_to_use, ) - vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint) vertex_location: Optional[str] = get_vertex_location_from_url(endpoint) vertex_credentials = passthrough_endpoint_router.get_vertex_credentials( diff --git a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py index 5267c3b26c..897faa1717 100644 --- a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py +++ b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py @@ -1,7 +1,8 @@ from typing import Dict, Optional -from litellm._logging import verbose_logger +from litellm._logging import verbose_router_logger from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials @@ -15,6 +16,7 @@ class PassthroughEndpointRouter: self.deployment_key_to_vertex_credentials: Dict[ str, VertexPassThroughCredentials ] = {} + self.default_vertex_config: Optional[VertexPassThroughCredentials] = None def set_pass_through_credentials( self, @@ -49,14 +51,14 @@ class PassthroughEndpointRouter: custom_llm_provider=custom_llm_provider, region_name=region_name, ) - verbose_logger.debug( + verbose_router_logger.debug( f"Pass-through llm endpoints router, looking for credentials for {credential_name}" ) if credential_name in self.credentials: - verbose_logger.debug(f"Found credentials for {credential_name}") + verbose_router_logger.debug(f"Found credentials for {credential_name}") return self.credentials[credential_name] else: - verbose_logger.debug( + verbose_router_logger.debug( f"No credentials found for {credential_name}, looking for env variable" ) _env_variable_name = ( @@ -66,6 +68,72 @@ class PassthroughEndpointRouter: ) return get_secret_str(_env_variable_name) + def _get_vertex_env_vars(self) -> VertexPassThroughCredentials: + """ + Helper to get vertex pass through config from environment variables + + The following environment variables are used: + - DEFAULT_VERTEXAI_PROJECT (project id) + - DEFAULT_VERTEXAI_LOCATION (location) + - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file) + """ + return VertexPassThroughCredentials( + vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"), + vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"), + vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"), + ) + + def set_default_vertex_config(self, config: Optional[dict] = None): + """Sets vertex configuration from provided config and/or environment variables + + Args: + config (Optional[dict]): Configuration dictionary + Example: { + "vertex_project": "my-project-123", + "vertex_location": "us-central1", + "vertex_credentials": "os.environ/GOOGLE_CREDS" + } + """ + # Initialize config dictionary if None + if config is None: + self.default_vertex_config = self._get_vertex_env_vars() + return + + if isinstance(config, dict): + for key, value in config.items(): + if isinstance(value, str) and value.startswith("os.environ/"): + config[key] = get_secret_str(value) + + self.default_vertex_config = VertexPassThroughCredentials(**config) + + def add_vertex_credentials( + self, + project_id: str, + location: str, + vertex_credentials: VERTEX_CREDENTIALS_TYPES, + ): + """ + Add the vertex credentials for the given project-id, location + """ + + deployment_key = self._get_deployment_key( + project_id=project_id, + location=location, + ) + if deployment_key is None: + verbose_router_logger.debug( + "No deployment key found for project-id, location" + ) + return + vertex_pass_through_credentials = VertexPassThroughCredentials( + vertex_project=project_id, + vertex_location=location, + vertex_credentials=vertex_credentials, + ) + self.deployment_key_to_vertex_credentials[deployment_key] = ( + vertex_pass_through_credentials + ) + def _get_deployment_key( self, project_id: Optional[str], location: Optional[str] ) -> Optional[str]: @@ -82,15 +150,13 @@ class PassthroughEndpointRouter: """ Get the vertex credentials for the given project-id, location """ - # from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - # default_vertex_config, - # ) default_vertex_config: Optional[VertexPassThroughCredentials] = None deployment_key = self._get_deployment_key( project_id=project_id, location=location, ) + if deployment_key is None: return default_vertex_config if deployment_key in self.deployment_key_to_vertex_credentials: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ae1c8d18af..ee2a906200 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -235,6 +235,9 @@ from litellm.proxy.openai_files_endpoints.files_endpoints import ( router as openai_files_router, ) from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config +from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + passthrough_endpoint_router, +) from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( router as llm_passthrough_router, ) @@ -272,8 +275,6 @@ from litellm.proxy.utils import ( from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import ( router as langfuse_router, ) -from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router -from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config from litellm.router import ( AssistantsTypedDict, Deployment, @@ -2115,7 +2116,9 @@ class ProxyConfig: ## default config for vertex ai routes default_vertex_config = config.get("default_vertex_config", None) - set_default_vertex_config(config=default_vertex_config) + passthrough_endpoint_router.set_default_vertex_config( + config=default_vertex_config + ) ## ROUTER SETTINGS (e.g. routing_strategy, ...) router_settings = config.get("router_settings", None) @@ -8161,7 +8164,6 @@ app.include_router(batches_router) app.include_router(rerank_router) app.include_router(fine_tuning_router) app.include_router(credential_router) -app.include_router(vertex_router) app.include_router(llm_passthrough_router) app.include_router(anthropic_router) app.include_router(langfuse_router) diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 4b706ed33a..6243fe79b4 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -23,48 +23,6 @@ # default_vertex_config: Optional[VertexPassThroughCredentials] = None -# def _get_vertex_env_vars() -> VertexPassThroughCredentials: -# """ -# Helper to get vertex pass through config from environment variables - -# The following environment variables are used: -# - DEFAULT_VERTEXAI_PROJECT (project id) -# - DEFAULT_VERTEXAI_LOCATION (location) -# - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file) -# """ -# return VertexPassThroughCredentials( -# vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"), -# vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"), -# vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"), -# ) - - -# def set_default_vertex_config(config: Optional[dict] = None): -# """Sets vertex configuration from provided config and/or environment variables - -# Args: -# config (Optional[dict]): Configuration dictionary -# Example: { -# "vertex_project": "my-project-123", -# "vertex_location": "us-central1", -# "vertex_credentials": "os.environ/GOOGLE_CREDS" -# } -# """ -# global default_vertex_config - -# # Initialize config dictionary if None -# if config is None: -# default_vertex_config = _get_vertex_env_vars() -# return - -# if isinstance(config, dict): -# for key, value in config.items(): -# if isinstance(value, str) and value.startswith("os.environ/"): -# config[key] = litellm.get_secret(value) - -# _set_default_vertex_config(VertexPassThroughCredentials(**config)) - - # def _set_default_vertex_config( # vertex_pass_through_credentials: VertexPassThroughCredentials, # ): diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py b/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py index 5017a8f661..dd17d49b8f 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py @@ -55,37 +55,6 @@ # else: # return default_vertex_config -# def add_vertex_credentials( -# self, -# project_id: str, -# location: str, -# vertex_credentials: VERTEX_CREDENTIALS_TYPES, -# ): -# """ -# Add the vertex credentials for the given project-id, location -# """ - -# deployment_key = self._get_deployment_key( -# project_id=project_id, -# location=location, -# ) -# if deployment_key is None: -# verbose_proxy_logger.debug( -# "No deployment key found for project-id, location" -# ) -# return -# vertex_pass_through_credentials = VertexPassThroughCredentials( -# vertex_project=project_id, -# vertex_location=location, -# vertex_credentials=vertex_credentials, -# ) -# self.deployment_key_to_vertex_credentials[deployment_key] = ( -# vertex_pass_through_credentials -# ) -# verbose_proxy_logger.debug( -# f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}" -# ) - # @staticmethod # def _get_vertex_project_id_from_url(url: str) -> Optional[str]: diff --git a/litellm/router.py b/litellm/router.py index a395c851dd..af7b00e79d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4495,11 +4495,11 @@ class Router: Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints """ if deployment.litellm_params.use_in_pass_through is True: - if custom_llm_provider == "vertex_ai": - from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - vertex_pass_through_router, - ) + from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + passthrough_endpoint_router, + ) + if custom_llm_provider == "vertex_ai": if ( deployment.litellm_params.vertex_project is None or deployment.litellm_params.vertex_location is None @@ -4508,16 +4508,12 @@ class Router: raise ValueError( "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints" ) - vertex_pass_through_router.add_vertex_credentials( + passthrough_endpoint_router.add_vertex_credentials( project_id=deployment.litellm_params.vertex_project, location=deployment.litellm_params.vertex_location, vertex_credentials=deployment.litellm_params.vertex_credentials, ) else: - from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( - passthrough_endpoint_router, - ) - passthrough_endpoint_router.set_pass_through_credentials( custom_llm_provider=custom_llm_provider, api_base=deployment.litellm_params.api_base, diff --git a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py index 48cb60968b..8f8fbbe9de 100644 --- a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py +++ b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py @@ -191,16 +191,39 @@ class TestVertexAIPassThroughHandler: """ @pytest.mark.asyncio - async def test_vertex_passthrough_with_credentials(self): + async def test_vertex_passthrough_with_credentials(self, monkeypatch): """ Test that when passthrough credentials are set, they are correctly used in the request """ + from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import ( + PassthroughEndpointRouter, + ) + + vertex_project = "test-project" + vertex_location = "us-central1" + vertex_credentials = "test-creds" + + pass_through_router = PassthroughEndpointRouter() + + pass_through_router.add_vertex_credentials( + project_id=vertex_project, + location=vertex_location, + vertex_credentials=vertex_credentials, + ) + + monkeypatch.setattr( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router", + pass_through_router, + ) + + endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/gemini-1.5-flash:generateContent" + # Mock request mock_request = Request( scope={ "type": "http", "method": "POST", - "path": "/vertex_ai/models/test-model/predict", + "path": endpoint, "headers": {}, } ) @@ -209,33 +232,24 @@ class TestVertexAIPassThroughHandler: mock_response = Response() # Mock vertex credentials - test_project = "test-project" - test_location = "us-central1" - test_token = "test-token-123" + test_project = vertex_project + test_location = vertex_location + test_token = vertex_credentials with mock.patch( - "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_vertex_credentials" - ) as mock_get_creds, mock.patch( "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async" ) as mock_ensure_token, mock.patch( "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url" ) as mock_get_token, mock.patch( "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route" ) as mock_create_route: - - # Setup mock returns - mock_get_creds.return_value = VertexPassThroughCredentials( - vertex_project=test_project, - vertex_location=test_location, - vertex_credentials="test-creds", - ) mock_ensure_token.return_value = ("test-auth-header", test_project) mock_get_token.return_value = (test_token, "") # Call the route try: await vertex_proxy_route( - endpoint="models/test-model/predict", + endpoint=endpoint, request=mock_request, fastapi_response=mock_response, ) @@ -244,7 +258,7 @@ class TestVertexAIPassThroughHandler: # Verify create_pass_through_route was called with correct arguments mock_create_route.assert_called_once_with( - endpoint="models/test-model/predict", - target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/models/test-model/predict", + endpoint=endpoint, + target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent", custom_headers={"Authorization": f"Bearer {test_token}"}, )