diff --git a/litellm/caching.py b/litellm/caching.py index eb7265001..3dc70d2b5 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -10,7 +10,7 @@ import litellm import time, logging import json, traceback, ast, hashlib -from typing import Optional, Literal, List +from typing import Optional, Literal, List, Union, Any def print_verbose(print_statement): @@ -118,21 +118,35 @@ class RedisCache(BaseCache): class S3Cache(BaseCache): def __init__( self, - bucket_name=None, - aws_access_key_id=None, - aws_secret_access_key=None, - region_name=None, - endpoint_url=None, + s3_bucket_name, + s3_region_name=None, + s3_api_version=None, + s3_use_ssl=True, + s3_verify=None, + s3_endpoint_url=None, + s3_aws_access_key_id=None, + s3_aws_secret_access_key=None, + s3_aws_session_token=None, + s3_config=None, **kwargs, ): import boto3 - self.bucket_name = "cache-bucket-litellm" - self.region_name = region_name - self.endpoint_url = endpoint_url # Add the endpoint_url parameter - + self.bucket_name = s3_bucket_name # Create an S3 client with custom endpoint URL - self.s3_client = boto3.client("s3", region_name="us-west-2", **kwargs) + self.s3_client = boto3.client( + "s3", + region_name=s3_region_name, + endpoint_url=s3_endpoint_url, + api_version=s3_api_version, + use_ssl=s3_use_ssl, + verify=s3_verify, + aws_access_key_id=s3_aws_access_key_id, + aws_secret_access_key=s3_aws_secret_access_key, + aws_session_token=s3_aws_session_token, + config=s3_config, + **kwargs, + ) def set_cache(self, key, value, **kwargs): try: @@ -281,6 +295,17 @@ class Cache: supported_call_types: Optional[ List[Literal["completion", "acompletion", "embedding", "aembedding"]] ] = ["completion", "acompletion", "embedding", "aembedding"], + # s3 Bucket, boto3 configuration + s3_bucket_name: Optional[str] = None, + s3_region_name: Optional[str] = None, + s3_api_version: Optional[str] = None, + s3_use_ssl: Optional[bool] = True, + s3_verify: Optional[Union[bool, str]] = None, + s3_endpoint_url: Optional[str] = None, + s3_aws_access_key_id: Optional[str] = None, + s3_aws_secret_access_key: Optional[str] = None, + s3_aws_session_token: Optional[str] = None, + s3_config: Optional[Any] = None, **kwargs, ): """ @@ -305,7 +330,19 @@ class Cache: if type == "local": self.cache = InMemoryCache() if type == "s3": - self.cache = S3Cache() + self.cache = S3Cache( + s3_bucket_name=s3_bucket_name, + s3_region_name=s3_region_name, + s3_api_version=s3_api_version, + s3_use_ssl=s3_use_ssl, + s3_verify=s3_verify, + s3_endpoint_url=s3_endpoint_url, + s3_aws_access_key_id=s3_aws_access_key_id, + s3_aws_secret_access_key=s3_aws_secret_access_key, + s3_aws_session_token=s3_aws_session_token, + s3_config=s3_config, + **kwargs, + ) if "cache" not in litellm.input_callback: litellm.input_callback.append("cache") if "cache" not in litellm.success_callback: