diff --git a/litellm/router.py b/litellm/router.py index 61fcd8b82..f724c96c4 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -522,6 +522,11 @@ class Router: self.service_logger_obj = ServiceLogging() self.routing_strategy_args = routing_strategy_args self.provider_budget_config = provider_budget_config + if self.provider_budget_config is not None: + self.provider_budget_logger = ProviderBudgetLimiting( + router_cache=self.cache, + provider_budget_config=self.provider_budget_config, + ) self.retry_policy: Optional[RetryPolicy] = None if retry_policy is not None: if isinstance(retry_policy, dict): @@ -5114,6 +5119,14 @@ class Router: healthy_deployments=healthy_deployments, ) + if self.provider_budget_config is not None: + healthy_deployments = ( + await self.provider_budget_logger.async_filter_deployments( + healthy_deployments=healthy_deployments, + request_kwargs=request_kwargs, + ) + ) + if len(healthy_deployments) == 0: exception = await async_raise_no_deployment_exception( litellm_router_instance=self, diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 590716c1b..9610af149 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -55,19 +55,23 @@ class ProviderBudgetLimiting(CustomLogger): async def async_filter_deployments( self, - healthy_deployments: List[Dict], + healthy_deployments: Union[List[Dict[str, Any]], Dict[str, Any]], request_kwargs: Optional[Dict] = None, - ) -> Optional[Dict]: + ): """ - For all deployments, check their LLM provider budget is less than their budget limit. + Filter out deployments that have exceeded their provider budget limit. - If multiple deployments are available, randomly pick one. Example: if deployment = openai/gpt-3.5-turbo - check if openai budget limit is exceeded - + and openai spend > openai budget limit + then skip this deployment """ + + # If a single deployment is passed, convert it to a list + if isinstance(healthy_deployments, dict): + healthy_deployments = [healthy_deployments] + potential_deployments: List[Dict] = [] # Extract the parent OpenTelemetry span for tracing @@ -134,8 +138,7 @@ class ProviderBudgetLimiting(CustomLogger): potential_deployments.append(deployment) - # Randomly pick one deployment from potential deployments - return random.choice(potential_deployments) if potential_deployments else None + return potential_deployments async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): """