Add tests for SageMaker region selection

This commit is contained in:
Peter Muller 2024-07-02 15:30:39 -07:00
parent c6be8326db
commit d9e9a8645b
No known key found for this signature in database
GPG key ID: A6094745C95B1613
2 changed files with 103 additions and 16 deletions

View file

@ -9,6 +9,9 @@ from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
import sys
from copy import deepcopy
import httpx # type: ignore
import boto3
import aioboto3
import io
from .prompt_templates.factory import prompt_factory, custom_prompt
@ -25,10 +28,6 @@ class SagemakerError(Exception):
) # Call the base class constructor with the parameters it needs
import io
import json
class TokenIterator:
def __init__(self, stream, acompletion: bool = False):
if acompletion == False:
@ -160,8 +159,6 @@ def completion(
logger_fn=None,
acompletion: bool = False,
):
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
@ -416,10 +413,6 @@ async def async_streaming(
aws_access_key_id: Optional[str],
aws_region_name: Optional[str],
):
"""
Use aioboto3
"""
import aioboto3
session = aioboto3.Session()
@ -484,10 +477,6 @@ async def async_completion(
aws_access_key_id: Optional[str],
aws_region_name: Optional[str],
):
"""
Use aioboto3
"""
import aioboto3
session = aioboto3.Session()
@ -639,8 +628,6 @@ def embedding(
"""
Supports Huggingface Jumpstart embeddings like GPT-6B
"""
### BOTO3 INIT
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)