diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py
index f845633cb..5fd972a76 100644
--- a/litellm/caching/caching.py
+++ b/litellm/caching/caching.py
@@ -271,7 +271,7 @@ class Cache:
                 cache_key += f"{str(param)}: {str(param_value)}"

         verbose_logger.debug("\nCreated cache key: %s", cache_key)
-        hashed_cache_key = self._get_hashed_cache_key(cache_key)
+        hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
         hashed_cache_key = self._add_redis_namespace_to_cache_key(
             hashed_cache_key, **kwargs
         )
@@ -431,7 +431,8 @@ class Cache:
         """
         return set(["metadata"])

-    def _get_hashed_cache_key(self, cache_key: str) -> str:
+    @staticmethod
+    def _get_hashed_cache_key(cache_key: str) -> str:
         """
         Get the hashed cache key for the given cache key.

diff --git a/litellm/caching/dual_cache.py b/litellm/caching/dual_cache.py
index 35659b865..1bf16bb65 100644
--- a/litellm/caching/dual_cache.py
+++ b/litellm/caching/dual_cache.py
@@ -8,8 +8,9 @@ Has 4 primary methods:
     - async_get_cache
 """

+import time
 import traceback
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple

 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -25,6 +26,19 @@ if TYPE_CHECKING:
 else:
     Span = Any

+from collections import OrderedDict
+
+
+class LimitedSizeOrderedDict(OrderedDict):
+    def __init__(self, *args, max_size=100, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.max_size = max_size
+
+    def __setitem__(self, key, value):
+        # If inserting a new key exceeds max size, remove the oldest item
+        if len(self) >= self.max_size:
+            self.popitem(last=False)
+        super().__setitem__(key, value)

 class DualCache(BaseCache):
     """
@@ -39,13 +53,18 @@ class DualCache(BaseCache):
         redis_cache: Optional[RedisCache] = None,
         default_in_memory_ttl: Optional[float] = None,
         default_redis_ttl: Optional[float] = None,
+        default_redis_batch_cache_expiry: float = 1,
+        default_max_redis_batch_cache_size: int = 100,
     ) -> None:
         super().__init__()
         # If in_memory_cache is not provided, use the default InMemoryCache
         self.in_memory_cache = in_memory_cache or InMemoryCache()
         # If redis_cache is not provided, use the default RedisCache
         self.redis_cache = redis_cache
-
+        self.last_redis_batch_access_time = LimitedSizeOrderedDict(
+            max_size=default_max_redis_batch_cache_size
+        )
+        self.redis_batch_cache_expiry = default_redis_batch_cache_expiry
         self.default_in_memory_ttl = (
             default_in_memory_ttl or litellm.default_in_memory_ttl
         )
@@ -150,20 +169,34 @@ class DualCache(BaseCache):
                 """
                 - for the none values in the result
                 - check the redis cache
                 """
-                sublist_keys = [
-                    key for key, value in zip(keys, result) if value is None
-                ]
-                # If not found in in-memory cache, try fetching from Redis
-                redis_result = self.redis_cache.batch_get_cache(
-                    sublist_keys, parent_otel_span=parent_otel_span
-                )
-                if redis_result is not None:
-                    # Update in-memory cache with the value from Redis
-                    for key in redis_result:
-                        self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
+                # Track the last access time for these keys
+                current_time = time.time()
+                key_tuple = tuple(keys)
-
-                for key, value in redis_result.items():
-                    result[keys.index(key)] = value
+
+                # Only hit Redis if these keys were last fetched longer ago than the configured expiry (default 1s)
+                if (
+                    key_tuple not in self.last_redis_batch_access_time
+                    or current_time - self.last_redis_batch_access_time[key_tuple]
+                    >= self.redis_batch_cache_expiry
+                ):
+
+                    sublist_keys = [
+                        key for key, value in zip(keys, result) if value is None
+                    ]
+                    # If not found in in-memory cache, try fetching from Redis
+                    redis_result = self.redis_cache.batch_get_cache(
+                        sublist_keys, parent_otel_span=parent_otel_span
+                    )
+                    if redis_result is not None:
+                        # Update in-memory cache with the value from Redis
+                        for key in redis_result:
+                            self.in_memory_cache.set_cache(
+                                key, redis_result[key], **kwargs
+                            )
+
+
+                    for key, value in redis_result.items():
+                        result[keys.index(key)] = value
             print_verbose(f"async batch get cache: cache result: {result}")
             return result
@@ -227,29 +260,41 @@ class DualCache(BaseCache):

             if in_memory_result is not None:
                 result = in_memory_result
+
             if None in result and self.redis_cache is not None and local_only is False:
                 """
                 - for the none values in the result
                 - check the redis cache
                 """
-                sublist_keys = [
-                    key for key, value in zip(keys, result) if value is None
-                ]
-                # If not found in in-memory cache, try fetching from Redis
-                redis_result = await self.redis_cache.async_batch_get_cache(
-                    sublist_keys, parent_otel_span=parent_otel_span
-                )
+                # Track the last access time for these keys
+                current_time = time.time()
+                key_tuple = tuple(keys)
-
-                if redis_result is not None:
-                    # Update in-memory cache with the value from Redis
+
+                # Only hit Redis if these keys were last fetched longer ago than the configured expiry (default 1s)
+                if (
+                    key_tuple not in self.last_redis_batch_access_time
+                    or current_time - self.last_redis_batch_access_time[key_tuple]
+                    >= self.redis_batch_cache_expiry
+                ):
+                    sublist_keys = [
+                        key for key, value in zip(keys, result) if value is None
+                    ]
+                    # If not found in in-memory cache, try fetching from Redis
+                    redis_result = await self.redis_cache.async_batch_get_cache(
+                        sublist_keys, parent_otel_span=parent_otel_span
+                    )
+
+
+                    if redis_result is not None:
+                        # Update in-memory cache with the value from Redis
+                        for key, value in redis_result.items():
+                            if value is not None:
+                                await self.in_memory_cache.async_set_cache(
+                                    key, redis_result[key], **kwargs
+                                )
                     for key, value in redis_result.items():
-                        if value is not None:
-                            await self.in_memory_cache.async_set_cache(
-                                key, redis_result[key], **kwargs
-                            )
-                for key, value in redis_result.items():
-                    index = keys.index(key)
-                    result[index] = value
+                        index = keys.index(key)
+                        result[index] = value

             return result
         except Exception:
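The dual_cache.py hunks above gate Redis batch lookups behind a per-key-set timestamp stored in a bounded LimitedSizeOrderedDict. Below is a minimal, self-contained sketch of that gating logic; it is not litellm code. The class and attribute names mirror the patch, the DualCache/RedisCache plumbing is omitted, and the sketch additionally records the fetch time after a hit — the patch as shown only reads last_redis_batch_access_time, so where the write belongs is this sketch's assumption — since the gate can only skip Redis once a timestamp has been stored.

```python
import time
from collections import OrderedDict


class LimitedSizeOrderedDict(OrderedDict):
    """Bounded mapping: inserting beyond max_size evicts the oldest entry first."""

    def __init__(self, *args, max_size=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size

    def __setitem__(self, key, value):
        # If inserting a new key exceeds max size, remove the oldest item
        if len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)


class BatchRedisGate:
    """Decides whether a batch of keys should be re-fetched from Redis."""

    def __init__(self, expiry: float = 1.0, max_size: int = 100):
        self.redis_batch_cache_expiry = expiry
        self.last_redis_batch_access_time = LimitedSizeOrderedDict(max_size=max_size)

    def should_hit_redis(self, keys: list) -> bool:
        key_tuple = tuple(keys)
        current_time = time.time()
        if (
            key_tuple not in self.last_redis_batch_access_time
            or current_time - self.last_redis_batch_access_time[key_tuple]
            >= self.redis_batch_cache_expiry
        ):
            # record the fetch time so calls within the expiry window skip Redis
            self.last_redis_batch_access_time[key_tuple] = current_time
            return True
        return False


if __name__ == "__main__":
    gate = BatchRedisGate(expiry=1.0)
    keys = ["deployment:a:cooldown", "deployment:b:cooldown"]
    print(gate.should_hit_redis(keys))  # True  -> first lookup goes to Redis
    print(gate.should_hit_redis(keys))  # False -> within the 1s window, Redis is skipped
    time.sleep(1.1)
    print(gate.should_hit_redis(keys))  # True  -> expiry window has passed
```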
diff --git a/litellm/router.py b/litellm/router.py
index e2c033c60..ac26aa61e 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -5153,6 +5153,7 @@ class Router:
         verbose_router_logger.debug(
             f"async cooldown deployments: {cooldown_deployments}"
         )
+        verbose_router_logger.debug(f"cooldown_deployments: {cooldown_deployments}")
         healthy_deployments = self._filter_cooldown_deployments(
             healthy_deployments=healthy_deployments,
             cooldown_deployments=cooldown_deployments,
@@ -5261,7 +5262,7 @@ class Router:
         _cooldown_time = self.cooldown_cache.get_min_cooldown(
             model_ids=model_ids, parent_otel_span=parent_otel_span
         )
-        _cooldown_list = _get_cooldown_deployments(
+        _cooldown_list = await _async_get_cooldown_deployments(
             litellm_router_instance=self, parent_otel_span=parent_otel_span
         )
         raise RouterRateLimitError(
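The router.py hunk above swaps the synchronous _get_cooldown_deployments call for the awaited _async_get_cooldown_deployments inside an async code path. The motivation is not stated in the patch; a plausible reading is that a blocking lookup inside a coroutine stalls everything else sharing the event loop. A toy illustration of that effect, independent of litellm's internals:

```python
import asyncio
import time


def blocking_lookup() -> str:
    time.sleep(0.5)  # stands in for a synchronous Redis round-trip
    return "cooldowns"


async def async_lookup() -> str:
    await asyncio.sleep(0.5)  # same latency, but yields control to the loop
    return "cooldowns"


async def heartbeat(stop: asyncio.Event) -> int:
    ticks = 0
    while not stop.is_set():
        await asyncio.sleep(0.1)  # other requests sharing the event loop
        ticks += 1
    return ticks


async def main() -> None:
    stop = asyncio.Event()
    hb = asyncio.create_task(heartbeat(stop))
    blocking_lookup()    # event loop is frozen for 0.5s: heartbeat cannot run
    await async_lookup() # event loop keeps running: heartbeat keeps ticking
    stop.set()
    print("heartbeat ticks:", await hb)  # ~5, all earned during the awaited half


asyncio.run(main())
```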
diff --git a/litellm/router_utils/cooldown_cache.py b/litellm/router_utils/cooldown_cache.py
index 792d91811..44174f3b1 100644
--- a/litellm/router_utils/cooldown_cache.py
+++ b/litellm/router_utils/cooldown_cache.py
@@ -7,7 +7,15 @@ import time
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict

 from litellm import verbose_logger
-from litellm.caching.caching import DualCache
+from litellm.caching.caching import Cache, DualCache
+from litellm.caching.in_memory_cache import InMemoryCache
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -28,6 +36,7 @@ class CooldownCache:
     def __init__(self, cache: DualCache, default_cooldown_time: float):
         self.cache = cache
         self.default_cooldown_time = default_cooldown_time
+        self.in_memory_cache = InMemoryCache()

     def _common_add_cooldown_logic(
         self, model_id: str, original_exception, exception_status, cooldown_time: float
@@ -83,21 +92,32 @@ class CooldownCache:
             )
             raise e

+    @staticmethod
+    def get_cooldown_cache_key(model_id: str) -> str:
+        return f"deployment:{model_id}:cooldown"
+
     async def async_get_active_cooldowns(
         self, model_ids: List[str], parent_otel_span: Optional[Span]
     ) -> List[Tuple[str, CooldownCacheValue]]:
         # Generate the keys for the deployments
-        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+        keys = [
+            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
+        ]

         # Retrieve the values for the keys using mget
-        results = (
-            await self.cache.async_batch_get_cache(
-                keys=keys, parent_otel_span=parent_otel_span
-            )
-            or []
-        )
+        ## more likely to be None if no models are rate-limited, so just check Redis every 1s
+        ## each Redis call adds ~100ms latency
+
+        ## check the in-memory cache first
+        results = await self.cache.async_batch_get_cache(
+            keys=keys, parent_otel_span=parent_otel_span
+        )
+        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []
+
+        if results is None:
+            return active_cooldowns
+

-        active_cooldowns = []
         # Process the results
         for model_id, result in zip(model_ids, results):
             if result and isinstance(result, dict):
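The cooldown_cache.py changes above centralize the cooldown key format behind get_cooldown_cache_key and let async_get_active_cooldowns return an empty list as soon as the batch lookup comes back as None. A self-contained sketch of that read path follows; the batch lookup and the cooldown payload are faked for illustration, and the field names are not claimed to match CooldownCacheValue's real schema.

```python
from typing import Dict, List, Optional, Tuple


def get_cooldown_cache_key(model_id: str) -> str:
    # mirrors the static helper added in the patch
    return f"deployment:{model_id}:cooldown"


def get_active_cooldowns(
    model_ids: List[str],
    cache_contents: Dict[str, dict],  # stand-in for what DualCache would return per key
) -> List[Tuple[str, dict]]:
    # Generate the keys for the deployments
    keys = [get_cooldown_cache_key(model_id) for model_id in model_ids]

    # simulate a batch lookup that may return None when nothing is cached at all
    results: Optional[list] = [cache_contents.get(key) for key in keys]
    if all(r is None for r in results):
        results = None

    active_cooldowns: List[Tuple[str, dict]] = []
    if results is None:
        # nothing is cooling down; skip per-key processing entirely
        return active_cooldowns

    # Process the results
    for model_id, result in zip(model_ids, results):
        if result and isinstance(result, dict):
            active_cooldowns.append((model_id, result))
    return active_cooldowns


if __name__ == "__main__":
    cached = {"deployment:m1:cooldown": {"exception_received": "429"}}
    print(get_active_cooldowns(["m1", "m2"], cached))
    # [('m1', {'exception_received': '429'})]
    print(get_active_cooldowns(["m3"], {}))
    # [] -> early return, no per-key processing
```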
diff --git a/tests/local_testing/test_acooldowns_router.py b/tests/local_testing/test_acooldowns_router.py
index cad4d9e66..f186d42f1 100644
--- a/tests/local_testing/test_acooldowns_router.py
+++ b/tests/local_testing/test_acooldowns_router.py
@@ -17,6 +17,7 @@ import concurrent
 from dotenv import load_dotenv

 import litellm
+
 from litellm import Router

 load_dotenv()
@@ -130,6 +131,7 @@ def test_multiple_deployments_parallel():
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_cooldown_same_model_name(sync_mode):
+    litellm._turn_on_debug()
     # users could have the same model with different api_base
     # example
     # azure/chatgpt, api_base: 1234
diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py
index b19585430..3456f4535 100644
--- a/tests/local_testing/test_caching.py
+++ b/tests/local_testing/test_caching.py
@@ -681,6 +681,7 @@ async def test_redis_cache_basic():


 @pytest.mark.asyncio
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_redis_batch_cache_write():
     """
     Init redis client
@@ -2477,3 +2478,30 @@ async def test_redis_caching_ttl_sadd():
         )
         print(f"expected_timedelta: {expected_timedelta}")
         assert mock_expire.call_args.args[1] == expected_timedelta
+
+
+@pytest.mark.asyncio()
+async def test_dual_cache_caching_batch_get_cache():
+    """
+    - check redis cache called for initial batch get cache
+    - check redis cache not called for consecutive batch get cache with same keys
+    """
+    from litellm.caching.dual_cache import DualCache
+    from litellm.caching.redis_cache import RedisCache
+
+    dc = DualCache(redis_cache=MagicMock(spec=RedisCache))
+
+    with patch.object(
+        dc.redis_cache,
+        "async_batch_get_cache",
+        new=AsyncMock(
+            return_value={"test_key1": "test_value1", "test_key2": "test_value2"}
+        ),
+    ) as mock_async_get_cache:
+        await dc.async_batch_get_cache(keys=["test_key1", "test_key2"])
+
+        assert mock_async_get_cache.call_count == 1
+
+        await dc.async_batch_get_cache(keys=["test_key1", "test_key2"])
+
+        assert mock_async_get_cache.call_count == 1
diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py
index d360d7317..7bf0b0bba 100644
--- a/tests/local_testing/test_router.py
+++ b/tests/local_testing/test_router.py
@@ -2445,6 +2445,8 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
     except litellm.RateLimitError:
         pass

+    await asyncio.sleep(2)
+
     if sync_mode:
         cooldown_deployments = _get_cooldown_deployments(
             litellm_router_instance=router, parent_otel_span=None
diff --git a/tests/local_testing/test_unit_test_caching.py b/tests/local_testing/test_unit_test_caching.py
index 4d7c50666..5f8f41ba5 100644
--- a/tests/local_testing/test_unit_test_caching.py
+++ b/tests/local_testing/test_unit_test_caching.py
@@ -135,7 +135,7 @@ def test_get_cache_key_text_completion():
 def test_get_hashed_cache_key():
     cache = Cache()
     cache_key = "model:gpt-3.5-turbo,messages:Hello world"
-    hashed_key = cache._get_hashed_cache_key(cache_key)
+    hashed_key = Cache._get_hashed_cache_key(cache_key)
     assert len(hashed_key) == 64  # SHA-256 produces a 64-character hex string
diff --git a/tests/router_unit_tests/test_router_cooldown_utils.py b/tests/router_unit_tests/test_router_cooldown_utils.py
index c8795e541..7ee2e927d 100644
--- a/tests/router_unit_tests/test_router_cooldown_utils.py
+++ b/tests/router_unit_tests/test_router_cooldown_utils.py
@@ -11,7 +11,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
 from concurrent.futures import ThreadPoolExecutor
 from collections import defaultdict
 from dotenv import load_dotenv
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch
 from litellm.integrations.prometheus import PrometheusLogger
 from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
 from litellm.router_utils.cooldown_handlers import (
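The test updates pin down the caching.py change: _get_hashed_cache_key is now a @staticmethod and is exercised as Cache._get_hashed_cache_key(...), with the 64-character assertion implying a SHA-256 hex digest. A minimal stand-in for that helper is sketched below; the real method may normalize or prefix the key differently (namespacing is handled separately by _add_redis_namespace_to_cache_key in the diff above), so treat this as assumed behavior rather than litellm's implementation.

```python
import hashlib


def get_hashed_cache_key(cache_key: str) -> str:
    # SHA-256 hex digest: deterministic, needs no instance state, always 64 chars
    return hashlib.sha256(cache_key.encode()).hexdigest()


hashed = get_hashed_cache_key("model:gpt-3.5-turbo,messages:Hello world")
print(hashed)
print(len(hashed))  # 64, matching the assertion in test_get_hashed_cache_key
```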