fix(sagemaker.py): fix async sagemaker calls

https://github.com/BerriAI/litellm/issues/2086
This commit is contained in:
Krrish Dholakia 2024-02-20 17:20:01 -08:00
parent 6546b43e5c
commit 49c4aa5e75

View file

@ -366,7 +366,37 @@ async def async_streaming(
import aioboto3
session = aioboto3.Session()
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
_client = session.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or "us-west-2" # default to us-west-2 if user not specified
)
_client = session.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
async with _client as client:
try:
response = await client.invoke_endpoint_with_response_stream(
EndpointName=model,
@ -395,7 +425,37 @@ async def async_completion(
import aioboto3
session = aioboto3.Session()
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
_client = session.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or "us-west-2" # default to us-west-2 if user not specified
)
_client = session.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
async with _client as client:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(