(feat - Router / Proxy) Allow setting budget limits per LLM deployment (#7220)

* fix test_deployment_budget_limits_e2e_test

* refactor async_log_success_event to track spend for provider + deployment

* fix format

* rename class to RouterBudgetLimiting

* rename func

* rename types used for budgets

* add new types for deployment budgets

* add budget limits for deployments

* fix checking budgets set for provider

* update file names

* fix linting error

* _track_provider_remaining_budget_prometheus

* async_filter_deployments

* fix model list passed to router

* update error

* test_deployment_budgets_e2e_test_expect_to_fail

* fix test case

* run deployment budget limits
Ishaan Jaff 2024-12-13 19:15:51 -08:00 committed by GitHub
parent b150faff90
commit 163529b40b
8 changed files with 557 additions and 151 deletions


@ -1,13 +1,19 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Provider Budget Routing
# Budget Routing
LiteLLM supports setting the following budgets:
- Provider budget - $100/day for OpenAI, $100/day for Azure.
- Model budget - $100/day for gpt-4 on https://api-base-1, $100/day for gpt-4o on https://api-base-2
## Provider Budgets
Use this to set budgets for LLM providers - for example, $100/day for OpenAI and $100/day for Azure.
## Quick Start
### Quick Start
Set provider budgets in your `proxy_config.yaml` file
### Proxy Config setup
#### Proxy Config setup
```yaml
model_list:
- model_name: gpt-3.5-turbo
@ -42,7 +48,7 @@ general_settings:
master_key: sk-1234
```
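The hunk above elides the budget settings themselves (they are unchanged in this commit). For reference, a minimal sketch of the provider budget block under `router_settings` - the field names (`budget_limit` in USD, `time_period` as a duration string) come from `GenericBudgetInfo` in this PR; the values are illustrative:

```yaml
router_settings:
  provider_budget_config:
    openai:
      budget_limit: 0.000000000001 # float, USD
      time_period: 1d              # 1s, 1m, 1h, 1d, 1mo
    azure:
      budget_limit: 100
      time_period: 1d
```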
### Make a test request
#### Make a test request
We expect the first request to succeed, and the second request to fail since we cross the budget for `openai`
@ -67,7 +73,7 @@ curl -i http://localhost:4000/v1/chat/completions \
</TabItem>
<TabItem label="Unsuccessful call" value = "not-allowed">
Expect this to fail since since `ishaan@berri.ai` in the request is PII
Expect this to fail since we cross the budget for provider `openai`
```shell
curl -i http://localhost:4000/v1/chat/completions \
@ -101,7 +107,7 @@ Expected response on failure
## How provider budget routing works
#### How provider budget routing works
1. **Budget Tracking**:
- Uses Redis to track spend for each provider
@ -124,9 +130,9 @@ Expected response on failure
- Redis required for tracking spend across instances
- Provider names must be litellm provider names. See [Supported Providers](https://docs.litellm.ai/docs/providers)
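Internally, spend counters are stored under per-provider cache keys. A small sketch of the key scheme, taken from the `RouterBudgetLimiting` code in this diff:

```python
# Spend for each provider is tracked in Redis / in-memory cache under
# keys of this shape (see the budget limiter diff further below).
provider = "openai"
time_period = "1d"  # from the provider's GenericBudgetInfo
spend_key = f"provider_spend:{provider}:{time_period}"
start_time_key = f"provider_budget_start_time:{provider}"
print(spend_key)  # -> provider_spend:openai:1d
```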
## Monitoring Provider Remaining Budget
### Monitoring Provider Remaining Budget
### Get Budget, Spend Details
#### Get Budget, Spend Details
Use this endpoint to check current budget, spend and budget reset time for a provider
@ -171,7 +177,7 @@ Example Response
}
```
### Prometheus Metric
#### Prometheus Metric
LiteLLM will emit the following metric on Prometheus to track the remaining budget for each provider
@ -181,6 +187,88 @@ This metric indicates the remaining budget for a provider in dollars (USD)
litellm_provider_remaining_budget_metric{api_provider="openai"} 10
```
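If you alert on this metric, a hypothetical Prometheus rule might look like the following. The metric name comes from this PR; the rule itself is only a sketch:

```yaml
groups:
  - name: litellm_budgets
    rules:
      - alert: ProviderBudgetLow
        # fires when a provider has under $5 of remaining budget
        expr: litellm_provider_remaining_budget_metric < 5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "{{ $labels.api_provider }} has under $5 of LiteLLM budget left"
```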
## Model Budgets
Use this to set budgets for models - for example, $10/day for openai/gpt-4o and $100/day for openai/gpt-4o-mini
### Quick Start
Set model budgets in your `proxy_config.yaml` file
```yaml
model_list:
- model_name: gpt-4o
litellm_params:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
max_budget: 0.000000000001 # (USD)
budget_duration: 1d # (Duration. can be 1s, 1m, 1h, 1d, 1mo)
- model_name: gpt-4o-mini
litellm_params:
model: openai/gpt-4o-mini
api_key: os.environ/OPENAI_API_KEY
max_budget: 100 # (USD)
budget_duration: 30d # (Duration. can be 1s, 1m, 1h, 1d, 1mo)
```
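The same `max_budget` / `budget_duration` params work when constructing a `Router` directly from the Python SDK. A sketch adapted from the e2e test added in this PR:

```python
import asyncio
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
                "max_budget": 10,         # USD
                "budget_duration": "1d",  # 1s, 1m, 1h, 1d, 1mo
            },
        },
    ],
)

async def main():
    # Once $10 of spend accumulates for this deployment within a day,
    # the router filters it out; if no deployment is left, it raises
    # "No deployments available - crossed budget: ..."
    response = await router.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "hi"}],
    )
    print(response)

asyncio.run(main())
```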
#### Make a test request
We expect the first request to succeed, and the second request to fail since we cross the budget for `openai/gpt-4o`
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
<Tabs>
<TabItem label="Successful Call " value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "hi my name is test request"}
]
}'
```
</TabItem>
<TabItem label="Unsuccessful call" value = "not-allowed">
Expect this to fail since we cross the budget for `openai/gpt-4o`
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "hi my name is test request"}
]
}'
```
Expected response on failure
```json
{
"error": {
"message": "No deployments available - crossed budget: Exceeded budget for deployment model_name: gpt-4o, litellm_params.model: openai/gpt-4o, model_id: dbe80f2fe2b2465f7bfa9a5e77e0f143a2eb3f7d167a8b55fb7fe31aed62587f: 0.00015250000000000002 >= 1e-12",
"type": "None",
"param": "None",
"code": "429"
}
}
```
</TabItem>
</Tabs>
## Multi-instance setup
If you are using a multi-instance setup, you will need to set the Redis host, port, and password in the `proxy_config.yaml` file. Redis is used to sync the spend across LiteLLM instances.
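A sketch of those settings (these are the standard LiteLLM Redis settings, not new in this PR):

```yaml
router_settings:
  redis_host: os.environ/REDIS_HOST
  redis_port: os.environ/REDIS_PORT
  redis_password: os.environ/REDIS_PASSWORD
```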


@ -1,4 +1,14 @@
model_list:
- model_name: openai/*
- model_name: gpt-4o
litellm_params:
model: openai/*
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
max_budget: 0.000000000001 # (USD)
budget_duration: 1d # (Duration)
- model_name: gpt-4o-mini
litellm_params:
model: openai/gpt-4o-mini
api_key: os.environ/OPENAI_API_KEY
max_budget: 100 # (USD)
budget_duration: 1d # (Duration)


@ -2533,13 +2533,15 @@ async def provider_budgets() -> ProviderBudgetResponse:
provider_budget_response_dict: Dict[str, ProviderBudgetResponseObject] = {}
for _provider, _budget_info in provider_budget_config.items():
if llm_router.router_budget_logger is None:
raise ValueError("No router budget logger found")
_provider_spend = (
await llm_router.provider_budget_logger._get_current_provider_spend(
await llm_router.router_budget_logger._get_current_provider_spend(
_provider
)
or 0.0
)
_provider_budget_ttl = await llm_router.provider_budget_logger._get_current_provider_budget_reset_at(
_provider_budget_ttl = await llm_router.router_budget_logger._get_current_provider_budget_reset_at(
_provider
)
provider_budget_response_object = ProviderBudgetResponseObject(


@ -56,12 +56,12 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.azure.azure import get_azure_ad_token_from_oidc
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_strategy.provider_budgets import ProviderBudgetLimiting
from litellm.router_strategy.simple_shuffle import simple_shuffle
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
from litellm.router_utils.batch_utils import (
@ -123,11 +123,11 @@ from litellm.types.router import (
CustomRoutingStrategyBase,
Deployment,
DeploymentTypedDict,
GenericBudgetConfigType,
LiteLLM_Params,
LiteLLMParamsTypedDict,
ModelGroupInfo,
ModelInfo,
ProviderBudgetConfigType,
RetryPolicy,
RouterCacheEnum,
RouterErrors,
@ -248,7 +248,7 @@ class Router:
"usage-based-routing-v2",
] = "simple-shuffle",
routing_strategy_args: dict = {}, # just for latency-based
provider_budget_config: Optional[ProviderBudgetConfigType] = None,
provider_budget_config: Optional[GenericBudgetConfigType] = None,
alerting_config: Optional[AlertingConfig] = None,
router_general_settings: Optional[
RouterGeneralSettings
@ -537,10 +537,14 @@ class Router:
self.service_logger_obj = ServiceLogging()
self.routing_strategy_args = routing_strategy_args
self.provider_budget_config = provider_budget_config
if self.provider_budget_config is not None:
self.provider_budget_logger = ProviderBudgetLimiting(
self.router_budget_logger: Optional[RouterBudgetLimiting] = None
if RouterBudgetLimiting.should_init_router_budget_limiter(
model_list=model_list, provider_budget_config=self.provider_budget_config
):
self.router_budget_logger = RouterBudgetLimiting(
router_cache=self.cache,
provider_budget_config=self.provider_budget_config,
model_list=self.model_list,
)
self.retry_policy: Optional[RetryPolicy] = None
if retry_policy is not None:
@ -5318,9 +5322,9 @@ class Router:
healthy_deployments=healthy_deployments,
)
if self.provider_budget_config is not None:
if self.router_budget_logger:
healthy_deployments = (
await self.provider_budget_logger.async_filter_deployments(
await self.router_budget_logger.async_filter_deployments(
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
)


@ -20,7 +20,7 @@ anthropic:
import asyncio
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union
import litellm
from litellm._logging import verbose_router_logger
@ -33,9 +33,10 @@ from litellm.router_utils.cooldown_callbacks import (
_get_prometheus_logger_from_callbacks,
)
from litellm.types.router import (
DeploymentTypedDict,
GenericBudgetConfigType,
GenericBudgetInfo,
LiteLLM_Params,
ProviderBudgetConfigType,
ProviderBudgetInfo,
RouterErrors,
)
from litellm.types.utils import StandardLoggingPayload
@ -50,35 +51,24 @@ else:
DEFAULT_REDIS_SYNC_INTERVAL = 1
class ProviderBudgetLimiting(CustomLogger):
def __init__(self, router_cache: DualCache, provider_budget_config: dict):
class RouterBudgetLimiting(CustomLogger):
def __init__(
self,
router_cache: DualCache,
provider_budget_config: Optional[dict],
model_list: Optional[
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
] = None,
):
self.router_cache = router_cache
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
# cast elements of provider_budget_config to ProviderBudgetInfo
for provider, config in provider_budget_config.items():
if config is None:
raise ValueError(
f"No budget config found for provider {provider}, provider_budget_config: {provider_budget_config}"
)
if not isinstance(config, ProviderBudgetInfo):
provider_budget_config[provider] = ProviderBudgetInfo(
budget_limit=config.get("budget_limit"),
time_period=config.get("time_period"),
)
asyncio.create_task(
self._init_provider_budget_in_cache(
provider=provider,
budget_config=provider_budget_config[provider],
)
)
self.provider_budget_config: ProviderBudgetConfigType = provider_budget_config
verbose_router_logger.debug(
f"Initalized Provider budget config: {self.provider_budget_config}"
self.provider_budget_config: Optional[GenericBudgetConfigType] = (
provider_budget_config
)
self.deployment_budget_config: Optional[GenericBudgetConfigType] = None
self._init_provider_budgets()
self._init_deployment_budgets(model_list=model_list)
# Add self to litellm callbacks if it's a list
if isinstance(litellm.callbacks, list):
@ -114,77 +104,132 @@ class ProviderBudgetLimiting(CustomLogger):
request_kwargs
)
# Collect all providers and their budget configs
# {"openai": ProviderBudgetInfo, "anthropic": ProviderBudgetInfo, "azure": None}
_provider_configs: Dict[str, Optional[ProviderBudgetInfo]] = {}
for deployment in healthy_deployments:
provider = self._get_llm_provider_for_deployment(deployment)
if provider is None:
continue
budget_config = self._get_budget_config_for_provider(provider)
_provider_configs[provider] = budget_config
# Filter out providers without budget config
provider_configs: Dict[str, ProviderBudgetInfo] = {
provider: config
for provider, config in _provider_configs.items()
if config is not None
}
# Build cache keys for batch retrieval
# Build combined cache keys for both provider and deployment budgets
cache_keys = []
for provider, config in provider_configs.items():
cache_keys.append(f"provider_spend:{provider}:{config.time_period}")
provider_configs: Dict[str, GenericBudgetInfo] = {}
deployment_configs: Dict[str, GenericBudgetInfo] = {}
# Fetch current spend for all providers using batch cache
_current_spends = await self.router_cache.async_batch_get_cache(
keys=cache_keys,
parent_otel_span=parent_otel_span,
)
current_spends: List = _current_spends or [0.0] * len(provider_configs)
# Map providers to their current spend values
provider_spend_map: Dict[str, float] = {}
for idx, provider in enumerate(provider_configs.keys()):
provider_spend_map[provider] = float(current_spends[idx] or 0.0)
# Filter healthy deployments based on budget constraints
deployment_above_budget_info: str = "" # used to return in error message
for deployment in healthy_deployments:
provider = self._get_llm_provider_for_deployment(deployment)
if provider is None:
continue
budget_config = provider_configs.get(provider)
# Check provider budgets
if self.provider_budget_config:
provider = self._get_llm_provider_for_deployment(deployment)
if provider is not None:
budget_config = self._get_budget_config_for_provider(provider)
if budget_config is not None:
provider_configs[provider] = budget_config
cache_keys.append(
f"provider_spend:{provider}:{budget_config.time_period}"
)
if not budget_config:
continue
# Check deployment budgets
if self.deployment_budget_config:
model_id = deployment.get("model_info", {}).get("id")
if model_id is not None:
budget_config = self._get_budget_config_for_deployment(model_id)
if budget_config is not None:
deployment_configs[model_id] = budget_config
cache_keys.append(
f"deployment_spend:{model_id}:{budget_config.time_period}"
)
current_spend = provider_spend_map.get(provider, 0.0)
budget_limit = budget_config.budget_limit
verbose_router_logger.debug(
f"Current spend for {provider}: {current_spend}, budget limit: {budget_limit}"
# Single cache read for all spend values
if len(cache_keys) > 0:
_current_spends = await self.router_cache.async_batch_get_cache(
keys=cache_keys,
parent_otel_span=parent_otel_span,
)
self._track_provider_remaining_budget_prometheus(
provider=provider,
spend=current_spend,
budget_limit=budget_limit,
current_spends: List = _current_spends or [0.0] * len(cache_keys)
# Map spends to their respective keys
spend_map: Dict[str, float] = {}
for idx, key in enumerate(cache_keys):
spend_map[key] = float(current_spends[idx] or 0.0)
potential_deployments, deployment_above_budget_info = (
self._filter_out_deployments_above_budget(
healthy_deployments=healthy_deployments,
provider_configs=provider_configs,
deployment_configs=deployment_configs,
spend_map=spend_map,
potential_deployments=potential_deployments,
)
)
if current_spend >= budget_limit:
debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {budget_limit}"
verbose_router_logger.debug(debug_msg)
deployment_above_budget_info += f"{debug_msg}\n"
continue
if len(potential_deployments) == 0:
raise ValueError(
f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
)
potential_deployments.append(deployment)
return potential_deployments
else:
return healthy_deployments
if len(potential_deployments) == 0:
raise ValueError(
f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
)
def _filter_out_deployments_above_budget(
self,
potential_deployments: List[Dict[str, Any]],
healthy_deployments: List[Dict[str, Any]],
provider_configs: Dict[str, GenericBudgetInfo],
deployment_configs: Dict[str, GenericBudgetInfo],
spend_map: Dict[str, float],
) -> Tuple[List[Dict[str, Any]], str]:
"""
Filter out deployments that have exceeded their budget limit.
The following budget checks are run here:
- Provider budget
- Deployment budget
return potential_deployments
Returns:
Tuple[List[Dict[str, Any]], str]:
- A tuple containing the filtered deployments
- A string containing debug information about deployments that exceeded their budget limit.
"""
# Filter deployments based on both provider and deployment budgets
deployment_above_budget_info: str = ""
for deployment in healthy_deployments:
is_within_budget = True
# Check provider budget
if self.provider_budget_config:
provider = self._get_llm_provider_for_deployment(deployment)
if provider in provider_configs:
config = provider_configs[provider]
current_spend = spend_map.get(
f"provider_spend:{provider}:{config.time_period}", 0.0
)
self._track_provider_remaining_budget_prometheus(
provider=provider,
spend=current_spend,
budget_limit=config.budget_limit,
)
if current_spend >= config.budget_limit:
debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.budget_limit}"
deployment_above_budget_info += f"{debug_msg}\n"
is_within_budget = False
continue
# Check deployment budget
if self.deployment_budget_config and is_within_budget:
_model_name = deployment.get("model_name")
_litellm_params = deployment.get("litellm_params") or {}
_litellm_model_name = _litellm_params.get("model")
model_id = deployment.get("model_info", {}).get("id")
if model_id in deployment_configs:
config = deployment_configs[model_id]
current_spend = spend_map.get(
f"deployment_spend:{model_id}:{config.time_period}", 0.0
)
if current_spend >= config.budget_limit:
debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_limit}"
verbose_router_logger.debug(debug_msg)
deployment_above_budget_info += f"{debug_msg}\n"
is_within_budget = False
continue
if is_within_budget:
potential_deployments.append(deployment)
return potential_deployments, deployment_above_budget_info
async def _get_or_set_budget_start_time(
self, start_time_key: str, current_time: float, ttl_seconds: int
@ -256,7 +301,7 @@ class ProviderBudgetLimiting(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""Original method now uses helper functions"""
verbose_router_logger.debug("in ProviderBudgetLimiting.async_log_success_event")
verbose_router_logger.debug("in RouterBudgetLimiting.async_log_success_event")
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
@ -264,7 +309,7 @@ class ProviderBudgetLimiting(CustomLogger):
raise ValueError("standard_logging_payload is required")
response_cost: float = standard_logging_payload.get("response_cost", 0)
model_id: str = str(standard_logging_payload.get("model_id", ""))
custom_llm_provider: str = kwargs.get("litellm_params", {}).get(
"custom_llm_provider", None
)
@ -272,14 +317,40 @@ class ProviderBudgetLimiting(CustomLogger):
raise ValueError("custom_llm_provider is required")
budget_config = self._get_budget_config_for_provider(custom_llm_provider)
if budget_config is None:
raise ValueError(
f"No budget config found for provider {custom_llm_provider}, self.provider_budget_config: {self.provider_budget_config}"
if budget_config:
# increment spend for provider
spend_key = (
f"provider_spend:{custom_llm_provider}:{budget_config.time_period}"
)
start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
await self._increment_spend_for_key(
budget_config=budget_config,
spend_key=spend_key,
start_time_key=start_time_key,
response_cost=response_cost,
)
spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}"
start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
deployment_budget_config = self._get_budget_config_for_deployment(model_id)
if deployment_budget_config:
# increment spend for specific deployment id
deployment_spend_key = (
f"deployment_spend:{model_id}:{deployment_budget_config.time_period}"
)
deployment_start_time_key = f"deployment_budget_start_time:{model_id}"
await self._increment_spend_for_key(
budget_config=deployment_budget_config,
spend_key=deployment_spend_key,
start_time_key=deployment_start_time_key,
response_cost=response_cost,
)
async def _increment_spend_for_key(
self,
budget_config: GenericBudgetInfo,
spend_key: str,
start_time_key: str,
response_cost: float,
):
current_time = datetime.now(timezone.utc).timestamp()
ttl_seconds = duration_in_seconds(budget_config.time_period)
@ -392,10 +463,20 @@ class ProviderBudgetLimiting(CustomLogger):
# 2. Fetch all current provider spend from Redis to update in-memory cache
cache_keys = []
for provider, config in self.provider_budget_config.items():
if config is None:
continue
cache_keys.append(f"provider_spend:{provider}:{config.time_period}")
if self.provider_budget_config is not None:
for provider, config in self.provider_budget_config.items():
if config is None:
continue
cache_keys.append(f"provider_spend:{provider}:{config.time_period}")
if self.deployment_budget_config is not None:
for model_id, config in self.deployment_budget_config.items():
if config is None:
continue
cache_keys.append(
f"deployment_spend:{model_id}:{config.time_period}"
)
# Batch fetch current spend values from Redis
redis_values = await self.router_cache.redis_cache.async_batch_get_cache(
@ -418,9 +499,19 @@ class ProviderBudgetLimiting(CustomLogger):
f"Error syncing in-memory cache with Redis: {str(e)}"
)
def _get_budget_config_for_deployment(
self,
model_id: str,
) -> Optional[GenericBudgetInfo]:
if self.deployment_budget_config is None:
return None
return self.deployment_budget_config.get(model_id, None)
def _get_budget_config_for_provider(
self, provider: str
) -> Optional[ProviderBudgetInfo]:
) -> Optional[GenericBudgetInfo]:
if self.provider_budget_config is None:
return None
return self.provider_budget_config.get(provider, None)
def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
@ -504,7 +595,7 @@ class ProviderBudgetLimiting(CustomLogger):
return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat()
async def _init_provider_budget_in_cache(
self, provider: str, budget_config: ProviderBudgetInfo
self, provider: str, budget_config: GenericBudgetInfo
):
"""
Initialize provider budget in cache by storing the following keys if they don't exist:
@ -527,3 +618,92 @@ class ProviderBudgetLimiting(CustomLogger):
await self.router_cache.async_set_cache(
key=spend_key, value=0.0, ttl=ttl_seconds
)
@staticmethod
def should_init_router_budget_limiter(
provider_budget_config: Optional[dict],
model_list: Optional[
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
] = None,
):
"""
Returns `True` if the router budget routing settings are set and RouterBudgetLimiting should be initialized
Either:
- provider_budget_config is set
- budgets are set for deployments in the model_list
"""
if provider_budget_config is not None:
return True
if model_list is None:
return False
for _model in model_list:
_litellm_params = _model.get("litellm_params", {})
if (
_litellm_params.get("max_budget")
or _litellm_params.get("budget_duration") is not None
):
return True
return False
def _init_provider_budgets(self):
if self.provider_budget_config is not None:
# cast elements of provider_budget_config to GenericBudgetInfo
for provider, config in self.provider_budget_config.items():
if config is None:
raise ValueError(
f"No budget config found for provider {provider}, provider_budget_config: {self.provider_budget_config}"
)
if not isinstance(config, GenericBudgetInfo):
self.provider_budget_config[provider] = GenericBudgetInfo(
budget_limit=config.get("budget_limit"),
time_period=config.get("time_period"),
)
asyncio.create_task(
self._init_provider_budget_in_cache(
provider=provider,
budget_config=self.provider_budget_config[provider],
)
)
verbose_router_logger.debug(
f"Initalized Provider budget config: {self.provider_budget_config}"
)
def _init_deployment_budgets(
self,
model_list: Optional[
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
] = None,
):
if model_list is None:
return
for _model in model_list:
_litellm_params = _model.get("litellm_params", {})
_model_info = _model.get("model_info", {})
_model_id = _model_info.get("id")
_max_budget = _litellm_params.get("max_budget")
_budget_duration = _litellm_params.get("budget_duration")
verbose_router_logger.debug(
f"Init Deployment Budget: max_budget: {_max_budget}, budget_duration: {_budget_duration}, model_id: {_model_id}"
)
if (
_max_budget is not None
and _budget_duration is not None
and _model_id is not None
):
_budget_config = GenericBudgetInfo(
time_period=_budget_duration,
budget_limit=_max_budget,
)
if self.deployment_budget_config is None:
self.deployment_budget_config = {}
self.deployment_budget_config[_model_id] = _budget_config
verbose_router_logger.debug(
f"Initialized Deployment Budget Config: {self.deployment_budget_config}"
)


@ -172,6 +172,10 @@ class GenericLiteLLMParams(BaseModel):
max_file_size_mb: Optional[float] = None
# Deployment budgets
max_budget: Optional[float] = None
budget_duration: Optional[str] = None
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
def __init__(
@ -207,6 +211,9 @@ class GenericLiteLLMParams(BaseModel):
input_cost_per_second: Optional[float] = None,
output_cost_per_second: Optional[float] = None,
max_file_size_mb: Optional[float] = None,
# Deployment budgets
max_budget: Optional[float] = None,
budget_duration: Optional[str] = None,
**params,
):
args = locals()
@ -351,6 +358,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
# use this for tag-based routing
tags: Optional[List[str]]
# deployment budgets
max_budget: Optional[float]
budget_duration: Optional[str]
class DeploymentTypedDict(TypedDict, total=False):
model_name: Required[str]
@ -436,7 +447,7 @@ class RouterErrors(enum.Enum):
"Not allowed to access model due to tags configuration"
)
no_deployments_with_provider_budget_routing = (
"No deployments available - crossed budget for provider"
"No deployments available - crossed budget"
)
@ -635,12 +646,12 @@ class RoutingStrategy(enum.Enum):
PROVIDER_BUDGET_LIMITING = "provider-budget-routing"
class ProviderBudgetInfo(BaseModel):
class GenericBudgetInfo(BaseModel):
time_period: str # e.g., '1d', '30d'
budget_limit: float
ProviderBudgetConfigType = Dict[str, ProviderBudgetInfo]
GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
class RouterCacheEnum(enum.Enum):
@ -648,7 +659,7 @@ class RouterCacheEnum(enum.Enum):
RPM = "global_router:{id}:{model}:rpm:{current_minute}"
class ProviderBudgetWindowDetails(BaseModel):
class GenericBudgetWindowDetails(BaseModel):
"""Details about a provider's budget window"""
budget_start: float


@ -1432,6 +1432,8 @@ all_litellm_params = [
"user_continue_message",
"fallback_depth",
"max_fallbacks",
"max_budget",
"budget_duration",
]


@ -11,11 +11,11 @@ sys.path.insert(
) # Adds the parent directory to the system-path
import pytest
from litellm import Router
from litellm.router_strategy.provider_budgets import ProviderBudgetLimiting
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.types.router import (
RoutingStrategy,
ProviderBudgetConfigType,
ProviderBudgetInfo,
GenericBudgetConfigType,
GenericBudgetInfo,
)
from litellm.caching.caching import DualCache, RedisCache
import logging
@ -43,6 +43,9 @@ def cleanup_redis():
for key in redis_client.scan_iter("provider_spend:*"):
print("deleting key", key)
redis_client.delete(key)
for key in redis_client.scan_iter("deployment_spend:*"):
print("deleting key", key)
redis_client.delete(key)
except Exception as e:
print(f"Error cleaning up Redis: {str(e)}")
@ -59,9 +62,9 @@ async def test_provider_budgets_e2e_test():
"""
cleanup_redis()
# Modify for test
provider_budget_config: ProviderBudgetConfigType = {
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001),
"azure": ProviderBudgetInfo(time_period="1d", budget_limit=100),
provider_budget_config: GenericBudgetConfigType = {
"openai": GenericBudgetInfo(time_period="1d", budget_limit=0.000000000001),
"azure": GenericBudgetInfo(time_period="1d", budget_limit=100),
}
router = Router(
@ -175,7 +178,7 @@ async def test_get_llm_provider_for_deployment():
"""
cleanup_redis()
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config={}
)
@ -208,11 +211,11 @@ async def test_get_budget_config_for_provider():
"""
cleanup_redis()
config = {
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
"anthropic": ProviderBudgetInfo(time_period="7d", budget_limit=500),
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
"anthropic": GenericBudgetInfo(time_period="7d", budget_limit=500),
}
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config=config
)
@ -244,18 +247,18 @@ async def test_prometheus_metric_tracking():
mock_prometheus = MagicMock(spec=PrometheusLogger)
# Setup provider budget limiting
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(),
provider_budget_config={
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=100)
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100)
},
)
litellm._async_success_callback = [mock_prometheus]
provider_budget_config: ProviderBudgetConfigType = {
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001),
"azure": ProviderBudgetInfo(time_period="1d", budget_limit=100),
provider_budget_config: GenericBudgetConfigType = {
"openai": GenericBudgetInfo(time_period="1d", budget_limit=0.000000000001),
"azure": GenericBudgetInfo(time_period="1d", budget_limit=100),
}
router = Router(
@ -308,7 +311,7 @@ async def test_handle_new_budget_window():
Current
"""
cleanup_redis()
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config={}
)
@ -349,7 +352,7 @@ async def test_get_or_set_budget_start_time():
scenario 2: existing start time in cache, should return existing start time
"""
cleanup_redis()
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config={}
)
@ -390,7 +393,7 @@ async def test_increment_spend_in_current_window():
- Queue the increment operation to Redis
"""
cleanup_redis()
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config={}
)
@ -437,11 +440,11 @@ async def test_sync_in_memory_spend_with_redis():
"""
cleanup_redis()
provider_budget_config = {
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
"anthropic": ProviderBudgetInfo(time_period="1d", budget_limit=200),
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
"anthropic": GenericBudgetInfo(time_period="1d", budget_limit=200),
}
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(
redis_cache=RedisCache(
host=os.getenv("REDIS_HOST"),
@ -491,10 +494,10 @@ async def test_get_current_provider_spend():
3. Provider with budget config and spend returns correct value
"""
cleanup_redis()
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(),
provider_budget_config={
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
},
)
@ -526,7 +529,7 @@ async def test_get_current_provider_budget_reset_at():
3. Provider with budget config and TTL returns correct ISO timestamp
"""
cleanup_redis()
provider_budget = ProviderBudgetLimiting(
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(
redis_cache=RedisCache(
host=os.getenv("REDIS_HOST"),
@ -535,8 +538,8 @@ async def test_get_current_provider_budget_reset_at():
)
),
provider_budget_config={
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
"vertex_ai": ProviderBudgetInfo(time_period="1h", budget_limit=100),
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
"vertex_ai": GenericBudgetInfo(time_period="1h", budget_limit=100),
},
)
@ -565,3 +568,109 @@ async def test_get_current_provider_budget_reset_at():
# Allow for small time differences (within 5 seconds)
time_difference = abs((reset_time - expected_time).total_seconds())
assert time_difference < 5
@pytest.mark.asyncio
async def test_deployment_budget_limits_e2e_test():
"""
Expected behavior:
- First request forced to openai/gpt-4o
- Hit budget limit for openai/gpt-4o
- Next 3 requests all go to openai/gpt-4o-mini
"""
litellm.set_verbose = True
cleanup_redis()
# Modify for test
router = Router(
model_list=[
{
"model_name": "gpt-4o", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "openai/gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
"max_budget": 0.000000000001,
"budget_duration": "1d",
},
"model_info": {"id": "openai-gpt-4o"},
},
{
"model_name": "gpt-4o", # openai model name
"litellm_params": {
"model": "openai/gpt-4o-mini",
"api_key": os.getenv("OPENAI_API_KEY"),
"max_budget": 10,
"budget_duration": "20d",
},
"model_info": {"id": "openai-gpt-4o-mini"},
},
],
)
response = await router.acompletion(
messages=[{"role": "user", "content": "Hello, how are you?"}],
model="openai-gpt-4o",
)
print(response)
await asyncio.sleep(2.5)
for _ in range(3):
response = await router.acompletion(
messages=[{"role": "user", "content": "Hello, how are you?"}],
model="gpt-4o",
)
print(response)
await asyncio.sleep(1)
print("response.hidden_params", response._hidden_params)
assert response._hidden_params.get("model_id") == "openai-gpt-4o-mini"
@pytest.mark.asyncio
async def test_deployment_budgets_e2e_test_expect_to_fail():
"""
Expected behavior:
- first request passes, all subsequent requests fail
"""
cleanup_redis()
router = Router(
model_list=[
{
"model_name": "openai/gpt-4o-mini", # openai model name
"litellm_params": {
"model": "openai/gpt-4o-mini",
"max_budget": 0.000000000001,
"budget_duration": "1d",
},
},
],
redis_host=os.getenv("REDIS_HOST"),
redis_port=int(os.getenv("REDIS_PORT")),
redis_password=os.getenv("REDIS_PASSWORD"),
)
response = await router.acompletion(
messages=[{"role": "user", "content": "Hello, how are you?"}],
model="openai/gpt-4o-mini",
)
print(response)
await asyncio.sleep(2.5)
for _ in range(3):
with pytest.raises(Exception) as exc_info:
response = await router.acompletion(
messages=[{"role": "user", "content": "Hello, how are you?"}],
model="openai/gpt-4o-mini",
)
print(response)
print("response.hidden_params", response._hidden_params)
await asyncio.sleep(0.5)
# Verify the error is related to budget exceeded
assert "Exceeded budget for deployment" in str(exc_info.value)