forked from phoenix/litellm-mirror
refactor(redis_cache.py): use a default cache value when writing to r… (#6358)
* refactor(redis_cache.py): use a default cache value when writing to redis; prevents redis from blowing up in high traffic.
* refactor(redis_cache.py): refactor all cache writes to use self.get_ttl; ensures the default ttl is always used when writing to redis. Prevents the redis db from blowing up in prod.
This commit is contained in:
parent
199896f912
commit
7338b24a74
3 changed files with 199 additions and 29 deletions
|
@ -23,8 +23,9 @@ import litellm
|
|||
from litellm import aembedding, completion, embedding
|
||||
from litellm.caching.caching import Cache
|
||||
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from unittest.mock import AsyncMock, patch, MagicMock, call
|
||||
import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
# litellm.set_verbose=True
|
||||
|
||||
|
@ -2394,3 +2395,123 @@ async def test_audio_caching(stream):
|
|||
)
|
||||
|
||||
assert "cache_hit" in completion._hidden_params
|
||||
|
||||
|
||||
def test_redis_caching_default_ttl():
    """
    Ensure that a new RedisCache uses `litellm.default_redis_ttl` as its default TTL.

    The module-level setting is set to 120 and a freshly constructed cache
    object must expose the same value as `default_ttl`.
    """
    from litellm.caching.redis_cache import RedisCache

    litellm.default_redis_ttl = 120

    cache_obj = RedisCache()
    assert cache_obj.default_ttl == 120
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_redis_caching_llm_caching_ttl(sync_mode):
    """
    Ensure the default redis cache ttl is used for a sample redis cache object.

    With `litellm.default_redis_ttl` set to 120, writing a key without an
    explicit ttl must pass `ex=120` to the redis `set` call, and incrementing
    a key must query that key's ttl on the redis client.

    sync_mode=True drives `set_cache` / `increment_cache` with individual
    methods of `cache_obj.redis_client` patched. sync_mode=False drives
    `async_set_cache` / `async_increment` against an AsyncMock client that a
    patched `init_async_client` returns.
    """
    from litellm.caching.redis_cache import RedisCache

    litellm.default_redis_ttl = 120
    cache_obj = RedisCache()
    assert cache_obj.default_ttl == 120

    if sync_mode is True:
        ## Set cache
        with patch.object(cache_obj.redis_client, "set") as mock_set:
            cache_obj.set_cache(key="test", value="test")
            mock_set.assert_called_once_with(name="test", value="test", ex=120)

        ## Increment cache
        # Only `ttl` is patched here; the assertion is about the ttl lookup.
        with patch.object(cache_obj.redis_client, "ttl") as mock_ttl:
            cache_obj.increment_cache(key="test", value=1)
            mock_ttl.assert_called_once_with("test")
    else:
        # Create an AsyncMock for the Redis client
        mock_redis_instance = AsyncMock()

        # Make sure the mock can be used as an async context manager
        mock_redis_instance.__aenter__.return_value = mock_redis_instance
        mock_redis_instance.__aexit__.return_value = None

        # Patch self.init_async_client to return our mock Redis client
        with patch.object(
            cache_obj, "init_async_client", return_value=mock_redis_instance
        ):
            ## Set cache
            await cache_obj.async_set_cache(key="test", value="test_value")

            # The value reaches redis as a JSON-encoded string with the default ttl
            mock_redis_instance.set.assert_called_once_with(
                name="test", value='"test_value"', ex=120
            )

            ## Increment cache
            # Use a numeric amount, matching the sync branch above
            await cache_obj.async_increment(key="test", value=1)

            # Verify that the key's ttl was looked up on the mock Redis instance
            mock_redis_instance.ttl.assert_called_once_with("test")
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
async def test_redis_caching_ttl_pipeline():
    """
    Ensure that the default ttl is applied to pipelined redis writes.

    `_pipeline_helper` is invoked with `ttl=None`; each `set` issued on the
    pipe is then expected to carry `ex=timedelta(seconds=120)`, matching the
    configured `litellm.default_redis_ttl`.
    """
    from litellm.caching.redis_cache import RedisCache

    litellm.default_redis_ttl = 120
    expected_timedelta = timedelta(seconds=120)
    cache_obj = RedisCache()

    ## TEST 1 - async_set_cache_pipeline
    # Stand-in for the redis pipeline object; only its `set` calls are inspected.
    mock_pipe_instance = AsyncMock()
    with patch.object(mock_pipe_instance, "set", return_value=None) as mock_set:
        await cache_obj._pipeline_helper(
            pipe=mock_pipe_instance,
            cache_list=[("test_key1", "test_value1"), ("test_key2", "test_value2")],
            ttl=None,
        )

        # `mock_set` is the patched `set` attribute itself, so the calls it
        # records are plain `call(...)` entries, not `call.set(...)`.
        mock_set.assert_has_calls(
            [
                call("test_key1", '"test_value1"', ex=expected_timedelta),
                call("test_key2", '"test_value2"', ex=expected_timedelta),
            ]
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
async def test_redis_caching_ttl_sadd():
    """
    Check that `_set_cache_sadd_helper` applies the default ttl when none is given.

    The helper is called with `ttl=None` against a mocked redis client; the
    second positional argument passed to the client's `expire` must equal
    `timedelta(seconds=120)`, the value configured via
    `litellm.default_redis_ttl`.
    """
    from litellm.caching.redis_cache import RedisCache

    litellm.default_redis_ttl = 120
    expected_timedelta = timedelta(seconds=120)

    cache_obj = RedisCache()
    fake_client = AsyncMock()

    expire_patch = patch.object(fake_client, "expire", return_value=None)
    with expire_patch as expire_mock:
        await cache_obj._set_cache_sadd_helper(
            redis_client=fake_client,
            key="test_key",
            value=["test_value"],
            ttl=None,
        )
        print(f"expected_timedelta: {expected_timedelta}")

        # Index 1 of the positional args given to `expire` is the ttl under test.
        expire_positional_args = expire_mock.call_args.args
        assert expire_positional_args[1] == expected_timedelta
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue