diff --git a/litellm/router.py b/litellm/router.py index d582f614f..db6debd56 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -5189,7 +5189,7 @@ class Router: ): deployment = ( await self.provider_budget_logger.async_get_available_deployments( - model_group=model, + request_kwargs=request_kwargs, healthy_deployments=healthy_deployments, # type: ignore ) ) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 423bcdd59..5c6594d30 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -45,38 +45,61 @@ class ProviderBudgetLimiting(CustomLogger): async def async_get_available_deployments( self, - model_group: str, healthy_deployments: List[Dict], - messages: Optional[List[Dict[str, str]]] = None, - input: Optional[Union[str, List]] = None, request_kwargs: Optional[Dict] = None, - ): + ) -> Optional[Dict]: """ Filter list of healthy deployments based on provider budget """ potential_deployments: List[Dict] = [] + # Extract the parent OpenTelemetry span for tracing parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( request_kwargs ) + # Collect all providers and their budget configs + # {"openai": ProviderBudgetInfo, "anthropic": ProviderBudgetInfo, "azure": None} + _provider_configs: Dict[str, Optional[ProviderBudgetInfo]] = {} for deployment in healthy_deployments: provider = self._get_llm_provider_for_deployment(deployment) budget_config = self._get_budget_config_for_provider(provider) - if budget_config is None: - verbose_router_logger.debug( - f"No budget config found for provider {provider}, skipping" - ) + _provider_configs[provider] = budget_config + + # Filter out providers without budget config + provider_configs: Dict[str, ProviderBudgetInfo] = { + provider: config + for provider, config in _provider_configs.items() + if config is not None + } + + # Build cache keys for batch retrieval + cache_keys = [] + for provider, config in provider_configs.items(): + cache_keys.append(f"provider_spend:{provider}:{config.time_period}") + + # Fetch current spend for all providers using batch cache + _current_spends = await self.router_cache.async_batch_get_cache( + keys=cache_keys, + parent_otel_span=parent_otel_span, + ) + current_spends: List = _current_spends or [0.0] * len(provider_configs) + + # Map providers to their current spend values + provider_spend_map: Dict[str, float] = {} + for idx, provider in enumerate(provider_configs.keys()): + provider_spend_map[provider] = float(current_spends[idx] or 0.0) + + # Filter healthy deployments based on budget constraints + for deployment in healthy_deployments: + provider = self._get_llm_provider_for_deployment(deployment) + budget_config = provider_configs.get(provider) + + if not budget_config: continue + current_spend = provider_spend_map.get(provider, 0.0) budget_limit = budget_config.budget_limit - current_spend: float = ( - await self.router_cache.async_get_cache( - key=f"provider_spend:{provider}:{budget_config.time_period}", - parent_otel_span=parent_otel_span, - ) - or 0.0 - ) verbose_router_logger.debug( f"Current spend for {provider}: {current_spend}, budget limit: {budget_limit}" @@ -89,18 +112,15 @@ class ProviderBudgetLimiting(CustomLogger): continue potential_deployments.append(deployment) - # randomly pick one deployment from potential_deployments - if potential_deployments: - return random.choice(potential_deployments) - return None + + # Randomly pick one deployment from potential deployments + return random.choice(potential_deployments) if potential_deployments else None async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): """ Increment provider spend in DualCache (InMemory + Redis) """ - verbose_router_logger.debug( - f"in ProviderBudgetLimiting.async_log_success_event" - ) + verbose_router_logger.debug("in ProviderBudgetLimiting.async_log_success_event") standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object", None )