mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
(fix) BaseAWSLLM
- cache IAM role credentials when used (#7775)
* fix base aws llm * fix auth with aws role * test aws base llm * fix base aws llm init * run ci/cd again * fix get_credentials * ci/cd run again * _auth_with_aws_role
This commit is contained in:
parent
5fbbf47581
commit
30bb4c4cdd
6 changed files with 397 additions and 129 deletions
|
@ -2,7 +2,7 @@
|
|||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
||||
### INIT VARIABLES #####
|
||||
### INIT VARIABLES ######
|
||||
import threading
|
||||
import os
|
||||
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Custom Logger that handles batching logic
|
||||
|
||||
Use this if you want your logs to be stored in memory and flushed periodically
|
||||
Use this if you want your logs to be stored in memory and flushed periodically.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
@ -48,7 +49,7 @@ class BaseAWSLLM:
|
|||
credential_str = json.dumps(credential_args, sort_keys=True)
|
||||
return hashlib.sha256(credential_str.encode()).hexdigest()
|
||||
|
||||
def get_credentials( # noqa: PLR0915
|
||||
def get_credentials(
|
||||
self,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
|
@ -63,10 +64,6 @@ class BaseAWSLLM:
|
|||
"""
|
||||
Return a boto3.Credentials object
|
||||
"""
|
||||
|
||||
import boto3
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
param_names = [
|
||||
"aws_access_key_id",
|
||||
|
@ -115,10 +112,6 @@ class BaseAWSLLM:
|
|||
aws_sts_endpoint,
|
||||
) = params_to_check
|
||||
|
||||
# create cache key for non-expiring auth flows
|
||||
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
|
||||
cache_key = self.get_cache_key(args)
|
||||
|
||||
verbose_logger.debug(
|
||||
"in get credentials\n"
|
||||
"aws_access_key_id=%s\n"
|
||||
|
@ -141,152 +134,241 @@ class BaseAWSLLM:
|
|||
aws_sts_endpoint,
|
||||
)
|
||||
|
||||
### CHECK STS ###
|
||||
# create cache key for non-expiring auth flows
|
||||
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
|
||||
|
||||
cache_key = self.get_cache_key(args)
|
||||
_cached_credentials = self.iam_cache.get_cache(cache_key)
|
||||
if _cached_credentials:
|
||||
return _cached_credentials
|
||||
|
||||
#########################################################
|
||||
# Handle diff boto3 auth flows
|
||||
# for each helper
|
||||
# Return:
|
||||
# Credentials - boto3.Credentials
|
||||
# cache ttl - Optional[int]. If None, the credentials are not cached. Some auth flows have no expiry time.
|
||||
#########################################################
|
||||
if (
|
||||
aws_web_identity_token is not None
|
||||
and aws_role_name is not None
|
||||
and aws_session_name is not None
|
||||
):
|
||||
verbose_logger.debug(
|
||||
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
|
||||
credentials, _cache_ttl = self._auth_with_web_identity_token(
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
|
||||
if aws_sts_endpoint is None:
|
||||
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
|
||||
else:
|
||||
sts_endpoint = aws_sts_endpoint
|
||||
|
||||
iam_creds_cache_key = json.dumps(
|
||||
{
|
||||
"aws_web_identity_token": aws_web_identity_token,
|
||||
"aws_role_name": aws_role_name,
|
||||
"aws_session_name": aws_session_name,
|
||||
}
|
||||
)
|
||||
|
||||
iam_creds_dict = self.iam_cache.get_cache(iam_creds_cache_key)
|
||||
if iam_creds_dict is None:
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise AwsAuthError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
region_name=aws_region_name,
|
||||
endpoint_url=sts_endpoint,
|
||||
)
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
Policy='{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}',
|
||||
)
|
||||
|
||||
iam_creds_dict = {
|
||||
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
||||
"aws_secret_access_key": sts_response["Credentials"][
|
||||
"SecretAccessKey"
|
||||
],
|
||||
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
||||
"region_name": aws_region_name,
|
||||
}
|
||||
|
||||
self.iam_cache.set_cache(
|
||||
key=iam_creds_cache_key,
|
||||
value=json.dumps(iam_creds_dict),
|
||||
ttl=3600 - 60,
|
||||
)
|
||||
|
||||
if sts_response["PackedPolicySize"] > 75:
|
||||
verbose_logger.warning(
|
||||
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
||||
)
|
||||
|
||||
session = boto3.Session(**iam_creds_dict)
|
||||
|
||||
iam_creds = session.get_credentials()
|
||||
|
||||
return iam_creds
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||
credentials, _cache_ttl = self._auth_with_aws_role(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
# Extract the credentials from the response and convert to Session Credentials
|
||||
sts_credentials = sts_response["Credentials"]
|
||||
|
||||
credentials = Credentials(
|
||||
access_key=sts_credentials["AccessKeyId"],
|
||||
secret_key=sts_credentials["SecretAccessKey"],
|
||||
token=sts_credentials["SessionToken"],
|
||||
)
|
||||
return credentials
|
||||
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
client = boto3.Session(profile_name=aws_profile_name)
|
||||
|
||||
return client.get_credentials()
|
||||
credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name)
|
||||
elif (
|
||||
aws_access_key_id is not None
|
||||
and aws_secret_access_key is not None
|
||||
and aws_session_token is not None
|
||||
): ### CHECK FOR AWS SESSION TOKEN ###
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
credentials = Credentials(
|
||||
access_key=aws_access_key_id,
|
||||
secret_key=aws_secret_access_key,
|
||||
token=aws_session_token,
|
||||
):
|
||||
credentials, _cache_ttl = self._auth_with_aws_session_token(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
)
|
||||
|
||||
return credentials
|
||||
elif (
|
||||
aws_access_key_id is not None
|
||||
and aws_secret_access_key is not None
|
||||
and aws_region_name is not None
|
||||
):
|
||||
# Check if credentials are already in cache. These credentials have no expiry time.
|
||||
cached_credentials: Optional[Credentials] = self.iam_cache.get_cache(
|
||||
cache_key
|
||||
)
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
session = boto3.Session(
|
||||
credentials, _cache_ttl = self._auth_with_access_key_and_secret_key(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
credentials, _cache_ttl = self._auth_with_env_vars()
|
||||
|
||||
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
|
||||
return credentials
|
||||
|
||||
def _auth_with_web_identity_token(
|
||||
self,
|
||||
aws_web_identity_token: str,
|
||||
aws_role_name: str,
|
||||
aws_session_name: str,
|
||||
aws_region_name: Optional[str],
|
||||
aws_sts_endpoint: Optional[str],
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Web Identity Token
|
||||
"""
|
||||
import boto3
|
||||
|
||||
verbose_logger.debug(
|
||||
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
|
||||
)
|
||||
|
||||
if aws_sts_endpoint is None:
|
||||
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
|
||||
else:
|
||||
sts_endpoint = aws_sts_endpoint
|
||||
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise AwsAuthError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
credentials = session.get_credentials()
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
region_name=aws_region_name,
|
||||
endpoint_url=sts_endpoint,
|
||||
)
|
||||
|
||||
if (
|
||||
credentials.token is None
|
||||
): # don't cache if session token exists. The expiry time for that is not known.
|
||||
self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60)
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
Policy='{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}',
|
||||
)
|
||||
|
||||
return credentials
|
||||
else:
|
||||
# check env var. Do not cache the response from this.
|
||||
session = boto3.Session()
|
||||
iam_creds_dict = {
|
||||
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
||||
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
|
||||
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
||||
"region_name": aws_region_name,
|
||||
}
|
||||
|
||||
credentials = session.get_credentials()
|
||||
if sts_response["PackedPolicySize"] > 75:
|
||||
verbose_logger.warning(
|
||||
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
||||
)
|
||||
|
||||
return credentials
|
||||
session = boto3.Session(**iam_creds_dict)
|
||||
|
||||
iam_creds = session.get_credentials()
|
||||
return iam_creds, self._get_default_ttl_for_boto3_credentials()
|
||||
|
||||
def _auth_with_aws_role(
|
||||
self,
|
||||
aws_access_key_id: Optional[str],
|
||||
aws_secret_access_key: Optional[str],
|
||||
aws_role_name: str,
|
||||
aws_session_name: str,
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Role
|
||||
"""
|
||||
import boto3
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
# Extract the credentials from the response and convert to Session Credentials
|
||||
sts_credentials = sts_response["Credentials"]
|
||||
|
||||
credentials = Credentials(
|
||||
access_key=sts_credentials["AccessKeyId"],
|
||||
secret_key=sts_credentials["SecretAccessKey"],
|
||||
token=sts_credentials["SessionToken"],
|
||||
)
|
||||
|
||||
sts_expiry = sts_credentials["Expiration"]
|
||||
# Convert to timezone-aware datetime for comparison
|
||||
current_time = datetime.now(sts_expiry.tzinfo)
|
||||
sts_ttl = (sts_expiry - current_time).total_seconds() - 60
|
||||
return credentials, sts_ttl
|
||||
|
||||
def _auth_with_aws_profile(
|
||||
self, aws_profile_name: str
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS profile
|
||||
"""
|
||||
import boto3
|
||||
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
client = boto3.Session(profile_name=aws_profile_name)
|
||||
return client.get_credentials(), None
|
||||
|
||||
def _auth_with_aws_session_token(
|
||||
self,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
aws_session_token: str,
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Session Token
|
||||
"""
|
||||
### CHECK FOR AWS SESSION TOKEN ###
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
credentials = Credentials(
|
||||
access_key=aws_access_key_id,
|
||||
secret_key=aws_secret_access_key,
|
||||
token=aws_session_token,
|
||||
)
|
||||
|
||||
return credentials, None
|
||||
|
||||
def _auth_with_access_key_and_secret_key(
|
||||
self,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
aws_region_name: Optional[str],
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Access Key and Secret Key
|
||||
"""
|
||||
import boto3
|
||||
|
||||
# Check if credentials are already in cache. These credentials have no expiry time.
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
credentials = session.get_credentials()
|
||||
return credentials, self._get_default_ttl_for_boto3_credentials()
|
||||
|
||||
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Environment Variables
|
||||
"""
|
||||
import boto3
|
||||
|
||||
session = boto3.Session()
|
||||
credentials = session.get_credentials()
|
||||
return credentials, None
|
||||
|
||||
def _get_default_ttl_for_boto3_credentials(self) -> int:
|
||||
"""
|
||||
Get the default TTL for boto3 credentials
|
||||
|
||||
Returns `3600-60` which is 59 minutes
|
||||
"""
|
||||
return 3600 - 60
|
||||
|
||||
def get_runtime_endpoint(
|
||||
self,
|
||||
|
|
|
@ -58,6 +58,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
|||
self.optional_params = kwargs
|
||||
|
||||
super().__init__(**kwargs)
|
||||
BaseAWSLLM.__init__(self)
|
||||
|
||||
def convert_to_bedrock_format(
|
||||
self,
|
||||
|
|
|
@ -33,6 +33,10 @@ from .base_secret_manager import BaseSecretManager
|
|||
|
||||
|
||||
class AWSSecretsManagerV2(BaseAWSLLM, BaseSecretManager):
|
||||
def __init__(self, **kwargs):
|
||||
BaseSecretManager.__init__(self, **kwargs)
|
||||
BaseAWSLLM.__init__(self, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def validate_environment(cls):
|
||||
if "AWS_REGION_NAME" not in os.environ:
|
||||
|
|
181
tests/llm_translation/test_aws_base_llm.py
Normal file
181
tests/llm_translation/test_aws_base_llm.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
import pytest
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from botocore.credentials import Credentials
|
||||
from typing import Dict, Any
|
||||
from litellm.llms.bedrock.base_aws_llm import (
|
||||
BaseAWSLLM,
|
||||
AwsAuthError,
|
||||
Boto3CredentialsInfo,
|
||||
)
|
||||
|
||||
|
||||
# Test fixtures
|
||||
@pytest.fixture
|
||||
def base_aws_llm():
|
||||
return BaseAWSLLM()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials():
|
||||
return Credentials(
|
||||
access_key="test_access", secret_key="test_secret", token="test_token"
|
||||
)
|
||||
|
||||
|
||||
# Test cache key generation
|
||||
def test_get_cache_key(base_aws_llm):
|
||||
test_args = {
|
||||
"aws_access_key_id": "test_key",
|
||||
"aws_secret_access_key": "test_secret",
|
||||
}
|
||||
cache_key = base_aws_llm.get_cache_key(test_args)
|
||||
assert isinstance(cache_key, str)
|
||||
assert len(cache_key) == 64 # SHA-256 produces 64 character hex string
|
||||
|
||||
|
||||
# Test web identity token authentication
|
||||
@patch("boto3.client")
|
||||
@patch("litellm.llms.bedrock.base_aws_llm.get_secret") # Add this patch
|
||||
def test_auth_with_web_identity_token(mock_get_secret, mock_boto3_client, base_aws_llm):
|
||||
# Mock get_secret to return a token
|
||||
mock_get_secret.return_value = "mocked_oidc_token"
|
||||
|
||||
# Mock the STS client and response
|
||||
mock_sts = MagicMock()
|
||||
mock_sts.assume_role_with_web_identity.return_value = {
|
||||
"Credentials": {
|
||||
"AccessKeyId": "test_access",
|
||||
"SecretAccessKey": "test_secret",
|
||||
"SessionToken": "test_token",
|
||||
},
|
||||
"PackedPolicySize": 10,
|
||||
}
|
||||
mock_boto3_client.return_value = mock_sts
|
||||
|
||||
credentials, ttl = base_aws_llm._auth_with_web_identity_token(
|
||||
aws_web_identity_token="test_token",
|
||||
aws_role_name="test_role",
|
||||
aws_session_name="test_session",
|
||||
aws_region_name="us-west-2",
|
||||
aws_sts_endpoint=None,
|
||||
)
|
||||
|
||||
# Verify get_secret was called with the correct argument
|
||||
mock_get_secret.assert_called_once_with("test_token")
|
||||
|
||||
assert isinstance(credentials, Credentials)
|
||||
assert ttl == 3540 # default TTL (3600 - 60)
|
||||
|
||||
|
||||
# Test AWS role authentication
|
||||
@patch("boto3.client")
|
||||
def test_auth_with_aws_role(mock_boto3_client, base_aws_llm):
|
||||
# Mock the STS client and response
|
||||
mock_sts = MagicMock()
|
||||
expiry_time = datetime.now(timezone.utc)
|
||||
mock_sts.assume_role.return_value = {
|
||||
"Credentials": {
|
||||
"AccessKeyId": "test_access",
|
||||
"SecretAccessKey": "test_secret",
|
||||
"SessionToken": "test_token",
|
||||
"Expiration": expiry_time,
|
||||
}
|
||||
}
|
||||
mock_boto3_client.return_value = mock_sts
|
||||
|
||||
credentials, ttl = base_aws_llm._auth_with_aws_role(
|
||||
aws_access_key_id="test_access",
|
||||
aws_secret_access_key="test_secret",
|
||||
aws_role_name="test_role",
|
||||
aws_session_name="test_session",
|
||||
)
|
||||
|
||||
assert isinstance(credentials, Credentials)
|
||||
assert isinstance(ttl, float)
|
||||
|
||||
|
||||
# Test AWS profile authentication
|
||||
@patch("boto3.Session")
|
||||
def test_auth_with_aws_profile(mock_session, base_aws_llm, mock_credentials):
|
||||
# Mock the session
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.get_credentials.return_value = mock_credentials
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
credentials, ttl = base_aws_llm._auth_with_aws_profile("test_profile")
|
||||
|
||||
assert credentials == mock_credentials
|
||||
assert ttl is None
|
||||
|
||||
|
||||
# Test session token authentication
|
||||
def test_auth_with_aws_session_token(base_aws_llm):
|
||||
credentials, ttl = base_aws_llm._auth_with_aws_session_token(
|
||||
aws_access_key_id="test_access",
|
||||
aws_secret_access_key="test_secret",
|
||||
aws_session_token="test_token",
|
||||
)
|
||||
|
||||
assert isinstance(credentials, Credentials)
|
||||
assert credentials.access_key == "test_access"
|
||||
assert credentials.secret_key == "test_secret"
|
||||
assert credentials.token == "test_token"
|
||||
assert ttl is None
|
||||
|
||||
|
||||
# Test access key and secret key authentication
|
||||
@patch("boto3.Session")
|
||||
def test_auth_with_access_key_and_secret_key(
|
||||
mock_session, base_aws_llm, mock_credentials
|
||||
):
|
||||
# Mock the session
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.get_credentials.return_value = mock_credentials
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
credentials, ttl = base_aws_llm._auth_with_access_key_and_secret_key(
|
||||
aws_access_key_id="test_access",
|
||||
aws_secret_access_key="test_secret",
|
||||
aws_region_name="us-west-2",
|
||||
)
|
||||
|
||||
assert credentials == mock_credentials
|
||||
assert ttl == 3540 # default TTL (3600 - 60)
|
||||
|
||||
|
||||
# Test environment variables authentication
|
||||
@patch("boto3.Session")
|
||||
def test_auth_with_env_vars(mock_session, base_aws_llm, mock_credentials):
|
||||
# Mock the session
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.get_credentials.return_value = mock_credentials
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
credentials, ttl = base_aws_llm._auth_with_env_vars()
|
||||
|
||||
assert credentials == mock_credentials
|
||||
assert ttl is None
|
||||
|
||||
|
||||
# Test runtime endpoint resolution
|
||||
def test_get_runtime_endpoint(base_aws_llm):
|
||||
endpoint_url, proxy_endpoint_url = base_aws_llm.get_runtime_endpoint(
|
||||
api_base=None, aws_bedrock_runtime_endpoint=None, aws_region_name="us-west-2"
|
||||
)
|
||||
assert endpoint_url == "https://bedrock-runtime.us-west-2.amazonaws.com"
|
||||
assert proxy_endpoint_url == "https://bedrock-runtime.us-west-2.amazonaws.com"
|
||||
|
||||
endpoint_url, proxy_endpoint_url = base_aws_llm.get_runtime_endpoint(
|
||||
aws_bedrock_runtime_endpoint=None, aws_region_name="us-east-1", api_base=None
|
||||
)
|
||||
assert endpoint_url == "https://bedrock-runtime.us-east-1.amazonaws.com"
|
||||
assert proxy_endpoint_url == "https://bedrock-runtime.us-east-1.amazonaws.com"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clear_cache(base_aws_llm):
|
||||
"""Clear the cache before each test"""
|
||||
base_aws_llm.iam_cache.in_memory_cache.cache_dict = {}
|
||||
yield
|
Loading…
Add table
Add a link
Reference in a new issue