forked from phoenix/litellm-mirror
refactor(redis_cache.py): use a default cache value when writing to r… (#6358)
* refactor(redis_cache.py): use a default cache value when writing to redis. This prevents redis from blowing up under high traffic. * refactor(redis_cache.py): refactor all cache writes to use self.get_ttl. This ensures the default ttl is always used when writing to redis, and prevents the redis db from blowing up in prod.
This commit is contained in:
parent
199896f912
commit
7338b24a74
3 changed files with 199 additions and 29 deletions
|
@ -8,8 +8,18 @@ Has 4 methods:
|
|||
- async_get_cache
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BaseCache:
|
||||
def __init__(self, default_ttl: int = 60):
    """Record the TTL that ``get_ttl`` falls back to when no ``ttl`` is given."""
    self.default_ttl = default_ttl
|
||||
|
||||
def get_ttl(self, **kwargs) -> Optional[int]:
    """
    Resolve the TTL to use for a cache write.

    Args:
        **kwargs: only the ``ttl`` key is read; every other key is ignored.

    Returns:
        ``int(ttl)`` when a usable ``ttl`` was passed, otherwise
        ``self.default_ttl``. Fractional values are truncated so the result
        always matches the declared integer return type.
    """
    requested_ttl = kwargs.get("ttl")
    if requested_ttl is None:
        return self.default_ttl
    try:
        # Callers are typed to pass floats; normalise to whole seconds here
        # so every write path receives the same kind of value.
        return int(requested_ttl)
    except (TypeError, ValueError):
        # An unusable ttl must not leave the key without an expiry.
        return self.default_ttl
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
    """Abstract write hook: always raises ``NotImplementedError`` here."""
    raise NotImplementedError
|
||||
|
||||
|
@ -22,7 +32,7 @@ class BaseCache:
|
|||
async def async_get_cache(self, key, **kwargs):
    """Abstract async read hook: always raises ``NotImplementedError`` here."""
    raise NotImplementedError
|
||||
|
||||
async def batch_cache_write(self, result, *args, **kwargs):
|
||||
async def batch_cache_write(self, key, value, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def disconnect(self):
|
||||
|
|
|
@ -14,8 +14,9 @@ import inspect
|
|||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
|
@ -23,6 +24,16 @@ from litellm.types.utils import all_litellm_params
|
|||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio.client import Pipeline
|
||||
|
||||
pipeline = Pipeline
|
||||
async_redis_client = Redis
|
||||
else:
|
||||
pipeline = Any
|
||||
async_redis_client = Any
|
||||
|
||||
|
||||
class RedisCache(BaseCache):
|
||||
# if users don't provide one, use the default litellm cache
|
||||
|
@ -104,6 +115,11 @@ class RedisCache(BaseCache):
|
|||
"Error connecting to Sync Redis client", extra={"error": str(e)}
|
||||
)
|
||||
|
||||
if litellm.default_redis_ttl is not None:
|
||||
super().__init__(default_ttl=int(litellm.default_redis_ttl))
|
||||
else:
|
||||
super().__init__() # defaults to 60s
|
||||
|
||||
def init_async_client(self):
|
||||
from .._redis import get_redis_async_client
|
||||
|
||||
|
@ -121,7 +137,7 @@ class RedisCache(BaseCache):
|
|||
return key
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
ttl = kwargs.get("ttl", None)
|
||||
ttl = self.get_ttl(**kwargs)
|
||||
print_verbose(
|
||||
f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
|
||||
)
|
||||
|
@ -139,15 +155,16 @@ class RedisCache(BaseCache):
|
|||
) -> int:
|
||||
_redis_client = self.redis_client
|
||||
start_time = time.time()
|
||||
set_ttl = self.get_ttl(ttl=ttl)
|
||||
try:
|
||||
result: int = _redis_client.incr(name=key, amount=value) # type: ignore
|
||||
|
||||
if ttl is not None:
|
||||
if set_ttl is not None:
|
||||
# check if key already has ttl, if not -> set ttl
|
||||
current_ttl = _redis_client.ttl(key)
|
||||
if current_ttl == -1:
|
||||
# Key has no expiration
|
||||
_redis_client.expire(key, ttl) # type: ignore
|
||||
_redis_client.expire(key, set_ttl) # type: ignore
|
||||
return result
|
||||
except Exception as e:
|
||||
## LOGGING ##
|
||||
|
@ -236,7 +253,7 @@ class RedisCache(BaseCache):
|
|||
|
||||
key = self.check_and_fix_namespace(key=key)
|
||||
async with _redis_client as redis_client:
|
||||
ttl = kwargs.get("ttl", None)
|
||||
ttl = self.get_ttl(**kwargs)
|
||||
print_verbose(
|
||||
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||
)
|
||||
|
@ -284,6 +301,26 @@ class RedisCache(BaseCache):
|
|||
value,
|
||||
)
|
||||
|
||||
async def _pipeline_helper(
    self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
) -> List:
    """
    Queue one SET per (key, value) pair on ``pipe`` and execute the pipeline.

    Args:
        pipe: pipeline object; ``set(key, value, ex=...)`` is called once per
            pair and ``execute()`` is awaited once at the end.
        cache_list: (key, value) pairs. Each key goes through
            ``self.check_and_fix_namespace`` and each value is JSON-encoded.
        ttl: expiry in seconds; resolved through ``self.get_ttl`` so a
            ``None`` falls back to the cache's default.

    Returns:
        The result of ``pipe.execute()``.
    """
    ttl = self.get_ttl(ttl=ttl)
    # The expiry is identical for every entry, so build it once up front.
    expiry: Optional[timedelta] = timedelta(seconds=ttl) if ttl is not None else None

    for cache_key, cache_value in cache_list:
        cache_key = self.check_and_fix_namespace(key=cache_key)
        print_verbose(
            f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
        )
        pipe.set(cache_key, json.dumps(cache_value), ex=expiry)

    # Nothing is sent until execute() runs the queued commands.
    return await pipe.execute()
|
||||
|
||||
async def async_set_cache_pipeline(
|
||||
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
|
||||
):
|
||||
|
@ -298,8 +335,6 @@ class RedisCache(BaseCache):
|
|||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
start_time = time.time()
|
||||
|
||||
ttl = ttl or kwargs.get("ttl", None)
|
||||
|
||||
print_verbose(
|
||||
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
||||
)
|
||||
|
@ -307,20 +342,7 @@ class RedisCache(BaseCache):
|
|||
try:
|
||||
async with _redis_client as redis_client:
|
||||
async with redis_client.pipeline(transaction=True) as pipe:
|
||||
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
||||
for cache_key, cache_value in cache_list:
|
||||
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||
print_verbose(
|
||||
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
||||
)
|
||||
json_cache_value = json.dumps(cache_value)
|
||||
# Set the value with a TTL if it's provided.
|
||||
_td: Optional[timedelta] = None
|
||||
if ttl is not None:
|
||||
_td = timedelta(seconds=ttl)
|
||||
pipe.set(cache_key, json_cache_value, ex=_td)
|
||||
# Execute the pipeline and return the results.
|
||||
results = await pipe.execute()
|
||||
results = await self._pipeline_helper(pipe, cache_list, ttl)
|
||||
|
||||
print_verbose(f"pipeline results: {results}")
|
||||
# Optionally, you could process 'results' to make sure that all set operations were successful.
|
||||
|
@ -360,6 +382,23 @@ class RedisCache(BaseCache):
|
|||
cache_value,
|
||||
)
|
||||
|
||||
async def _set_cache_sadd_helper(
    self,
    redis_client: "async_redis_client",
    key: str,
    value: List,
    ttl: Optional[float],
) -> None:
    """
    Helper function for async_set_cache_sadd. Separated for testing.

    Adds every item of ``value`` to the set stored at ``key`` and then, when
    ``self.get_ttl`` yields a TTL, applies it to the key as a ``timedelta``.

    Args:
        redis_client: client whose ``sadd`` and ``expire`` are awaited.
        key: name of the set.
        value: members to add; unpacked into ``sadd``.
        ttl: expiry in seconds; ``None`` defers to the cache's default.

    Raises:
        Any exception raised by ``redis_client`` propagates unchanged.
    """
    ttl = self.get_ttl(ttl=ttl)
    await redis_client.sadd(key, *value)  # type: ignore
    if ttl is not None:
        await redis_client.expire(key, timedelta(seconds=ttl))
|
||||
|
||||
async def async_set_cache_sadd(
|
||||
self, key, value: List, ttl: Optional[float], **kwargs
|
||||
):
|
||||
|
@ -396,10 +435,9 @@ class RedisCache(BaseCache):
|
|||
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||
)
|
||||
try:
|
||||
await redis_client.sadd(key, *value) # type: ignore
|
||||
if ttl is not None:
|
||||
_td = timedelta(seconds=ttl)
|
||||
await redis_client.expire(key, _td)
|
||||
await self._set_cache_sadd_helper(
|
||||
redis_client=redis_client, key=key, value=value, ttl=ttl
|
||||
)
|
||||
print_verbose(
|
||||
f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
|
||||
)
|
||||
|
@ -452,16 +490,17 @@ class RedisCache(BaseCache):
|
|||
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
start_time = time.time()
|
||||
_used_ttl = self.get_ttl(ttl=ttl)
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
result = await redis_client.incrbyfloat(name=key, amount=value)
|
||||
|
||||
if ttl is not None:
|
||||
if _used_ttl is not None:
|
||||
# check if key already has ttl, if not -> set ttl
|
||||
current_ttl = await redis_client.ttl(key)
|
||||
if current_ttl == -1:
|
||||
# Key has no expiration
|
||||
await redis_client.expire(key, ttl)
|
||||
await redis_client.expire(key, _used_ttl)
|
||||
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
|
|
|
@ -23,8 +23,9 @@ import litellm
|
|||
from litellm import aembedding, completion, embedding
|
||||
from litellm.caching.caching import Cache
|
||||
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from unittest.mock import AsyncMock, patch, MagicMock, call
|
||||
import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
# litellm.set_verbose=True
|
||||
|
||||
|
@ -2394,3 +2395,123 @@ async def test_audio_caching(stream):
|
|||
)
|
||||
|
||||
assert "cache_hit" in completion._hidden_params
|
||||
|
||||
|
||||
def test_redis_caching_default_ttl():
    """
    Ensure RedisCache adopts `litellm.default_redis_ttl` as its default TTL
    """
    from litellm.caching.redis_cache import RedisCache

    original_ttl = litellm.default_redis_ttl
    litellm.default_redis_ttl = 120
    try:
        cache_obj = RedisCache()
        assert cache_obj.default_ttl == 120
    finally:
        # Restore the module-level setting so later tests are unaffected.
        litellm.default_redis_ttl = original_ttl
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_redis_caching_llm_caching_ttl(sync_mode):
    """
    Ensure the default redis cache ttl is used by the set and increment paths
    of a sample redis cache object, for both the sync and async clients
    """
    from litellm.caching.redis_cache import RedisCache

    original_ttl = litellm.default_redis_ttl
    litellm.default_redis_ttl = 120
    try:
        cache_obj = RedisCache()
        assert cache_obj.default_ttl == 120

        if sync_mode is True:
            ## Set cache
            with patch.object(cache_obj.redis_client, "set") as mock_set:
                cache_obj.set_cache(key="test", value="test")
                mock_set.assert_called_once_with(name="test", value="test", ex=120)

            ## Increment cache
            with patch.object(cache_obj.redis_client, "ttl") as mock_ttl:
                cache_obj.increment_cache(key="test", value=1)
                # The existing ttl is only looked up when a ttl is in effect.
                mock_ttl.assert_called_once_with("test")
        else:
            # Create an AsyncMock for the Redis client
            mock_redis_instance = AsyncMock()

            # Make sure the mock can be used as an async context manager
            mock_redis_instance.__aenter__.return_value = mock_redis_instance
            mock_redis_instance.__aexit__.return_value = None

            ## Set cache
            # Patch self.init_async_client to return our mock Redis client
            with patch.object(
                cache_obj, "init_async_client", return_value=mock_redis_instance
            ):
                await cache_obj.async_set_cache(key="test", value="test_value")

                # Verify that set was called with the default ttl
                mock_redis_instance.set.assert_called_once_with(
                    name="test", value='"test_value"', ex=120
                )

            ## Increment cache
            with patch.object(
                cache_obj, "init_async_client", return_value=mock_redis_instance
            ):
                await cache_obj.async_increment(key="test", value="test_value")

                # The existing ttl is only looked up when a ttl is in effect.
                mock_redis_instance.ttl.assert_called_once_with("test")
    finally:
        # Restore the module-level setting so later tests are unaffected.
        litellm.default_redis_ttl = original_ttl
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
async def test_redis_caching_ttl_pipeline():
    """
    Ensure `_pipeline_helper` applies the default ttl when none is passed
    """
    from litellm.caching.redis_cache import RedisCache

    original_ttl = litellm.default_redis_ttl
    litellm.default_redis_ttl = 120
    try:
        expected_timedelta = timedelta(seconds=120)
        cache_obj = RedisCache()

        # Drive the helper directly with a mock pipeline; no redis client needed.
        mock_pipe_instance = AsyncMock()
        with patch.object(mock_pipe_instance, "set", return_value=None) as mock_set:
            await cache_obj._pipeline_helper(
                pipe=mock_pipe_instance,
                cache_list=[("test_key1", "test_value1"), ("test_key2", "test_value2")],
                ttl=None,
            )

            # Calls are recorded on the `set` mock itself, so compare with plain `call`.
            mock_set.assert_has_calls(
                [
                    call("test_key1", '"test_value1"', ex=expected_timedelta),
                    call("test_key2", '"test_value2"', ex=expected_timedelta),
                ]
            )
    finally:
        # Restore the module-level setting so later tests are unaffected.
        litellm.default_redis_ttl = original_ttl
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
async def test_redis_caching_ttl_sadd():
    """
    Ensure `_set_cache_sadd_helper` applies the default ttl when none is passed
    """
    from litellm.caching.redis_cache import RedisCache

    original_ttl = litellm.default_redis_ttl
    litellm.default_redis_ttl = 120
    try:
        expected_timedelta = timedelta(seconds=120)
        cache_obj = RedisCache()
        redis_client = AsyncMock()

        with patch.object(redis_client, "expire", return_value=None) as mock_expire:
            await cache_obj._set_cache_sadd_helper(
                redis_client=redis_client, key="test_key", value=["test_value"], ttl=None
            )
            # Exactly one expire, on the same key, carrying the default ttl.
            mock_expire.assert_called_once_with("test_key", expected_timedelta)
    finally:
        # Restore the module-level setting so later tests are unaffected.
        litellm.default_redis_ttl = original_ttl
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue