(feat) sagemaker auth in completion

This commit is contained in:
ishaan-jaff 2023-10-07 15:27:56 -07:00
parent a58d0e9c94
commit acef90b923

View file

@ -64,16 +64,35 @@ def completion(
): ):
import boto3 import boto3
region_name = ( # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
get_secret("AWS_REGION_NAME") or aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
"us-west-2" # default to us-west-2 aws_access_key_id = optional_params.pop("aws_access_key_id", None)
) aws_region_name = optional_params.pop("aws_region_name", None)
client = boto3.client( if aws_access_key_id != None:
"sagemaker-runtime", # uses auth params passed to completion
region_name=region_name # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
) client = boto3.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automaticaly reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME") or
"us-west-2" # default to us-west-2 if user not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params) inference_params = deepcopy(optional_params)
inference_params.pop("stream", None) inference_params.pop("stream", None)