Provider Budget Routing - Get Budget, Spend Details (#7063)

* add async_get_ttl to dual cache

* add ProviderBudgetResponse

* add provider_budgets

* test_redis_get_ttl

* _init_or_get_provider_budget_in_cache

* test_init_or_get_provider_budget_in_cache

* use _init_provider_budget_in_cache

* test_get_current_provider_budget_reset_at

* doc Get Budget, Spend Details

* doc Provider Budget Routing
Ishaan Jaff 2024-12-06 21:14:12 -08:00 committed by GitHub
parent aaa4d4178a
commit 87ca62943b
11 changed files with 444 additions and 1 deletion

View file

@@ -126,6 +126,53 @@ Expected response on failure
## Monitoring Provider Remaining Budget
### Get Budget, Spend Details
Use this endpoint to check current budget, spend and budget reset time for a provider
Example Request
```bash
curl -X GET http://localhost:4000/provider/budgets \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234"
```
Example Response
```json
{
"providers": {
"openai": {
"budget_limit": 1e-12,
"time_period": "1d",
"spend": 0.0,
"budget_reset_at": null
},
"azure": {
"budget_limit": 100.0,
"time_period": "1d",
"spend": 0.0,
"budget_reset_at": null
},
"anthropic": {
"budget_limit": 100.0,
"time_period": "10d",
"spend": 0.0,
"budget_reset_at": null
},
"vertex_ai": {
"budget_limit": 100.0,
"time_period": "12d",
"spend": 0.0,
"budget_reset_at": null
}
}
}
```
### Prometheus Metric
LiteLLM will emit the following metric on Prometheus to track the remaining budget for each provider
This metric indicates the remaining budget for a provider in dollars (USD)
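The metric name itself lies outside this diff hunk; purely as an illustration of how such a remaining-budget gauge could be emitted with `prometheus_client` (metric and label names below are placeholders, not LiteLLM's actual identifiers):

```python
# Illustrative only: metric and label names are placeholders, not LiteLLM's
# actual Prometheus identifiers (the real metric name sits outside this hunk).
from prometheus_client import Gauge

provider_remaining_budget = Gauge(
    "provider_remaining_budget_usd",  # hypothetical metric name
    "Remaining budget for an LLM provider in USD",
    labelnames=["provider"],
)


def record_remaining_budget(provider: str, budget_limit: float, spend: float) -> None:
    # Remaining budget = configured limit minus spend accrued in the current window.
    provider_remaining_budget.labels(provider=provider).set(budget_limit - spend)
```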

View file

@@ -423,3 +423,12 @@ class DualCache(BaseCache):
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
await self.redis_cache.async_delete_cache(key)
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in in-memory cache or redis
"""
ttl = await self.in_memory_cache.async_get_ttl(key)
if ttl is None and self.redis_cache is not None:
ttl = await self.redis_cache.async_get_ttl(key)
return ttl
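A minimal usage sketch of the new `async_get_ttl` read-through, assuming a `DualCache` wired to Redis (environment variable names are illustrative):

```python
# Sketch: TTL lookup falls back from the in-memory cache to Redis when the key
# is only tracked in Redis (e.g. spend keys written by another LiteLLM instance).
import asyncio
import os

from litellm.caching.caching import DualCache, RedisCache


async def main():
    dual_cache = DualCache(
        redis_cache=RedisCache(
            host=os.environ["REDIS_HOST"],  # illustrative env vars
            port=os.environ["REDIS_PORT"],
            password=os.environ["REDIS_PASSWORD"],
        )
    )
    ttl = await dual_cache.async_get_ttl("provider_spend:openai:1d")
    print("seconds until the spend key expires:", ttl)  # None if the key is not set


asyncio.run(main())
```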

View file

@@ -145,3 +145,9 @@ class InMemoryCache(BaseCache):
def delete_cache(self, key):
self.cache_dict.pop(key, None)
self.ttl_dict.pop(key, None)
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in in-memory cache
"""
return self.ttl_dict.get(key, None)

View file

@@ -980,3 +980,26 @@ class RedisCache(BaseCache):
str(e),
)
raise e
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in Redis
Args:
key (str): The key to get TTL for
Returns:
Optional[int]: The remaining TTL in seconds, or None if key doesn't exist
Redis ref: https://redis.io/docs/latest/commands/ttl/
"""
try:
_redis_client = await self.init_async_client()
async with _redis_client as redis_client:
ttl = await redis_client.ttl(key)
if ttl <= -1:  # -2: key does not exist, -1: key exists but has no expiry; treat both as no TTL
return None
return ttl
except Exception as e:
verbose_logger.debug(f"Redis TTL Error: {e}")
return None

View file

@@ -2193,3 +2193,25 @@ LiteLLM_ManagementEndpoint_MetadataFields = [
"tags",
"enforced_params",
]
class ProviderBudgetResponseObject(LiteLLMBase):
"""
Configuration for a single provider's budget settings
"""
budget_limit: float # Budget limit in USD for the time period
time_period: str # Time period for budget (e.g., '1d', '30d', '1mo')
spend: float = 0.0 # Current spend for this provider
budget_reset_at: Optional[str] = None # When the current budget period resets
class ProviderBudgetResponse(LiteLLMBase):
"""
Complete provider budget configuration and status.
Maps provider names to their budget configs.
"""
providers: Dict[str, ProviderBudgetResponseObject] = (
{}
) # Dictionary mapping provider names to their budget configurations
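For reference, a small sketch of how these two models nest when building the `/provider/budgets` response (values are made up; serialization assumes the usual pydantic v2 API):

```python
# Sketch: nesting the two response models by hand (values are made up).
from litellm.proxy._types import ProviderBudgetResponse, ProviderBudgetResponseObject

response = ProviderBudgetResponse(
    providers={
        "openai": ProviderBudgetResponseObject(
            budget_limit=100.0,  # USD for the window
            time_period="1d",
            spend=12.5,  # made-up spend value
            budget_reset_at="2024-12-08T05:14:12+00:00",  # ISO timestamp or None
        )
    }
)
print(response.model_dump_json(indent=2))  # assumes pydantic v2 serialization
```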

View file

@@ -12,3 +12,27 @@ model_list:
litellm_settings:
  callbacks: ["datadog"]
router_settings:
  provider_budget_config:
    openai:
      budget_limit: 0.000000000001 # float of $ value budget for time period
      time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo
    azure:
      budget_limit: 100
      time_period: 1d
    anthropic:
      budget_limit: 100
      time_period: 10d
    vertex_ai:
      budget_limit: 100
      time_period: 12d
    gemini:
      budget_limit: 100
      time_period: 12d
  # OPTIONAL: Set Redis Host, Port, and Password if using multiple instances of LiteLLM
  redis_host: os.environ/REDIS_HOST
  redis_port: os.environ/REDIS_PORT
  redis_password: os.environ/REDIS_PASSWORD
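The same budgets can be expressed in Python when constructing the Router directly; a minimal sketch assuming the `provider_budget_config` Router parameter from the provider budget routing feature (model list and API key are placeholders):

```python
# Sketch: the router_settings above expressed in Python when building a Router directly
# (model list and API key are placeholders).
from litellm import Router
from litellm.types.router import ProviderBudgetInfo

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "sk-..."},
        }
    ],
    provider_budget_config={
        "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001),
        "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100),
    },
)
```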

View file

@@ -8,10 +8,12 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy._types import ProviderBudgetResponse, ProviderBudgetResponseObject
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.spend_tracking.spend_tracking_utils import (
get_spend_by_team_and_customer,
)
from litellm.proxy.utils import handle_exception_on_proxy
router = APIRouter()
@@ -2464,3 +2466,92 @@ async def global_predict_spend_logs(request: Request):
data = await request.json()
data = data.get("data")
return _forecast_daily_cost(data)
@router.get("/provider/budgets", response_model=ProviderBudgetResponse)
async def provider_budgets() -> ProviderBudgetResponse:
"""
Provider Budget Routing - Get Budget, Spend Details https://docs.litellm.ai/docs/proxy/provider_budget_routing
Use this endpoint to check current budget, spend and budget reset time for a provider
Example Request
```bash
curl -X GET http://localhost:4000/provider/budgets \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234"
```
Example Response
```json
{
"providers": {
"openai": {
"budget_limit": 1e-12,
"time_period": "1d",
"spend": 0.0,
"budget_reset_at": null
},
"azure": {
"budget_limit": 100.0,
"time_period": "1d",
"spend": 0.0,
"budget_reset_at": null
},
"anthropic": {
"budget_limit": 100.0,
"time_period": "10d",
"spend": 0.0,
"budget_reset_at": null
},
"vertex_ai": {
"budget_limit": 100.0,
"time_period": "12d",
"spend": 0.0,
"budget_reset_at": null
}
}
}
```
"""
from litellm.proxy.proxy_server import llm_router
try:
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": "No llm_router found"}
)
provider_budget_config = llm_router.provider_budget_config
if provider_budget_config is None:
raise ValueError(
"No provider budget config found. Please set a provider budget config in the router settings. https://docs.litellm.ai/docs/proxy/provider_budget_routing"
)
provider_budget_response_dict: Dict[str, ProviderBudgetResponseObject] = {}
for _provider, _budget_info in provider_budget_config.items():
_provider_spend = (
await llm_router.provider_budget_logger._get_current_provider_spend(
_provider
)
or 0.0
)
_provider_budget_ttl = await llm_router.provider_budget_logger._get_current_provider_budget_reset_at(
_provider
)
provider_budget_response_object = ProviderBudgetResponseObject(
budget_limit=_budget_info.budget_limit,
time_period=_budget_info.time_period,
spend=_provider_spend,
budget_reset_at=_provider_budget_ttl,
)
provider_budget_response_dict[_provider] = provider_budget_response_object
return ProviderBudgetResponse(providers=provider_budget_response_dict)
except Exception as e:
verbose_proxy_logger.exception(
"/provider/budgets: Exception occured - {}".format(str(e))
)
raise handle_exception_on_proxy(e)
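Beyond the curl example in the docstring, the endpoint can also be queried from Python; a small sketch using `httpx` (base URL and key are placeholders):

```python
# Sketch: querying GET /provider/budgets from Python (URL and key are placeholders).
import httpx

resp = httpx.get(
    "http://localhost:4000/provider/budgets",
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()
for provider, info in resp.json()["providers"].items():
    # e.g. "openai: 0.0 / 1e-12 USD spent, resets at None"
    print(f"{provider}: {info['spend']} / {info['budget_limit']} USD spent, resets at {info['budget_reset_at']}")
```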

View file

@@ -19,7 +19,7 @@ anthropic:
"""
import asyncio
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
import litellm
@@ -68,6 +68,12 @@ class ProviderBudgetLimiting(CustomLogger):
budget_limit=config.get("budget_limit"),
time_period=config.get("time_period"),
)
asyncio.create_task(
self._init_provider_budget_in_cache(
provider=provider,
budget_config=provider_budget_config[provider],
)
)
self.provider_budget_config: ProviderBudgetConfigType = provider_budget_config
verbose_router_logger.debug(
@@ -450,3 +456,74 @@ class ProviderBudgetLimiting(CustomLogger):
spend=spend,
budget_limit=budget_limit,
)
async def _get_current_provider_spend(self, provider: str) -> Optional[float]:
"""
GET the current spend for a provider from cache
used for GET /provider/budgets endpoint in spend_management_endpoints.py
Args:
provider (str): The provider to get spend for (e.g., "openai", "anthropic")
Returns:
Optional[float]: The current spend for the provider, or None if not found
"""
budget_config = self._get_budget_config_for_provider(provider)
if budget_config is None:
return None
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
if self.router_cache.redis_cache:
# use Redis as source of truth since that has spend across all instances
current_spend = await self.router_cache.redis_cache.async_get_cache(
spend_key
)
else:
# use in-memory cache if Redis is not initialized
current_spend = await self.router_cache.async_get_cache(spend_key)
return float(current_spend) if current_spend is not None else 0.0
async def _get_current_provider_budget_reset_at(
self, provider: str
) -> Optional[str]:
budget_config = self._get_budget_config_for_provider(provider)
if budget_config is None:
return None
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
if self.router_cache.redis_cache:
ttl_seconds = await self.router_cache.redis_cache.async_get_ttl(spend_key)
else:
ttl_seconds = await self.router_cache.async_get_ttl(spend_key)
if ttl_seconds is None:
return None
return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat()
async def _init_provider_budget_in_cache(
self, provider: str, budget_config: ProviderBudgetInfo
):
"""
Initialize provider budget in cache by storing the following keys if they don't exist:
- provider_spend:{provider}:{budget_config.time_period} - stores the current spend
- provider_budget_start_time:{provider} - stores the start time of the budget window
"""
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
start_time_key = f"provider_budget_start_time:{provider}"
ttl_seconds = duration_in_seconds(budget_config.time_period)
budget_start = await self.router_cache.async_get_cache(start_time_key)
if budget_start is None:
budget_start = datetime.now(timezone.utc).timestamp()
await self.router_cache.async_set_cache(
key=start_time_key, value=budget_start, ttl=ttl_seconds
)
_spend_key = await self.router_cache.async_get_cache(spend_key)
if _spend_key is None:
await self.router_cache.async_set_cache(
key=spend_key, value=0.0, ttl=ttl_seconds
)
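The `budget_reset_at` value returned above is derived purely from the spend key's remaining TTL; a standalone sketch of that arithmetic:

```python
# Sketch: how a remaining TTL on the spend key becomes the budget_reset_at timestamp
# (same arithmetic as _get_current_provider_budget_reset_at).
from datetime import datetime, timedelta, timezone

ttl_seconds = 86_400  # e.g. a freshly initialized "1d" window

budget_reset_at = (
    datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
).isoformat()
print(budget_reset_at)  # e.g. "2024-12-08T05:14:12.123456+00:00"
```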

View file

@@ -646,3 +646,12 @@ ProviderBudgetConfigType = Dict[str, ProviderBudgetInfo]
class RouterCacheEnum(enum.Enum):
TPM = "global_router:{id}:{model}:tpm:{current_minute}"
RPM = "global_router:{id}:{model}:rpm:{current_minute}"
class ProviderBudgetWindowDetails(BaseModel):
"""Details about a provider's budget window"""
budget_start: float
spend_key: str
start_time_key: str
ttl_seconds: int
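This diff does not show where `ProviderBudgetWindowDetails` is populated; purely as an illustration of the state it groups for one provider's budget window (values are made up, import path assumed from the surrounding router types file):

```python
# Illustration only: the state grouped for one provider's budget window.
# Where ProviderBudgetWindowDetails is constructed is not shown in this diff;
# the import path is assumed from the surrounding router types file.
from datetime import datetime, timezone

from litellm.types.router import ProviderBudgetWindowDetails

window = ProviderBudgetWindowDetails(
    budget_start=datetime.now(timezone.utc).timestamp(),  # epoch seconds the window opened
    spend_key="provider_spend:openai:1d",                 # cache key holding accumulated spend
    start_time_key="provider_budget_start_time:openai",   # cache key holding the window start
    ttl_seconds=86_400,                                    # window length ("1d") in seconds
)
print(window)
```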

View file

@@ -2478,3 +2478,51 @@ async def test_redis_increment_pipeline():
except Exception as e:
print(f"Error occurred: {str(e)}")
raise e
@pytest.mark.asyncio
async def test_redis_get_ttl():
"""
Test Redis get TTL functionality
Redis returns -2 if the key does not exist and -1 if the key exists but has no associated expire.
test that litellm redis caching wrapper handles -1 and -2 values and returns them as None
"""
try:
from litellm.caching.redis_cache import RedisCache
redis_cache = RedisCache(
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
# Test case 1: Key does not exist
result = await redis_cache.async_get_ttl("nonexistent_key")
print("ttl for nonexistent key: ", result)
assert result is None, f"Expected None for nonexistent key, got {result}"
# Test case 2: Key exists with TTL
test_key = "test_key_ttl"
test_value = "test_value"
ttl = 10 # 10 seconds TTL
# Set a key with TTL
_redis_client = await redis_cache.init_async_client()
async with _redis_client as redis_client:
await redis_client.set(test_key, test_value, ex=ttl)
# Get TTL and verify it's close to what we set
result = await redis_cache.async_get_ttl(test_key)
print("ttl for test_key: ", result)
assert (
result is not None and 0 <= result <= ttl
), f"Expected TTL between 0 and {ttl}, got {result}"
# Clean up
await redis_client.delete(test_key)
except Exception as e:
print(f"Error occurred: {str(e)}")
raise e

View file

@@ -21,6 +21,7 @@ from litellm.caching.caching import DualCache, RedisCache
import logging
from litellm._logging import verbose_router_logger
import litellm
from datetime import timezone, timedelta
verbose_router_logger.setLevel(logging.DEBUG)
@@ -476,3 +477,89 @@ async def test_sync_in_memory_spend_with_redis():
assert float(openai_spend) == 50.0
assert float(anthropic_spend) == 75.0
@pytest.mark.asyncio
async def test_get_current_provider_spend():
"""
Test _get_current_provider_spend helper method
Scenarios:
1. Provider with no budget config returns None
2. Provider with budget config but no spend returns 0.0
3. Provider with budget config and spend returns correct value
"""
cleanup_redis()
provider_budget = ProviderBudgetLimiting(
router_cache=DualCache(),
provider_budget_config={
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
},
)
# Test provider with no budget config
spend = await provider_budget._get_current_provider_spend("anthropic")
assert spend is None
# Test provider with budget config but no spend
spend = await provider_budget._get_current_provider_spend("openai")
assert spend == 0.0
# Test provider with budget config and spend
spend_key = "provider_spend:openai:1d"
await provider_budget.router_cache.async_set_cache(key=spend_key, value=50.5)
spend = await provider_budget._get_current_provider_spend("openai")
assert spend == 50.5
@pytest.mark.asyncio
async def test_get_current_provider_budget_reset_at():
"""
Test _get_current_provider_budget_reset_at helper method
Scenarios:
1. Provider with no budget config returns None
2. Provider with a 1d budget window returns a reset time ~24 hours out
3. Provider with a 1h budget window returns a reset time ~1 hour out
"""
cleanup_redis()
provider_budget = ProviderBudgetLimiting(
router_cache=DualCache(
redis_cache=RedisCache(
host=os.getenv("REDIS_HOST"),
port=int(os.getenv("REDIS_PORT")),
password=os.getenv("REDIS_PASSWORD"),
)
),
provider_budget_config={
"openai": ProviderBudgetInfo(time_period="1d", budget_limit=100),
"vertex_ai": ProviderBudgetInfo(time_period="1h", budget_limit=100),
},
)
await asyncio.sleep(2)
# Test provider with no budget config
reset_at = await provider_budget._get_current_provider_budget_reset_at("anthropic")
assert reset_at is None
# Test provider with a 1d budget window (initialized by the constructor above)
reset_at = await provider_budget._get_current_provider_budget_reset_at("openai")
assert reset_at is not None
reset_time = datetime.fromisoformat(reset_at.replace("Z", "+00:00"))
expected_time = datetime.now(timezone.utc) + timedelta(seconds=(24 * 60 * 60))
time_difference = abs((reset_time - expected_time).total_seconds())
assert time_difference < 5
# Test provider with budget config and TTL
reset_at = await provider_budget._get_current_provider_budget_reset_at("vertex_ai")
assert reset_at is not None
# Verify the timestamp format and approximate time
reset_time = datetime.fromisoformat(reset_at.replace("Z", "+00:00"))
expected_time = datetime.now(timezone.utc) + timedelta(seconds=3600)
# Allow for small time differences (within 5 seconds)
time_difference = abs((reset_time - expected_time).total_seconds())
assert time_difference < 5