diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py
index 65b1e3c97..002e2d1c4 100644
--- a/litellm/router_strategy/provider_budgets.py
+++ b/litellm/router_strategy/provider_budgets.py
@@ -51,6 +51,7 @@ DEFAULT_REDIS_SYNC_INTERVAL = 60
 class ProviderBudgetLimiting(CustomLogger):
     def __init__(self, router_cache: DualCache, provider_budget_config: dict):
         self.router_cache = router_cache
+        self.last_synced_values = {}  # To track last synced values for each key
         asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
 
         # cast elements of provider_budget_config to ProviderBudgetInfo
@@ -217,10 +218,9 @@ class ProviderBudgetLimiting(CustomLogger):
             f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}"
         )
         # Increment the spend in Redis and set TTL
-        await self.router_cache.async_increment_cache(
+        await self.router_cache.in_memory_cache.async_increment(
             key=spend_key,
             value=response_cost,
-            ttl=ttl_seconds,
         )
         verbose_router_logger.debug(
             f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}"
         )
@@ -244,6 +244,55 @@ class ProviderBudgetLimiting(CustomLogger):
                     DEFAULT_REDIS_SYNC_INTERVAL
                 )  # Still wait 5 seconds on error before retrying
 
+    async def _push_in_memory_increments_to_redis(self):
+        """
+        Sync in-memory spend to Redis.
+
+        Pushes the delta since the last sync for each key to Redis.
+        """
+        try:
+            if not self.router_cache.redis_cache:
+                return  # Redis is not initialized
+
+            # Build cache keys for all providers
+            cache_keys = [
+                f"provider_spend:{provider}:{config.time_period}"
+                for provider, config in self.provider_budget_config.items()
+                if config is not None
+            ]
+
+            # Fetch current in-memory values
+            current_values = (
+                await self.router_cache.in_memory_cache.async_batch_get_cache(
+                    keys=cache_keys
+                )
+            )
+
+            for key, current_value in zip(cache_keys, current_values):
+                if current_value is None:
+                    continue  # Skip keys with no in-memory value
+
+                # Get the last synced value (default to 0 if not synced before)
+                last_synced = self.last_synced_values.get(key, 0.0)
+
+                # Calculate the delta to push to Redis
+                delta = float(current_value) - last_synced
+                if delta > 0:  # Only push if there is a positive increment
+                    await self.router_cache.redis_cache.async_increment(
+                        key=key, value=delta
+                    )
+                    verbose_router_logger.debug(
+                        f"Pushed delta to Redis for {key}: {delta} (last synced: {last_synced}, current: {current_value})"
+                    )
+
+                # Update last synced value
+                self.last_synced_values[key] = float(current_value)
+
+        except Exception as e:
+            verbose_router_logger.error(
+                f"Error syncing in-memory cache with Redis: {str(e)}"
+            )
+
     async def _sync_in_memory_spend_with_redis(self):
         """
         Ensures in-memory cache is updated with latest Redis values for all provider spends.
@@ -263,6 +312,8 @@ class ProviderBudgetLimiting(CustomLogger):
         if self.router_cache.redis_cache is None:
             return
 
+        await self._push_in_memory_increments_to_redis()
+
         # Get all providers and their budget configs
         cache_keys = []
         for provider, config in self.provider_budget_config.items():
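
For readers skimming the diff: the hot path now only touches the in-memory counter, and the periodic task pushes the accumulated difference (current value minus last synced value) to Redis. Below is a rough, standalone sketch of that delta-sync pattern, not litellm code: plain dicts stand in for the in-memory cache and for Redis, and the key name is purely illustrative.

```python
import asyncio

in_memory_spend: dict[str, float] = {}  # written on every request (cheap, local)
shared_spend: dict[str, float] = {}     # stands in for Redis (shared, remote)
last_synced: dict[str, float] = {}      # last value pushed to the shared store, per key


def record_spend(key: str, cost: float) -> None:
    """Hot path: local increment only, no network round trip."""
    in_memory_spend[key] = in_memory_spend.get(key, 0.0) + cost


async def push_increments(sync_interval: float = 60.0) -> None:
    """Background task: periodically push only positive deltas to the shared store."""
    while True:
        for key, current in in_memory_spend.items():
            delta = current - last_synced.get(key, 0.0)
            if delta > 0:
                # An increment (rather than a set) keeps the shared total correct
                # when several router instances push their own deltas.
                shared_spend[key] = shared_spend.get(key, 0.0) + delta
                last_synced[key] = current
        await asyncio.sleep(sync_interval)


async def demo() -> None:
    sync_task = asyncio.create_task(push_increments(sync_interval=0.1))
    record_spend("provider_spend:openai:1d", 0.25)
    record_spend("provider_spend:openai:1d", 0.10)
    await asyncio.sleep(0.3)  # let the sync task run at least once
    print(shared_spend)       # roughly {'provider_spend:openai:1d': 0.35}
    sync_task.cancel()


if __name__ == "__main__":
    asyncio.run(demo())
```

The trade-off is the usual one for this pattern: far fewer Redis round trips on the request path, at the cost of Redis lagging local spend by up to one sync interval, which is why the push is also invoked at the start of `_sync_in_memory_spend_with_redis`.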