(feat) s3 cache support all boto3 params

This commit is contained in:
ishaan-jaff 2024-01-03 15:41:38 +05:30
parent b51371952b
commit 58ce5d44ae

View file

@ -10,7 +10,7 @@
import litellm import litellm
import time, logging import time, logging
import json, traceback, ast, hashlib import json, traceback, ast, hashlib
from typing import Optional, Literal, List from typing import Optional, Literal, List, Union, Any
def print_verbose(print_statement): def print_verbose(print_statement):
@ -118,21 +118,35 @@ class RedisCache(BaseCache):
class S3Cache(BaseCache): class S3Cache(BaseCache):
def __init__( def __init__(
self, self,
bucket_name=None, s3_bucket_name,
aws_access_key_id=None, s3_region_name=None,
aws_secret_access_key=None, s3_api_version=None,
region_name=None, s3_use_ssl=True,
endpoint_url=None, s3_verify=None,
s3_endpoint_url=None,
s3_aws_access_key_id=None,
s3_aws_secret_access_key=None,
s3_aws_session_token=None,
s3_config=None,
**kwargs, **kwargs,
): ):
import boto3 import boto3
self.bucket_name = "cache-bucket-litellm" self.bucket_name = s3_bucket_name
self.region_name = region_name
self.endpoint_url = endpoint_url # Add the endpoint_url parameter
# Create an S3 client with custom endpoint URL # Create an S3 client with custom endpoint URL
self.s3_client = boto3.client("s3", region_name="us-west-2", **kwargs) self.s3_client = boto3.client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
api_version=s3_api_version,
use_ssl=s3_use_ssl,
verify=s3_verify,
aws_access_key_id=s3_aws_access_key_id,
aws_secret_access_key=s3_aws_secret_access_key,
aws_session_token=s3_aws_session_token,
config=s3_config,
**kwargs,
)
def set_cache(self, key, value, **kwargs): def set_cache(self, key, value, **kwargs):
try: try:
@ -281,6 +295,17 @@ class Cache:
supported_call_types: Optional[ supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]] List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"], ] = ["completion", "acompletion", "embedding", "aembedding"],
# s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None,
s3_api_version: Optional[str] = None,
s3_use_ssl: Optional[bool] = True,
s3_verify: Optional[Union[bool, str]] = None,
s3_endpoint_url: Optional[str] = None,
s3_aws_access_key_id: Optional[str] = None,
s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None,
**kwargs, **kwargs,
): ):
""" """
@ -305,7 +330,19 @@ class Cache:
if type == "local": if type == "local":
self.cache = InMemoryCache() self.cache = InMemoryCache()
if type == "s3": if type == "s3":
self.cache = S3Cache() self.cache = S3Cache(
s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name,
s3_api_version=s3_api_version,
s3_use_ssl=s3_use_ssl,
s3_verify=s3_verify,
s3_endpoint_url=s3_endpoint_url,
s3_aws_access_key_id=s3_aws_access_key_id,
s3_aws_secret_access_key=s3_aws_secret_access_key,
s3_aws_session_token=s3_aws_session_token,
s3_config=s3_config,
**kwargs,
)
if "cache" not in litellm.input_callback: if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache") litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback: if "cache" not in litellm.success_callback: