mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Revert imports changes, update tests to match
This commit is contained in:
parent
47c97e1fa2
commit
d8fc8252fa
2 changed files with 16 additions and 6 deletions
|
@ -9,8 +9,6 @@ from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
|
||||||
import sys
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import httpx # type: ignore
|
import httpx # type: ignore
|
||||||
import boto3
|
|
||||||
import aioboto3
|
|
||||||
import io
|
import io
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
@ -159,6 +157,8 @@ def completion(
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
acompletion: bool = False,
|
acompletion: bool = False,
|
||||||
):
|
):
|
||||||
|
import boto3
|
||||||
|
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
@ -413,6 +413,10 @@ async def async_streaming(
|
||||||
aws_access_key_id: Optional[str],
|
aws_access_key_id: Optional[str],
|
||||||
aws_region_name: Optional[str],
|
aws_region_name: Optional[str],
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Use aioboto3
|
||||||
|
"""
|
||||||
|
import aioboto3
|
||||||
|
|
||||||
session = aioboto3.Session()
|
session = aioboto3.Session()
|
||||||
|
|
||||||
|
@ -477,6 +481,10 @@ async def async_completion(
|
||||||
aws_access_key_id: Optional[str],
|
aws_access_key_id: Optional[str],
|
||||||
aws_region_name: Optional[str],
|
aws_region_name: Optional[str],
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Use aioboto3
|
||||||
|
"""
|
||||||
|
import aioboto3
|
||||||
|
|
||||||
session = aioboto3.Session()
|
session = aioboto3.Session()
|
||||||
|
|
||||||
|
@ -628,6 +636,8 @@ def embedding(
|
||||||
"""
|
"""
|
||||||
Supports Huggingface Jumpstart embeddings like GPT-6B
|
Supports Huggingface Jumpstart embeddings like GPT-6B
|
||||||
"""
|
"""
|
||||||
|
### BOTO3 INIT
|
||||||
|
import boto3
|
||||||
|
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
|
|
@ -517,7 +517,7 @@ def test_sagemaker_default_region(mocker):
|
||||||
"""
|
"""
|
||||||
If no regions are specified in config or in environment, the default region is us-west-2
|
If no regions are specified in config or in environment, the default region is us-west-2
|
||||||
"""
|
"""
|
||||||
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client")
|
mock_client = mocker.patch("boto3.client")
|
||||||
try:
|
try:
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model="sagemaker/mock-endpoint",
|
model="sagemaker/mock-endpoint",
|
||||||
|
@ -541,7 +541,7 @@ def test_sagemaker_environment_region(mocker):
|
||||||
"""
|
"""
|
||||||
expected_region = "us-east-1"
|
expected_region = "us-east-1"
|
||||||
os.environ["AWS_REGION_NAME"] = expected_region
|
os.environ["AWS_REGION_NAME"] = expected_region
|
||||||
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client")
|
mock_client = mocker.patch("boto3.client")
|
||||||
try:
|
try:
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model="sagemaker/mock-endpoint",
|
model="sagemaker/mock-endpoint",
|
||||||
|
@ -566,7 +566,7 @@ def test_sagemaker_config_region(mocker):
|
||||||
part of the config file, then use that region instead of us-west-2
|
part of the config file, then use that region instead of us-west-2
|
||||||
"""
|
"""
|
||||||
expected_region = "us-east-1"
|
expected_region = "us-east-1"
|
||||||
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client")
|
mock_client = mocker.patch("boto3.client")
|
||||||
try:
|
try:
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model="sagemaker/mock-endpoint",
|
model="sagemaker/mock-endpoint",
|
||||||
|
@ -592,7 +592,7 @@ def test_sagemaker_config_and_environment_region(mocker):
|
||||||
expected_region = "us-east-1"
|
expected_region = "us-east-1"
|
||||||
unexpected_region = "us-east-2"
|
unexpected_region = "us-east-2"
|
||||||
os.environ["AWS_REGION_NAME"] = expected_region
|
os.environ["AWS_REGION_NAME"] = expected_region
|
||||||
mock_client = mocker.patch("litellm.llms.sagemaker.boto3.client")
|
mock_client = mocker.patch("boto3.client")
|
||||||
try:
|
try:
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model="sagemaker/mock-endpoint",
|
model="sagemaker/mock-endpoint",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue