mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat: add workload identity federation between GCP and AWS
This commit is contained in:
parent
f5996b2f6b
commit
62bf56c54d
2 changed files with 65 additions and 4 deletions
|
@ -41,14 +41,11 @@ class VertexBase:
|
||||||
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
|
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
|
||||||
) -> Tuple[Any, str]:
|
) -> Tuple[Any, str]:
|
||||||
import google.auth as google_auth
|
import google.auth as google_auth
|
||||||
from google.auth import identity_pool
|
|
||||||
from google.auth.transport.requests import (
|
from google.auth.transport.requests import (
|
||||||
Request, # type: ignore[import-untyped]
|
Request, # type: ignore[import-untyped]
|
||||||
)
|
)
|
||||||
|
|
||||||
if credentials is not None:
|
if credentials is not None:
|
||||||
import google.oauth2.service_account
|
|
||||||
|
|
||||||
if isinstance(credentials, str):
|
if isinstance(credentials, str):
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"Vertex: Loading vertex credentials from %s", credentials
|
"Vertex: Loading vertex credentials from %s", credentials
|
||||||
|
@ -80,8 +77,18 @@ class VertexBase:
|
||||||
|
|
||||||
# Check if the JSON object contains Workload Identity Federation configuration
|
# Check if the JSON object contains Workload Identity Federation configuration
|
||||||
if "type" in json_obj and json_obj["type"] == "external_account":
|
if "type" in json_obj and json_obj["type"] == "external_account":
|
||||||
creds = identity_pool.Credentials.from_info(json_obj)
|
# If environment_id key contains "aws" value it corresponds to an AWS config file
|
||||||
|
if (
|
||||||
|
"credential_source" in json_obj
|
||||||
|
and "environment_id" in json_obj["credential_source"]
|
||||||
|
and "aws" in json_obj["credential_source"]["environment_id"]
|
||||||
|
):
|
||||||
|
creds = google_auth.aws.Credentials.from_info(json_obj)
|
||||||
|
else:
|
||||||
|
creds = google_auth.identity_pool.Credentials.from_info(json_obj)
|
||||||
else:
|
else:
|
||||||
|
import google.oauth2.service_account
|
||||||
|
|
||||||
creds = (
|
creds = (
|
||||||
google.oauth2.service_account.Credentials.from_service_account_info(
|
google.oauth2.service_account.Credentials.from_service_account_info(
|
||||||
json_obj,
|
json_obj,
|
||||||
|
|
|
@ -174,3 +174,57 @@ class TestVertexBase:
|
||||||
)
|
)
|
||||||
assert token == ""
|
assert token == ""
|
||||||
assert project == ""
|
assert project == ""
|
||||||
|
|
||||||
|
def test_load_auth_wif(self):
|
||||||
|
vertex_base = VertexBase()
|
||||||
|
input_project_id = "some_project_id"
|
||||||
|
|
||||||
|
# Test case 1: Using Workload Identity Federation for Microsoft Azure and
|
||||||
|
# OIDC identity providers (default behavior)
|
||||||
|
json_obj_1 = {
|
||||||
|
"type": "external_account",
|
||||||
|
}
|
||||||
|
mock_auth_1 = MagicMock()
|
||||||
|
mock_creds_1 = MagicMock()
|
||||||
|
mock_request_1 = MagicMock()
|
||||||
|
mock_creds_1 = mock_auth_1.identity_pool.Credentials.from_info.return_value
|
||||||
|
with patch.dict(sys.modules, {"google.auth": mock_auth_1,
|
||||||
|
"google.auth.transport.requests": mock_request_1}):
|
||||||
|
|
||||||
|
creds_1, project_id = vertex_base.load_auth(
|
||||||
|
credentials=json_obj_1, project_id=input_project_id
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_auth_1.identity_pool.Credentials.from_info.assert_called_once_with(
|
||||||
|
json_obj_1
|
||||||
|
)
|
||||||
|
mock_creds_1.refresh.assert_called_once_with(
|
||||||
|
mock_request_1.Request.return_value
|
||||||
|
)
|
||||||
|
assert creds_1 == mock_creds_1
|
||||||
|
assert project_id == input_project_id
|
||||||
|
|
||||||
|
# Test case 2: Using Workload Identity Federation for AWS
|
||||||
|
json_obj_2 = {
|
||||||
|
"type": "external_account",
|
||||||
|
"credential_source": {"environment_id": "aws1"}
|
||||||
|
}
|
||||||
|
mock_auth_2 = MagicMock()
|
||||||
|
mock_creds_2 = MagicMock()
|
||||||
|
mock_request_2 = MagicMock()
|
||||||
|
mock_creds_2 = mock_auth_2.aws.Credentials.from_info.return_value
|
||||||
|
with patch.dict(sys.modules, {"google.auth": mock_auth_2,
|
||||||
|
"google.auth.transport.requests": mock_request_2}):
|
||||||
|
|
||||||
|
creds_2, project_id = vertex_base.load_auth(
|
||||||
|
credentials=json_obj_2, project_id=input_project_id
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_auth_2.aws.Credentials.from_info.assert_called_once_with(
|
||||||
|
json_obj_2
|
||||||
|
)
|
||||||
|
mock_creds_2.refresh.assert_called_once_with(
|
||||||
|
mock_request_2.Request.return_value
|
||||||
|
)
|
||||||
|
assert creds_2 == mock_creds_2
|
||||||
|
assert project_id == input_project_id
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue