mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge 62bf56c54d
into b82af5b826
This commit is contained in:
commit
cb16ac5698
2 changed files with 65 additions and 4 deletions
|
@ -41,14 +41,11 @@ class VertexBase:
|
|||
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
|
||||
) -> Tuple[Any, str]:
|
||||
import google.auth as google_auth
|
||||
from google.auth import identity_pool
|
||||
from google.auth.transport.requests import (
|
||||
Request, # type: ignore[import-untyped]
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
import google.oauth2.service_account
|
||||
|
||||
if isinstance(credentials, str):
|
||||
verbose_logger.debug(
|
||||
"Vertex: Loading vertex credentials from %s", credentials
|
||||
|
@ -80,8 +77,18 @@ class VertexBase:
|
|||
|
||||
# Check if the JSON object contains Workload Identity Federation configuration
|
||||
if "type" in json_obj and json_obj["type"] == "external_account":
|
||||
creds = identity_pool.Credentials.from_info(json_obj)
|
||||
# If environment_id key contains "aws" value it corresponds to an AWS config file
|
||||
if (
|
||||
"credential_source" in json_obj
|
||||
and "environment_id" in json_obj["credential_source"]
|
||||
and "aws" in json_obj["credential_source"]["environment_id"]
|
||||
):
|
||||
creds = google_auth.aws.Credentials.from_info(json_obj)
|
||||
else:
|
||||
creds = google_auth.identity_pool.Credentials.from_info(json_obj)
|
||||
else:
|
||||
import google.oauth2.service_account
|
||||
|
||||
creds = (
|
||||
google.oauth2.service_account.Credentials.from_service_account_info(
|
||||
json_obj,
|
||||
|
|
|
@ -174,3 +174,57 @@ class TestVertexBase:
|
|||
)
|
||||
assert token == ""
|
||||
assert project == ""
|
||||
|
||||
def test_load_auth_wif(self):
|
||||
vertex_base = VertexBase()
|
||||
input_project_id = "some_project_id"
|
||||
|
||||
# Test case 1: Using Workload Identity Federation for Microsoft Azure and
|
||||
# OIDC identity providers (default behavior)
|
||||
json_obj_1 = {
|
||||
"type": "external_account",
|
||||
}
|
||||
mock_auth_1 = MagicMock()
|
||||
mock_creds_1 = MagicMock()
|
||||
mock_request_1 = MagicMock()
|
||||
mock_creds_1 = mock_auth_1.identity_pool.Credentials.from_info.return_value
|
||||
with patch.dict(sys.modules, {"google.auth": mock_auth_1,
|
||||
"google.auth.transport.requests": mock_request_1}):
|
||||
|
||||
creds_1, project_id = vertex_base.load_auth(
|
||||
credentials=json_obj_1, project_id=input_project_id
|
||||
)
|
||||
|
||||
mock_auth_1.identity_pool.Credentials.from_info.assert_called_once_with(
|
||||
json_obj_1
|
||||
)
|
||||
mock_creds_1.refresh.assert_called_once_with(
|
||||
mock_request_1.Request.return_value
|
||||
)
|
||||
assert creds_1 == mock_creds_1
|
||||
assert project_id == input_project_id
|
||||
|
||||
# Test case 2: Using Workload Identity Federation for AWS
|
||||
json_obj_2 = {
|
||||
"type": "external_account",
|
||||
"credential_source": {"environment_id": "aws1"}
|
||||
}
|
||||
mock_auth_2 = MagicMock()
|
||||
mock_creds_2 = MagicMock()
|
||||
mock_request_2 = MagicMock()
|
||||
mock_creds_2 = mock_auth_2.aws.Credentials.from_info.return_value
|
||||
with patch.dict(sys.modules, {"google.auth": mock_auth_2,
|
||||
"google.auth.transport.requests": mock_request_2}):
|
||||
|
||||
creds_2, project_id = vertex_base.load_auth(
|
||||
credentials=json_obj_2, project_id=input_project_id
|
||||
)
|
||||
|
||||
mock_auth_2.aws.Credentials.from_info.assert_called_once_with(
|
||||
json_obj_2
|
||||
)
|
||||
mock_creds_2.refresh.assert_called_once_with(
|
||||
mock_request_2.Request.return_value
|
||||
)
|
||||
assert creds_2 == mock_creds_2
|
||||
assert project_id == input_project_id
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue