diff --git a/litellm/router.py b/litellm/router.py
index db6debd56..61fcd8b82 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -120,6 +120,7 @@ from litellm.types.router import (
     LiteLLMParamsTypedDict,
     ModelGroupInfo,
     ModelInfo,
+    ProviderBudgetConfigType,
     RetryPolicy,
     RouterErrors,
     RouterGeneralSettings,
@@ -235,9 +236,9 @@ class Router:
             "latency-based-routing",
             "cost-based-routing",
             "usage-based-routing-v2",
-            "provider-budget-routing",
         ] = "simple-shuffle",
-        routing_strategy_args: dict = {},  # just for latency-based,
+        routing_strategy_args: dict = {},  # just for latency-based
+        provider_budget_config: Optional[ProviderBudgetConfigType] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         alerting_config: Optional[AlertingConfig] = None,
         router_general_settings: Optional[
@@ -274,6 +275,7 @@ class Router:
             routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing", "cost-based-routing"]): Routing strategy. Defaults to "simple-shuffle".
             routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
             alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None.
+            provider_budget_config (ProviderBudgetConfig): Provider budget configuration. Use this to set llm_provider budget limits. example $100/day to OpenAI, $100/day to Azure, etc. Defaults to None.
 
         Returns:
             Router: An instance of the litellm.Router class.
@@ -519,6 +521,7 @@ class Router:
         )
         self.service_logger_obj = ServiceLogging()
         self.routing_strategy_args = routing_strategy_args
+        self.provider_budget_config = provider_budget_config
         self.retry_policy: Optional[RetryPolicy] = None
         if retry_policy is not None:
             if isinstance(retry_policy, dict):
@@ -646,16 +649,6 @@ class Router:
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowestcost_logger)  # type: ignore
-        elif (
-            routing_strategy == RoutingStrategy.PROVIDER_BUDGET_LIMITING.value
-            or routing_strategy == RoutingStrategy.PROVIDER_BUDGET_LIMITING
-        ):
-            self.provider_budget_logger = ProviderBudgetLimiting(
-                router_cache=self.cache,
-                provider_budget_config=routing_strategy_args,
-            )
-            if isinstance(litellm.callbacks, list):
-                litellm.callbacks.append(self.provider_budget_logger)  # type: ignore
         else:
             pass
 
@@ -5067,7 +5060,6 @@ class Router:
             and self.routing_strategy != "cost-based-routing"
             and self.routing_strategy != "latency-based-routing"
             and self.routing_strategy != "least-busy"
-            and self.routing_strategy != "provider-budget-routing"
         ):  # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
             return self.get_available_deployment(
                 model=model,
@@ -5183,16 +5175,6 @@ class Router:
                     healthy_deployments=healthy_deployments,  # type: ignore
                 )
             )
-        elif (
-            self.routing_strategy == "provider-budget-routing"
-            and self.provider_budget_logger is not None
-        ):
-            deployment = (
-                await self.provider_budget_logger.async_get_available_deployments(
-                    request_kwargs=request_kwargs,
-                    healthy_deployments=healthy_deployments,  # type: ignore
-                )
-            )
         else:
             deployment = None
         if deployment is None:
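
Usage note (not part of the diff): a minimal sketch of how the new `provider_budget_config` argument might be passed to `Router`, instead of selecting a `"provider-budget-routing"` strategy and smuggling the budgets through `routing_strategy_args` as before. The `ProviderBudgetInfo` class and its `time_period`/`budget_limit` fields are assumptions about the shape behind `ProviderBudgetConfigType` (a provider-name-to-budget-spec mapping); they are not shown in this diff.

```python
import os

from litellm import Router
# Assumed import: ProviderBudgetInfo is not visible in this diff.
from litellm.types.router import ProviderBudgetInfo

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "azure/gpt-4o",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        },
    ],
    # Budgets are now a first-class constructor argument; routing_strategy
    # stays at its default and no longer needs a dedicated budget strategy.
    provider_budget_config={
        "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100.0),
        "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100.0),
    },
)
```

Decoupling the budget config from `routing_strategy` lets provider budgets apply on top of whichever routing strategy is selected, rather than being a mutually exclusive strategy of their own.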