diff --git a/litellm/llms/vertex_ai/vertex_llm_base.py b/litellm/llms/vertex_ai/vertex_llm_base.py index 8f3037c791..928dbb5d18 100644 --- a/litellm/llms/vertex_ai/vertex_llm_base.py +++ b/litellm/llms/vertex_ai/vertex_llm_base.py @@ -41,14 +41,11 @@ class VertexBase: self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str] ) -> Tuple[Any, str]: import google.auth as google_auth - from google.auth import identity_pool from google.auth.transport.requests import ( Request, # type: ignore[import-untyped] ) if credentials is not None: - import google.oauth2.service_account - if isinstance(credentials, str): verbose_logger.debug( "Vertex: Loading vertex credentials from %s", credentials @@ -80,8 +77,18 @@ class VertexBase: # Check if the JSON object contains Workload Identity Federation configuration if "type" in json_obj and json_obj["type"] == "external_account": - creds = identity_pool.Credentials.from_info(json_obj) + # If environment_id key contains "aws" value it corresponds to an AWS config file + if ( + "credential_source" in json_obj + and "environment_id" in json_obj["credential_source"] + and "aws" in json_obj["credential_source"]["environment_id"] + ): + creds = google_auth.aws.Credentials.from_info(json_obj) + else: + creds = google_auth.identity_pool.Credentials.from_info(json_obj) else: + import google.oauth2.service_account + creds = ( google.oauth2.service_account.Credentials.from_service_account_info( json_obj, diff --git a/tests/litellm/llms/vertex_ai/test_vertex_llm_base.py b/tests/litellm/llms/vertex_ai/test_vertex_llm_base.py index 135dc5b616..b3f512ce43 100644 --- a/tests/litellm/llms/vertex_ai/test_vertex_llm_base.py +++ b/tests/litellm/llms/vertex_ai/test_vertex_llm_base.py @@ -174,3 +174,57 @@ class TestVertexBase: ) assert token == "" assert project == "" + + def test_load_auth_wif(self): + vertex_base = VertexBase() + input_project_id = "some_project_id" + + # Test case 1: Using Workload Identity Federation for Microsoft Azure and + # OIDC identity providers (default behavior) + json_obj_1 = { + "type": "external_account", + } + mock_auth_1 = MagicMock() + mock_creds_1 = MagicMock() + mock_request_1 = MagicMock() + mock_creds_1 = mock_auth_1.identity_pool.Credentials.from_info.return_value + with patch.dict(sys.modules, {"google.auth": mock_auth_1, + "google.auth.transport.requests": mock_request_1}): + + creds_1, project_id = vertex_base.load_auth( + credentials=json_obj_1, project_id=input_project_id + ) + + mock_auth_1.identity_pool.Credentials.from_info.assert_called_once_with( + json_obj_1 + ) + mock_creds_1.refresh.assert_called_once_with( + mock_request_1.Request.return_value + ) + assert creds_1 == mock_creds_1 + assert project_id == input_project_id + + # Test case 2: Using Workload Identity Federation for AWS + json_obj_2 = { + "type": "external_account", + "credential_source": {"environment_id": "aws1"} + } + mock_auth_2 = MagicMock() + mock_creds_2 = MagicMock() + mock_request_2 = MagicMock() + mock_creds_2 = mock_auth_2.aws.Credentials.from_info.return_value + with patch.dict(sys.modules, {"google.auth": mock_auth_2, + "google.auth.transport.requests": mock_request_2}): + + creds_2, project_id = vertex_base.load_auth( + credentials=json_obj_2, project_id=input_project_id + ) + + mock_auth_2.aws.Credentials.from_info.assert_called_once_with( + json_obj_2 + ) + mock_creds_2.refresh.assert_called_once_with( + mock_request_2.Request.return_value + ) + assert creds_2 == mock_creds_2 + assert project_id == input_project_id