From 9ba17657ad664a21b5e91259a152db58540be024 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 20:50:06 -0800 Subject: [PATCH] (feat) init redis cache with **kwargs --- litellm/caching.py | 10 ++++++---- litellm/tests/test_caching.py | 27 +++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 5e8fcf447..d9b94b958 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -69,10 +69,10 @@ class InMemoryCache(BaseCache): class RedisCache(BaseCache): - def __init__(self, host, port, password): + def __init__(self, host, port, password, **kwargs): import redis # if users don't provider one, use the default litellm cache - self.redis_client = redis.Redis(host=host, port=port, password=password) + self.redis_client = redis.Redis(host=host, port=port, password=password, **kwargs) def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) @@ -168,7 +168,8 @@ class Cache: type="local", host=None, port=None, - password=None + password=None, + **kwargs ): """ Initializes the cache based on the given type. @@ -178,6 +179,7 @@ class Cache: host (str, optional): The host address for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis". + **kwargs: Additional keyword arguments for redis.Redis() cache Raises: ValueError: If an invalid cache type is provided. @@ -186,7 +188,7 @@ class Cache: None """ if type == "redis": - self.cache = RedisCache(host, port, password) + self.cache = RedisCache(host, port, password, **kwargs) if type == "local": self.cache = InMemoryCache() if "cache" not in litellm.input_callback: diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index ab24d3e70..713f97b3e 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -90,7 +90,7 @@ def test_embedding_caching(): print(f"embedding2: {embedding2}") pytest.fail("Error occurred: Embedding caching failed") -test_embedding_caching() +# test_embedding_caching() def test_embedding_caching_azure(): @@ -190,7 +190,7 @@ def test_redis_cache_completion(): print(f"response4: {response4}") pytest.fail(f"Error occurred:") -test_redis_cache_completion() +# test_redis_cache_completion() # redis cache with custom keys def custom_get_cache_key(*args, **kwargs): @@ -231,6 +231,29 @@ def test_custom_redis_cache_with_key(): # test_custom_redis_cache_with_key() + +def test_custom_redis_cache_params(): + # test if we can init redis with **kwargs + try: + litellm.cache = Cache( + type="redis", + host=os.environ['REDIS_HOST'], + port=os.environ['REDIS_PORT'], + password=os.environ['REDIS_PASSWORD'], + db = 0, + ssl=True, + ssl_certfile="./redis_user.crt", + ssl_keyfile="./redis_user_private.key", + ssl_ca_certs="./redis_ca.pem", + ) + + print(litellm.cache.cache.redis_client) + litellm.cache = None + except Exception as e: + pytest.fail(f"Error occurred:", e) + +# test_custom_redis_cache_params() + # def test_redis_cache_with_ttl(): # cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) # sample_model_response_object_str = """{