This commit is contained in:
Pascal Lim 2025-04-24 00:54:14 -07:00 committed by GitHub
commit cb16ac5698
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 65 additions and 4 deletions

View file

@ -41,14 +41,11 @@ class VertexBase:
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
) -> Tuple[Any, str]:
import google.auth as google_auth
from google.auth import identity_pool
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
if credentials is not None:
import google.oauth2.service_account
if isinstance(credentials, str):
verbose_logger.debug(
"Vertex: Loading vertex credentials from %s", credentials
@ -80,8 +77,18 @@ class VertexBase:
# Check if the JSON object contains Workload Identity Federation configuration
if "type" in json_obj and json_obj["type"] == "external_account":
creds = identity_pool.Credentials.from_info(json_obj)
# If environment_id key contains "aws" value it corresponds to an AWS config file
if (
"credential_source" in json_obj
and "environment_id" in json_obj["credential_source"]
and "aws" in json_obj["credential_source"]["environment_id"]
):
creds = google_auth.aws.Credentials.from_info(json_obj)
else:
creds = google_auth.identity_pool.Credentials.from_info(json_obj)
else:
import google.oauth2.service_account
creds = (
google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,

View file

@ -174,3 +174,57 @@ class TestVertexBase:
)
assert token == ""
assert project == ""
def test_load_auth_wif(self):
vertex_base = VertexBase()
input_project_id = "some_project_id"
# Test case 1: Using Workload Identity Federation for Microsoft Azure and
# OIDC identity providers (default behavior)
json_obj_1 = {
"type": "external_account",
}
mock_auth_1 = MagicMock()
mock_creds_1 = MagicMock()
mock_request_1 = MagicMock()
mock_creds_1 = mock_auth_1.identity_pool.Credentials.from_info.return_value
with patch.dict(sys.modules, {"google.auth": mock_auth_1,
"google.auth.transport.requests": mock_request_1}):
creds_1, project_id = vertex_base.load_auth(
credentials=json_obj_1, project_id=input_project_id
)
mock_auth_1.identity_pool.Credentials.from_info.assert_called_once_with(
json_obj_1
)
mock_creds_1.refresh.assert_called_once_with(
mock_request_1.Request.return_value
)
assert creds_1 == mock_creds_1
assert project_id == input_project_id
# Test case 2: Using Workload Identity Federation for AWS
json_obj_2 = {
"type": "external_account",
"credential_source": {"environment_id": "aws1"}
}
mock_auth_2 = MagicMock()
mock_creds_2 = MagicMock()
mock_request_2 = MagicMock()
mock_creds_2 = mock_auth_2.aws.Credentials.from_info.return_value
with patch.dict(sys.modules, {"google.auth": mock_auth_2,
"google.auth.transport.requests": mock_request_2}):
creds_2, project_id = vertex_base.load_auth(
credentials=json_obj_2, project_id=input_project_id
)
mock_auth_2.aws.Credentials.from_info.assert_called_once_with(
json_obj_2
)
mock_creds_2.refresh.assert_called_once_with(
mock_request_2.Request.return_value
)
assert creds_2 == mock_creds_2
assert project_id == input_project_id