diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index b9478081fd..bd57301fda 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -64,16 +64,35 @@ def completion( ): import boto3 - region_name = ( - get_secret("AWS_REGION_NAME") or - "us-west-2" # default to us-west-2 - ) + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) - client = boto3.client( - "sagemaker-runtime", - region_name=region_name - ) + if aws_access_key_id != None: + # uses auth params passed to completion + # aws_access_key_id is not None, assume user is trying to auth using litellm.completion + client = boto3.client( + service_name="sagemaker-runtime", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=aws_region_name, + ) + else: + # aws_access_key_id is None, assume user is trying to auth using env variables + # boto3 automatically reads env variables + # we need to read region name from env + # I assume majority of users use .env for auth + region_name = ( + get_secret("AWS_REGION_NAME") or + "us-west-2" # default to us-west-2 if user not specified + ) + client = boto3.client( + service_name="sagemaker-runtime", + region_name=region_name, + ) + # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker inference_params = deepcopy(optional_params) inference_params.pop("stream", None)