forked from phoenix/litellm-mirror
(feat) sagemaker auth in completion
This commit is contained in:
parent
a58d0e9c94
commit
acef90b923
1 changed file with 27 additions and 8 deletions
|
@@ -64,16 +64,35 @@ def completion(
|
||||||
):
|
):
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
region_name = (
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
get_secret("AWS_REGION_NAME") or
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
"us-west-2" # default to us-west-2
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
)
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
|
||||||
client = boto3.client(
|
if aws_access_key_id != None:
|
||||||
"sagemaker-runtime",
|
# uses auth params passed to completion
|
||||||
region_name=region_name
|
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||||
)
|
client = boto3.client(
|
||||||
|
service_name="sagemaker-runtime",
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||||
|
# boto3 automatically reads env variables
|
||||||
|
|
||||||
|
# we need to read region name from env
|
||||||
|
# I assume majority of users use .env for auth
|
||||||
|
region_name = (
|
||||||
|
get_secret("AWS_REGION_NAME") or
|
||||||
|
"us-west-2" # default to us-west-2 if user not specified
|
||||||
|
)
|
||||||
|
client = boto3.client(
|
||||||
|
service_name="sagemaker-runtime",
|
||||||
|
region_name=region_name,
|
||||||
|
)
|
||||||
|
|
||||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||||
inference_params = deepcopy(optional_params)
|
inference_params = deepcopy(optional_params)
|
||||||
inference_params.pop("stream", None)
|
inference_params.pop("stream", None)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue