From 7338b24a7458e5d85a66c4bd6807827abfe3f29f Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Mon, 21 Oct 2024 16:42:12 -0700
Subject: [PATCH] =?UTF-8?q?refactor(redis=5Fcache.py):=20use=20a=20default?=
 =?UTF-8?q?=20cache=20value=20when=20writing=20to=20r=E2=80=A6=20(#6358)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refactor(redis_cache.py): use a default cache value when writing to redis

prevent redis from blowing up in high traffic

* refactor(redis_cache.py): refactor all cache writes to use self.get_ttl

ensures default ttl always used when writing to redis

Prevents redis db from blowing up in prod
---
 litellm/caching/base_cache.py       |  12 ++-
 litellm/caching/redis_cache.py      |  93 +++++++++++++------
 tests/local_testing/test_caching.py | 123 +++++++++++++++++++++++++++-
 3 files changed, 199 insertions(+), 29 deletions(-)

diff --git a/litellm/caching/base_cache.py b/litellm/caching/base_cache.py
index ec4a56d23..016ad70b9 100644
--- a/litellm/caching/base_cache.py
+++ b/litellm/caching/base_cache.py
@@ -8,8 +8,18 @@ Has 4 methods:
 - async_get_cache
 """
 
+from typing import Optional
+
 
 class BaseCache:
+    def __init__(self, default_ttl: int = 60):
+        self.default_ttl = default_ttl
+
+    def get_ttl(self, **kwargs) -> Optional[int]:
+        if kwargs.get("ttl") is not None:
+            return kwargs.get("ttl")
+        return self.default_ttl
+
     def set_cache(self, key, value, **kwargs):
         raise NotImplementedError
 
@@ -22,7 +32,7 @@ class BaseCache:
     async def async_get_cache(self, key, **kwargs):
         raise NotImplementedError
 
-    async def batch_cache_write(self, result, *args, **kwargs):
+    async def batch_cache_write(self, key, value, **kwargs):
         raise NotImplementedError
 
     async def disconnect(self):
diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py
index 29d9a71d2..8604bdad6 100644
--- a/litellm/caching/redis_cache.py
+++ b/litellm/caching/redis_cache.py
@@ -14,8 +14,9 @@ import inspect
 import json
 import time
 from datetime import timedelta
-from typing import Any, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple
 
+import litellm
 from litellm._logging import print_verbose, verbose_logger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
@@ -23,6 +24,16 @@ from litellm.types.utils import all_litellm_params
 
 from .base_cache import BaseCache
 
+if TYPE_CHECKING:
+    from redis.asyncio import Redis
+    from redis.asyncio.client import Pipeline
+
+    pipeline = Pipeline
+    async_redis_client = Redis
+else:
+    pipeline = Any
+    async_redis_client = Any
+
 
 class RedisCache(BaseCache):
     # if users don't provider one, use the default litellm cache
@@ -104,6 +115,11 @@ class RedisCache(BaseCache):
                 "Error connecting to Sync Redis client", extra={"error": str(e)}
             )
 
+        if litellm.default_redis_ttl is not None:
+            super().__init__(default_ttl=int(litellm.default_redis_ttl))
+        else:
+            super().__init__()  # defaults to 60s
+
     def init_async_client(self):
         from .._redis import get_redis_async_client
 
@@ -121,7 +137,7 @@ class RedisCache(BaseCache):
         return key
 
     def set_cache(self, key, value, **kwargs):
-        ttl = kwargs.get("ttl", None)
+        ttl = self.get_ttl(**kwargs)
         print_verbose(
             f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
         )
@@ -139,15 +155,16 @@ class RedisCache(BaseCache):
     ) -> int:
         _redis_client = self.redis_client
         start_time = time.time()
+        set_ttl = self.get_ttl(ttl=ttl)
         try:
             result: int = _redis_client.incr(name=key, amount=value)  # type: ignore
 
-            if ttl is not None:
+            if set_ttl is not None:
                 # check if key already has ttl, if not -> set ttl
                 current_ttl = _redis_client.ttl(key)
                 if current_ttl == -1:
                     # Key has no expiration
-                    _redis_client.expire(key, ttl)  # type: ignore
+                    _redis_client.expire(key, set_ttl)  # type: ignore
             return result
         except Exception as e:
             ## LOGGING ##
@@ -236,7 +253,7 @@ class RedisCache(BaseCache):
 
         key = self.check_and_fix_namespace(key=key)
         async with _redis_client as redis_client:
-            ttl = kwargs.get("ttl", None)
+            ttl = self.get_ttl(**kwargs)
             print_verbose(
                 f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
             )
@@ -284,6 +301,26 @@ class RedisCache(BaseCache):
                 value,
             )
 
+    async def _pipeline_helper(
+        self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
+    ) -> List:
+        ttl = self.get_ttl(ttl=ttl)
+        # Iterate through each key-value pair in the cache_list and set them in the pipeline.
+        for cache_key, cache_value in cache_list:
+            cache_key = self.check_and_fix_namespace(key=cache_key)
+            print_verbose(
+                f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
+            )
+            json_cache_value = json.dumps(cache_value)
+            # Set the value with a TTL if it's provided.
+            _td: Optional[timedelta] = None
+            if ttl is not None:
+                _td = timedelta(seconds=ttl)
+            pipe.set(cache_key, json_cache_value, ex=_td)
+        # Execute the pipeline and return the results.
+        results = await pipe.execute()
+        return results
+
     async def async_set_cache_pipeline(
         self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
     ):
@@ -298,8 +335,6 @@ class RedisCache(BaseCache):
         _redis_client: Redis = self.init_async_client()  # type: ignore
         start_time = time.time()
 
-        ttl = ttl or kwargs.get("ttl", None)
-
         print_verbose(
             f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
         )
@@ -307,20 +342,7 @@ class RedisCache(BaseCache):
         try:
             async with _redis_client as redis_client:
                 async with redis_client.pipeline(transaction=True) as pipe:
-                    # Iterate through each key-value pair in the cache_list and set them in the pipeline.
-                    for cache_key, cache_value in cache_list:
-                        cache_key = self.check_and_fix_namespace(key=cache_key)
-                        print_verbose(
-                            f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
-                        )
-                        json_cache_value = json.dumps(cache_value)
-                        # Set the value with a TTL if it's provided.
-                        _td: Optional[timedelta] = None
-                        if ttl is not None:
-                            _td = timedelta(seconds=ttl)
-                        pipe.set(cache_key, json_cache_value, ex=_td)
-                    # Execute the pipeline and return the results.
-                    results = await pipe.execute()
+                    results = await self._pipeline_helper(pipe, cache_list, ttl)
 
             print_verbose(f"pipeline results: {results}")
             # Optionally, you could process 'results' to make sure that all set operations were successful.
@@ -360,6 +382,23 @@ class RedisCache(BaseCache):
                 cache_value,
             )
 
+    async def _set_cache_sadd_helper(
+        self,
+        redis_client: async_redis_client,
+        key: str,
+        value: List,
+        ttl: Optional[float],
+    ) -> None:
+        """Helper function for async_set_cache_sadd. Separated for testing."""
+        ttl = self.get_ttl(ttl=ttl)
+        try:
+            await redis_client.sadd(key, *value)  # type: ignore
+            if ttl is not None:
+                _td = timedelta(seconds=ttl)
+                await redis_client.expire(key, _td)
+        except Exception:
+            raise
+
     async def async_set_cache_sadd(
         self, key, value: List, ttl: Optional[float], **kwargs
     ):
@@ -396,10 +435,9 @@ class RedisCache(BaseCache):
                 f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
             )
             try:
-                await redis_client.sadd(key, *value)  # type: ignore
-                if ttl is not None:
-                    _td = timedelta(seconds=ttl)
-                    await redis_client.expire(key, _td)
+                await self._set_cache_sadd_helper(
+                    redis_client=redis_client, key=key, value=value, ttl=ttl
+                )
                 print_verbose(
                     f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
                 )
@@ -452,16 +490,17 @@ class RedisCache(BaseCache):
 
         _redis_client: Redis = self.init_async_client()  # type: ignore
         start_time = time.time()
+        _used_ttl = self.get_ttl(ttl=ttl)
         try:
             async with _redis_client as redis_client:
                 result = await redis_client.incrbyfloat(name=key, amount=value)
 
-                if ttl is not None:
+                if _used_ttl is not None:
                     # check if key already has ttl, if not -> set ttl
                     current_ttl = await redis_client.ttl(key)
                     if current_ttl == -1:
                         # Key has no expiration
-                        await redis_client.expire(key, ttl)
+                        await redis_client.expire(key, _used_ttl)
 
                 ## LOGGING ##
                 end_time = time.time()
diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py
index 3d927dc33..dfadf11bb 100644
--- a/tests/local_testing/test_caching.py
+++ b/tests/local_testing/test_caching.py
@@ -23,8 +23,9 @@ import litellm
 from litellm import aembedding, completion, embedding
 from litellm.caching.caching import Cache
 
-from unittest.mock import AsyncMock, patch, MagicMock
+from unittest.mock import AsyncMock, patch, MagicMock, call
 import datetime
+from datetime import timedelta
 
 # litellm.set_verbose=True
 
@@ -2394,3 +2395,123 @@ async def test_audio_caching(stream):
     )
 
     assert "cache_hit" in completion._hidden_params
+
+
+def test_redis_caching_default_ttl():
+    """
+    Ensure that the default redis cache TTL is 60s
+    """
+    from litellm.caching.redis_cache import RedisCache
+
+    litellm.default_redis_ttl = 120
+
+    cache_obj = RedisCache()
+    assert cache_obj.default_ttl == 120
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_redis_caching_llm_caching_ttl(sync_mode):
+    """
+    Ensure default redis cache ttl is used for a sample redis cache object
+    """
+    from litellm.caching.redis_cache import RedisCache
+
+    litellm.default_redis_ttl = 120
+    cache_obj = RedisCache()
+    assert cache_obj.default_ttl == 120
+
+    if sync_mode is False:
+        # Create an AsyncMock for the Redis client
+        mock_redis_instance = AsyncMock()
+
+        # Make sure the mock can be used as an async context manager
+        mock_redis_instance.__aenter__.return_value = mock_redis_instance
+        mock_redis_instance.__aexit__.return_value = None
+
+    ## Set cache
+    if sync_mode is True:
+        with patch.object(cache_obj.redis_client, "set") as mock_set:
+            cache_obj.set_cache(key="test", value="test")
+            mock_set.assert_called_once_with(name="test", value="test", ex=120)
+    else:
+
+        # Patch self.init_async_client to return our mock Redis client
+        with patch.object(
+            cache_obj, "init_async_client", return_value=mock_redis_instance
+        ):
+            # Call async_set_cache
+            await cache_obj.async_set_cache(key="test", value="test_value")
+
+            # Verify that the set method was called on the mock Redis instance
+            mock_redis_instance.set.assert_called_once_with(
+                name="test", value='"test_value"', ex=120
+            )
+
+    ## Increment cache
+    if sync_mode is True:
+        with patch.object(cache_obj.redis_client, "ttl") as mock_incr:
+            cache_obj.increment_cache(key="test", value=1)
+            mock_incr.assert_called_once_with("test")
+    else:
+        # Patch self.init_async_client to return our mock Redis client
+        with patch.object(
+            cache_obj, "init_async_client", return_value=mock_redis_instance
+        ):
+            # Call async_set_cache
+            await cache_obj.async_increment(key="test", value="test_value")
+
+            # Verify that the set method was called on the mock Redis instance
+            mock_redis_instance.ttl.assert_called_once_with("test")
+
+
+@pytest.mark.asyncio()
+async def test_redis_caching_ttl_pipeline():
+    """
+    Ensure that a default ttl is set for all redis functions
+    """
+
+    from litellm.caching.redis_cache import RedisCache
+
+    litellm.default_redis_ttl = 120
+    expected_timedelta = timedelta(seconds=120)
+    cache_obj = RedisCache()
+
+    ## TEST 1 - async_set_cache_pipeline
+    # Patch self.init_async_client to return our mock Redis client
+    # Call async_set_cache
+    mock_pipe_instance = AsyncMock()
+    with patch.object(mock_pipe_instance, "set", return_value=None) as mock_set:
+        await cache_obj._pipeline_helper(
+            pipe=mock_pipe_instance,
+            cache_list=[("test_key1", "test_value1"), ("test_key2", "test_value2")],
+            ttl=None,
+        )
+
+    # Verify that the set method was called on the mock Redis instance
+    mock_set.assert_has_calls(
+        [
+            call.set("test_key1", '"test_value1"', ex=expected_timedelta),
+            call.set("test_key2", '"test_value2"', ex=expected_timedelta),
+        ]
+    )
+
+
+@pytest.mark.asyncio()
+async def test_redis_caching_ttl_sadd():
+    """
+    Ensure that a default ttl is set for all redis functions
+    """
+    from litellm.caching.redis_cache import RedisCache
+
+    litellm.default_redis_ttl = 120
+    expected_timedelta = timedelta(seconds=120)
+    cache_obj = RedisCache()
+    redis_client = AsyncMock()
+
+    with patch.object(redis_client, "expire", return_value=None) as mock_expire:
+        await cache_obj._set_cache_sadd_helper(
+            redis_client=redis_client, key="test_key", value=["test_value"], ttl=None
+        )
+        print(f"expected_timedelta: {expected_timedelta}")
+        assert mock_expire.call_args.args[1] == expected_timedelta
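
The core idea of the patch above: every Redis write now resolves its TTL through BaseCache.get_ttl, so a caller-supplied ttl still wins, but when none is given the write falls back to litellm.default_redis_ttl (or 60 seconds), keeping keys from accumulating forever in a high-traffic deployment. Below is a minimal, self-contained sketch of that resolution logic; BaseCacheSketch is a hypothetical stand-in that mirrors the get_ttl added to base_cache.py, so it runs without Redis or litellm installed and is not part of the patch itself.

# Sketch only: mirrors the BaseCache.get_ttl logic added in this patch.
# "BaseCacheSketch" is a hypothetical stand-in, not litellm's class.
from typing import Optional


class BaseCacheSketch:
    def __init__(self, default_ttl: int = 60):
        # 60s mirrors the patch's fallback when litellm.default_redis_ttl is unset
        self.default_ttl = default_ttl

    def get_ttl(self, **kwargs) -> Optional[int]:
        # An explicit ttl kwarg wins; ttl=None (or no ttl at all) falls back to
        # the default, so every Redis write ends up with an expiry.
        if kwargs.get("ttl") is not None:
            return kwargs.get("ttl")
        return self.default_ttl


cache = BaseCacheSketch(default_ttl=120)   # e.g. litellm.default_redis_ttl = 120
assert cache.get_ttl() == 120              # no ttl passed -> default applies
assert cache.get_ttl(ttl=None) == 120      # ttl=None is treated as "not set"
assert cache.get_ttl(ttl=5) == 5           # caller-supplied ttl still wins

In RedisCache, that resolved value is what becomes the ex= argument on set / pipeline set calls and the expire() applied after incr/sadd writes, as shown in the hunks above.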