Use provider budget as a flag (provider_budget_config), not a routing strategy

This commit is contained in:
Ishaan Jaff 2024-11-19 17:09:29 -08:00
parent d14347df95
commit a7e96ff9ed

View file

@ -120,6 +120,7 @@ from litellm.types.router import (
LiteLLMParamsTypedDict,
ModelGroupInfo,
ModelInfo,
ProviderBudgetConfigType,
RetryPolicy,
RouterErrors,
RouterGeneralSettings,
@ -235,9 +236,9 @@ class Router:
"latency-based-routing",
"cost-based-routing",
"usage-based-routing-v2",
"provider-budget-routing",
] = "simple-shuffle",
routing_strategy_args: dict = {}, # just for latency-based,
routing_strategy_args: dict = {}, # just for latency-based
provider_budget_config: Optional[ProviderBudgetConfigType] = None,
semaphore: Optional[asyncio.Semaphore] = None,
alerting_config: Optional[AlertingConfig] = None,
router_general_settings: Optional[
@ -274,6 +275,7 @@ class Router:
routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing", "cost-based-routing"]): Routing strategy. Defaults to "simple-shuffle".
routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None.
provider_budget_config (ProviderBudgetConfigType): Provider budget configuration. Use this to set budget limits per llm_provider, e.g. $100/day for OpenAI, $100/day for Azure. Defaults to None.
Returns:
Router: An instance of the litellm.Router class.
@ -519,6 +521,7 @@ class Router:
)
self.service_logger_obj = ServiceLogging()
self.routing_strategy_args = routing_strategy_args
self.provider_budget_config = provider_budget_config
self.retry_policy: Optional[RetryPolicy] = None
if retry_policy is not None:
if isinstance(retry_policy, dict):
@ -646,16 +649,6 @@ class Router:
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestcost_logger) # type: ignore
elif (
routing_strategy == RoutingStrategy.PROVIDER_BUDGET_LIMITING.value
or routing_strategy == RoutingStrategy.PROVIDER_BUDGET_LIMITING
):
self.provider_budget_logger = ProviderBudgetLimiting(
router_cache=self.cache,
provider_budget_config=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.provider_budget_logger) # type: ignore
else:
pass
@ -5067,7 +5060,6 @@ class Router:
and self.routing_strategy != "cost-based-routing"
and self.routing_strategy != "latency-based-routing"
and self.routing_strategy != "least-busy"
and self.routing_strategy != "provider-budget-routing"
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
return self.get_available_deployment(
model=model,
@ -5183,16 +5175,6 @@ class Router:
healthy_deployments=healthy_deployments, # type: ignore
)
)
elif (
self.routing_strategy == "provider-budget-routing"
and self.provider_budget_logger is not None
):
deployment = (
await self.provider_budget_logger.async_get_available_deployments(
request_kwargs=request_kwargs,
healthy_deployments=healthy_deployments, # type: ignore
)
)
else:
deployment = None
if deployment is None: