From c6be8326dbd0da26fcb2164f1a140929cd14fd92 Mon Sep 17 00:00:00 2001 From: Peter Muller Date: Mon, 1 Jul 2024 16:00:42 -0700 Subject: [PATCH 1/4] Allow calling SageMaker endpoints from different regions --- litellm/llms/sagemaker.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 8e75428bb..079951b93 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -185,7 +185,8 @@ def completion( # I assume majority of users use .env for auth region_name = ( get_secret("AWS_REGION_NAME") - or "us-west-2" # default to us-west-2 if user not specified + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified ) client = boto3.client( service_name="sagemaker-runtime", @@ -439,7 +440,8 @@ async def async_streaming( # I assume majority of users use .env for auth region_name = ( get_secret("AWS_REGION_NAME") - or "us-west-2" # default to us-west-2 if user not specified + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified ) _client = session.client( service_name="sagemaker-runtime", @@ -506,7 +508,8 @@ async def async_completion( # I assume majority of users use .env for auth region_name = ( get_secret("AWS_REGION_NAME") - or "us-west-2" # default to us-west-2 if user not specified + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified ) _client = session.client( service_name="sagemaker-runtime", @@ -661,7 +664,8 @@ def embedding( # I assume majority of users use .env for auth region_name = ( get_secret("AWS_REGION_NAME") - or "us-west-2" # default to us-west-2 if user not specified + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified ) client = boto3.client( service_name="sagemaker-runtime", From d9e9a8645bf05c6dbd7519034ddf1ce3f5f21d85 Mon Sep 17 00:00:00 2001 From: Peter Muller Date: Tue, 2 Jul 2024 15:30:39 -0700 Subject: [PATCH 2/4] Add tests for SageMaker region selection --- litellm/llms/sagemaker.py | 19 +--- .../tests/test_provider_specific_config.py | 100 ++++++++++++++++++ 2 files changed, 103 insertions(+), 16 deletions(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 079951b93..0e0fa8006 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -9,6 +9,9 @@ from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage import sys from copy import deepcopy import httpx # type: ignore +import boto3 +import aioboto3 +import io from .prompt_templates.factory import prompt_factory, custom_prompt @@ -25,10 +28,6 @@ class SagemakerError(Exception): ) # Call the base class constructor with the parameters it needs -import io -import json - - class TokenIterator: def __init__(self, stream, acompletion: bool = False): if acompletion == False: @@ -160,8 +159,6 @@ def completion( logger_fn=None, acompletion: bool = False, ): - import boto3 - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) @@ -416,10 +413,6 @@ async def async_streaming( aws_access_key_id: Optional[str], aws_region_name: Optional[str], ): - """ - Use aioboto3 - """ - import aioboto3 session = aioboto3.Session() @@ -484,10 +477,6 @@ async def async_completion( aws_access_key_id: Optional[str], aws_region_name: Optional[str], ): - """ - Use aioboto3 - """ - import aioboto3 session = aioboto3.Session() @@ -639,8 +628,6 @@ def embedding( """ Supports Huggingface Jumpstart embeddings like GPT-6B """ - ### BOTO3 INIT - import boto3 # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) diff --git a/litellm/tests/test_provider_specific_config.py b/litellm/tests/test_provider_specific_config.py index 08a84b560..e79b5769f 100644 --- a/litellm/tests/test_provider_specific_config.py +++ b/litellm/tests/test_provider_specific_config.py @@ -512,6 +512,106 @@ def sagemaker_test_completion(): # sagemaker_test_completion() + +def test_sagemaker_default_region(mocker): + """ + If no regions are specified in config or in environment, the default region is us-west-2 + """ + mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") + try: + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[ + { + "content": "Hello, world!", + "role": "user" + } + ] + ) + except Exception: + pass # expected serialization exception because AWS client was replaced with a Mock + assert mock_client.call_args.kwargs["region_name"] == "us-west-2" + +# test_sagemaker_provided_region() + + +def test_sagemaker_environment_region(mocker): + """ + If a region is specified in the environment, use that region instead of us-west-2 + """ + expected_region = "us-east-1" + os.environ["AWS_REGION_NAME"] = expected_region + mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") + try: + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[ + { + "content": "Hello, world!", + "role": "user" + } + ] + ) + except Exception: + pass # expected serialization exception because AWS client was replaced with a Mock + del os.environ["AWS_REGION_NAME"] # cleanup + assert mock_client.call_args.kwargs["region_name"] == expected_region + +# test_sagemaker_environment_region() + + +def test_sagemaker_config_region(mocker): + """ + If a region is specified as part of the optional parameters of the completion, including as + part of the config file, then use that region instead of us-west-2 + """ + expected_region = "us-east-1" + mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") + try: + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[ + { + "content": "Hello, world!", + "role": "user" + } + ], + aws_region_name=expected_region, + ) + except Exception: + pass # expected serialization exception because AWS client was replaced with a Mock + assert mock_client.call_args.kwargs["region_name"] == expected_region + +# test_sagemaker_config_region() + + +def test_sagemaker_config_and_environment_region(mocker): + """ + If both the environment and config file specify a region, the environment region is expected + """ + expected_region = "us-east-1" + unexpected_region = "us-east-2" + os.environ["AWS_REGION_NAME"] = expected_region + mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") + try: + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[ + { + "content": "Hello, world!", + "role": "user" + } + ], + aws_region_name=unexpected_region, + ) + except Exception: + pass # expected serialization exception because AWS client was replaced with a Mock + del os.environ["AWS_REGION_NAME"] # cleanup + assert mock_client.call_args.kwargs["region_name"] == expected_region + +# test_sagemaker_config_and_environment_region() + + # Bedrock From 47c97e1fa28345859da11de920c533681f0e1aa1 Mon Sep 17 00:00:00 2001 From: Peter Muller Date: Tue, 2 Jul 2024 15:38:15 -0700 Subject: [PATCH 3/4] Fix test name typo in comment --- litellm/tests/test_provider_specific_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_provider_specific_config.py b/litellm/tests/test_provider_specific_config.py index e79b5769f..c1d5362ec 100644 --- a/litellm/tests/test_provider_specific_config.py +++ b/litellm/tests/test_provider_specific_config.py @@ -532,7 +532,7 @@ def test_sagemaker_default_region(mocker): pass # expected serialization exception because AWS client was replaced with a Mock assert mock_client.call_args.kwargs["region_name"] == "us-west-2" -# test_sagemaker_provided_region() +# test_sagemaker_default_region() def test_sagemaker_environment_region(mocker): From d8fc8252fa6d4371a46643cf255a948837563bb9 Mon Sep 17 00:00:00 2001 From: Peter Muller Date: Tue, 2 Jul 2024 19:09:22 -0700 Subject: [PATCH 4/4] Revert imports changes, update tests to match --- litellm/llms/sagemaker.py | 14 ++++++++++++-- litellm/tests/test_provider_specific_config.py | 8 ++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 0e0fa8006..6892445f0 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -9,8 +9,6 @@ from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage import sys from copy import deepcopy import httpx # type: ignore -import boto3 -import aioboto3 import io from .prompt_templates.factory import prompt_factory, custom_prompt @@ -159,6 +157,8 @@ def completion( logger_fn=None, acompletion: bool = False, ): + import boto3 + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) @@ -413,6 +413,10 @@ async def async_streaming( aws_access_key_id: Optional[str], aws_region_name: Optional[str], ): + """ + Use aioboto3 + """ + import aioboto3 session = aioboto3.Session() @@ -477,6 +481,10 @@ async def async_completion( aws_access_key_id: Optional[str], aws_region_name: Optional[str], ): + """ + Use aioboto3 + """ + import aioboto3 session = aioboto3.Session() @@ -628,6 +636,8 @@ def embedding( """ Supports Huggingface Jumpstart embeddings like GPT-6B """ + ### BOTO3 INIT + import boto3 # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) diff --git a/litellm/tests/test_provider_specific_config.py b/litellm/tests/test_provider_specific_config.py index c1d5362ec..c20c44fb1 100644 --- a/litellm/tests/test_provider_specific_config.py +++ b/litellm/tests/test_provider_specific_config.py @@ -517,7 +517,7 @@ def test_sagemaker_default_region(mocker): """ If no regions are specified in config or in environment, the default region is us-west-2 """ - mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") + mock_client = mocker.patch("boto3.client") try: response = litellm.completion( model="sagemaker/mock-endpoint", @@ -541,7 +541,7 @@ def test_sagemaker_environment_region(mocker): """ expected_region = "us-east-1" os.environ["AWS_REGION_NAME"] = expected_region - mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") + mock_client = mocker.patch("boto3.client") try: response = litellm.completion( model="sagemaker/mock-endpoint", @@ -566,7 +566,7 @@ def test_sagemaker_config_region(mocker): part of the config file, then use that region instead of us-west-2 """ expected_region = "us-east-1" - mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") + mock_client = mocker.patch("boto3.client") try: response = litellm.completion( model="sagemaker/mock-endpoint", @@ -592,7 +592,7 @@ def test_sagemaker_config_and_environment_region(mocker): expected_region = "us-east-1" unexpected_region = "us-east-2" os.environ["AWS_REGION_NAME"] = expected_region - mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") + mock_client = mocker.patch("boto3.client") try: response = litellm.completion( model="sagemaker/mock-endpoint",