From 24ab97948621b8dd785960b2f13a1eb86fcc4de7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 09:42:25 -0800 Subject: [PATCH] use redis async_increment_pipeline --- litellm/router_strategy/provider_budgets.py | 51 ++++++--------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 55ac18606..c0686104c 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union import litellm from litellm._logging import verbose_router_logger from litellm.caching.caching import DualCache +from litellm.caching.redis_cache import RedisPipelineIncrementOperation from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.litellm_core_utils.duration_parser import duration_in_seconds @@ -51,7 +52,7 @@ DEFAULT_REDIS_SYNC_INTERVAL = 60 class ProviderBudgetLimiting(CustomLogger): def __init__(self, router_cache: DualCache, provider_budget_config: dict): self.router_cache = router_cache - self.last_synced_values = {} # To track last synced values for each key + self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) # cast elements of provider_budget_config to ProviderBudgetInfo @@ -214,14 +215,17 @@ class ProviderBudgetLimiting(CustomLogger): spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}" ttl_seconds = duration_in_seconds(duration=budget_config.time_period) - verbose_router_logger.debug( - f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" + + # Create RedisPipelineIncrementOperation object + increment_op = RedisPipelineIncrementOperation( + key=spend_key, increment_value=response_cost, ttl_seconds=ttl_seconds ) - # Increment the spend in Redis and set TTL + await self.router_cache.in_memory_cache.async_increment( key=spend_key, value=response_cost, ) + self.redis_increment_operation_queue.append(increment_op) verbose_router_logger.debug( f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" ) @@ -256,40 +260,15 @@ class ProviderBudgetLimiting(CustomLogger): try: if not self.router_cache.redis_cache: return # Redis is not initialized - - # Build cache keys for all providers - cache_keys = [ - f"provider_spend:{provider}:{config.time_period}" - for provider, config in self.provider_budget_config.items() - if config is not None - ] - - # Fetch current in-memory values - current_values = ( - await self.router_cache.in_memory_cache.async_batch_get_cache( - keys=cache_keys - ) + verbose_router_logger.debug( + "Pushing Redis Increment Pipeline for queue: %s", + self.redis_increment_operation_queue, + ) + await self.router_cache.redis_cache.async_increment_pipeline( + increment_list=self.redis_increment_operation_queue, ) - for key, current_value in zip(cache_keys, current_values): - if current_value is None: - continue # Skip keys with no in-memory value - - # Get the last synced value (default to 0 if not synced before) - last_synced = self.last_synced_values.get(key, 0.0) - # Calculate the delta to push to Redis - delta = float(current_value) - last_synced - if delta > 0: # Only push if there is a positive increment - await self.router_cache.redis_cache.async_increment( - key=key, - value=delta, - ) - verbose_router_logger.debug( - f"Pushed delta to Redis for {key}: {delta} (last synced: {last_synced}, current: {current_value})" - ) - - # Update last synced value - self.last_synced_values[key] = float(current_value) + self.redis_increment_operation_queue = [] except Exception as e: verbose_router_logger.error(