(feat) provider budget routing improvements (#6827)

* minor fix for provider budget

* fix raise good error message when budget crossed for provider budget

* fix test provider budgets

* test provider budgets

* feat - emit llm provider spend on prometheus

* test_prometheus_metric_tracking

* doc provider budgets
Ishaan Jaff 2024-11-19 21:25:08 -08:00 committed by GitHub
parent 3c6fe21935
commit 7463dab9c6
7 changed files with 261 additions and 20 deletions
@@ -4,18 +4,16 @@ import TabItem from '@theme/TabItem';
 # Provider Budget Routing

 Use this to set budgets for LLM Providers - example $100/day for OpenAI, $100/day for Azure.

+## Quick Start
+
+Set provider budgets in your `proxy_config.yaml` file
+
+### Proxy Config setup
+
 ```yaml
 model_list:
   - model_name: gpt-3.5-turbo
     litellm_params:
       model: openai/gpt-3.5-turbo
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: gpt-3.5-turbo
-    litellm_params:
-      model: azure/chatgpt-functioncalling
-      api_key: os.environ/AZURE_API_KEY
-      api_version: os.environ/AZURE_API_VERSION
-      api_base: os.environ/AZURE_API_BASE

 router_settings:
   redis_host: <your-redis-host>
@@ -42,8 +40,66 @@ general_settings:
   master_key: sk-1234
 ```

+### Make a test request
+
+We expect the first request to succeed, and the second request to fail since we cross the budget for `openai`
+
+**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
+
+<Tabs>
+<TabItem label="Successful Call" value="allowed">
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "model": "gpt-4o",
+    "messages": [
+      {"role": "user", "content": "hi my name is test request"}
+    ]
+  }'
+```
+
+</TabItem>
+<TabItem label="Unsuccessful call" value="not-allowed">
+
+Expect this to fail, since the budget for `openai` has been crossed
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "model": "gpt-4o",
+    "messages": [
+      {"role": "user", "content": "hi my name is test request"}
+    ]
+  }'
+```
+
+Expected response on failure
+
+```json
+{
+  "error": {
+    "message": "No deployments available - crossed budget for provider: Exceeded budget for provider openai: 0.0007350000000000001 >= 1e-12",
+    "type": "None",
+    "param": "None",
+    "code": "429"
+  }
+}
+```
+
+</TabItem>
+</Tabs>
+
-#### How provider-budget-routing works
+## How provider budget routing works
+
 1. **Budget Tracking**:
    - Uses Redis to track spend for each provider
@@ -62,3 +118,33 @@ general_settings:
 4. **Requirements**:
    - Redis required for tracking spend across instances
    - Provider names must be litellm provider names. See [Supported Providers](https://docs.litellm.ai/docs/providers)
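The Redis-backed tracking flow described above can be sketched roughly as follows. This is a hedged illustration, not LiteLLM's implementation: `InMemorySpendTracker` is a hypothetical stand-in for the Redis-backed cache (the real code presumably uses Redis counters with a TTL equal to the budget window), kept in-process so the sketch is self-contained:

```python
import time

# Hypothetical stand-in for the Redis-backed spend cache described above.
class InMemorySpendTracker:
    def __init__(self):
        self._store = {}  # provider -> (spend, window_expires_at)

    def add_spend(self, provider, amount, ttl_seconds):
        now = time.time()
        spend, expires_at = self._store.get(provider, (0.0, now + ttl_seconds))
        if now >= expires_at:
            # budget window rolled over; start a fresh window
            spend, expires_at = 0.0, now + ttl_seconds
        spend += amount
        self._store[provider] = (spend, expires_at)
        return spend

    def is_within_budget(self, provider, budget_limit):
        spend, expires_at = self._store.get(provider, (0.0, float("inf")))
        if time.time() >= expires_at:
            return True  # window expired; spend effectively resets
        return spend < budget_limit

tracker = InMemorySpendTracker()
tracker.add_spend("openai", 0.0007, ttl_seconds=86400)
print(tracker.is_within_budget("openai", budget_limit=100.0))   # True
print(tracker.is_within_budget("openai", budget_limit=0.0001))  # False
```

The `0.0001` case mirrors the tiny budgets used in this commit's tests: once tracked spend reaches the limit, the provider's deployments are filtered out until the window's TTL expires.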
+## Monitoring Provider Remaining Budget
+
+LiteLLM will emit the following metric on Prometheus to track the remaining budget for each provider
+
+This metric indicates the remaining budget for a provider in dollars (USD)
+
+```
+litellm_provider_remaining_budget_metric{api_provider="openai"} 10
+```
+
+## Spec for provider_budget_config
+
+The `provider_budget_config` is a dictionary where:
+- **Key**: Provider name (string) - Must be a valid [LiteLLM provider name](https://docs.litellm.ai/docs/providers)
+- **Value**: Budget configuration object with the following parameters:
+  - `budget_limit`: Float value representing the budget in USD
+  - `time_period`: String in the format "Xd" where X is the number of days (e.g., "1d", "30d")
+
+Example structure:
+```yaml
+provider_budget_config:
+  openai:
+    budget_limit: 100.0 # $100 USD
+    time_period: "1d" # 1 day period
+  azure:
+    budget_limit: 500.0 # $500 USD
+    time_period: "30d" # 30 day period
+```
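The spec above maps naturally onto a small typed loader. A hedged sketch, not LiteLLM's code: `ProviderBudgetEntry` is a hypothetical stand-in for the `ProviderBudgetInfo` type referenced elsewhere in this commit, and the None check mirrors the validation added to `ProviderBudgetLimiting.__init__` below:

```python
from dataclasses import dataclass

# Hypothetical stand-in for LiteLLM's ProviderBudgetInfo; field names
# mirror the spec above, nothing more.
@dataclass
class ProviderBudgetEntry:
    budget_limit: float
    time_period: str

def load_provider_budget_config(raw: dict) -> dict:
    """Cast raw YAML dicts into typed entries, rejecting missing configs."""
    typed = {}
    for provider, cfg in raw.items():
        if cfg is None:
            raise ValueError(f"No budget config found for provider {provider}")
        typed[provider] = ProviderBudgetEntry(
            budget_limit=float(cfg["budget_limit"]),
            time_period=str(cfg["time_period"]),
        )
    return typed

config = load_provider_budget_config(
    {"openai": {"budget_limit": 100.0, "time_period": "1d"},
     "azure": {"budget_limit": 500.0, "time_period": "30d"}}
)
print(config["openai"].budget_limit)  # 100.0
print(config["azure"].time_period)    # 30d
```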


@@ -228,6 +228,13 @@ class PrometheusLogger(CustomLogger):
                 "api_key_alias",
             ],
         )
+        # llm api provider budget metrics
+        self.litellm_provider_remaining_budget_metric = Gauge(
+            "litellm_provider_remaining_budget_metric",
+            "Remaining budget for provider - used when you set provider budget limits",
+            labelnames=["api_provider"],
+        )
+
         # Get all keys
         _logged_llm_labels = [
             "litellm_model_name",
@@ -1130,6 +1137,19 @@ class PrometheusLogger(CustomLogger):
             litellm_model_name, model_id, api_base, api_provider, exception_status
         ).inc()

+    def track_provider_remaining_budget(
+        self, provider: str, spend: float, budget_limit: float
+    ):
+        """
+        Track provider remaining budget in Prometheus
+        """
+        self.litellm_provider_remaining_budget_metric.labels(provider).set(
+            self._safe_get_remaining_budget(
+                max_budget=budget_limit,
+                spend=spend,
+            )
+        )
+
     def _safe_get_remaining_budget(
         self, max_budget: Optional[float], spend: Optional[float]
     ) -> float:
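The hunk above calls `_safe_get_remaining_budget`, whose body is truncated here. A plausible sketch of such a helper follows; the None-handling below is an assumption for illustration, not the actual LiteLLM implementation:

```python
from typing import Optional

# Assumption: only the signature matches the diff above; the body is a
# sketch of one reasonable way to handle unset budgets and unset spend.
def safe_get_remaining_budget(max_budget: Optional[float], spend: Optional[float]) -> float:
    if max_budget is None:
        return float("inf")  # no budget configured, nothing to subtract from
    if spend is None:
        spend = 0.0  # no spend recorded yet
    return max_budget - spend

print(safe_get_remaining_budget(100.0, 25.0))  # 75.0
```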


@@ -1,14 +1,18 @@
 model_list:
-  - model_name: fake-openai-endpoint
+  - model_name: gpt-4o
     litellm_params:
-      model: openai/fake
+      model: openai/gpt-4o
       api_key: os.environ/OPENAI_API_KEY
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/

+router_settings:
+  provider_budget_config:
+    openai:
+      budget_limit: 0.000000000001 # float of $ value budget for time period
+      time_period: 1d # can be 1d, 2d, 30d
+    azure:
+      budget_limit: 100
+      time_period: 1d
+
-general_settings:
-  key_management_system: "aws_secret_manager"
-  key_management_settings:
-    store_virtual_keys: true
-    access_mode: "write_only"
+litellm_settings:
+  callbacks: ["prometheus"]


@@ -25,10 +25,14 @@ from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.router_utils.cooldown_callbacks import (
+    _get_prometheus_logger_from_callbacks,
+)
 from litellm.types.router import (
     LiteLLM_Params,
     ProviderBudgetConfigType,
     ProviderBudgetInfo,
+    RouterErrors,
 )
 from litellm.types.utils import StandardLoggingPayload
@@ -43,6 +47,20 @@ else:

 class ProviderBudgetLimiting(CustomLogger):
     def __init__(self, router_cache: DualCache, provider_budget_config: dict):
         self.router_cache = router_cache
+
+        # cast elements of provider_budget_config to ProviderBudgetInfo
+        for provider, config in provider_budget_config.items():
+            if config is None:
+                raise ValueError(
+                    f"No budget config found for provider {provider}, provider_budget_config: {provider_budget_config}"
+                )
+            if not isinstance(config, ProviderBudgetInfo):
+                provider_budget_config[provider] = ProviderBudgetInfo(
+                    budget_limit=config.get("budget_limit"),
+                    time_period=config.get("time_period"),
+                )
+
         self.provider_budget_config: ProviderBudgetConfigType = provider_budget_config
         verbose_router_logger.debug(
             f"Initialized Provider budget config: {self.provider_budget_config}"
@@ -71,6 +89,10 @@ class ProviderBudgetLimiting(CustomLogger):
         if isinstance(healthy_deployments, dict):
             healthy_deployments = [healthy_deployments]

+        # Don't do any filtering if there are no healthy deployments
+        if len(healthy_deployments) == 0:
+            return healthy_deployments
+
         potential_deployments: List[Dict] = []

         # Extract the parent OpenTelemetry span for tracing
@@ -113,6 +135,7 @@ class ProviderBudgetLimiting(CustomLogger):
             provider_spend_map[provider] = float(current_spends[idx] or 0.0)

         # Filter healthy deployments based on budget constraints
+        deployment_above_budget_info: str = ""  # used to return in error message
         for deployment in healthy_deployments:
             provider = self._get_llm_provider_for_deployment(deployment)
             if provider is None:
@@ -128,15 +151,25 @@ class ProviderBudgetLimiting(CustomLogger):
             verbose_router_logger.debug(
                 f"Current spend for {provider}: {current_spend}, budget limit: {budget_limit}"
             )
+            self._track_provider_remaining_budget_prometheus(
+                provider=provider,
+                spend=current_spend,
+                budget_limit=budget_limit,
+            )
+
             if current_spend >= budget_limit:
-                verbose_router_logger.debug(
-                    f"Skipping deployment {deployment} for provider {provider} as spend limit exceeded"
-                )
+                debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {budget_limit}"
+                verbose_router_logger.debug(debug_msg)
+                deployment_above_budget_info += f"{debug_msg}\n"
                 continue

             potential_deployments.append(deployment)

+        if len(potential_deployments) == 0:
+            raise ValueError(
+                f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
+            )
+
         return potential_deployments

     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@@ -217,3 +250,21 @@ class ProviderBudgetLimiting(CustomLogger):
             days = int(time_period[:-1])
             return days * 24 * 60 * 60
         raise ValueError(f"Unsupported time period format: {time_period}")
+
+    def _track_provider_remaining_budget_prometheus(
+        self, provider: str, spend: float, budget_limit: float
+    ):
+        """
+        Optional helper - emit provider remaining budget metric to Prometheus
+
+        This is helpful for debugging and monitoring provider budget limits.
+        """
+        from litellm.integrations.prometheus import PrometheusLogger
+
+        prometheus_logger = _get_prometheus_logger_from_callbacks()
+        if prometheus_logger:
+            prometheus_logger.track_provider_remaining_budget(
+                provider=provider,
+                spend=spend,
+                budget_limit=budget_limit,
+            )
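The `"Xd"` time-period parsing partially visible at the top of this hunk (`days * 24 * 60 * 60`) can be shown as a small self-contained function; the body is taken directly from the diffed lines:

```python
# Standalone version of the "Xd" -> seconds conversion shown in the hunk above.
def get_ttl_seconds(time_period: str) -> int:
    if time_period.endswith("d"):
        days = int(time_period[:-1])
        return days * 24 * 60 * 60
    raise ValueError(f"Unsupported time period format: {time_period}")

print(get_ttl_seconds("1d"))   # 86400
print(get_ttl_seconds("30d"))  # 2592000
```

The resulting value is what a Redis-style TTL would be set to, so tracked spend expires with the budget window.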


@@ -88,6 +88,9 @@ def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
     """
     from litellm.integrations.prometheus import PrometheusLogger

+    for _callback in litellm._async_success_callback:
+        if isinstance(_callback, PrometheusLogger):
+            return _callback
     for _callback in litellm.callbacks:
         if isinstance(_callback, PrometheusLogger):
             return _callback
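The lookup order added in this hunk (scan `litellm._async_success_callback` first, then `litellm.callbacks`, returning the first `PrometheusLogger` found) can be sketched standalone. `PrometheusLoggerStub` is a hypothetical stand-in for the real class so the sketch has no dependencies:

```python
from typing import Optional

# Hypothetical stand-in for litellm's PrometheusLogger.
class PrometheusLoggerStub:
    pass

def find_prometheus_logger(async_success_callbacks, callbacks) -> Optional[PrometheusLoggerStub]:
    # Mirror of the lookup order in the hunk above: async success
    # callbacks take precedence over the general callback list.
    for cb in async_success_callbacks:
        if isinstance(cb, PrometheusLoggerStub):
            return cb
    for cb in callbacks:
        if isinstance(cb, PrometheusLoggerStub):
            return cb
    return None

logger = PrometheusLoggerStub()
print(find_prometheus_logger([logger], []) is logger)  # True
print(find_prometheus_logger([], ["not-a-logger"]))    # None
```

This ordering is what lets the new test below register its mock logger via `litellm._async_success_callback` and have it picked up.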


@@ -434,6 +434,9 @@ class RouterErrors(enum.Enum):
     no_deployments_with_tag_routing = (
         "Not allowed to access model due to tags configuration"
     )
+    no_deployments_with_provider_budget_routing = (
+        "No deployments available - crossed budget for provider"
+    )

 class AllowedFailsPolicy(BaseModel):


@@ -20,6 +20,7 @@ from litellm.types.router import (
 from litellm.caching.caching import DualCache
 import logging
 from litellm._logging import verbose_router_logger
+import litellm

 verbose_router_logger.setLevel(logging.DEBUG)
@@ -93,8 +94,14 @@ async def test_provider_budgets_e2e_test_expect_to_fail():
     - first request passes, all subsequent requests fail
     """

-    provider_budget_config: ProviderBudgetConfigType = {
-        "anthropic": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001),
+    # Note: We intentionally use a dictionary with string keys for budget_limit and time_period
+    # we want to test that the router can handle type conversion, since the proxy config yaml passes these values as a dictionary
+    provider_budget_config = {
+        "anthropic": {
+            "budget_limit": 0.000000000001,
+            "time_period": "1d",
+        }
     }

     router = Router(
@@ -132,6 +139,8 @@ async def test_provider_budgets_e2e_test_expect_to_fail():
     await asyncio.sleep(0.5)

     # Verify the error is related to budget exceeded
+    assert "Exceeded budget for provider" in str(exc_info.value)
+

 def test_get_ttl_seconds():
     """
@@ -207,3 +216,68 @@ def test_get_budget_config_for_provider():
     # Test non-existent provider
     assert provider_budget._get_budget_config_for_provider("unknown") is None

+
+@pytest.mark.asyncio
+async def test_prometheus_metric_tracking():
+    """
+    Test that the Prometheus metric for provider budget is tracked correctly
+    """
+    from unittest.mock import MagicMock
+
+    from litellm.integrations.prometheus import PrometheusLogger
+
+    # Create a mock PrometheusLogger
+    mock_prometheus = MagicMock(spec=PrometheusLogger)
+
+    # Setup provider budget limiting
+    provider_budget = ProviderBudgetLimiting(
+        router_cache=DualCache(),
+        provider_budget_config={
+            "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100)
+        },
+    )
+
+    litellm._async_success_callback = [mock_prometheus]
+
+    provider_budget_config: ProviderBudgetConfigType = {
+        "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001),
+        "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100),
+    }
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",  # openai model name
+                "litellm_params": {  # params for litellm completion/embedding call
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_version": os.getenv("AZURE_API_VERSION"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                },
+                "model_info": {"id": "azure-model-id"},
+            },
+            {
+                "model_name": "gpt-3.5-turbo",  # openai model name
+                "litellm_params": {
+                    "model": "openai/gpt-4o-mini",
+                },
+                "model_info": {"id": "openai-model-id"},
+            },
+        ],
+        provider_budget_config=provider_budget_config,
+        redis_host=os.getenv("REDIS_HOST"),
+        redis_port=int(os.getenv("REDIS_PORT")),
+        redis_password=os.getenv("REDIS_PASSWORD"),
+    )
+
+    response = await router.acompletion(
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        model="openai/gpt-4o-mini",
+        mock_response="hi",
+    )
+    print(response)
+
+    await asyncio.sleep(0.5)
+
+    # Verify the mock was called correctly
+    mock_prometheus.track_provider_remaining_budget.assert_called_once()