(fix) BaseAWSLLM - cache IAM role credentials when used (#7775)

* fix base aws llm

* fix auth with aws role

* test aws base llm

* fix base aws llm init

* run ci/cd again

* fix get_credentials

* ci/cd run again

* _auth_with_aws_role
This commit is contained in:
Ishaan Jaff 2025-01-14 20:16:22 -08:00 committed by GitHub
parent 5fbbf47581
commit 30bb4c4cdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 397 additions and 129 deletions

View file

@ -2,7 +2,7 @@
import warnings
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES #####
### INIT VARIABLES ######
import threading
import os
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args

View file

@ -1,7 +1,7 @@
"""
Custom Logger that handles batching logic
Use this if you want your logs to be stored in memory and flushed periodically
Use this if you want your logs to be stored in memory and flushed periodically.
"""
import asyncio

View file

@ -1,6 +1,7 @@
import hashlib
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import httpx
@ -48,7 +49,7 @@ class BaseAWSLLM:
credential_str = json.dumps(credential_args, sort_keys=True)
return hashlib.sha256(credential_str.encode()).hexdigest()
def get_credentials( # noqa: PLR0915
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
@ -63,10 +64,6 @@ class BaseAWSLLM:
"""
Return a boto3.Credentials object
"""
import boto3
from botocore.credentials import Credentials
## CHECK IS 'os.environ/' passed in
param_names = [
"aws_access_key_id",
@ -115,10 +112,6 @@ class BaseAWSLLM:
aws_sts_endpoint,
) = params_to_check
# create cache key for non-expiring auth flows
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
cache_key = self.get_cache_key(args)
verbose_logger.debug(
"in get credentials\n"
"aws_access_key_id=%s\n"
@ -141,152 +134,241 @@ class BaseAWSLLM:
aws_sts_endpoint,
)
### CHECK STS ###
# create cache key for non-expiring auth flows
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
cache_key = self.get_cache_key(args)
_cached_credentials = self.iam_cache.get_cache(cache_key)
if _cached_credentials:
return _cached_credentials
#########################################################
# Handle diff boto3 auth flows
# for each helper
# Return:
# Credentials - boto3.Credentials
# cache ttl - Optional[int]. If None, the credentials are not cached. Some auth flows have no expiry time.
#########################################################
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
verbose_logger.debug(
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
credentials, _cache_ttl = self._auth_with_web_identity_token(
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
aws_region_name=aws_region_name,
aws_sts_endpoint=aws_sts_endpoint,
)
if aws_sts_endpoint is None:
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
else:
sts_endpoint = aws_sts_endpoint
iam_creds_cache_key = json.dumps(
{
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
}
)
iam_creds_dict = self.iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise AwsAuthError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=sts_endpoint,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
Policy='{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}',
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"][
"SecretAccessKey"
],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
self.iam_cache.set_cache(
key=iam_creds_cache_key,
value=json.dumps(iam_creds_dict),
ttl=3600 - 60,
)
if sts_response["PackedPolicySize"] > 75:
verbose_logger.warning(
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
)
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
credentials, _cache_ttl = self._auth_with_aws_role(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name)
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_session_token is not None
): ### CHECK FOR AWS SESSION TOKEN ###
from botocore.credentials import Credentials
credentials = Credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
):
credentials, _cache_ttl = self._auth_with_aws_session_token(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
)
return credentials
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_region_name is not None
):
# Check if credentials are already in cache. These credentials have no expiry time.
cached_credentials: Optional[Credentials] = self.iam_cache.get_cache(
cache_key
)
if cached_credentials:
return cached_credentials
session = boto3.Session(
credentials, _cache_ttl = self._auth_with_access_key_and_secret_key(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
aws_region_name=aws_region_name,
)
else:
credentials, _cache_ttl = self._auth_with_env_vars()
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
return credentials
# Exchange an OIDC web identity token for temporary STS credentials.
#
# Flow, as visible in this block:
#   1. Pick the STS endpoint: `aws_sts_endpoint` if given, otherwise the
#      regional default built from `aws_region_name`.
#   2. Obtain the OIDC token by passing `aws_web_identity_token` to `get_secret`;
#      raise `AwsAuthError` (status 401) when that returns None.
#   3. Call `assume_role_with_web_identity` with a 3600 second duration and an
#      inline session policy.
#   4. Build a `boto3.Session` from the returned keys and return its credentials
#      together with `_get_default_ttl_for_boto3_credentials()`.
#
# NOTE(review): this listing is a rendered diff. Several lines below look like
# removed lines from the previous inline implementation in `get_credentials`
# interleaved with the new helper; they are marked individually. TODO confirm
# against the actual file before relying on them.
def _auth_with_web_identity_token(
self,
aws_web_identity_token: str,
aws_role_name: str,
aws_session_name: str,
aws_region_name: Optional[str],
aws_sts_endpoint: Optional[str],
) -> Tuple[Credentials, Optional[int]]:
"""
Authenticate with AWS Web Identity Token
"""
import boto3
verbose_logger.debug(
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
)
# Fall back to the regional STS endpoint when no explicit endpoint is supplied.
if aws_sts_endpoint is None:
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
else:
sts_endpoint = aws_sts_endpoint
# `aws_web_identity_token` is handed to `get_secret`; the value it returns is
# what gets sent to STS as the web identity token.
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise AwsAuthError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
# NOTE(review): presumably diff residue (removed line) — `session` is not defined yet at this point.
credentials = session.get_credentials()
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=sts_endpoint,
)
# NOTE(review): the next four lines reference `cache_key`, which is not a
# parameter of this helper — presumably diff residue. TODO confirm.
if (
credentials.token is None
): # don't cache if session token exists. The expiry time for that is not known.
self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
Policy='{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}',
)
# NOTE(review): the next four lines look like diff residue from the old env-var branch. TODO confirm.
return credentials
else:
# check env var. Do not cache the response from this.
session = boto3.Session()
# Keyword arguments for `boto3.Session`, taken from the STS response.
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
# NOTE(review): presumably diff residue (removed line).
credentials = session.get_credentials()
# Warn when the packed session policy uses more than 75% of the allowed size.
if sts_response["PackedPolicySize"] > 75:
verbose_logger.warning(
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
)
# NOTE(review): presumably diff residue (removed line).
return credentials
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds, self._get_default_ttl_for_boto3_credentials()
def _auth_with_aws_role(
    self,
    aws_access_key_id: Optional[str],
    aws_secret_access_key: Optional[str],
    aws_role_name: str,
    aws_session_name: str,
) -> Tuple[Credentials, Optional[int]]:
    """
    Assume an IAM role through STS and return the temporary credentials.

    Args:
        aws_access_key_id: Optional access key used to build the STS client.
        aws_secret_access_key: Optional secret key used to build the STS client.
        aws_role_name: Passed to STS as ``RoleArn``.
        aws_session_name: Passed to STS as ``RoleSessionName``.

    Returns:
        A ``(credentials, ttl)`` tuple. ``ttl`` is the number of seconds until
        the ``Expiration`` reported by STS, minus a 60 second margin. It is
        computed with ``timedelta.total_seconds()``, so at runtime it is a
        float, and it is negative when fewer than 60 seconds remain.
    """
    import boto3
    from botocore.credentials import Credentials

    assumed_role = boto3.client(
        "sts",
        aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
        aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
    ).assume_role(RoleArn=aws_role_name, RoleSessionName=aws_session_name)

    # Wrap the temporary keys from the response in a botocore Credentials object.
    role_keys = assumed_role["Credentials"]
    role_credentials = Credentials(
        access_key=role_keys["AccessKeyId"],
        secret_key=role_keys["SecretAccessKey"],
        token=role_keys["SessionToken"],
    )

    # Take "now" in the expiry's own timezone so the subtraction never mixes
    # naive and timezone-aware datetimes.
    expires_at = role_keys["Expiration"]
    seconds_remaining = (
        expires_at - datetime.now(expires_at.tzinfo)
    ).total_seconds()
    return role_credentials, seconds_remaining - 60
def _auth_with_aws_profile(
    self, aws_profile_name: str
) -> Tuple[Credentials, Optional[int]]:
    """
    Load credentials from a named AWS profile.

    Args:
        aws_profile_name: Profile name handed to ``boto3.Session``; profiles
            are usually stored in ``~/.aws/credentials``.

    Returns:
        ``(credentials, None)`` - this flow reports no TTL.
    """
    import boto3

    profile_session = boto3.Session(profile_name=aws_profile_name)
    profile_credentials = profile_session.get_credentials()
    return profile_credentials, None
def _auth_with_aws_session_token(
    self,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_session_token: str,
) -> Tuple[Credentials, Optional[int]]:
    """
    Wrap an explicit access key, secret key and session token in a
    botocore ``Credentials`` object.

    Returns:
        ``(credentials, None)`` - this flow reports no TTL.
    """
    from botocore.credentials import Credentials

    return (
        Credentials(
            access_key=aws_access_key_id,
            secret_key=aws_secret_access_key,
            token=aws_session_token,
        ),
        None,
    )
def _auth_with_access_key_and_secret_key(
    self,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_region_name: Optional[str],
) -> Tuple[Credentials, Optional[int]]:
    """
    Build credentials from a static access key / secret key pair.

    This helper performs no cache lookup itself; it only creates the session
    and reports the TTL to use.

    Returns:
        ``(credentials, ttl)`` where ``ttl`` comes from
        ``_get_default_ttl_for_boto3_credentials()``.
    """
    import boto3

    key_session = boto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        region_name=aws_region_name,
    )
    key_credentials = key_session.get_credentials()
    default_ttl = self._get_default_ttl_for_boto3_credentials()
    return key_credentials, default_ttl
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
    """
    Fallback flow: create a ``boto3.Session`` with no arguments and return
    whatever credentials it resolves.

    Returns:
        ``(credentials, None)`` - this flow reports no TTL.
    """
    import boto3

    return boto3.Session().get_credentials(), None
def _get_default_ttl_for_boto3_credentials(self) -> int:
"""
Get the default TTL for boto3 credentials
Returns `3600-60` which is 59 minutes
"""
return 3600 - 60
def get_runtime_endpoint(
self,

View file

@ -58,6 +58,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
self.optional_params = kwargs
super().__init__(**kwargs)
BaseAWSLLM.__init__(self)
def convert_to_bedrock_format(
self,

View file

@ -33,6 +33,10 @@ from .base_secret_manager import BaseSecretManager
class AWSSecretsManagerV2(BaseAWSLLM, BaseSecretManager):
# Initialise each base class explicitly, forwarding the same keyword
# arguments to both BaseSecretManager and BaseAWSLLM.
def __init__(self, **kwargs):
BaseSecretManager.__init__(self, **kwargs)
BaseAWSLLM.__init__(self, **kwargs)
@classmethod
def validate_environment(cls):
if "AWS_REGION_NAME" not in os.environ:

View file

@ -0,0 +1,181 @@
import pytest
import os
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from botocore.credentials import Credentials
from typing import Dict, Any
from litellm.llms.bedrock.base_aws_llm import (
BaseAWSLLM,
AwsAuthError,
Boto3CredentialsInfo,
)
# Test fixtures
@pytest.fixture
def base_aws_llm():
    """Provide a ``BaseAWSLLM`` instance built with no arguments."""
    llm = BaseAWSLLM()
    return llm
@pytest.fixture
def mock_credentials():
    """Provide a botocore ``Credentials`` object filled with placeholder values."""
    placeholder = Credentials(
        access_key="test_access",
        secret_key="test_secret",
        token="test_token",
    )
    return placeholder
# Test cache key generation
def test_get_cache_key(base_aws_llm):
    """``get_cache_key`` should return a 64-character string (a SHA-256 hex digest)."""
    credential_args = {
        "aws_access_key_id": "test_key",
        "aws_secret_access_key": "test_secret",
    }

    digest = base_aws_llm.get_cache_key(credential_args)

    assert isinstance(digest, str)
    # A SHA-256 digest rendered as hex is always 64 characters long.
    assert len(digest) == 64
# Test web identity token authentication
@patch("boto3.client")
@patch("litellm.llms.bedrock.base_aws_llm.get_secret")  # stub the secret lookup
def test_auth_with_web_identity_token(mock_get_secret, mock_boto3_client, base_aws_llm):
    """
    The web-identity flow should resolve the OIDC token through ``get_secret``
    and return credentials with the default TTL.
    """
    # The secret lookup yields a fake OIDC token.
    mock_get_secret.return_value = "mocked_oidc_token"

    # boto3.client(...) hands back an STS stub with a canned response.
    canned_response = {
        "Credentials": {
            "AccessKeyId": "test_access",
            "SecretAccessKey": "test_secret",
            "SessionToken": "test_token",
        },
        "PackedPolicySize": 10,
    }
    sts_stub = MagicMock()
    sts_stub.assume_role_with_web_identity.return_value = canned_response
    mock_boto3_client.return_value = sts_stub

    result = base_aws_llm._auth_with_web_identity_token(
        aws_web_identity_token="test_token",
        aws_role_name="test_role",
        aws_session_name="test_session",
        aws_region_name="us-west-2",
        aws_sts_endpoint=None,
    )
    credentials, ttl = result

    # The secret lookup must receive the identifier passed in.
    mock_get_secret.assert_called_once_with("test_token")
    assert isinstance(credentials, Credentials)
    assert ttl == 3540  # default TTL (3600 - 60)
# Test AWS role authentication
@patch("boto3.client")
def test_auth_with_aws_role(mock_boto3_client, base_aws_llm):
    """
    The assume-role flow should return botocore credentials plus a float TTL.

    Only the types are asserted: the stubbed ``Expiration`` is "now", so the
    exact TTL value depends on timing.
    """
    expires_at = datetime.now(timezone.utc)
    canned_response = {
        "Credentials": {
            "AccessKeyId": "test_access",
            "SecretAccessKey": "test_secret",
            "SessionToken": "test_token",
            "Expiration": expires_at,
        }
    }
    sts_stub = MagicMock()
    sts_stub.assume_role.return_value = canned_response
    mock_boto3_client.return_value = sts_stub

    result = base_aws_llm._auth_with_aws_role(
        aws_access_key_id="test_access",
        aws_secret_access_key="test_secret",
        aws_role_name="test_role",
        aws_session_name="test_session",
    )
    credentials, ttl = result

    assert isinstance(credentials, Credentials)
    assert isinstance(ttl, float)
# Test AWS profile authentication
@patch("boto3.Session")
def test_auth_with_aws_profile(mock_session, base_aws_llm, mock_credentials):
    """The profile flow should return the session's credentials and no TTL."""
    # Any boto3.Session(...) call yields a stub whose get_credentials()
    # returns the fixture credentials.
    mock_session.return_value.get_credentials.return_value = mock_credentials

    result = base_aws_llm._auth_with_aws_profile("test_profile")
    credentials, ttl = result

    assert credentials == mock_credentials
    assert ttl is None
# Test session token authentication
def test_auth_with_aws_session_token(base_aws_llm):
    """The session-token flow should echo its inputs into ``Credentials`` with no TTL."""
    result = base_aws_llm._auth_with_aws_session_token(
        aws_access_key_id="test_access",
        aws_secret_access_key="test_secret",
        aws_session_token="test_token",
    )
    credentials, ttl = result

    assert isinstance(credentials, Credentials)
    assert credentials.access_key == "test_access"
    assert credentials.secret_key == "test_secret"
    assert credentials.token == "test_token"
    assert ttl is None
# Test access key and secret key authentication
@patch("boto3.Session")
def test_auth_with_access_key_and_secret_key(
    mock_session, base_aws_llm, mock_credentials
):
    """The static-key flow should return the session's credentials with the default TTL."""
    # Any boto3.Session(...) call yields a stub whose get_credentials()
    # returns the fixture credentials.
    mock_session.return_value.get_credentials.return_value = mock_credentials

    result = base_aws_llm._auth_with_access_key_and_secret_key(
        aws_access_key_id="test_access",
        aws_secret_access_key="test_secret",
        aws_region_name="us-west-2",
    )
    credentials, ttl = result

    assert credentials == mock_credentials
    assert ttl == 3540  # default TTL (3600 - 60)
# Test environment variables authentication
@patch("boto3.Session")
def test_auth_with_env_vars(mock_session, base_aws_llm, mock_credentials):
    """The fallback flow should return the default session's credentials and no TTL."""
    # Any boto3.Session(...) call yields a stub whose get_credentials()
    # returns the fixture credentials.
    mock_session.return_value.get_credentials.return_value = mock_credentials

    result = base_aws_llm._auth_with_env_vars()
    credentials, ttl = result

    assert credentials == mock_credentials
    assert ttl is None
# Test runtime endpoint resolution
def test_get_runtime_endpoint(base_aws_llm):
    """
    With no ``api_base`` and no explicit runtime endpoint, both returned URLs
    should be the regional Bedrock runtime endpoint.
    """
    cases = [
        ("us-west-2", "https://bedrock-runtime.us-west-2.amazonaws.com"),
        ("us-east-1", "https://bedrock-runtime.us-east-1.amazonaws.com"),
    ]
    for region, expected_url in cases:
        endpoint_url, proxy_endpoint_url = base_aws_llm.get_runtime_endpoint(
            api_base=None,
            aws_bedrock_runtime_endpoint=None,
            aws_region_name=region,
        )
        assert endpoint_url == expected_url
        assert proxy_endpoint_url == expected_url
# Fixture that replaces the in-memory cache dict on the `base_aws_llm`
# fixture's `iam_cache` with an empty dict, then yields to the test.
# NOTE(review): declared without `autouse=True`, so it only runs for tests that
# list `clear_cache` as an argument; the "before each test" wording in the
# docstring holds only for those tests.
@pytest.fixture
def clear_cache(base_aws_llm):
"""Clear the cache before each test"""
base_aws_llm.iam_cache.in_memory_cache.cache_dict = {}
yield