mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(sagemaker.py): fix async sagemaker calls
https://github.com/BerriAI/litellm/issues/2086
This commit is contained in:
parent
6546b43e5c
commit
49c4aa5e75
1 changed files with 62 additions and 2 deletions
|
@ -366,7 +366,37 @@ async def async_streaming(
|
|||
import aioboto3
|
||||
|
||||
session = aioboto3.Session()
|
||||
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
|
||||
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
|
||||
if aws_access_key_id != None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
_client = session.client(
|
||||
service_name="sagemaker-runtime",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automatically reads env variables
|
||||
|
||||
# we need to read region name from env
|
||||
# I assume the majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or "us-west-2" # default to us-west-2 if user not specified
|
||||
)
|
||||
_client = session.client(
|
||||
service_name="sagemaker-runtime",
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
async with _client as client:
|
||||
try:
|
||||
response = await client.invoke_endpoint_with_response_stream(
|
||||
EndpointName=model,
|
||||
|
@ -395,7 +425,37 @@ async def async_completion(
|
|||
import aioboto3
|
||||
|
||||
session = aioboto3.Session()
|
||||
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
|
||||
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
|
||||
if aws_access_key_id != None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
_client = session.client(
|
||||
service_name="sagemaker-runtime",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automatically reads env variables
|
||||
|
||||
# we need to read region name from env
|
||||
# I assume the majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or "us-west-2" # default to us-west-2 if user not specified
|
||||
)
|
||||
_client = session.client(
|
||||
service_name="sagemaker-runtime",
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
async with _client as client:
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_endpoint(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue