mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(fix) BaseAWSLLM
- cache IAM role credentials when used (#7775)
* fix base aws llm * fix auth with aws role * test aws base llm * fix base aws llm init * run ci/cd again * fix get_credentials * ci/cd run again * _auth_with_aws_role
This commit is contained in:
parent
5fbbf47581
commit
30bb4c4cdd
6 changed files with 397 additions and 129 deletions
|
@ -2,7 +2,7 @@
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
||||||
### INIT VARIABLES #####
|
### INIT VARIABLES ######
|
||||||
import threading
|
import threading
|
||||||
import os
|
import os
|
||||||
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
|
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Custom Logger that handles batching logic
|
Custom Logger that handles batching logic
|
||||||
|
|
||||||
Use this if you want your logs to be stored in memory and flushed periodically
|
Use this if you want your logs to be stored in memory and flushed periodically.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -48,7 +49,7 @@ class BaseAWSLLM:
|
||||||
credential_str = json.dumps(credential_args, sort_keys=True)
|
credential_str = json.dumps(credential_args, sort_keys=True)
|
||||||
return hashlib.sha256(credential_str.encode()).hexdigest()
|
return hashlib.sha256(credential_str.encode()).hexdigest()
|
||||||
|
|
||||||
def get_credentials( # noqa: PLR0915
|
def get_credentials(
|
||||||
self,
|
self,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
aws_secret_access_key: Optional[str] = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
@ -63,10 +64,6 @@ class BaseAWSLLM:
|
||||||
"""
|
"""
|
||||||
Return a boto3.Credentials object
|
Return a boto3.Credentials object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import boto3
|
|
||||||
from botocore.credentials import Credentials
|
|
||||||
|
|
||||||
## CHECK IS 'os.environ/' passed in
|
## CHECK IS 'os.environ/' passed in
|
||||||
param_names = [
|
param_names = [
|
||||||
"aws_access_key_id",
|
"aws_access_key_id",
|
||||||
|
@ -115,10 +112,6 @@ class BaseAWSLLM:
|
||||||
aws_sts_endpoint,
|
aws_sts_endpoint,
|
||||||
) = params_to_check
|
) = params_to_check
|
||||||
|
|
||||||
# create cache key for non-expiring auth flows
|
|
||||||
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
|
|
||||||
cache_key = self.get_cache_key(args)
|
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"in get credentials\n"
|
"in get credentials\n"
|
||||||
"aws_access_key_id=%s\n"
|
"aws_access_key_id=%s\n"
|
||||||
|
@ -141,12 +134,82 @@ class BaseAWSLLM:
|
||||||
aws_sts_endpoint,
|
aws_sts_endpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
### CHECK STS ###
|
# create cache key for non-expiring auth flows
|
||||||
|
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
|
||||||
|
|
||||||
|
cache_key = self.get_cache_key(args)
|
||||||
|
_cached_credentials = self.iam_cache.get_cache(cache_key)
|
||||||
|
if _cached_credentials:
|
||||||
|
return _cached_credentials
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Handle diff boto3 auth flows
|
||||||
|
# for each helper
|
||||||
|
# Return:
|
||||||
|
# Credentials - boto3.Credentials
|
||||||
|
# cache ttl - Optional[int]. If None, the credentials are not cached. Some auth flows have no expiry time.
|
||||||
|
#########################################################
|
||||||
if (
|
if (
|
||||||
aws_web_identity_token is not None
|
aws_web_identity_token is not None
|
||||||
and aws_role_name is not None
|
and aws_role_name is not None
|
||||||
and aws_session_name is not None
|
and aws_session_name is not None
|
||||||
):
|
):
|
||||||
|
credentials, _cache_ttl = self._auth_with_web_identity_token(
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_sts_endpoint=aws_sts_endpoint,
|
||||||
|
)
|
||||||
|
elif aws_role_name is not None and aws_session_name is not None:
|
||||||
|
credentials, _cache_ttl = self._auth_with_aws_role(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||||
|
credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name)
|
||||||
|
elif (
|
||||||
|
aws_access_key_id is not None
|
||||||
|
and aws_secret_access_key is not None
|
||||||
|
and aws_session_token is not None
|
||||||
|
):
|
||||||
|
credentials, _cache_ttl = self._auth_with_aws_session_token(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_session_token=aws_session_token,
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
aws_access_key_id is not None
|
||||||
|
and aws_secret_access_key is not None
|
||||||
|
and aws_region_name is not None
|
||||||
|
):
|
||||||
|
credentials, _cache_ttl = self._auth_with_access_key_and_secret_key(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
credentials, _cache_ttl = self._auth_with_env_vars()
|
||||||
|
|
||||||
|
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
def _auth_with_web_identity_token(
|
||||||
|
self,
|
||||||
|
aws_web_identity_token: str,
|
||||||
|
aws_role_name: str,
|
||||||
|
aws_session_name: str,
|
||||||
|
aws_region_name: Optional[str],
|
||||||
|
aws_sts_endpoint: Optional[str],
|
||||||
|
) -> Tuple[Credentials, Optional[int]]:
|
||||||
|
"""
|
||||||
|
Authenticate with AWS Web Identity Token
|
||||||
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
|
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
|
||||||
)
|
)
|
||||||
|
@ -156,16 +219,6 @@ class BaseAWSLLM:
|
||||||
else:
|
else:
|
||||||
sts_endpoint = aws_sts_endpoint
|
sts_endpoint = aws_sts_endpoint
|
||||||
|
|
||||||
iam_creds_cache_key = json.dumps(
|
|
||||||
{
|
|
||||||
"aws_web_identity_token": aws_web_identity_token,
|
|
||||||
"aws_role_name": aws_role_name,
|
|
||||||
"aws_session_name": aws_session_name,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
iam_creds_dict = self.iam_cache.get_cache(iam_creds_cache_key)
|
|
||||||
if iam_creds_dict is None:
|
|
||||||
oidc_token = get_secret(aws_web_identity_token)
|
oidc_token = get_secret(aws_web_identity_token)
|
||||||
|
|
||||||
if oidc_token is None:
|
if oidc_token is None:
|
||||||
|
@ -192,19 +245,11 @@ class BaseAWSLLM:
|
||||||
|
|
||||||
iam_creds_dict = {
|
iam_creds_dict = {
|
||||||
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
||||||
"aws_secret_access_key": sts_response["Credentials"][
|
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
|
||||||
"SecretAccessKey"
|
|
||||||
],
|
|
||||||
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
||||||
"region_name": aws_region_name,
|
"region_name": aws_region_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.iam_cache.set_cache(
|
|
||||||
key=iam_creds_cache_key,
|
|
||||||
value=json.dumps(iam_creds_dict),
|
|
||||||
ttl=3600 - 60,
|
|
||||||
)
|
|
||||||
|
|
||||||
if sts_response["PackedPolicySize"] > 75:
|
if sts_response["PackedPolicySize"] > 75:
|
||||||
verbose_logger.warning(
|
verbose_logger.warning(
|
||||||
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
||||||
|
@ -213,9 +258,21 @@ class BaseAWSLLM:
|
||||||
session = boto3.Session(**iam_creds_dict)
|
session = boto3.Session(**iam_creds_dict)
|
||||||
|
|
||||||
iam_creds = session.get_credentials()
|
iam_creds = session.get_credentials()
|
||||||
|
return iam_creds, self._get_default_ttl_for_boto3_credentials()
|
||||||
|
|
||||||
|
def _auth_with_aws_role(
|
||||||
|
self,
|
||||||
|
aws_access_key_id: Optional[str],
|
||||||
|
aws_secret_access_key: Optional[str],
|
||||||
|
aws_role_name: str,
|
||||||
|
aws_session_name: str,
|
||||||
|
) -> Tuple[Credentials, Optional[int]]:
|
||||||
|
"""
|
||||||
|
Authenticate with AWS Role
|
||||||
|
"""
|
||||||
|
import boto3
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
return iam_creds
|
|
||||||
elif aws_role_name is not None and aws_session_name is not None:
|
|
||||||
sts_client = boto3.client(
|
sts_client = boto3.client(
|
||||||
"sts",
|
"sts",
|
||||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||||
|
@ -234,17 +291,35 @@ class BaseAWSLLM:
|
||||||
secret_key=sts_credentials["SecretAccessKey"],
|
secret_key=sts_credentials["SecretAccessKey"],
|
||||||
token=sts_credentials["SessionToken"],
|
token=sts_credentials["SessionToken"],
|
||||||
)
|
)
|
||||||
return credentials
|
|
||||||
elif aws_profile_name is not None: ### CHECK SESSION ###
|
sts_expiry = sts_credentials["Expiration"]
|
||||||
|
# Convert to timezone-aware datetime for comparison
|
||||||
|
current_time = datetime.now(sts_expiry.tzinfo)
|
||||||
|
sts_ttl = (sts_expiry - current_time).total_seconds() - 60
|
||||||
|
return credentials, sts_ttl
|
||||||
|
|
||||||
|
def _auth_with_aws_profile(
|
||||||
|
self, aws_profile_name: str
|
||||||
|
) -> Tuple[Credentials, Optional[int]]:
|
||||||
|
"""
|
||||||
|
Authenticate with AWS profile
|
||||||
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||||
client = boto3.Session(profile_name=aws_profile_name)
|
client = boto3.Session(profile_name=aws_profile_name)
|
||||||
|
return client.get_credentials(), None
|
||||||
|
|
||||||
return client.get_credentials()
|
def _auth_with_aws_session_token(
|
||||||
elif (
|
self,
|
||||||
aws_access_key_id is not None
|
aws_access_key_id: str,
|
||||||
and aws_secret_access_key is not None
|
aws_secret_access_key: str,
|
||||||
and aws_session_token is not None
|
aws_session_token: str,
|
||||||
): ### CHECK FOR AWS SESSION TOKEN ###
|
) -> Tuple[Credentials, Optional[int]]:
|
||||||
|
"""
|
||||||
|
Authenticate with AWS Session Token
|
||||||
|
"""
|
||||||
|
### CHECK FOR AWS SESSION TOKEN ###
|
||||||
from botocore.credentials import Credentials
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
credentials = Credentials(
|
credentials = Credentials(
|
||||||
|
@ -253,18 +328,20 @@ class BaseAWSLLM:
|
||||||
token=aws_session_token,
|
token=aws_session_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
return credentials
|
return credentials, None
|
||||||
elif (
|
|
||||||
aws_access_key_id is not None
|
def _auth_with_access_key_and_secret_key(
|
||||||
and aws_secret_access_key is not None
|
self,
|
||||||
and aws_region_name is not None
|
aws_access_key_id: str,
|
||||||
):
|
aws_secret_access_key: str,
|
||||||
|
aws_region_name: Optional[str],
|
||||||
|
) -> Tuple[Credentials, Optional[int]]:
|
||||||
|
"""
|
||||||
|
Authenticate with AWS Access Key and Secret Key
|
||||||
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
# Check if credentials are already in cache. These credentials have no expiry time.
|
# Check if credentials are already in cache. These credentials have no expiry time.
|
||||||
cached_credentials: Optional[Credentials] = self.iam_cache.get_cache(
|
|
||||||
cache_key
|
|
||||||
)
|
|
||||||
if cached_credentials:
|
|
||||||
return cached_credentials
|
|
||||||
|
|
||||||
session = boto3.Session(
|
session = boto3.Session(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
@ -273,20 +350,25 @@ class BaseAWSLLM:
|
||||||
)
|
)
|
||||||
|
|
||||||
credentials = session.get_credentials()
|
credentials = session.get_credentials()
|
||||||
|
return credentials, self._get_default_ttl_for_boto3_credentials()
|
||||||
|
|
||||||
if (
|
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
|
||||||
credentials.token is None
|
"""
|
||||||
): # don't cache if session token exists. The expiry time for that is not known.
|
Authenticate with AWS Environment Variables
|
||||||
self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60)
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
return credentials
|
|
||||||
else:
|
|
||||||
# check env var. Do not cache the response from this.
|
|
||||||
session = boto3.Session()
|
session = boto3.Session()
|
||||||
|
|
||||||
credentials = session.get_credentials()
|
credentials = session.get_credentials()
|
||||||
|
return credentials, None
|
||||||
|
|
||||||
return credentials
|
def _get_default_ttl_for_boto3_credentials(self) -> int:
|
||||||
|
"""
|
||||||
|
Get the default TTL for boto3 credentials
|
||||||
|
|
||||||
|
Returns `3600-60` which is 59 minutes
|
||||||
|
"""
|
||||||
|
return 3600 - 60
|
||||||
|
|
||||||
def get_runtime_endpoint(
|
def get_runtime_endpoint(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -58,6 +58,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||||
self.optional_params = kwargs
|
self.optional_params = kwargs
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
BaseAWSLLM.__init__(self)
|
||||||
|
|
||||||
def convert_to_bedrock_format(
|
def convert_to_bedrock_format(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -33,6 +33,10 @@ from .base_secret_manager import BaseSecretManager
|
||||||
|
|
||||||
|
|
||||||
class AWSSecretsManagerV2(BaseAWSLLM, BaseSecretManager):
|
class AWSSecretsManagerV2(BaseAWSLLM, BaseSecretManager):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
BaseSecretManager.__init__(self, **kwargs)
|
||||||
|
BaseAWSLLM.__init__(self, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_environment(cls):
|
def validate_environment(cls):
|
||||||
if "AWS_REGION_NAME" not in os.environ:
|
if "AWS_REGION_NAME" not in os.environ:
|
||||||
|
|
181
tests/llm_translation/test_aws_base_llm.py
Normal file
181
tests/llm_translation/test_aws_base_llm.py
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
from typing import Dict, Any
|
||||||
|
from litellm.llms.bedrock.base_aws_llm import (
|
||||||
|
BaseAWSLLM,
|
||||||
|
AwsAuthError,
|
||||||
|
Boto3CredentialsInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test fixtures
|
||||||
|
@pytest.fixture
|
||||||
|
def base_aws_llm():
|
||||||
|
return BaseAWSLLM()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_credentials():
|
||||||
|
return Credentials(
|
||||||
|
access_key="test_access", secret_key="test_secret", token="test_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test cache key generation
|
||||||
|
def test_get_cache_key(base_aws_llm):
|
||||||
|
test_args = {
|
||||||
|
"aws_access_key_id": "test_key",
|
||||||
|
"aws_secret_access_key": "test_secret",
|
||||||
|
}
|
||||||
|
cache_key = base_aws_llm.get_cache_key(test_args)
|
||||||
|
assert isinstance(cache_key, str)
|
||||||
|
assert len(cache_key) == 64 # SHA-256 produces 64 character hex string
|
||||||
|
|
||||||
|
|
||||||
|
# Test web identity token authentication
|
||||||
|
@patch("boto3.client")
|
||||||
|
@patch("litellm.llms.bedrock.base_aws_llm.get_secret") # Add this patch
|
||||||
|
def test_auth_with_web_identity_token(mock_get_secret, mock_boto3_client, base_aws_llm):
|
||||||
|
# Mock get_secret to return a token
|
||||||
|
mock_get_secret.return_value = "mocked_oidc_token"
|
||||||
|
|
||||||
|
# Mock the STS client and response
|
||||||
|
mock_sts = MagicMock()
|
||||||
|
mock_sts.assume_role_with_web_identity.return_value = {
|
||||||
|
"Credentials": {
|
||||||
|
"AccessKeyId": "test_access",
|
||||||
|
"SecretAccessKey": "test_secret",
|
||||||
|
"SessionToken": "test_token",
|
||||||
|
},
|
||||||
|
"PackedPolicySize": 10,
|
||||||
|
}
|
||||||
|
mock_boto3_client.return_value = mock_sts
|
||||||
|
|
||||||
|
credentials, ttl = base_aws_llm._auth_with_web_identity_token(
|
||||||
|
aws_web_identity_token="test_token",
|
||||||
|
aws_role_name="test_role",
|
||||||
|
aws_session_name="test_session",
|
||||||
|
aws_region_name="us-west-2",
|
||||||
|
aws_sts_endpoint=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify get_secret was called with the correct argument
|
||||||
|
mock_get_secret.assert_called_once_with("test_token")
|
||||||
|
|
||||||
|
assert isinstance(credentials, Credentials)
|
||||||
|
assert ttl == 3540 # default TTL (3600 - 60)
|
||||||
|
|
||||||
|
|
||||||
|
# Test AWS role authentication
|
||||||
|
@patch("boto3.client")
|
||||||
|
def test_auth_with_aws_role(mock_boto3_client, base_aws_llm):
|
||||||
|
# Mock the STS client and response
|
||||||
|
mock_sts = MagicMock()
|
||||||
|
expiry_time = datetime.now(timezone.utc)
|
||||||
|
mock_sts.assume_role.return_value = {
|
||||||
|
"Credentials": {
|
||||||
|
"AccessKeyId": "test_access",
|
||||||
|
"SecretAccessKey": "test_secret",
|
||||||
|
"SessionToken": "test_token",
|
||||||
|
"Expiration": expiry_time,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_boto3_client.return_value = mock_sts
|
||||||
|
|
||||||
|
credentials, ttl = base_aws_llm._auth_with_aws_role(
|
||||||
|
aws_access_key_id="test_access",
|
||||||
|
aws_secret_access_key="test_secret",
|
||||||
|
aws_role_name="test_role",
|
||||||
|
aws_session_name="test_session",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(credentials, Credentials)
|
||||||
|
assert isinstance(ttl, float)
|
||||||
|
|
||||||
|
|
||||||
|
# Test AWS profile authentication
|
||||||
|
@patch("boto3.Session")
|
||||||
|
def test_auth_with_aws_profile(mock_session, base_aws_llm, mock_credentials):
|
||||||
|
# Mock the session
|
||||||
|
mock_session_instance = MagicMock()
|
||||||
|
mock_session_instance.get_credentials.return_value = mock_credentials
|
||||||
|
mock_session.return_value = mock_session_instance
|
||||||
|
|
||||||
|
credentials, ttl = base_aws_llm._auth_with_aws_profile("test_profile")
|
||||||
|
|
||||||
|
assert credentials == mock_credentials
|
||||||
|
assert ttl is None
|
||||||
|
|
||||||
|
|
||||||
|
# Test session token authentication
|
||||||
|
def test_auth_with_aws_session_token(base_aws_llm):
|
||||||
|
credentials, ttl = base_aws_llm._auth_with_aws_session_token(
|
||||||
|
aws_access_key_id="test_access",
|
||||||
|
aws_secret_access_key="test_secret",
|
||||||
|
aws_session_token="test_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(credentials, Credentials)
|
||||||
|
assert credentials.access_key == "test_access"
|
||||||
|
assert credentials.secret_key == "test_secret"
|
||||||
|
assert credentials.token == "test_token"
|
||||||
|
assert ttl is None
|
||||||
|
|
||||||
|
|
||||||
|
# Test access key and secret key authentication
|
||||||
|
@patch("boto3.Session")
|
||||||
|
def test_auth_with_access_key_and_secret_key(
|
||||||
|
mock_session, base_aws_llm, mock_credentials
|
||||||
|
):
|
||||||
|
# Mock the session
|
||||||
|
mock_session_instance = MagicMock()
|
||||||
|
mock_session_instance.get_credentials.return_value = mock_credentials
|
||||||
|
mock_session.return_value = mock_session_instance
|
||||||
|
|
||||||
|
credentials, ttl = base_aws_llm._auth_with_access_key_and_secret_key(
|
||||||
|
aws_access_key_id="test_access",
|
||||||
|
aws_secret_access_key="test_secret",
|
||||||
|
aws_region_name="us-west-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert credentials == mock_credentials
|
||||||
|
assert ttl == 3540 # default TTL (3600 - 60)
|
||||||
|
|
||||||
|
|
||||||
|
# Test environment variables authentication
|
||||||
|
@patch("boto3.Session")
|
||||||
|
def test_auth_with_env_vars(mock_session, base_aws_llm, mock_credentials):
|
||||||
|
# Mock the session
|
||||||
|
mock_session_instance = MagicMock()
|
||||||
|
mock_session_instance.get_credentials.return_value = mock_credentials
|
||||||
|
mock_session.return_value = mock_session_instance
|
||||||
|
|
||||||
|
credentials, ttl = base_aws_llm._auth_with_env_vars()
|
||||||
|
|
||||||
|
assert credentials == mock_credentials
|
||||||
|
assert ttl is None
|
||||||
|
|
||||||
|
|
||||||
|
# Test runtime endpoint resolution
|
||||||
|
def test_get_runtime_endpoint(base_aws_llm):
|
||||||
|
endpoint_url, proxy_endpoint_url = base_aws_llm.get_runtime_endpoint(
|
||||||
|
api_base=None, aws_bedrock_runtime_endpoint=None, aws_region_name="us-west-2"
|
||||||
|
)
|
||||||
|
assert endpoint_url == "https://bedrock-runtime.us-west-2.amazonaws.com"
|
||||||
|
assert proxy_endpoint_url == "https://bedrock-runtime.us-west-2.amazonaws.com"
|
||||||
|
|
||||||
|
endpoint_url, proxy_endpoint_url = base_aws_llm.get_runtime_endpoint(
|
||||||
|
aws_bedrock_runtime_endpoint=None, aws_region_name="us-east-1", api_base=None
|
||||||
|
)
|
||||||
|
assert endpoint_url == "https://bedrock-runtime.us-east-1.amazonaws.com"
|
||||||
|
assert proxy_endpoint_url == "https://bedrock-runtime.us-east-1.amazonaws.com"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def clear_cache(base_aws_llm):
|
||||||
|
"""Clear the cache before each test"""
|
||||||
|
base_aws_llm.iam_cache.in_memory_cache.cache_dict = {}
|
||||||
|
yield
|
Loading…
Add table
Add a link
Reference in a new issue