mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Add tests for SageMaker region selection
This commit is contained in:
parent
c6be8326db
commit
d9e9a8645b
2 changed files with 103 additions and 16 deletions
|
@ -9,6 +9,9 @@ from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
|
||||||
import sys
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import httpx # type: ignore
|
import httpx # type: ignore
|
||||||
|
import boto3
|
||||||
|
import aioboto3
|
||||||
|
import io
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,10 +28,6 @@ class SagemakerError(Exception):
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
class TokenIterator:
|
class TokenIterator:
|
||||||
def __init__(self, stream, acompletion: bool = False):
|
def __init__(self, stream, acompletion: bool = False):
|
||||||
if acompletion == False:
|
if acompletion == False:
|
||||||
|
@ -160,8 +159,6 @@ def completion(
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
acompletion: bool = False,
|
acompletion: bool = False,
|
||||||
):
|
):
|
||||||
import boto3
|
|
||||||
|
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
@ -416,10 +413,6 @@ async def async_streaming(
|
||||||
aws_access_key_id: Optional[str],
|
aws_access_key_id: Optional[str],
|
||||||
aws_region_name: Optional[str],
|
aws_region_name: Optional[str],
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Use aioboto3
|
|
||||||
"""
|
|
||||||
import aioboto3
|
|
||||||
|
|
||||||
session = aioboto3.Session()
|
session = aioboto3.Session()
|
||||||
|
|
||||||
|
@ -484,10 +477,6 @@ async def async_completion(
|
||||||
aws_access_key_id: Optional[str],
|
aws_access_key_id: Optional[str],
|
||||||
aws_region_name: Optional[str],
|
aws_region_name: Optional[str],
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Use aioboto3
|
|
||||||
"""
|
|
||||||
import aioboto3
|
|
||||||
|
|
||||||
session = aioboto3.Session()
|
session = aioboto3.Session()
|
||||||
|
|
||||||
|
@ -639,8 +628,6 @@ def embedding(
|
||||||
"""
|
"""
|
||||||
Supports Huggingface Jumpstart embeddings like GPT-6B
|
Supports Huggingface Jumpstart embeddings like GPT-6B
|
||||||
"""
|
"""
|
||||||
### BOTO3 INIT
|
|
||||||
import boto3
|
|
||||||
|
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
|
|
@ -512,6 +512,106 @@ def sagemaker_test_completion():
|
||||||
|
|
||||||
# sagemaker_test_completion()
|
# sagemaker_test_completion()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sagemaker_default_region(mocker):
|
||||||
|
"""
|
||||||
|
If no regions are specified in config or in environment, the default region is us-west-2
|
||||||
|
"""
|
||||||
|
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client")
|
||||||
|
try:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/mock-endpoint",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"content": "Hello, world!",
|
||||||
|
"role": "user"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # expected serialization exception because AWS client was replaced with a Mock
|
||||||
|
assert mock_client.call_args.kwargs["region_name"] == "us-west-2"
|
||||||
|
|
||||||
|
# test_sagemaker_provided_region()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sagemaker_environment_region(mocker):
|
||||||
|
"""
|
||||||
|
If a region is specified in the environment, use that region instead of us-west-2
|
||||||
|
"""
|
||||||
|
expected_region = "us-east-1"
|
||||||
|
os.environ["AWS_REGION_NAME"] = expected_region
|
||||||
|
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client")
|
||||||
|
try:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/mock-endpoint",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"content": "Hello, world!",
|
||||||
|
"role": "user"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # expected serialization exception because AWS client was replaced with a Mock
|
||||||
|
del os.environ["AWS_REGION_NAME"] # cleanup
|
||||||
|
assert mock_client.call_args.kwargs["region_name"] == expected_region
|
||||||
|
|
||||||
|
# test_sagemaker_environment_region()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sagemaker_config_region(mocker):
|
||||||
|
"""
|
||||||
|
If a region is specified as part of the optional parameters of the completion, including as
|
||||||
|
part of the config file, then use that region instead of us-west-2
|
||||||
|
"""
|
||||||
|
expected_region = "us-east-1"
|
||||||
|
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client")
|
||||||
|
try:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/mock-endpoint",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"content": "Hello, world!",
|
||||||
|
"role": "user"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
aws_region_name=expected_region,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # expected serialization exception because AWS client was replaced with a Mock
|
||||||
|
assert mock_client.call_args.kwargs["region_name"] == expected_region
|
||||||
|
|
||||||
|
# test_sagemaker_config_region()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sagemaker_config_and_environment_region(mocker):
|
||||||
|
"""
|
||||||
|
If both the environment and config file specify a region, the environment region is expected
|
||||||
|
"""
|
||||||
|
expected_region = "us-east-1"
|
||||||
|
unexpected_region = "us-east-2"
|
||||||
|
os.environ["AWS_REGION_NAME"] = expected_region
|
||||||
|
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client")
|
||||||
|
try:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/mock-endpoint",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"content": "Hello, world!",
|
||||||
|
"role": "user"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
aws_region_name=unexpected_region,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # expected serialization exception because AWS client was replaced with a Mock
|
||||||
|
del os.environ["AWS_REGION_NAME"] # cleanup
|
||||||
|
assert mock_client.call_args.kwargs["region_name"] == expected_region
|
||||||
|
|
||||||
|
# test_sagemaker_config_and_environment_region()
|
||||||
|
|
||||||
|
|
||||||
# Bedrock
|
# Bedrock
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue