Merge pull request #2379 from BerriAI/litellm_s3_bucket_folder_path

fix(caching.py): add s3 path as a top-level param
This commit is contained in:
Krish Dholakia 2024-03-06 19:35:46 -08:00 committed by GitHub
commit 06bde2b8c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 16 additions and 3 deletions

View file

@@ -572,6 +572,7 @@ class S3Cache(BaseCache):
self.bucket_name = s3_bucket_name self.bucket_name = s3_bucket_name
self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else "" self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
# Create an S3 client with custom endpoint URL # Create an S3 client with custom endpoint URL
self.s3_client = boto3.client( self.s3_client = boto3.client(
"s3", "s3",
region_name=s3_region_name, region_name=s3_region_name,
@@ -776,6 +777,7 @@ class Cache:
s3_aws_secret_access_key: Optional[str] = None, s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None, s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None, s3_config: Optional[Any] = None,
s3_path: Optional[str] = None,
redis_semantic_cache_use_async=False, redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002", redis_semantic_cache_embedding_model="text-embedding-ada-002",
**kwargs, **kwargs,
@@ -825,6 +827,7 @@ class Cache:
s3_aws_secret_access_key=s3_aws_secret_access_key, s3_aws_secret_access_key=s3_aws_secret_access_key,
s3_aws_session_token=s3_aws_session_token, s3_aws_session_token=s3_aws_session_token,
s3_config=s3_config, s3_config=s3_config,
s3_path=s3_path,
**kwargs, **kwargs,
) )
if "cache" not in litellm.input_callback: if "cache" not in litellm.input_callback:

View file

@@ -485,7 +485,12 @@ def convert_url_to_base64(url):
import requests import requests
import base64 import base64
response = requests.get(url) for _ in range(3):
try:
response = requests.get(url)
break
except:
pass
if response.status_code == 200: if response.status_code == 200:
image_bytes = response.content image_bytes = response.content
base64_image = base64.b64encode(image_bytes).decode("utf-8") base64_image = base64.b64encode(image_bytes).decode("utf-8")
@@ -536,6 +541,8 @@ def convert_to_anthropic_image_obj(openai_image_url: str):
"data": base64_data, "data": base64_data,
} }
except Exception as e: except Exception as e:
if "Error: Unable to fetch image from URL" in str(e):
raise e
raise Exception( raise Exception(
"""Image url not in expected format. Example Expected input - "image_url": "data:image/jpeg;base64,{base64_image}". Supported formats - ['image/jpeg', 'image/png', 'image/gif', 'image/webp'] """ """Image url not in expected format. Example Expected input - "image_url": "data:image/jpeg;base64,{base64_image}". Supported formats - ['image/jpeg', 'image/png', 'image/gif', 'image/webp'] """
) )

View file

@@ -695,7 +695,6 @@ def test_s3_cache_acompletion_stream_azure():
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.skip(reason="AWS Suspended Account")
async def test_s3_cache_acompletion_azure(): async def test_s3_cache_acompletion_azure():
import asyncio import asyncio
import logging import logging
@@ -714,7 +713,9 @@ async def test_s3_cache_acompletion_azure():
} }
] ]
litellm.cache = Cache( litellm.cache = Cache(
type="s3", s3_bucket_name="cache-bucket-litellm", s3_region_name="us-west-2" type="s3",
s3_bucket_name="litellm-my-test-bucket-2",
s3_region_name="us-east-1",
) )
print("s3 Cache: test for caching, streaming + completion") print("s3 Cache: test for caching, streaming + completion")

View file

@@ -219,6 +219,7 @@ def test_completion_claude_3_base64():
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.skip(reason="issue getting wikipedia images in ci/cd")
def test_completion_claude_3_function_plus_image(): def test_completion_claude_3_function_plus_image():
litellm.set_verbose = True litellm.set_verbose = True

View file

@@ -647,6 +647,7 @@ async def test_streaming_router_tpm_limit():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bad_router_call(): async def test_bad_router_call():
litellm.set_verbose = True
model_list = [ model_list = [
{ {
"model_name": "azure-model", "model_name": "azure-model",