refactor(redis_cache.py): use a default cache value when writing to redis (#6358)

* refactor(redis_cache.py): use a default cache value when writing to redis

prevent redis from blowing up in high traffic

* refactor(redis_cache.py): refactor all cache writes to use self.get_ttl

ensures default ttl always used when writing to redis

Prevents redis db from blowing up in prod
This commit is contained in:
Krish Dholakia 2024-10-21 16:42:12 -07:00 committed by GitHub
parent 199896f912
commit 7338b24a74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 199 additions and 29 deletions

View file

@ -8,8 +8,18 @@ Has 4 methods:
- async_get_cache
"""
from typing import Optional
class BaseCache:
def __init__(self, default_ttl: int = 60):
    # Fallback TTL in seconds, applied by get_ttl() whenever a cache
    # write does not supply an explicit `ttl` kwarg.
    self.default_ttl = default_ttl
def get_ttl(self, **kwargs) -> Optional[int]:
if kwargs.get("ttl") is not None:
return kwargs.get("ttl")
return self.default_ttl
def set_cache(self, key, value, **kwargs):
    """Store ``value`` under ``key``. Must be implemented by subclasses."""
    raise NotImplementedError
@ -22,7 +32,7 @@ class BaseCache:
async def async_get_cache(self, key, **kwargs):
    """Async lookup of ``key``. Must be implemented by subclasses."""
    raise NotImplementedError
async def batch_cache_write(self, result, *args, **kwargs):
async def batch_cache_write(self, key, value, **kwargs):
raise NotImplementedError
async def disconnect(self):

View file

@ -14,8 +14,9 @@ import inspect
import json
import time
from datetime import timedelta
from typing import Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
@ -23,6 +24,16 @@ from litellm.types.utils import all_litellm_params
from .base_cache import BaseCache
if TYPE_CHECKING:
from redis.asyncio import Redis
from redis.asyncio.client import Pipeline
pipeline = Pipeline
async_redis_client = Redis
else:
pipeline = Any
async_redis_client = Any
class RedisCache(BaseCache):
# if users don't provide one, use the default litellm cache
@ -104,6 +115,11 @@ class RedisCache(BaseCache):
"Error connecting to Sync Redis client", extra={"error": str(e)}
)
if litellm.default_redis_ttl is not None:
super().__init__(default_ttl=int(litellm.default_redis_ttl))
else:
super().__init__() # defaults to 60s
def init_async_client(self):
from .._redis import get_redis_async_client
@ -121,7 +137,7 @@ class RedisCache(BaseCache):
return key
def set_cache(self, key, value, **kwargs):
ttl = kwargs.get("ttl", None)
ttl = self.get_ttl(**kwargs)
print_verbose(
f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
)
@ -139,15 +155,16 @@ class RedisCache(BaseCache):
) -> int:
_redis_client = self.redis_client
start_time = time.time()
set_ttl = self.get_ttl(ttl=ttl)
try:
result: int = _redis_client.incr(name=key, amount=value) # type: ignore
if ttl is not None:
if set_ttl is not None:
# check if key already has ttl, if not -> set ttl
current_ttl = _redis_client.ttl(key)
if current_ttl == -1:
# Key has no expiration
_redis_client.expire(key, ttl) # type: ignore
_redis_client.expire(key, set_ttl) # type: ignore
return result
except Exception as e:
## LOGGING ##
@ -236,7 +253,7 @@ class RedisCache(BaseCache):
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
ttl = kwargs.get("ttl", None)
ttl = self.get_ttl(**kwargs)
print_verbose(
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
@ -284,6 +301,26 @@ class RedisCache(BaseCache):
value,
)
async def _pipeline_helper(
    self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
) -> List:
    """Queue a SET for every (key, value) pair on ``pipe``, then execute it.

    The effective TTL is resolved through ``self.get_ttl`` so the cache
    default is applied when the caller passes ``ttl=None``. Returns the
    list of per-command results from the pipeline execution.
    """
    ttl = self.get_ttl(ttl=ttl)
    # TTL is the same for every entry, so build the expiry once.
    expiry: Optional[timedelta] = timedelta(seconds=ttl) if ttl is not None else None
    for raw_key, raw_value in cache_list:
        namespaced_key = self.check_and_fix_namespace(key=raw_key)
        print_verbose(
            f"Set ASYNC Redis Cache PIPELINE: key: {namespaced_key}\nValue {raw_value}\nttl={ttl}"
        )
        pipe.set(namespaced_key, json.dumps(raw_value), ex=expiry)
    # One round-trip executes all queued SETs.
    return await pipe.execute()
async def async_set_cache_pipeline(
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
):
@ -298,8 +335,6 @@ class RedisCache(BaseCache):
_redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time()
ttl = ttl or kwargs.get("ttl", None)
print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
)
@ -307,20 +342,7 @@ class RedisCache(BaseCache):
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
for cache_key, cache_value in cache_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
print_verbose(
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
)
json_cache_value = json.dumps(cache_value)
# Set the value with a TTL if it's provided.
_td: Optional[timedelta] = None
if ttl is not None:
_td = timedelta(seconds=ttl)
pipe.set(cache_key, json_cache_value, ex=_td)
# Execute the pipeline and return the results.
results = await pipe.execute()
results = await self._pipeline_helper(pipe, cache_list, ttl)
print_verbose(f"pipeline results: {results}")
# Optionally, you could process 'results' to make sure that all set operations were successful.
@ -360,6 +382,23 @@ class RedisCache(BaseCache):
cache_value,
)
async def _set_cache_sadd_helper(
    self,
    redis_client: async_redis_client,
    key: str,
    value: List,
    ttl: Optional[float],
) -> None:
    """Helper function for async_set_cache_sadd. Separated for testing.

    Adds each element of ``value`` to the redis set at ``key`` via SADD,
    then applies the resolved TTL (caller-supplied, or the cache default
    from ``self.get_ttl``). Exceptions propagate to the caller, which
    handles logging.
    """
    ttl = self.get_ttl(ttl=ttl)
    await redis_client.sadd(key, *value)  # type: ignore
    if ttl is not None:
        # Redis TTLs are per-key: this expires the whole set at once.
        _td = timedelta(seconds=ttl)
        await redis_client.expire(key, _td)
async def async_set_cache_sadd(
self, key, value: List, ttl: Optional[float], **kwargs
):
@ -396,10 +435,9 @@ class RedisCache(BaseCache):
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
await redis_client.sadd(key, *value) # type: ignore
if ttl is not None:
_td = timedelta(seconds=ttl)
await redis_client.expire(key, _td)
await self._set_cache_sadd_helper(
redis_client=redis_client, key=key, value=value, ttl=ttl
)
print_verbose(
f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
)
@ -452,16 +490,17 @@ class RedisCache(BaseCache):
_redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time()
_used_ttl = self.get_ttl(ttl=ttl)
try:
async with _redis_client as redis_client:
result = await redis_client.incrbyfloat(name=key, amount=value)
if ttl is not None:
if _used_ttl is not None:
# check if key already has ttl, if not -> set ttl
current_ttl = await redis_client.ttl(key)
if current_ttl == -1:
# Key has no expiration
await redis_client.expire(key, ttl)
await redis_client.expire(key, _used_ttl)
## LOGGING ##
end_time = time.time()

View file

@ -23,8 +23,9 @@ import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, patch, MagicMock, call
import datetime
from datetime import timedelta
# litellm.set_verbose=True
@ -2394,3 +2395,123 @@ async def test_audio_caching(stream):
)
assert "cache_hit" in completion._hidden_params
def test_redis_caching_default_ttl():
    """
    Ensure RedisCache inherits its default TTL from litellm.default_redis_ttl
    (120s here) instead of falling back to the hard-coded 60s default.
    """
    from litellm.caching.redis_cache import RedisCache

    litellm.default_redis_ttl = 120
    cache_obj = RedisCache()
    assert cache_obj.default_ttl == 120
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_redis_caching_llm_caching_ttl(sync_mode):
    """
    Ensure default redis cache ttl is used for a sample redis cache object
    """
    from litellm.caching.redis_cache import RedisCache

    litellm.default_redis_ttl = 120
    cache_obj = RedisCache()
    assert cache_obj.default_ttl == 120

    if sync_mode is False:
        # Create an AsyncMock for the Redis client
        mock_redis_instance = AsyncMock()

        # Make sure the mock can be used as an async context manager
        mock_redis_instance.__aenter__.return_value = mock_redis_instance
        mock_redis_instance.__aexit__.return_value = None

    ## Set cache
    if sync_mode is True:
        with patch.object(cache_obj.redis_client, "set") as mock_set:
            cache_obj.set_cache(key="test", value="test")
            # The 120s default must be forwarded to redis SET as `ex`.
            mock_set.assert_called_once_with(name="test", value="test", ex=120)
    else:
        # Patch self.init_async_client to return our mock Redis client
        with patch.object(
            cache_obj, "init_async_client", return_value=mock_redis_instance
        ):
            # Call async_set_cache
            await cache_obj.async_set_cache(key="test", value="test_value")

            # Verify that the set method was called on the mock Redis instance
            mock_redis_instance.set.assert_called_once_with(
                name="test", value='"test_value"', ex=120
            )

    ## Increment cache
    if sync_mode is True:
        # NOTE(review): the patched attribute is "ttl", not "incr" — increment_cache
        # checks the key's current TTL before deciding whether to set an expiry,
        # so this asserts the TTL lookup happened. The mock's name is misleading.
        with patch.object(cache_obj.redis_client, "ttl") as mock_incr:
            cache_obj.increment_cache(key="test", value=1)
            mock_incr.assert_called_once_with("test")
    else:
        # Patch self.init_async_client to return our mock Redis client
        with patch.object(
            cache_obj, "init_async_client", return_value=mock_redis_instance
        ):
            # Call async_increment
            await cache_obj.async_increment(key="test", value="test_value")

            # Verify the TTL lookup happened on the mock Redis instance
            mock_redis_instance.ttl.assert_called_once_with("test")
@pytest.mark.asyncio()
async def test_redis_caching_ttl_pipeline():
    """
    Ensure that a default ttl is set for all redis functions
    """
    from litellm.caching.redis_cache import RedisCache

    litellm.default_redis_ttl = 120
    default_expiry = timedelta(seconds=120)
    cache_obj = RedisCache()

    ## TEST 1 - async_set_cache_pipeline
    # _pipeline_helper should stamp every queued SET with the default TTL
    # when the caller passes ttl=None.
    fake_pipe = AsyncMock()
    with patch.object(fake_pipe, "set", return_value=None) as mock_set:
        await cache_obj._pipeline_helper(
            pipe=fake_pipe,
            cache_list=[("test_key1", "test_value1"), ("test_key2", "test_value2")],
            ttl=None,
        )
        expected_calls = [
            call.set("test_key1", '"test_value1"', ex=default_expiry),
            call.set("test_key2", '"test_value2"', ex=default_expiry),
        ]
        # Values arrive JSON-encoded; `ex` carries the 120s default.
        mock_set.assert_has_calls(expected_calls)
@pytest.mark.asyncio()
async def test_redis_caching_ttl_sadd():
    """
    Ensure that a default ttl is set for all redis functions
    """
    from litellm.caching.redis_cache import RedisCache

    litellm.default_redis_ttl = 120
    default_expiry = timedelta(seconds=120)
    cache_obj = RedisCache()
    fake_client = AsyncMock()

    # With ttl=None the helper must still call EXPIRE using the default TTL.
    with patch.object(fake_client, "expire", return_value=None) as mock_expire:
        await cache_obj._set_cache_sadd_helper(
            redis_client=fake_client, key="test_key", value=["test_value"], ttl=None
        )
        print(f"expected_timedelta: {default_expiry}")
        assert mock_expire.call_args.args[1] == default_expiry