Revert imports changes, update tests to match

This commit is contained in:
Peter Muller 2024-07-02 19:09:22 -07:00
parent 47c97e1fa2
commit d8fc8252fa
No known key found for this signature in database
GPG key ID: A6094745C95B1613
2 changed files with 16 additions and 6 deletions

View file

@ -9,8 +9,6 @@ from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
import sys import sys
from copy import deepcopy from copy import deepcopy
import httpx # type: ignore import httpx # type: ignore
import boto3
import aioboto3
import io import io
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
@ -159,6 +157,8 @@ def completion(
logger_fn=None, logger_fn=None,
acompletion: bool = False, acompletion: bool = False,
): ):
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
@ -413,6 +413,10 @@ async def async_streaming(
aws_access_key_id: Optional[str], aws_access_key_id: Optional[str],
aws_region_name: Optional[str], aws_region_name: Optional[str],
): ):
"""
Use aioboto3
"""
import aioboto3
session = aioboto3.Session() session = aioboto3.Session()
@ -477,6 +481,10 @@ async def async_completion(
aws_access_key_id: Optional[str], aws_access_key_id: Optional[str],
aws_region_name: Optional[str], aws_region_name: Optional[str],
): ):
"""
Use aioboto3
"""
import aioboto3
session = aioboto3.Session() session = aioboto3.Session()
@ -628,6 +636,8 @@ def embedding(
""" """
Supports Huggingface Jumpstart embeddings like GPT-6B Supports Huggingface Jumpstart embeddings like GPT-6B
""" """
### BOTO3 INIT
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)

View file

@ -517,7 +517,7 @@ def test_sagemaker_default_region(mocker):
""" """
If no regions are specified in config or in environment, the default region is us-west-2 If no regions are specified in config or in environment, the default region is us-west-2
""" """
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") mock_client = mocker.patch("boto3.client")
try: try:
response = litellm.completion( response = litellm.completion(
model="sagemaker/mock-endpoint", model="sagemaker/mock-endpoint",
@ -541,7 +541,7 @@ def test_sagemaker_environment_region(mocker):
""" """
expected_region = "us-east-1" expected_region = "us-east-1"
os.environ["AWS_REGION_NAME"] = expected_region os.environ["AWS_REGION_NAME"] = expected_region
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") mock_client = mocker.patch("boto3.client")
try: try:
response = litellm.completion( response = litellm.completion(
model="sagemaker/mock-endpoint", model="sagemaker/mock-endpoint",
@ -566,7 +566,7 @@ def test_sagemaker_config_region(mocker):
part of the config file, then use that region instead of us-west-2 part of the config file, then use that region instead of us-west-2
""" """
expected_region = "us-east-1" expected_region = "us-east-1"
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") mock_client = mocker.patch("boto3.client")
try: try:
response = litellm.completion( response = litellm.completion(
model="sagemaker/mock-endpoint", model="sagemaker/mock-endpoint",
@ -592,7 +592,7 @@ def test_sagemaker_config_and_environment_region(mocker):
expected_region = "us-east-1" expected_region = "us-east-1"
unexpected_region = "us-east-2" unexpected_region = "us-east-2"
os.environ["AWS_REGION_NAME"] = expected_region os.environ["AWS_REGION_NAME"] = expected_region
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client") mock_client = mocker.patch("boto3.client")
try: try:
response = litellm.completion( response = litellm.completion(
model="sagemaker/mock-endpoint", model="sagemaker/mock-endpoint",