forked from phoenix/litellm-mirror
Merge pull request #4499 from petermuller/main
Allow calling SageMaker endpoints from different regions
This commit is contained in:
commit
944f22a089
2 changed files with 109 additions and 8 deletions
|
@ -9,6 +9,7 @@ from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
|
|||
import sys
|
||||
from copy import deepcopy
|
||||
import httpx # type: ignore
|
||||
import io
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
|
@ -25,10 +26,6 @@ class SagemakerError(Exception):
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
|
||||
class TokenIterator:
|
||||
def __init__(self, stream, acompletion: bool = False):
|
||||
if acompletion == False:
|
||||
|
@ -185,7 +182,8 @@ def completion(
|
|||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or "us-west-2" # default to us-west-2 if user not specified
|
||||
or aws_region_name # get region from config file if specified
|
||||
or "us-west-2" # default to us-west-2 if region not specified
|
||||
)
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
|
@ -439,7 +437,8 @@ async def async_streaming(
|
|||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or "us-west-2" # default to us-west-2 if user not specified
|
||||
or aws_region_name # get region from config file if specified
|
||||
or "us-west-2" # default to us-west-2 if region not specified
|
||||
)
|
||||
_client = session.client(
|
||||
service_name="sagemaker-runtime",
|
||||
|
@ -506,7 +505,8 @@ async def async_completion(
|
|||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or "us-west-2" # default to us-west-2 if user not specified
|
||||
or aws_region_name # get region from config file if specified
|
||||
or "us-west-2" # default to us-west-2 if region not specified
|
||||
)
|
||||
_client = session.client(
|
||||
service_name="sagemaker-runtime",
|
||||
|
@ -661,7 +661,8 @@ def embedding(
|
|||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or "us-west-2" # default to us-west-2 if user not specified
|
||||
or aws_region_name # get region from config file if specified
|
||||
or "us-west-2" # default to us-west-2 if region not specified
|
||||
)
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
|
|
|
@ -512,6 +512,106 @@ def sagemaker_test_completion():
|
|||
|
||||
# sagemaker_test_completion()
|
||||
|
||||
|
||||
def test_sagemaker_default_region(mocker):
    """
    If no regions are specified in config or in environment, the default region is us-west-2
    """
    # Work on a self-restoring copy of the environment and drop any ambient
    # AWS_REGION_NAME, otherwise a value exported by CI or the developer's shell
    # would win over the default and make this test fail spuriously.
    mocker.patch.dict(os.environ)
    os.environ.pop("AWS_REGION_NAME", None)

    mock_client = mocker.patch("boto3.client")
    try:
        litellm.completion(
            model="sagemaker/mock-endpoint",
            messages=[
                {
                    "content": "Hello, world!",
                    "role": "user",
                }
            ],
        )
    except Exception:
        pass  # expected serialization exception because AWS client was replaced with a Mock

    # Fail with a clear message if the client was never built, rather than an
    # AttributeError on `call_args` being None.
    mock_client.assert_called()
    assert mock_client.call_args.kwargs["region_name"] == "us-west-2"
|
||||
|
||||
# test_sagemaker_default_region()
|
||||
|
||||
|
||||
def test_sagemaker_environment_region(mocker):
    """
    If a region is specified in the environment, use that region instead of us-west-2
    """
    expected_region = "us-east-1"
    # patch.dict restores the previous environment at teardown, so a value that
    # was already exported before the test is neither overwritten permanently
    # nor deleted, and cleanup happens even if the test body raises.
    mocker.patch.dict(os.environ, {"AWS_REGION_NAME": expected_region})

    mock_client = mocker.patch("boto3.client")
    try:
        litellm.completion(
            model="sagemaker/mock-endpoint",
            messages=[
                {
                    "content": "Hello, world!",
                    "role": "user",
                }
            ],
        )
    except Exception:
        pass  # expected serialization exception because AWS client was replaced with a Mock

    # Fail with a clear message if the client was never built, rather than an
    # AttributeError on `call_args` being None.
    mock_client.assert_called()
    assert mock_client.call_args.kwargs["region_name"] == expected_region
|
||||
|
||||
# test_sagemaker_environment_region()
|
||||
|
||||
|
||||
def test_sagemaker_config_region(mocker):
    """
    If a region is specified as part of the optional parameters of the completion, including as
    part of the config file, then use that region instead of us-west-2
    """
    expected_region = "us-east-1"
    # The environment variable takes precedence over the config value, so an
    # ambient AWS_REGION_NAME would mask the behaviour under test. Remove it
    # inside a self-restoring patch of os.environ.
    mocker.patch.dict(os.environ)
    os.environ.pop("AWS_REGION_NAME", None)

    mock_client = mocker.patch("boto3.client")
    try:
        litellm.completion(
            model="sagemaker/mock-endpoint",
            messages=[
                {
                    "content": "Hello, world!",
                    "role": "user",
                }
            ],
            aws_region_name=expected_region,
        )
    except Exception:
        pass  # expected serialization exception because AWS client was replaced with a Mock

    # Fail with a clear message if the client was never built, rather than an
    # AttributeError on `call_args` being None.
    mock_client.assert_called()
    assert mock_client.call_args.kwargs["region_name"] == expected_region
|
||||
|
||||
# test_sagemaker_config_region()
|
||||
|
||||
|
||||
def test_sagemaker_config_and_environment_region(mocker):
    """
    If both the environment and config file specify a region, the environment region is expected
    """
    expected_region = "us-east-1"
    unexpected_region = "us-east-2"
    # patch.dict restores the previous environment at teardown, so a value that
    # was already exported before the test is neither overwritten permanently
    # nor deleted, and cleanup happens even if the test body raises.
    mocker.patch.dict(os.environ, {"AWS_REGION_NAME": expected_region})

    mock_client = mocker.patch("boto3.client")
    try:
        litellm.completion(
            model="sagemaker/mock-endpoint",
            messages=[
                {
                    "content": "Hello, world!",
                    "role": "user",
                }
            ],
            aws_region_name=unexpected_region,
        )
    except Exception:
        pass  # expected serialization exception because AWS client was replaced with a Mock

    # Fail with a clear message if the client was never built, rather than an
    # AttributeError on `call_args` being None.
    mock_client.assert_called()
    assert mock_client.call_args.kwargs["region_name"] == expected_region
|
||||
|
||||
# test_sagemaker_config_and_environment_region()
|
||||
|
||||
|
||||
# Bedrock
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue