(feat - Router / Proxy) Allow setting budget limits per LLM deployment (#7220)

* fix test_deployment_budget_limits_e2e_test

* refactor async_log_success_event to track spend for provider + deployment

* fix format

* rename class to RouterBudgetLimiting

* rename func

* rename types used for budgets

* add new types for deployment budgets

* add budget limits for deployments

* fix checking budgets set for provider

* update file names

* fix linting error

* _track_provider_remaining_budget_prometheus

* async_filter_deployments

* fix model list passed to router

* update error

* test_deployment_budgets_e2e_test_expect_to_fail

* fix test case

* run deployment budget limits
Ishaan Jaff 2024-12-13 19:15:51 -08:00 committed by GitHub
parent b150faff90
commit 163529b40b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 557 additions and 151 deletions

View file

@@ -1,13 +1,19 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
-# Provider Budget Routing
+# Budget Routing
+LiteLLM supports setting the following budgets:
+- Provider budget - $100/day for OpenAI, $100/day for Azure.
+- Model budget - $100/day for gpt-4 https://api-base-1, $100/day for gpt-4o https://api-base-2
+## Provider Budgets
 Use this to set budgets for LLM Providers - example $100/day for OpenAI, $100/day for Azure.
-## Quick Start
+### Quick Start
 Set provider budgets in your `proxy_config.yaml` file
-### Proxy Config setup
+#### Proxy Config setup
 ```yaml
 model_list:
   - model_name: gpt-3.5-turbo
@@ -42,7 +48,7 @@ general_settings:
   master_key: sk-1234
 ```
-### Make a test request
+#### Make a test request
 We expect the first request to succeed, and the second request to fail since we cross the budget for `openai`
@@ -67,7 +73,7 @@ curl -i http://localhost:4000/v1/chat/completions \
 </TabItem>
 <TabItem label="Unsuccessful call" value = "not-allowed">
-Expect this to fail since `ishaan@berri.ai` in the request is PII
+Expect this to fail since we cross the budget for provider `openai`
 ```shell
 curl -i http://localhost:4000/v1/chat/completions \
@@ -101,7 +107,7 @@ Expected response on failure
 }
 ```
-## How provider budget routing works
+#### How provider budget routing works
 1. **Budget Tracking**:
    - Uses Redis to track spend for each provider
@@ -124,9 +130,9 @@ Expected response on failure
 - Redis required for tracking spend across instances
 - Provider names must be litellm provider names. See [Supported Providers](https://docs.litellm.ai/docs/providers)
-## Monitoring Provider Remaining Budget
+### Monitoring Provider Remaining Budget
-### Get Budget, Spend Details
+#### Get Budget, Spend Details
 Use this endpoint to check current budget, spend and budget reset time for a provider
@@ -171,7 +177,7 @@ Example Response
 }
 ```
-### Prometheus Metric
+#### Prometheus Metric
 LiteLLM will emit the following metric on Prometheus to track the remaining budget for each provider
@@ -181,6 +187,88 @@ This metric indicates the remaining budget for a provider in dollars (USD)
 litellm_provider_remaining_budget_metric{api_provider="openai"} 10
 ```
+## Model Budgets
+Use this to set budgets for models - example $10/day for openai/gpt-4o, $100/day for openai/gpt-4o-mini
+### Quick Start
+Set model budgets in your `proxy_config.yaml` file
+```yaml
+model_list:
+  - model_name: gpt-4o
+    litellm_params:
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+      max_budget: 0.000000000001 # (USD)
+      budget_duration: 1d # (Duration. can be 1s, 1m, 1h, 1d, 1mo)
+  - model_name: gpt-4o-mini
+    litellm_params:
+      model: openai/gpt-4o-mini
+      api_key: os.environ/OPENAI_API_KEY
+      max_budget: 100 # (USD)
+      budget_duration: 30d # (Duration. can be 1s, 1m, 1h, 1d, 1mo)
+```
+#### Make a test request
+We expect the first request to succeed, and the second request to fail since we cross the budget for `openai/gpt-4o`
+**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
+<Tabs>
+<TabItem label="Successful Call " value = "allowed">
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "model": "gpt-4o",
+    "messages": [
+      {"role": "user", "content": "hi my name is test request"}
+    ]
+  }'
+```
+</TabItem>
+<TabItem label="Unsuccessful call" value = "not-allowed">
+Expect this to fail since we cross the budget for `openai/gpt-4o`
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "model": "gpt-4o",
+    "messages": [
+      {"role": "user", "content": "hi my name is test request"}
+    ]
+  }'
+```
+Expected response on failure
+```json
+{
+  "error": {
+    "message": "No deployments available - crossed budget: Exceeded budget for deployment model_name: gpt-4o, litellm_params.model: openai/gpt-4o, model_id: dbe80f2fe2b2465f7bfa9a5e77e0f143a2eb3f7d167a8b55fb7fe31aed62587f: 0.00015250000000000002 >= 1e-12",
+    "type": "None",
+    "param": "None",
+    "code": "429"
+  }
+}
+```
+</TabItem>
+</Tabs>
 ## Multi-instance setup
 If you are using a multi-instance setup, you will need to set the Redis host, port, and password in the `proxy_config.yaml` file. Redis is used to sync the spend across LiteLLM instances.
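
Editor's note: the model budgets documented above can also be set when using the `Router` directly in Python. A minimal sketch mirroring the docs and the test added later in this commit (model name, env var, and values are illustrative):

```python
import asyncio
import os

from litellm import Router

async def main() -> None:
    # Budgets are per deployment: max_budget is USD, budget_duration
    # accepts 1s / 1m / 1h / 1d / 1mo style strings.
    router = Router(
        model_list=[
            {
                "model_name": "gpt-4o",
                "litellm_params": {
                    "model": "openai/gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "max_budget": 10,
                    "budget_duration": "1d",
                },
            },
        ],
    )
    response = await router.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response)

asyncio.run(main())
```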

View file

@@ -1,4 +1,14 @@
 model_list:
-  - model_name: openai/*
+  - model_name: gpt-4o
     litellm_params:
-      model: openai/*
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+      max_budget: 0.000000000001 # (USD)
+      budget_duration: 1d # (Duration)
+  - model_name: gpt-4o-mini
+    litellm_params:
+      model: openai/gpt-4o-mini
+      api_key: os.environ/OPENAI_API_KEY
+      max_budget: 100 # (USD)
+      budget_duration: 1d # (Duration)
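
Editor's note: the two new keys round-trip through `GenericLiteLLMParams` (extended later in this commit). A quick sketch, with field names taken from the diff and values illustrative:

```python
from litellm.types.router import GenericLiteLLMParams

# max_budget / budget_duration are now first-class litellm_params fields.
params = GenericLiteLLMParams(
    max_budget=0.000000000001,  # USD
    budget_duration="1d",
)
print(params.max_budget, params.budget_duration)
```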

View file

@@ -2533,13 +2533,15 @@ async def provider_budgets() -> ProviderBudgetResponse:
     provider_budget_response_dict: Dict[str, ProviderBudgetResponseObject] = {}
     for _provider, _budget_info in provider_budget_config.items():
+        if llm_router.router_budget_logger is None:
+            raise ValueError("No router budget logger found")
         _provider_spend = (
-            await llm_router.provider_budget_logger._get_current_provider_spend(
+            await llm_router.router_budget_logger._get_current_provider_spend(
                 _provider
             )
             or 0.0
         )
-        _provider_budget_ttl = await llm_router.provider_budget_logger._get_current_provider_budget_reset_at(
+        _provider_budget_ttl = await llm_router.router_budget_logger._get_current_provider_budget_reset_at(
             _provider
         )
         provider_budget_response_object = ProviderBudgetResponseObject(
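
Editor's note: the helpers this endpoint calls can also be used directly, as the tests in this commit do. A minimal sketch (cache contents and values are illustrative; the limiter must be constructed inside a running event loop because its `__init__` schedules a background sync task):

```python
import asyncio

from litellm.caching.caching import DualCache
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.types.router import GenericBudgetInfo

async def main() -> None:
    limiter = RouterBudgetLimiting(
        router_cache=DualCache(),
        provider_budget_config={
            "openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
        },
    )
    # Current spend for a provider (None until spend has been tracked).
    spend = await limiter._get_current_provider_spend("openai") or 0.0
    print(f"openai spend so far: ${spend}")

asyncio.run(main())
```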

View file

@@ -56,12 +56,12 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.azure.azure import get_azure_ad_token_from_oidc
+from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
 from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
-from litellm.router_strategy.provider_budgets import ProviderBudgetLimiting
 from litellm.router_strategy.simple_shuffle import simple_shuffle
 from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
 from litellm.router_utils.batch_utils import (
@@ -123,11 +123,11 @@ from litellm.types.router import (
     CustomRoutingStrategyBase,
     Deployment,
     DeploymentTypedDict,
+    GenericBudgetConfigType,
     LiteLLM_Params,
     LiteLLMParamsTypedDict,
     ModelGroupInfo,
     ModelInfo,
-    ProviderBudgetConfigType,
     RetryPolicy,
     RouterCacheEnum,
     RouterErrors,
@@ -248,7 +248,7 @@ class Router:
             "usage-based-routing-v2",
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based
-        provider_budget_config: Optional[ProviderBudgetConfigType] = None,
+        provider_budget_config: Optional[GenericBudgetConfigType] = None,
         alerting_config: Optional[AlertingConfig] = None,
         router_general_settings: Optional[
             RouterGeneralSettings
@@ -537,10 +537,14 @@ class Router:
         self.service_logger_obj = ServiceLogging()
         self.routing_strategy_args = routing_strategy_args
         self.provider_budget_config = provider_budget_config
-        if self.provider_budget_config is not None:
-            self.provider_budget_logger = ProviderBudgetLimiting(
+        self.router_budget_logger: Optional[RouterBudgetLimiting] = None
+        if RouterBudgetLimiting.should_init_router_budget_limiter(
+            model_list=model_list, provider_budget_config=self.provider_budget_config
+        ):
+            self.router_budget_logger = RouterBudgetLimiting(
                 router_cache=self.cache,
                 provider_budget_config=self.provider_budget_config,
+                model_list=self.model_list,
             )
         self.retry_policy: Optional[RetryPolicy] = None
         if retry_policy is not None:
@@ -5318,9 +5322,9 @@ class Router:
                 healthy_deployments=healthy_deployments,
             )
-            if self.provider_budget_config is not None:
+            if self.router_budget_logger:
                 healthy_deployments = (
-                    await self.provider_budget_logger.async_filter_deployments(
+                    await self.router_budget_logger.async_filter_deployments(
                         healthy_deployments=healthy_deployments,
                         request_kwargs=request_kwargs,
                     )
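
Editor's note: the new guard means the limiter is initialized when either budget style is configured. A minimal sketch of the static check introduced above (deployment values are illustrative):

```python
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting

# Initialized if provider budgets are set...
assert RouterBudgetLimiting.should_init_router_budget_limiter(
    provider_budget_config={"openai": {"time_period": "1d", "budget_limit": 100}},
)

# ...or if any deployment in model_list sets max_budget / budget_duration.
model_list = [
    {
        "model_name": "gpt-4o",
        "litellm_params": {
            "model": "openai/gpt-4o",
            "max_budget": 10.0,
            "budget_duration": "1d",
        },
    }
]
assert RouterBudgetLimiting.should_init_router_budget_limiter(
    provider_budget_config=None, model_list=model_list
)
```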

View file

@@ -20,7 +20,7 @@ anthropic:
 import asyncio
 from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union
 import litellm
 from litellm._logging import verbose_router_logger
@@ -33,9 +33,10 @@ from litellm.router_utils.cooldown_callbacks import (
     _get_prometheus_logger_from_callbacks,
 )
 from litellm.types.router import (
+    DeploymentTypedDict,
+    GenericBudgetConfigType,
+    GenericBudgetInfo,
     LiteLLM_Params,
-    ProviderBudgetConfigType,
-    ProviderBudgetInfo,
     RouterErrors,
 )
 from litellm.types.utils import StandardLoggingPayload
@@ -50,35 +51,24 @@ else:
 DEFAULT_REDIS_SYNC_INTERVAL = 1
-class ProviderBudgetLimiting(CustomLogger):
-    def __init__(self, router_cache: DualCache, provider_budget_config: dict):
+class RouterBudgetLimiting(CustomLogger):
+    def __init__(
+        self,
+        router_cache: DualCache,
+        provider_budget_config: Optional[dict],
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+        ] = None,
+    ):
         self.router_cache = router_cache
         self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
         asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
-        # cast elements of provider_budget_config to ProviderBudgetInfo
-        for provider, config in provider_budget_config.items():
-            if config is None:
-                raise ValueError(
-                    f"No budget config found for provider {provider}, provider_budget_config: {provider_budget_config}"
-                )
-            if not isinstance(config, ProviderBudgetInfo):
-                provider_budget_config[provider] = ProviderBudgetInfo(
-                    budget_limit=config.get("budget_limit"),
-                    time_period=config.get("time_period"),
-                )
-            asyncio.create_task(
-                self._init_provider_budget_in_cache(
-                    provider=provider,
-                    budget_config=provider_budget_config[provider],
-                )
-            )
-        self.provider_budget_config: ProviderBudgetConfigType = provider_budget_config
-        verbose_router_logger.debug(
-            f"Initalized Provider budget config: {self.provider_budget_config}"
-        )
+        self.provider_budget_config: Optional[GenericBudgetConfigType] = (
+            provider_budget_config
+        )
+        self.deployment_budget_config: Optional[GenericBudgetConfigType] = None
+        self._init_provider_budgets()
+        self._init_deployment_budgets(model_list=model_list)
 
         # Add self to litellm callbacks if it's a list
         if isinstance(litellm.callbacks, list):
@@ -114,77 +104,132 @@ class ProviderBudgetLimiting(CustomLogger):
             request_kwargs
         )
-        # Collect all providers and their budget configs
-        # {"openai": ProviderBudgetInfo, "anthropic": ProviderBudgetInfo, "azure": None}
-        _provider_configs: Dict[str, Optional[ProviderBudgetInfo]] = {}
-        for deployment in healthy_deployments:
-            provider = self._get_llm_provider_for_deployment(deployment)
-            if provider is None:
-                continue
-            budget_config = self._get_budget_config_for_provider(provider)
-            _provider_configs[provider] = budget_config
-        # Filter out providers without budget config
-        provider_configs: Dict[str, ProviderBudgetInfo] = {
-            provider: config
-            for provider, config in _provider_configs.items()
-            if config is not None
-        }
-        # Build cache keys for batch retrieval
-        cache_keys = []
-        for provider, config in provider_configs.items():
-            cache_keys.append(f"provider_spend:{provider}:{config.time_period}")
-        # Fetch current spend for all providers using batch cache
-        _current_spends = await self.router_cache.async_batch_get_cache(
-            keys=cache_keys,
-            parent_otel_span=parent_otel_span,
-        )
-        current_spends: List = _current_spends or [0.0] * len(provider_configs)
-        # Map providers to their current spend values
-        provider_spend_map: Dict[str, float] = {}
-        for idx, provider in enumerate(provider_configs.keys()):
-            provider_spend_map[provider] = float(current_spends[idx] or 0.0)
-        # Filter healthy deployments based on budget constraints
-        deployment_above_budget_info: str = ""  # used to return in error message
-        for deployment in healthy_deployments:
-            provider = self._get_llm_provider_for_deployment(deployment)
-            if provider is None:
-                continue
-            budget_config = provider_configs.get(provider)
-            if not budget_config:
-                continue
-            current_spend = provider_spend_map.get(provider, 0.0)
-            budget_limit = budget_config.budget_limit
-            verbose_router_logger.debug(
-                f"Current spend for {provider}: {current_spend}, budget limit: {budget_limit}"
-            )
-            self._track_provider_remaining_budget_prometheus(
-                provider=provider,
-                spend=current_spend,
-                budget_limit=budget_limit,
-            )
-            if current_spend >= budget_limit:
-                debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {budget_limit}"
-                verbose_router_logger.debug(debug_msg)
-                deployment_above_budget_info += f"{debug_msg}\n"
-                continue
-            potential_deployments.append(deployment)
-        if len(potential_deployments) == 0:
-            raise ValueError(
-                f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
-            )
-        return potential_deployments
+        # Build combined cache keys for both provider and deployment budgets
+        cache_keys = []
+        provider_configs: Dict[str, GenericBudgetInfo] = {}
+        deployment_configs: Dict[str, GenericBudgetInfo] = {}
+        for deployment in healthy_deployments:
+            # Check provider budgets
+            if self.provider_budget_config:
+                provider = self._get_llm_provider_for_deployment(deployment)
+                if provider is not None:
+                    budget_config = self._get_budget_config_for_provider(provider)
+                    if budget_config is not None:
+                        provider_configs[provider] = budget_config
+                        cache_keys.append(
+                            f"provider_spend:{provider}:{budget_config.time_period}"
+                        )
+            # Check deployment budgets
+            if self.deployment_budget_config:
+                model_id = deployment.get("model_info", {}).get("id")
+                if model_id is not None:
+                    budget_config = self._get_budget_config_for_deployment(model_id)
+                    if budget_config is not None:
+                        deployment_configs[model_id] = budget_config
+                        cache_keys.append(
+                            f"deployment_spend:{model_id}:{budget_config.time_period}"
+                        )
+        # Single cache read for all spend values
+        if len(cache_keys) > 0:
+            _current_spends = await self.router_cache.async_batch_get_cache(
+                keys=cache_keys,
+                parent_otel_span=parent_otel_span,
+            )
+            current_spends: List = _current_spends or [0.0] * len(cache_keys)
+            # Map spends to their respective keys
+            spend_map: Dict[str, float] = {}
+            for idx, key in enumerate(cache_keys):
+                spend_map[key] = float(current_spends[idx] or 0.0)
+            potential_deployments, deployment_above_budget_info = (
+                self._filter_out_deployments_above_budget(
+                    healthy_deployments=healthy_deployments,
+                    provider_configs=provider_configs,
+                    deployment_configs=deployment_configs,
+                    spend_map=spend_map,
+                    potential_deployments=potential_deployments,
+                )
+            )
+            if len(potential_deployments) == 0:
+                raise ValueError(
+                    f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
+                )
+            return potential_deployments
+        else:
+            return healthy_deployments
+    def _filter_out_deployments_above_budget(
+        self,
+        potential_deployments: List[Dict[str, Any]],
+        healthy_deployments: List[Dict[str, Any]],
+        provider_configs: Dict[str, GenericBudgetInfo],
+        deployment_configs: Dict[str, GenericBudgetInfo],
+        spend_map: Dict[str, float],
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        """
+        Filter out deployments that have exceeded their budget limit.
+        The following budget checks are run here:
+        - Provider budget
+        - Deployment budget
+        Returns:
+            Tuple[List[Dict[str, Any]], str]:
+                - A tuple containing the filtered deployments
+                - A string containing debug information about deployments that exceeded their budget limit.
+        """
+        # Filter deployments based on both provider and deployment budgets
+        deployment_above_budget_info: str = ""
+        for deployment in healthy_deployments:
+            is_within_budget = True
+            # Check provider budget
+            if self.provider_budget_config:
+                provider = self._get_llm_provider_for_deployment(deployment)
+                if provider in provider_configs:
+                    config = provider_configs[provider]
+                    current_spend = spend_map.get(
+                        f"provider_spend:{provider}:{config.time_period}", 0.0
+                    )
+                    self._track_provider_remaining_budget_prometheus(
+                        provider=provider,
+                        spend=current_spend,
+                        budget_limit=config.budget_limit,
+                    )
+                    if current_spend >= config.budget_limit:
+                        debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.budget_limit}"
+                        deployment_above_budget_info += f"{debug_msg}\n"
+                        is_within_budget = False
+                        continue
+            # Check deployment budget
+            if self.deployment_budget_config and is_within_budget:
+                _model_name = deployment.get("model_name")
+                _litellm_params = deployment.get("litellm_params") or {}
+                _litellm_model_name = _litellm_params.get("model")
+                model_id = deployment.get("model_info", {}).get("id")
+                if model_id in deployment_configs:
+                    config = deployment_configs[model_id]
+                    current_spend = spend_map.get(
+                        f"deployment_spend:{model_id}:{config.time_period}", 0.0
+                    )
+                    if current_spend >= config.budget_limit:
+                        debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_limit}"
+                        verbose_router_logger.debug(debug_msg)
+                        deployment_above_budget_info += f"{debug_msg}\n"
+                        is_within_budget = False
+                        continue
+            if is_within_budget:
+                potential_deployments.append(deployment)
+        return potential_deployments, deployment_above_budget_info
     async def _get_or_set_budget_start_time(
         self, start_time_key: str, current_time: float, ttl_seconds: int
@@ -256,7 +301,7 @@ class ProviderBudgetLimiting(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         """Original method now uses helper functions"""
-        verbose_router_logger.debug("in ProviderBudgetLimiting.async_log_success_event")
+        verbose_router_logger.debug("in RouterBudgetLimiting.async_log_success_event")
         standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
             "standard_logging_object", None
         )
@@ -264,7 +309,7 @@ class ProviderBudgetLimiting(CustomLogger):
             raise ValueError("standard_logging_payload is required")
         response_cost: float = standard_logging_payload.get("response_cost", 0)
+        model_id: str = str(standard_logging_payload.get("model_id", ""))
         custom_llm_provider: str = kwargs.get("litellm_params", {}).get(
             "custom_llm_provider", None
         )
@@ -272,14 +317,40 @@ class ProviderBudgetLimiting(CustomLogger):
             raise ValueError("custom_llm_provider is required")
         budget_config = self._get_budget_config_for_provider(custom_llm_provider)
-        if budget_config is None:
-            raise ValueError(
-                f"No budget config found for provider {custom_llm_provider}, self.provider_budget_config: {self.provider_budget_config}"
-            )
-        spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}"
-        start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
+        if budget_config:
+            # increment spend for provider
+            spend_key = (
+                f"provider_spend:{custom_llm_provider}:{budget_config.time_period}"
+            )
+            start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
+            await self._increment_spend_for_key(
+                budget_config=budget_config,
+                spend_key=spend_key,
+                start_time_key=start_time_key,
+                response_cost=response_cost,
+            )
+        deployment_budget_config = self._get_budget_config_for_deployment(model_id)
+        if deployment_budget_config:
+            # increment spend for specific deployment id
+            deployment_spend_key = (
+                f"deployment_spend:{model_id}:{deployment_budget_config.time_period}"
+            )
+            deployment_start_time_key = f"deployment_budget_start_time:{model_id}"
+            await self._increment_spend_for_key(
+                budget_config=deployment_budget_config,
+                spend_key=deployment_spend_key,
+                start_time_key=deployment_start_time_key,
+                response_cost=response_cost,
+            )
+    async def _increment_spend_for_key(
+        self,
+        budget_config: GenericBudgetInfo,
+        spend_key: str,
+        start_time_key: str,
+        response_cost: float,
+    ):
         current_time = datetime.now(timezone.utc).timestamp()
         ttl_seconds = duration_in_seconds(budget_config.time_period)
@@ -392,10 +463,20 @@ class ProviderBudgetLimiting(CustomLogger):
             # 2. Fetch all current provider spend from Redis to update in-memory cache
             cache_keys = []
-            for provider, config in self.provider_budget_config.items():
-                if config is None:
-                    continue
-                cache_keys.append(f"provider_spend:{provider}:{config.time_period}")
+            if self.provider_budget_config is not None:
+                for provider, config in self.provider_budget_config.items():
+                    if config is None:
+                        continue
+                    cache_keys.append(f"provider_spend:{provider}:{config.time_period}")
+            if self.deployment_budget_config is not None:
+                for model_id, config in self.deployment_budget_config.items():
+                    if config is None:
+                        continue
+                    cache_keys.append(
+                        f"deployment_spend:{model_id}:{config.time_period}"
+                    )
             # Batch fetch current spend values from Redis
             redis_values = await self.router_cache.redis_cache.async_batch_get_cache(
@@ -418,9 +499,19 @@ class ProviderBudgetLimiting(CustomLogger):
                 f"Error syncing in-memory cache with Redis: {str(e)}"
             )
+    def _get_budget_config_for_deployment(
+        self,
+        model_id: str,
+    ) -> Optional[GenericBudgetInfo]:
+        if self.deployment_budget_config is None:
+            return None
+        return self.deployment_budget_config.get(model_id, None)
     def _get_budget_config_for_provider(
         self, provider: str
-    ) -> Optional[ProviderBudgetInfo]:
+    ) -> Optional[GenericBudgetInfo]:
+        if self.provider_budget_config is None:
+            return None
         return self.provider_budget_config.get(provider, None)
     def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
@@ -504,7 +595,7 @@ class ProviderBudgetLimiting(CustomLogger):
         return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat()
     async def _init_provider_budget_in_cache(
-        self, provider: str, budget_config: ProviderBudgetInfo
+        self, provider: str, budget_config: GenericBudgetInfo
     ):
         """
         Initialize provider budget in cache by storing the following keys if they don't exist:
@@ -527,3 +618,92 @@ class ProviderBudgetLimiting(CustomLogger):
             await self.router_cache.async_set_cache(
                 key=spend_key, value=0.0, ttl=ttl_seconds
             )
+    @staticmethod
+    def should_init_router_budget_limiter(
+        provider_budget_config: Optional[dict],
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+        ] = None,
+    ):
+        """
+        Returns `True` if the router budget routing settings are set and RouterBudgetLimiting should be initialized
+        Either:
+        - provider_budget_config is set
+        - budgets are set for deployments in the model_list
+        """
+        if provider_budget_config is not None:
+            return True
+        if model_list is None:
+            return False
+        for _model in model_list:
+            _litellm_params = _model.get("litellm_params", {})
+            if (
+                _litellm_params.get("max_budget")
+                or _litellm_params.get("budget_duration") is not None
+            ):
+                return True
+        return False
+    def _init_provider_budgets(self):
+        if self.provider_budget_config is not None:
+            # cast elements of provider_budget_config to GenericBudgetInfo
+            for provider, config in self.provider_budget_config.items():
+                if config is None:
+                    raise ValueError(
+                        f"No budget config found for provider {provider}, provider_budget_config: {self.provider_budget_config}"
+                    )
+                if not isinstance(config, GenericBudgetInfo):
+                    self.provider_budget_config[provider] = GenericBudgetInfo(
+                        budget_limit=config.get("budget_limit"),
+                        time_period=config.get("time_period"),
+                    )
+                asyncio.create_task(
+                    self._init_provider_budget_in_cache(
+                        provider=provider,
+                        budget_config=self.provider_budget_config[provider],
+                    )
+                )
+            verbose_router_logger.debug(
+                f"Initialized Provider budget config: {self.provider_budget_config}"
+            )
+    def _init_deployment_budgets(
+        self,
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+        ] = None,
+    ):
+        if model_list is None:
+            return
+        for _model in model_list:
+            _litellm_params = _model.get("litellm_params", {})
+            _model_info = _model.get("model_info", {})
+            _model_id = _model_info.get("id")
+            _max_budget = _litellm_params.get("max_budget")
+            _budget_duration = _litellm_params.get("budget_duration")
+            verbose_router_logger.debug(
+                f"Init Deployment Budget: max_budget: {_max_budget}, budget_duration: {_budget_duration}, model_id: {_model_id}"
+            )
+            if (
+                _max_budget is not None
+                and _budget_duration is not None
+                and _model_id is not None
+            ):
+                _budget_config = GenericBudgetInfo(
+                    time_period=_budget_duration,
+                    budget_limit=_max_budget,
+                )
+                if self.deployment_budget_config is None:
+                    self.deployment_budget_config = {}
+                self.deployment_budget_config[_model_id] = _budget_config
+        verbose_router_logger.debug(
+            f"Initialized Deployment Budget Config: {self.deployment_budget_config}"
+        )

View file

@@ -172,6 +172,10 @@ class GenericLiteLLMParams(BaseModel):
     max_file_size_mb: Optional[float] = None
+    # Deployment budgets
+    max_budget: Optional[float] = None
+    budget_duration: Optional[str] = None
     model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
     def __init__(
@@ -207,6 +211,9 @@ class GenericLiteLLMParams(BaseModel):
         input_cost_per_second: Optional[float] = None,
         output_cost_per_second: Optional[float] = None,
         max_file_size_mb: Optional[float] = None,
+        # Deployment budgets
+        max_budget: Optional[float] = None,
+        budget_duration: Optional[str] = None,
         **params,
     ):
         args = locals()
@@ -351,6 +358,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     # use this for tag-based routing
     tags: Optional[List[str]]
+    # deployment budgets
+    max_budget: Optional[float]
+    budget_duration: Optional[str]
 class DeploymentTypedDict(TypedDict, total=False):
     model_name: Required[str]
@@ -436,7 +447,7 @@ class RouterErrors(enum.Enum):
         "Not allowed to access model due to tags configuration"
     )
     no_deployments_with_provider_budget_routing = (
-        "No deployments available - crossed budget for provider"
+        "No deployments available - crossed budget"
     )
@@ -635,12 +646,12 @@ class RoutingStrategy(enum.Enum):
     PROVIDER_BUDGET_LIMITING = "provider-budget-routing"
-class ProviderBudgetInfo(BaseModel):
+class GenericBudgetInfo(BaseModel):
     time_period: str  # e.g., '1d', '30d'
     budget_limit: float
-ProviderBudgetConfigType = Dict[str, ProviderBudgetInfo]
+GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
 class RouterCacheEnum(enum.Enum):
@@ -648,7 +659,7 @@ class RouterCacheEnum(enum.Enum):
     RPM = "global_router:{id}:{model}:rpm:{current_minute}"
-class ProviderBudgetWindowDetails(BaseModel):
+class GenericBudgetWindowDetails(BaseModel):
     """Details about a provider's budget window"""
     budget_start: float
View file

@@ -1432,6 +1432,8 @@ all_litellm_params = [
     "user_continue_message",
     "fallback_depth",
     "max_fallbacks",
+    "max_budget",
+    "budget_duration",
 ]
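
Editor's note: registering the two keys in `all_litellm_params` marks them as litellm-internal so they are separated from the parameters forwarded to the underlying provider. Conceptually, and only as a sketch (this is not litellm's actual filtering code):

```python
# conceptual sketch: litellm-recognized params are split off before the
# request payload is sent to the provider API
all_litellm_params = ["max_fallbacks", "max_budget", "budget_duration"]  # excerpt

def split_params(kwargs: dict) -> tuple[dict, dict]:
    litellm_kwargs = {k: v for k, v in kwargs.items() if k in all_litellm_params}
    provider_kwargs = {k: v for k, v in kwargs.items() if k not in all_litellm_params}
    return litellm_kwargs, provider_kwargs
```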

View file

@@ -11,11 +11,11 @@ sys.path.insert(
 )  # Adds the parent directory to the system-path
 import pytest
 from litellm import Router
-from litellm.router_strategy.provider_budgets import ProviderBudgetLimiting
+from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
 from litellm.types.router import (
     RoutingStrategy,
-    ProviderBudgetConfigType,
-    ProviderBudgetInfo,
+    GenericBudgetConfigType,
+    GenericBudgetInfo,
 )
 from litellm.caching.caching import DualCache, RedisCache
 import logging
@@ -43,6 +43,9 @@ def cleanup_redis():
         for key in redis_client.scan_iter("provider_spend:*"):
             print("deleting key", key)
             redis_client.delete(key)
+        for key in redis_client.scan_iter("deployment_spend:*"):
+            print("deleting key", key)
+            redis_client.delete(key)
     except Exception as e:
         print(f"Error cleaning up Redis: {str(e)}")
@@ -59,9 +62,9 @@ async def test_provider_budgets_e2e_test():
     """
     cleanup_redis()
     # Modify for test
-    provider_budget_config: ProviderBudgetConfigType = {
-        "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001),
-        "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100),
+    provider_budget_config: GenericBudgetConfigType = {
+        "openai": GenericBudgetInfo(time_period="1d", budget_limit=0.000000000001),
+        "azure": GenericBudgetInfo(time_period="1d", budget_limit=100),
     }
     router = Router(
@@ -175,7 +178,7 @@ async def test_get_llm_provider_for_deployment():
     """
     cleanup_redis()
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(), provider_budget_config={}
     )
@@ -208,11 +211,11 @@ async def test_get_budget_config_for_provider():
     """
     cleanup_redis()
     config = {
-        "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
-        "anthropic": ProviderBudgetInfo(time_period="7d", budget_limit=500),
+        "openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
+        "anthropic": GenericBudgetInfo(time_period="7d", budget_limit=500),
     }
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(), provider_budget_config=config
     )
@@ -244,18 +247,18 @@ async def test_prometheus_metric_tracking():
     mock_prometheus = MagicMock(spec=PrometheusLogger)
     # Setup provider budget limiting
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(),
         provider_budget_config={
-            "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100)
+            "openai": GenericBudgetInfo(time_period="1d", budget_limit=100)
         },
     )
     litellm._async_success_callback = [mock_prometheus]
-    provider_budget_config: ProviderBudgetConfigType = {
-        "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001),
-        "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100),
+    provider_budget_config: GenericBudgetConfigType = {
+        "openai": GenericBudgetInfo(time_period="1d", budget_limit=0.000000000001),
+        "azure": GenericBudgetInfo(time_period="1d", budget_limit=100),
     }
     router = Router(
@@ -308,7 +311,7 @@ async def test_handle_new_budget_window():
     Current
     """
     cleanup_redis()
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(), provider_budget_config={}
     )
@@ -349,7 +352,7 @@ async def test_get_or_set_budget_start_time():
     scenario 2: existing start time in cache, should return existing start time
     """
     cleanup_redis()
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(), provider_budget_config={}
     )
@@ -390,7 +393,7 @@ async def test_increment_spend_in_current_window():
     - Queue the increment operation to Redis
     """
     cleanup_redis()
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(), provider_budget_config={}
     )
@@ -437,11 +440,11 @@ async def test_sync_in_memory_spend_with_redis():
     """
     cleanup_redis()
     provider_budget_config = {
-        "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
-        "anthropic": ProviderBudgetInfo(time_period="1d", budget_limit=200),
+        "openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
+        "anthropic": GenericBudgetInfo(time_period="1d", budget_limit=200),
     }
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(
             redis_cache=RedisCache(
                 host=os.getenv("REDIS_HOST"),
@@ -491,10 +494,10 @@ async def test_get_current_provider_spend():
     3. Provider with budget config and spend returns correct value
     """
     cleanup_redis()
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(),
         provider_budget_config={
-            "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
+            "openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
         },
     )
@@ -526,7 +529,7 @@ async def test_get_current_provider_budget_reset_at():
     3. Provider with budget config and TTL returns correct ISO timestamp
     """
     cleanup_redis()
-    provider_budget = ProviderBudgetLimiting(
+    provider_budget = RouterBudgetLimiting(
         router_cache=DualCache(
             redis_cache=RedisCache(
                 host=os.getenv("REDIS_HOST"),
@@ -535,8 +538,8 @@
             )
         ),
         provider_budget_config={
-            "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
-            "vertex_ai": ProviderBudgetInfo(time_period="1h", budget_limit=100),
+            "openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
+            "vertex_ai": GenericBudgetInfo(time_period="1h", budget_limit=100),
         },
     )
@@ -565,3 +568,109 @@
     # Allow for small time differences (within 5 seconds)
     time_difference = abs((reset_time - expected_time).total_seconds())
     assert time_difference < 5
+@pytest.mark.asyncio
+async def test_deployment_budget_limits_e2e_test():
+    """
+    Expected behavior:
+    - First request forced to openai/gpt-4o
+    - Hit budget limit for openai/gpt-4o
+    - Next 3 requests all go to openai/gpt-4o-mini
+    """
+    litellm.set_verbose = True
+    cleanup_redis()
+    # Modify for test
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-4o",  # openai model name
+                "litellm_params": {  # params for litellm completion/embedding call
+                    "model": "openai/gpt-4o",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "max_budget": 0.000000000001,
+                    "budget_duration": "1d",
+                },
+                "model_info": {"id": "openai-gpt-4o"},
+            },
+            {
+                "model_name": "gpt-4o",  # openai model name
+                "litellm_params": {
+                    "model": "openai/gpt-4o-mini",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "max_budget": 10,
+                    "budget_duration": "20d",
+                },
+                "model_info": {"id": "openai-gpt-4o-mini"},
+            },
+        ],
+    )
+    response = await router.acompletion(
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        model="openai-gpt-4o",
+    )
+    print(response)
+    await asyncio.sleep(2.5)
+    for _ in range(3):
+        response = await router.acompletion(
+            messages=[{"role": "user", "content": "Hello, how are you?"}],
+            model="gpt-4o",
+        )
+        print(response)
+        await asyncio.sleep(1)
+        print("response.hidden_params", response._hidden_params)
+        assert response._hidden_params.get("model_id") == "openai-gpt-4o-mini"
+@pytest.mark.asyncio
+async def test_deployment_budgets_e2e_test_expect_to_fail():
+    """
+    Expected behavior:
+    - first request passes, all subsequent requests fail
+    """
+    cleanup_redis()
+    router = Router(
+        model_list=[
+            {
+                "model_name": "openai/gpt-4o-mini",  # openai model name
+                "litellm_params": {
+                    "model": "openai/gpt-4o-mini",
+                    "max_budget": 0.000000000001,
+                    "budget_duration": "1d",
+                },
+            },
+        ],
+        redis_host=os.getenv("REDIS_HOST"),
+        redis_port=int(os.getenv("REDIS_PORT")),
+        redis_password=os.getenv("REDIS_PASSWORD"),
+    )
+    response = await router.acompletion(
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        model="openai/gpt-4o-mini",
+    )
+    print(response)
+    await asyncio.sleep(2.5)
+    for _ in range(3):
+        with pytest.raises(Exception) as exc_info:
+            response = await router.acompletion(
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                model="openai/gpt-4o-mini",
+            )
+            print(response)
+            print("response.hidden_params", response._hidden_params)
+        await asyncio.sleep(0.5)
+        # Verify the error is related to budget exceeded
+        assert "Exceeded budget for deployment" in str(exc_info.value)
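
Editor's note: when calling the Router directly, the budget rejection surfaces as a raised exception whose message begins with the `RouterErrors.no_deployments_with_provider_budget_routing` string changed above. A minimal handling sketch (router setup omitted; the matched substring is taken from the types diff in this commit):

```python
from litellm import Router

async def call_with_budget_fallback(router: Router) -> None:
    try:
        await router.acompletion(
            model="gpt-4o",
            messages=[{"role": "user", "content": "hi"}],
        )
    except Exception as e:
        # Raised by async_filter_deployments when every deployment is over budget.
        if "No deployments available - crossed budget" in str(e):
            print("All deployments over budget; backing off")
        else:
            raise
```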