diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 8e75428bb..6892445f0 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -9,6 +9,7 @@ from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage import sys from copy import deepcopy import httpx # type: ignore +import io from .prompt_templates.factory import prompt_factory, custom_prompt @@ -25,10 +26,6 @@ class SagemakerError(Exception): ) # Call the base class constructor with the parameters it needs -import io -import json - - class TokenIterator: def __init__(self, stream, acompletion: bool = False): if acompletion == False: @@ -185,7 +182,8 @@ def completion( # I assume majority of users use .env for auth region_name = ( get_secret("AWS_REGION_NAME") - or "us-west-2" # default to us-west-2 if user not specified + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified ) client = boto3.client( service_name="sagemaker-runtime", @@ -439,7 +437,8 @@ async def async_streaming( # I assume majority of users use .env for auth region_name = ( get_secret("AWS_REGION_NAME") - or "us-west-2" # default to us-west-2 if user not specified + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified ) _client = session.client( service_name="sagemaker-runtime", @@ -506,7 +505,8 @@ async def async_completion( # I assume majority of users use .env for auth region_name = ( get_secret("AWS_REGION_NAME") - or "us-west-2" # default to us-west-2 if user not specified + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified ) _client = session.client( service_name="sagemaker-runtime", @@ -661,7 +661,8 @@ def embedding( # I assume majority of users use .env for auth region_name = ( get_secret("AWS_REGION_NAME") - or "us-west-2" # default to us-west-2 if user not specified + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified ) client = boto3.client( service_name="sagemaker-runtime", diff --git a/litellm/tests/test_provider_specific_config.py b/litellm/tests/test_provider_specific_config.py index 08a84b560..c20c44fb1 100644 --- a/litellm/tests/test_provider_specific_config.py +++ b/litellm/tests/test_provider_specific_config.py @@ -512,6 +512,106 @@ def sagemaker_test_completion(): # sagemaker_test_completion() + +def test_sagemaker_default_region(mocker): + """ + If no regions are specified in config or in environment, the default region is us-west-2 + """ + mock_client = mocker.patch("boto3.client") + try: + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[ + { + "content": "Hello, world!", + "role": "user" + } + ] + ) + except Exception: + pass # expected serialization exception because AWS client was replaced with a Mock + assert mock_client.call_args.kwargs["region_name"] == "us-west-2" + +# test_sagemaker_default_region() + + +def test_sagemaker_environment_region(mocker): + """ + If a region is specified in the environment, use that region instead of us-west-2 + """ + expected_region = "us-east-1" + os.environ["AWS_REGION_NAME"] = expected_region + mock_client = mocker.patch("boto3.client") + try: + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[ + { + "content": "Hello, world!", + "role": "user" + } + ] + ) + except Exception: + pass # expected serialization exception because AWS client was replaced with a Mock + del os.environ["AWS_REGION_NAME"] # cleanup + assert mock_client.call_args.kwargs["region_name"] == expected_region + +# test_sagemaker_environment_region() + + +def test_sagemaker_config_region(mocker): + """ + If a region is specified as part of the optional parameters of the completion, including as + part of the config file, then use that region instead of us-west-2 + """ + expected_region = "us-east-1" + mock_client = mocker.patch("boto3.client") + try: + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[ + { + "content": "Hello, world!", + "role": "user" + } + ], + aws_region_name=expected_region, + ) + except Exception: + pass # expected serialization exception because AWS client was replaced with a Mock + assert mock_client.call_args.kwargs["region_name"] == expected_region + +# test_sagemaker_config_region() + + +def test_sagemaker_config_and_environment_region(mocker): + """ + If both the environment and config file specify a region, the environment region is expected + """ + expected_region = "us-east-1" + unexpected_region = "us-east-2" + os.environ["AWS_REGION_NAME"] = expected_region + mock_client = mocker.patch("boto3.client") + try: + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[ + { + "content": "Hello, world!", + "role": "user" + } + ], + aws_region_name=unexpected_region, + ) + except Exception: + pass # expected serialization exception because AWS client was replaced with a Mock + del os.environ["AWS_REGION_NAME"] # cleanup + assert mock_client.call_args.kwargs["region_name"] == expected_region + +# test_sagemaker_config_and_environment_region() + + # Bedrock