forked from phoenix/litellm-mirror
refactor(redis_cache.py): use a default cache value when writing to r… (#6358)
* refactor(redis_cache.py): use a default cache value when writing to redis prevent redis from blowing up in high traffic * refactor(redis_cache.py): refactor all cache writes to use self.get_ttl ensures default ttl always used when writing to redis Prevents redis db from blowing up in prod
This commit is contained in:
parent
199896f912
commit
7338b24a74
3 changed files with 199 additions and 29 deletions
|
@ -8,8 +8,18 @@ Has 4 methods:
|
||||||
- async_get_cache
|
- async_get_cache
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class BaseCache:
|
class BaseCache:
|
||||||
|
def __init__(self, default_ttl: int = 60):
|
||||||
|
self.default_ttl = default_ttl
|
||||||
|
|
||||||
|
def get_ttl(self, **kwargs) -> Optional[int]:
|
||||||
|
if kwargs.get("ttl") is not None:
|
||||||
|
return kwargs.get("ttl")
|
||||||
|
return self.default_ttl
|
||||||
|
|
||||||
def set_cache(self, key, value, **kwargs):
|
def set_cache(self, key, value, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -22,7 +32,7 @@ class BaseCache:
|
||||||
async def async_get_cache(self, key, **kwargs):
|
async def async_get_cache(self, key, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def batch_cache_write(self, result, *args, **kwargs):
|
async def batch_cache_write(self, key, value, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
|
|
|
@ -14,8 +14,9 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm._logging import print_verbose, verbose_logger
|
from litellm._logging import print_verbose, verbose_logger
|
||||||
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
||||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||||
|
@ -23,6 +24,16 @@ from litellm.types.utils import all_litellm_params
|
||||||
|
|
||||||
from .base_cache import BaseCache
|
from .base_cache import BaseCache
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
from redis.asyncio.client import Pipeline
|
||||||
|
|
||||||
|
pipeline = Pipeline
|
||||||
|
async_redis_client = Redis
|
||||||
|
else:
|
||||||
|
pipeline = Any
|
||||||
|
async_redis_client = Any
|
||||||
|
|
||||||
|
|
||||||
class RedisCache(BaseCache):
|
class RedisCache(BaseCache):
|
||||||
# if users don't provider one, use the default litellm cache
|
# if users don't provider one, use the default litellm cache
|
||||||
|
@ -104,6 +115,11 @@ class RedisCache(BaseCache):
|
||||||
"Error connecting to Sync Redis client", extra={"error": str(e)}
|
"Error connecting to Sync Redis client", extra={"error": str(e)}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if litellm.default_redis_ttl is not None:
|
||||||
|
super().__init__(default_ttl=int(litellm.default_redis_ttl))
|
||||||
|
else:
|
||||||
|
super().__init__() # defaults to 60s
|
||||||
|
|
||||||
def init_async_client(self):
|
def init_async_client(self):
|
||||||
from .._redis import get_redis_async_client
|
from .._redis import get_redis_async_client
|
||||||
|
|
||||||
|
@ -121,7 +137,7 @@ class RedisCache(BaseCache):
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def set_cache(self, key, value, **kwargs):
|
def set_cache(self, key, value, **kwargs):
|
||||||
ttl = kwargs.get("ttl", None)
|
ttl = self.get_ttl(**kwargs)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
|
f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
|
||||||
)
|
)
|
||||||
|
@ -139,15 +155,16 @@ class RedisCache(BaseCache):
|
||||||
) -> int:
|
) -> int:
|
||||||
_redis_client = self.redis_client
|
_redis_client = self.redis_client
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
set_ttl = self.get_ttl(ttl=ttl)
|
||||||
try:
|
try:
|
||||||
result: int = _redis_client.incr(name=key, amount=value) # type: ignore
|
result: int = _redis_client.incr(name=key, amount=value) # type: ignore
|
||||||
|
|
||||||
if ttl is not None:
|
if set_ttl is not None:
|
||||||
# check if key already has ttl, if not -> set ttl
|
# check if key already has ttl, if not -> set ttl
|
||||||
current_ttl = _redis_client.ttl(key)
|
current_ttl = _redis_client.ttl(key)
|
||||||
if current_ttl == -1:
|
if current_ttl == -1:
|
||||||
# Key has no expiration
|
# Key has no expiration
|
||||||
_redis_client.expire(key, ttl) # type: ignore
|
_redis_client.expire(key, set_ttl) # type: ignore
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
## LOGGING ##
|
## LOGGING ##
|
||||||
|
@ -236,7 +253,7 @@ class RedisCache(BaseCache):
|
||||||
|
|
||||||
key = self.check_and_fix_namespace(key=key)
|
key = self.check_and_fix_namespace(key=key)
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
ttl = kwargs.get("ttl", None)
|
ttl = self.get_ttl(**kwargs)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||||
)
|
)
|
||||||
|
@ -284,6 +301,26 @@ class RedisCache(BaseCache):
|
||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _pipeline_helper(
|
||||||
|
self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
|
||||||
|
) -> List:
|
||||||
|
ttl = self.get_ttl(ttl=ttl)
|
||||||
|
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
||||||
|
for cache_key, cache_value in cache_list:
|
||||||
|
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||||
|
print_verbose(
|
||||||
|
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
||||||
|
)
|
||||||
|
json_cache_value = json.dumps(cache_value)
|
||||||
|
# Set the value with a TTL if it's provided.
|
||||||
|
_td: Optional[timedelta] = None
|
||||||
|
if ttl is not None:
|
||||||
|
_td = timedelta(seconds=ttl)
|
||||||
|
pipe.set(cache_key, json_cache_value, ex=_td)
|
||||||
|
# Execute the pipeline and return the results.
|
||||||
|
results = await pipe.execute()
|
||||||
|
return results
|
||||||
|
|
||||||
async def async_set_cache_pipeline(
|
async def async_set_cache_pipeline(
|
||||||
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
|
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
|
||||||
):
|
):
|
||||||
|
@ -298,8 +335,6 @@ class RedisCache(BaseCache):
|
||||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
ttl = ttl or kwargs.get("ttl", None)
|
|
||||||
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
||||||
)
|
)
|
||||||
|
@ -307,20 +342,7 @@ class RedisCache(BaseCache):
|
||||||
try:
|
try:
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
async with redis_client.pipeline(transaction=True) as pipe:
|
async with redis_client.pipeline(transaction=True) as pipe:
|
||||||
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
results = await self._pipeline_helper(pipe, cache_list, ttl)
|
||||||
for cache_key, cache_value in cache_list:
|
|
||||||
cache_key = self.check_and_fix_namespace(key=cache_key)
|
|
||||||
print_verbose(
|
|
||||||
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
|
||||||
)
|
|
||||||
json_cache_value = json.dumps(cache_value)
|
|
||||||
# Set the value with a TTL if it's provided.
|
|
||||||
_td: Optional[timedelta] = None
|
|
||||||
if ttl is not None:
|
|
||||||
_td = timedelta(seconds=ttl)
|
|
||||||
pipe.set(cache_key, json_cache_value, ex=_td)
|
|
||||||
# Execute the pipeline and return the results.
|
|
||||||
results = await pipe.execute()
|
|
||||||
|
|
||||||
print_verbose(f"pipeline results: {results}")
|
print_verbose(f"pipeline results: {results}")
|
||||||
# Optionally, you could process 'results' to make sure that all set operations were successful.
|
# Optionally, you could process 'results' to make sure that all set operations were successful.
|
||||||
|
@ -360,6 +382,23 @@ class RedisCache(BaseCache):
|
||||||
cache_value,
|
cache_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _set_cache_sadd_helper(
|
||||||
|
self,
|
||||||
|
redis_client: async_redis_client,
|
||||||
|
key: str,
|
||||||
|
value: List,
|
||||||
|
ttl: Optional[float],
|
||||||
|
) -> None:
|
||||||
|
"""Helper function for async_set_cache_sadd. Separated for testing."""
|
||||||
|
ttl = self.get_ttl(ttl=ttl)
|
||||||
|
try:
|
||||||
|
await redis_client.sadd(key, *value) # type: ignore
|
||||||
|
if ttl is not None:
|
||||||
|
_td = timedelta(seconds=ttl)
|
||||||
|
await redis_client.expire(key, _td)
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
async def async_set_cache_sadd(
|
async def async_set_cache_sadd(
|
||||||
self, key, value: List, ttl: Optional[float], **kwargs
|
self, key, value: List, ttl: Optional[float], **kwargs
|
||||||
):
|
):
|
||||||
|
@ -396,10 +435,9 @@ class RedisCache(BaseCache):
|
||||||
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await redis_client.sadd(key, *value) # type: ignore
|
await self._set_cache_sadd_helper(
|
||||||
if ttl is not None:
|
redis_client=redis_client, key=key, value=value, ttl=ttl
|
||||||
_td = timedelta(seconds=ttl)
|
)
|
||||||
await redis_client.expire(key, _td)
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
|
f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
|
||||||
)
|
)
|
||||||
|
@ -452,16 +490,17 @@ class RedisCache(BaseCache):
|
||||||
|
|
||||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
_used_ttl = self.get_ttl(ttl=ttl)
|
||||||
try:
|
try:
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
result = await redis_client.incrbyfloat(name=key, amount=value)
|
result = await redis_client.incrbyfloat(name=key, amount=value)
|
||||||
|
|
||||||
if ttl is not None:
|
if _used_ttl is not None:
|
||||||
# check if key already has ttl, if not -> set ttl
|
# check if key already has ttl, if not -> set ttl
|
||||||
current_ttl = await redis_client.ttl(key)
|
current_ttl = await redis_client.ttl(key)
|
||||||
if current_ttl == -1:
|
if current_ttl == -1:
|
||||||
# Key has no expiration
|
# Key has no expiration
|
||||||
await redis_client.expire(key, ttl)
|
await redis_client.expire(key, _used_ttl)
|
||||||
|
|
||||||
## LOGGING ##
|
## LOGGING ##
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
|
@ -23,8 +23,9 @@ import litellm
|
||||||
from litellm import aembedding, completion, embedding
|
from litellm import aembedding, completion, embedding
|
||||||
from litellm.caching.caching import Cache
|
from litellm.caching.caching import Cache
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock, call
|
||||||
import datetime
|
import datetime
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
# litellm.set_verbose=True
|
# litellm.set_verbose=True
|
||||||
|
|
||||||
|
@ -2394,3 +2395,123 @@ async def test_audio_caching(stream):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "cache_hit" in completion._hidden_params
|
assert "cache_hit" in completion._hidden_params
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_caching_default_ttl():
|
||||||
|
"""
|
||||||
|
Ensure that the default redis cache TTL is 60s
|
||||||
|
"""
|
||||||
|
from litellm.caching.redis_cache import RedisCache
|
||||||
|
|
||||||
|
litellm.default_redis_ttl = 120
|
||||||
|
|
||||||
|
cache_obj = RedisCache()
|
||||||
|
assert cache_obj.default_ttl == 120
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
async def test_redis_caching_llm_caching_ttl(sync_mode):
|
||||||
|
"""
|
||||||
|
Ensure default redis cache ttl is used for a sample redis cache object
|
||||||
|
"""
|
||||||
|
from litellm.caching.redis_cache import RedisCache
|
||||||
|
|
||||||
|
litellm.default_redis_ttl = 120
|
||||||
|
cache_obj = RedisCache()
|
||||||
|
assert cache_obj.default_ttl == 120
|
||||||
|
|
||||||
|
if sync_mode is False:
|
||||||
|
# Create an AsyncMock for the Redis client
|
||||||
|
mock_redis_instance = AsyncMock()
|
||||||
|
|
||||||
|
# Make sure the mock can be used as an async context manager
|
||||||
|
mock_redis_instance.__aenter__.return_value = mock_redis_instance
|
||||||
|
mock_redis_instance.__aexit__.return_value = None
|
||||||
|
|
||||||
|
## Set cache
|
||||||
|
if sync_mode is True:
|
||||||
|
with patch.object(cache_obj.redis_client, "set") as mock_set:
|
||||||
|
cache_obj.set_cache(key="test", value="test")
|
||||||
|
mock_set.assert_called_once_with(name="test", value="test", ex=120)
|
||||||
|
else:
|
||||||
|
|
||||||
|
# Patch self.init_async_client to return our mock Redis client
|
||||||
|
with patch.object(
|
||||||
|
cache_obj, "init_async_client", return_value=mock_redis_instance
|
||||||
|
):
|
||||||
|
# Call async_set_cache
|
||||||
|
await cache_obj.async_set_cache(key="test", value="test_value")
|
||||||
|
|
||||||
|
# Verify that the set method was called on the mock Redis instance
|
||||||
|
mock_redis_instance.set.assert_called_once_with(
|
||||||
|
name="test", value='"test_value"', ex=120
|
||||||
|
)
|
||||||
|
|
||||||
|
## Increment cache
|
||||||
|
if sync_mode is True:
|
||||||
|
with patch.object(cache_obj.redis_client, "ttl") as mock_incr:
|
||||||
|
cache_obj.increment_cache(key="test", value=1)
|
||||||
|
mock_incr.assert_called_once_with("test")
|
||||||
|
else:
|
||||||
|
# Patch self.init_async_client to return our mock Redis client
|
||||||
|
with patch.object(
|
||||||
|
cache_obj, "init_async_client", return_value=mock_redis_instance
|
||||||
|
):
|
||||||
|
# Call async_set_cache
|
||||||
|
await cache_obj.async_increment(key="test", value="test_value")
|
||||||
|
|
||||||
|
# Verify that the set method was called on the mock Redis instance
|
||||||
|
mock_redis_instance.ttl.assert_called_once_with("test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_redis_caching_ttl_pipeline():
|
||||||
|
"""
|
||||||
|
Ensure that a default ttl is set for all redis functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
from litellm.caching.redis_cache import RedisCache
|
||||||
|
|
||||||
|
litellm.default_redis_ttl = 120
|
||||||
|
expected_timedelta = timedelta(seconds=120)
|
||||||
|
cache_obj = RedisCache()
|
||||||
|
|
||||||
|
## TEST 1 - async_set_cache_pipeline
|
||||||
|
# Patch self.init_async_client to return our mock Redis client
|
||||||
|
# Call async_set_cache
|
||||||
|
mock_pipe_instance = AsyncMock()
|
||||||
|
with patch.object(mock_pipe_instance, "set", return_value=None) as mock_set:
|
||||||
|
await cache_obj._pipeline_helper(
|
||||||
|
pipe=mock_pipe_instance,
|
||||||
|
cache_list=[("test_key1", "test_value1"), ("test_key2", "test_value2")],
|
||||||
|
ttl=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the set method was called on the mock Redis instance
|
||||||
|
mock_set.assert_has_calls(
|
||||||
|
[
|
||||||
|
call.set("test_key1", '"test_value1"', ex=expected_timedelta),
|
||||||
|
call.set("test_key2", '"test_value2"', ex=expected_timedelta),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_redis_caching_ttl_sadd():
|
||||||
|
"""
|
||||||
|
Ensure that a default ttl is set for all redis functions
|
||||||
|
"""
|
||||||
|
from litellm.caching.redis_cache import RedisCache
|
||||||
|
|
||||||
|
litellm.default_redis_ttl = 120
|
||||||
|
expected_timedelta = timedelta(seconds=120)
|
||||||
|
cache_obj = RedisCache()
|
||||||
|
redis_client = AsyncMock()
|
||||||
|
|
||||||
|
with patch.object(redis_client, "expire", return_value=None) as mock_expire:
|
||||||
|
await cache_obj._set_cache_sadd_helper(
|
||||||
|
redis_client=redis_client, key="test_key", value=["test_value"], ttl=None
|
||||||
|
)
|
||||||
|
print(f"expected_timedelta: {expected_timedelta}")
|
||||||
|
assert mock_expire.call_args.args[1] == expected_timedelta
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue