refactor(redis_cache.py): use a default cache value when writing to r… (#6358)

* refactor(redis_cache.py): use a default cache value when writing to redis

prevent redis from blowing up in high traffic

* refactor(redis_cache.py): refactor all cache writes to use self.get_ttl

ensures default ttl always used when writing to redis

Prevents redis db from blowing up in prod
This commit is contained in:
Krish Dholakia 2024-10-21 16:42:12 -07:00 committed by GitHub
parent 199896f912
commit 7338b24a74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 199 additions and 29 deletions

View file

@ -23,8 +23,9 @@ import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, patch, MagicMock, call
import datetime
from datetime import timedelta
# litellm.set_verbose=True
@ -2394,3 +2395,123 @@ async def test_audio_caching(stream):
)
assert "cache_hit" in completion._hidden_params
def test_redis_caching_default_ttl():
"""
Ensure that the default redis cache TTL is 60s
"""
from litellm.caching.redis_cache import RedisCache
litellm.default_redis_ttl = 120
cache_obj = RedisCache()
assert cache_obj.default_ttl == 120
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_redis_caching_llm_caching_ttl(sync_mode):
"""
Ensure default redis cache ttl is used for a sample redis cache object
"""
from litellm.caching.redis_cache import RedisCache
litellm.default_redis_ttl = 120
cache_obj = RedisCache()
assert cache_obj.default_ttl == 120
if sync_mode is False:
# Create an AsyncMock for the Redis client
mock_redis_instance = AsyncMock()
# Make sure the mock can be used as an async context manager
mock_redis_instance.__aenter__.return_value = mock_redis_instance
mock_redis_instance.__aexit__.return_value = None
## Set cache
if sync_mode is True:
with patch.object(cache_obj.redis_client, "set") as mock_set:
cache_obj.set_cache(key="test", value="test")
mock_set.assert_called_once_with(name="test", value="test", ex=120)
else:
# Patch self.init_async_client to return our mock Redis client
with patch.object(
cache_obj, "init_async_client", return_value=mock_redis_instance
):
# Call async_set_cache
await cache_obj.async_set_cache(key="test", value="test_value")
# Verify that the set method was called on the mock Redis instance
mock_redis_instance.set.assert_called_once_with(
name="test", value='"test_value"', ex=120
)
## Increment cache
if sync_mode is True:
with patch.object(cache_obj.redis_client, "ttl") as mock_incr:
cache_obj.increment_cache(key="test", value=1)
mock_incr.assert_called_once_with("test")
else:
# Patch self.init_async_client to return our mock Redis client
with patch.object(
cache_obj, "init_async_client", return_value=mock_redis_instance
):
# Call async_set_cache
await cache_obj.async_increment(key="test", value="test_value")
# Verify that the set method was called on the mock Redis instance
mock_redis_instance.ttl.assert_called_once_with("test")
@pytest.mark.asyncio()
async def test_redis_caching_ttl_pipeline():
"""
Ensure that a default ttl is set for all redis functions
"""
from litellm.caching.redis_cache import RedisCache
litellm.default_redis_ttl = 120
expected_timedelta = timedelta(seconds=120)
cache_obj = RedisCache()
## TEST 1 - async_set_cache_pipeline
# Patch self.init_async_client to return our mock Redis client
# Call async_set_cache
mock_pipe_instance = AsyncMock()
with patch.object(mock_pipe_instance, "set", return_value=None) as mock_set:
await cache_obj._pipeline_helper(
pipe=mock_pipe_instance,
cache_list=[("test_key1", "test_value1"), ("test_key2", "test_value2")],
ttl=None,
)
# Verify that the set method was called on the mock Redis instance
mock_set.assert_has_calls(
[
call.set("test_key1", '"test_value1"', ex=expected_timedelta),
call.set("test_key2", '"test_value2"', ex=expected_timedelta),
]
)
@pytest.mark.asyncio()
async def test_redis_caching_ttl_sadd():
"""
Ensure that a default ttl is set for all redis functions
"""
from litellm.caching.redis_cache import RedisCache
litellm.default_redis_ttl = 120
expected_timedelta = timedelta(seconds=120)
cache_obj = RedisCache()
redis_client = AsyncMock()
with patch.object(redis_client, "expire", return_value=None) as mock_expire:
await cache_obj._set_cache_sadd_helper(
redis_client=redis_client, key="test_key", value=["test_value"], ttl=None
)
print(f"expected_timedelta: {expected_timedelta}")
assert mock_expire.call_args.args[1] == expected_timedelta