(feat) s3 cache support all boto3 params

This commit is contained in:
ishaan-jaff 2024-01-03 15:41:38 +05:30
parent b51371952b
commit 58ce5d44ae

View file

@ -10,7 +10,7 @@
import litellm
import time, logging
import json, traceback, ast, hashlib
from typing import Optional, Literal, List
from typing import Optional, Literal, List, Union, Any
def print_verbose(print_statement):
@ -118,21 +118,35 @@ class RedisCache(BaseCache):
class S3Cache(BaseCache):
def __init__(
self,
bucket_name=None,
aws_access_key_id=None,
aws_secret_access_key=None,
region_name=None,
endpoint_url=None,
s3_bucket_name,
s3_region_name=None,
s3_api_version=None,
s3_use_ssl=True,
s3_verify=None,
s3_endpoint_url=None,
s3_aws_access_key_id=None,
s3_aws_secret_access_key=None,
s3_aws_session_token=None,
s3_config=None,
**kwargs,
):
import boto3
self.bucket_name = "cache-bucket-litellm"
self.region_name = region_name
self.endpoint_url = endpoint_url # Add the endpoint_url parameter
self.bucket_name = s3_bucket_name
# Create an S3 client with custom endpoint URL
self.s3_client = boto3.client("s3", region_name="us-west-2", **kwargs)
self.s3_client = boto3.client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
api_version=s3_api_version,
use_ssl=s3_use_ssl,
verify=s3_verify,
aws_access_key_id=s3_aws_access_key_id,
aws_secret_access_key=s3_aws_secret_access_key,
aws_session_token=s3_aws_session_token,
config=s3_config,
**kwargs,
)
def set_cache(self, key, value, **kwargs):
try:
@ -281,6 +295,17 @@ class Cache:
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
# s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None,
s3_api_version: Optional[str] = None,
s3_use_ssl: Optional[bool] = True,
s3_verify: Optional[Union[bool, str]] = None,
s3_endpoint_url: Optional[str] = None,
s3_aws_access_key_id: Optional[str] = None,
s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None,
**kwargs,
):
"""
@ -305,7 +330,19 @@ class Cache:
if type == "local":
self.cache = InMemoryCache()
if type == "s3":
self.cache = S3Cache()
self.cache = S3Cache(
s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name,
s3_api_version=s3_api_version,
s3_use_ssl=s3_use_ssl,
s3_verify=s3_verify,
s3_endpoint_url=s3_endpoint_url,
s3_aws_access_key_id=s3_aws_access_key_id,
s3_aws_secret_access_key=s3_aws_secret_access_key,
s3_aws_session_token=s3_aws_session_token,
s3_config=s3_config,
**kwargs,
)
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback: