(feat proxy) v2 - model max budgets (#7302)

* clean up unused code

* add _PROXY_VirtualKeyModelMaxBudgetLimiter

* adjust type imports

* working _PROXY_VirtualKeyModelMaxBudgetLimiter

* fix user_api_key_model_max_budget

* fix user_api_key_model_max_budget

* update naming

* update naming

* fix changes to RouterBudgetLimiting

* test_call_with_key_over_model_budget

* test_call_with_key_over_model_budget

* handle _get_request_model_budget_config

* e2e test for test_call_with_key_over_model_budget

* clean up test

* run ci/cd again

* add validate_model_max_budget

* docs fix

* update doc

* add e2e testing for _PROXY_VirtualKeyModelMaxBudgetLimiter

* test_unit_test_max_model_budget_limiter.py
Ishaan Jaff 2024-12-18 19:42:46 -08:00 committed by GitHub
parent 1a4910f6c0
commit 6220e17ebf
14 changed files with 628 additions and 261 deletions

View file

@ -10,7 +10,7 @@ Requirements:
## Set Budgets
You can set budgets at 3 levels:
You can set budgets at 5 levels:
- For the proxy
- For an internal user
- For a customer (end-user)
@ -392,32 +392,88 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
<TabItem value="per-model-key" label="For Key (model specific)">
Apply model specific budgets on a key.
**Expected Behaviour**
- `model_spend` gets auto-populated in `LiteLLM_VerificationToken` Table
- After the key crosses the budget set for the `model` in `model_max_budget`, calls fail
By default, `model_max_budget` is set to `{}` and is not checked for keys.
:::info
- LiteLLM will track the cost/budgets for the `model` passed to LLM endpoints (`/chat/completions`, `/embeddings`)
:::
Apply model specific budgets on a key. Example:
- Budget for `gpt-4o` is $0.0000001 over a `1d` time period for `key = "sk-12345"`
- Budget for `gpt-4o-mini` is $10 over a `30d` time period for `key = "sk-12345"`
#### **Add model specific budgets to keys**
The spec for `model_max_budget` is **[`Dict[str, GenericBudgetInfo]`](#genericbudgetinfo)**
```bash
curl 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{
model_max_budget={"gpt4": 0.5, "gpt-5": 0.01}
"model_max_budget": {"gpt-4o": {"budget_limit": "0.0000001", "time_period": "1d"}}
}'
```
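For reference, here is a rough Python equivalent of the request above that sets both example budgets from earlier. This is a sketch assuming the `requests` package and a proxy running at `http://0.0.0.0:4000`; it is not an official SDK snippet.
```python
import requests

# Sketch: create a key with per-model budgets for gpt-4o (1d) and gpt-4o-mini (30d).
# Assumes a local proxy and a master key; adjust both for your deployment.
resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={
        "Authorization": "Bearer <your-master-key>",
        "Content-Type": "application/json",
    },
    json={
        "model_max_budget": {
            "gpt-4o": {"budget_limit": "0.0000001", "time_period": "1d"},
            "gpt-4o-mini": {"budget_limit": "10", "time_period": "30d"},
        }
    },
)
print(resp.json()["key"])  # the generated virtual key
```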
#### Make a test request
We expect the first request to succeed, and the second request to fail since we cross the budget for `gpt-4o` on the Virtual Key
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
<Tabs>
<TabItem label="Successful Call " value = "allowed">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <sk-generated-key>' \
--data ' {
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "testing request"
}
]
}
'
```
</TabItem>
<TabItem label="Unsuccessful call" value = "not-allowed">
Expect this to fail, since we cross the budget for `model=gpt-4o` on the Virtual Key
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <sk-generated-key>' \
--data ' {
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "testing request"
}
]
}
'
```
Expected response on failure
```json
{
"error": {
"message": "LiteLLM Virtual Key: 9769f3f6768a199f76cc29xxxx, key_alias: None, exceeded budget for model=gpt-4o",
"type": "budget_exceeded",
"param": null,
"code": "400"
}
}
```
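If you call through the OpenAI Python SDK instead of curl, the same failure surfaces as a 400 error. A minimal sketch, assuming the `openai` v1 SDK pointed at the proxy:
```python
import openai

client = openai.OpenAI(api_key="<sk-generated-key>", base_url="http://0.0.0.0:4000")

try:
    client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": "testing request"}],
    )
except openai.BadRequestError as e:
    # e.g. "... exceeded budget for model=gpt-4o", type: budget_exceeded, code 400
    print(e)
```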
</TabItem>
</Tabs>
</TabItem>
</Tabs>
@ -783,3 +839,32 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Content-Type: application/json' \
--data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}'
```
## API Specification
### `GenericBudgetInfo`
A Pydantic model that defines budget information with a time period and limit.
```python
class GenericBudgetInfo(BaseModel):
budget_limit: float # The maximum budget amount in USD
time_period: str # Duration string like "1d", "30d", etc.
```
#### Fields:
- `budget_limit` (float): The maximum budget amount in USD
- `time_period` (str): Duration string specifying the time period for the budget. Supported formats:
- Seconds: "30s"
- Minutes: "30m"
- Hours: "30h"
- Days: "30d"
#### Example:
```json
{
"budget_limit": "0.0001",
"time_period": "1d"
}
```
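For illustration, here is roughly how this shape is consumed: each entry is cast to a `GenericBudgetInfo`, coercing string `budget_limit` values to floats (this mirrors the cast done in `is_key_within_model_budget` below; the class definition is repeated locally so the snippet is self-contained):
```python
from typing import Dict
from pydantic import BaseModel

class GenericBudgetInfo(BaseModel):
    time_period: str     # e.g. "1d", "30d"
    budget_limit: float  # USD

# model_max_budget as stored on the key (budget_limit may arrive as a string)
model_max_budget = {
    "gpt-4o": {"budget_limit": "0.0000001", "time_period": "1d"},
    "gpt-4o-mini": {"budget_limit": 10, "time_period": "30d"},
}

internal_budget: Dict[str, GenericBudgetInfo] = {
    model: GenericBudgetInfo(
        time_period=cfg["time_period"],
        budget_limit=float(cfg["budget_limit"]),  # coerce string -> float
    )
    for model, cfg in model_max_budget.items()
}

print(internal_budget["gpt-4o"].budget_limit)  # 1e-07
```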

View file

@ -262,6 +262,7 @@ async def user_api_key_auth( # noqa: PLR0915
llm_model_list,
llm_router,
master_key,
model_max_budget_limiter,
open_telemetry_logger,
prisma_client,
proxy_logging_obj,
@ -1053,37 +1054,10 @@ async def user_api_key_auth( # noqa: PLR0915
and valid_token.token is not None
):
## GET THE SPEND FOR THIS MODEL
twenty_eight_days_ago = datetime.now() - timedelta(days=28)
model_spend = await prisma_client.db.litellm_spendlogs.group_by(
by=["model"],
sum={"spend": True},
where={
"AND": [
{"api_key": valid_token.token},
{"startTime": {"gt": twenty_eight_days_ago}},
{"model": current_model},
]
}, # type: ignore
await model_max_budget_limiter.is_key_within_model_budget(
user_api_key_dict=valid_token,
model=current_model,
)
if (
len(model_spend) > 0
and max_budget_per_model.get(current_model, None) is not None
):
if (
"model" in model_spend[0]
and model_spend[0].get("model") == current_model
and "_sum" in model_spend[0]
and "spend" in model_spend[0]["_sum"]
and model_spend[0]["_sum"]["spend"]
>= max_budget_per_model[current_model]
):
current_model_spend = model_spend[0]["_sum"]["spend"]
current_model_budget = max_budget_per_model[current_model]
raise litellm.BudgetExceededError(
current_cost=current_model_spend,
max_budget=current_model_budget,
)
# Check 6. Team spend is under Team budget
if (
hasattr(valid_token, "team_spend")

View file

@ -0,0 +1,196 @@
import json
import traceback
from typing import List, Optional
from fastapi import HTTPException
import litellm
from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.proxy._types import UserAPIKeyAuth
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
GenericBudgetConfigType,
GenericBudgetInfo,
StandardLoggingPayload,
)
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
"""
Handles budgets for model + virtual key
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
"""
def __init__(self, dual_cache: DualCache):
self.dual_cache = dual_cache
self.redis_increment_operation_queue = []
async def is_key_within_model_budget(
self,
user_api_key_dict: UserAPIKeyAuth,
model: str,
) -> bool:
"""
Check if the user_api_key_dict is within the model budget
Raises:
BudgetExceededError: If the user_api_key_dict has exceeded the model budget
"""
_model_max_budget = user_api_key_dict.model_max_budget
internal_model_max_budget: GenericBudgetConfigType = {}
# cast each element in _model_max_budget to GenericBudgetInfo
for _model, _budget_info in _model_max_budget.items():
internal_model_max_budget[_model] = GenericBudgetInfo(
time_period=_budget_info.get("time_period"),
budget_limit=float(_budget_info.get("budget_limit")),
)
verbose_proxy_logger.debug(
"internal_model_max_budget %s",
json.dumps(internal_model_max_budget, indent=4, default=str),
)
# check if current model is in internal_model_max_budget
_current_model_budget_info = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
if _current_model_budget_info is None:
verbose_proxy_logger.debug(
f"Model {model} not found in internal_model_max_budget"
)
return True
# check if current model is within budget
if _current_model_budget_info.budget_limit > 0:
_current_spend = await self._get_virtual_key_spend_for_model(
user_api_key_hash=user_api_key_dict.token,
model=model,
key_budget_config=_current_model_budget_info,
)
if (
_current_spend is not None
and _current_spend > _current_model_budget_info.budget_limit
):
raise litellm.BudgetExceededError(
message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
current_cost=_current_spend,
max_budget=_current_model_budget_info.budget_limit,
)
return True
async def _get_virtual_key_spend_for_model(
self,
user_api_key_hash: Optional[str],
model: str,
key_budget_config: GenericBudgetInfo,
) -> Optional[float]:
"""
Get the current spend for a virtual key for a model
Lookup model in this order:
1. model: directly look up `model`
2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
"""
# 1. model: directly look up `model`
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.time_period}"
_current_spend = await self.dual_cache.async_get_cache(
key=virtual_key_model_spend_cache_key,
)
if _current_spend is None:
# 2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
# if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.time_period}"
_current_spend = await self.dual_cache.async_get_cache(
key=virtual_key_model_spend_cache_key,
)
return _current_spend
def _get_request_model_budget_config(
self, model: str, internal_model_max_budget: GenericBudgetConfigType
) -> Optional[GenericBudgetInfo]:
"""
Get the budget config for the request model
1. Check if `model` is in `internal_model_max_budget`
2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
"""
return internal_model_max_budget.get(
model, None
) or internal_model_max_budget.get(
self._get_model_without_custom_llm_provider(model), None
)
def _get_model_without_custom_llm_provider(self, model: str) -> str:
if "/" in model:
return model.split("/")[-1]
return model
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None, # type: ignore
) -> List[dict]:
return healthy_deployments
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Track spend for virtual key + model in DualCache
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
"""
verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event")
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if standard_logging_payload is None:
raise ValueError("standard_logging_payload is required")
_litellm_params = kwargs.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {})
user_api_key_model_max_budget: Optional[dict] = _metadata.get(
"user_api_key_model_max_budget", None
)
if (
user_api_key_model_max_budget is None
or len(user_api_key_model_max_budget) == 0
):
verbose_proxy_logger.debug(
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
user_api_key_model_max_budget,
)
return
response_cost: float = standard_logging_payload.get("response_cost", 0)
model = standard_logging_payload.get("model")
virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
if virtual_key is not None:
budget_config = GenericBudgetInfo(time_period="1d", budget_limit=0.1)
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.time_period}"
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
await self._increment_spend_for_key(
budget_config=budget_config,
spend_key=virtual_spend_key,
start_time_key=virtual_start_time_key,
response_cost=response_cost,
)
verbose_proxy_logger.debug(
"current state of in memory cache %s",
json.dumps(
self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
),
)
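
A minimal usage sketch of the new hook, mirroring how `proxy_server.py` registers it and how the unit tests added in this commit exercise it (the key values below are placeholders):
```python
import asyncio

import litellm
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.model_max_budget_limiter import (
    _PROXY_VirtualKeyModelMaxBudgetLimiter,
)

limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(dual_cache=DualCache())
litellm.callbacks.append(limiter)  # spend is then tracked on each successful call

async def main():
    key = UserAPIKeyAuth(
        token="hashed-sk-placeholder",
        model_max_budget={"gpt-4o": {"budget_limit": "0.0000001", "time_period": "1d"}},
    )
    # Returns True while under budget; raises litellm.BudgetExceededError once exceeded.
    print(await limiter.is_key_within_model_budget(user_api_key_dict=key, model="gpt-4o"))

asyncio.run(main())
```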

View file

@ -499,6 +499,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
data[_metadata_variable_name][
"user_api_key_max_budget"
] = user_api_key_dict.max_budget
data[_metadata_variable_name][
"user_api_key_model_max_budget"
] = user_api_key_dict.model_max_budget
data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)

View file

@ -40,7 +40,11 @@ from litellm.proxy.utils import (
handle_exception_on_proxy,
)
from litellm.secret_managers.main import get_secret
from litellm.types.utils import PersonalUIKeyGenerationConfig, TeamUIKeyGenerationConfig
from litellm.types.utils import (
GenericBudgetInfo,
PersonalUIKeyGenerationConfig,
TeamUIKeyGenerationConfig,
)
def _is_team_key(data: GenerateKeyRequest):
@ -246,7 +250,7 @@ async def generate_key_fn( # noqa: PLR0915
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
- guardrails: Optional[List[str]] - List of active guardrails for the key
- permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
- model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget.
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}. IF null or {} then no model specific budget.
- model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
- model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit.
- allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request
@ -515,6 +519,10 @@ def prepare_key_update_data(
_metadata = existing_key_row.metadata or {}
# validate model_max_budget
if "model_max_budget" in non_default_values:
validate_model_max_budget(non_default_values["model_max_budget"])
non_default_values = prepare_metadata_fields(
data=data, non_default_values=non_default_values, existing_metadata=_metadata
)
@ -548,7 +556,7 @@ async def update_key_fn(
- enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
- spend: Optional[float] - Amount spent by key
- max_budget: Optional[float] - Max budget for key
- model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
- max_parallel_requests: Optional[int] - Rate limit for parallel requests
@ -1035,6 +1043,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
metadata["guardrails"] = guardrails
metadata_json = json.dumps(metadata)
validate_model_max_budget(model_max_budget)
model_max_budget_json = json.dumps(model_max_budget)
user_role = user_role
tpm_limit = tpm_limit
@ -1266,7 +1275,7 @@ async def regenerate_key_fn(
- tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
- spend: Optional[float] - Amount spent by key
- max_budget: Optional[float] - Max budget for key
- model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
- max_parallel_requests: Optional[int] - Rate limit for parallel requests
@ -1293,8 +1302,7 @@ async def regenerate_key_fn(
--data-raw '{
"max_budget": 100,
"metadata": {"team": "core-infra"},
"models": ["gpt-4", "gpt-3.5-turbo"],
"model_max_budget": {"gpt-4": 50, "gpt-3.5-turbo": 50}
"models": ["gpt-4", "gpt-3.5-turbo"]
}'
```
@ -1949,3 +1957,29 @@ async def _enforce_unique_key_alias(
param="key_alias",
code=status.HTTP_400_BAD_REQUEST,
)
def validate_model_max_budget(model_max_budget: Optional[Dict]) -> None:
"""
Validate that model_max_budget is a valid GenericBudgetConfigType
Raises:
Exception: If model_max_budget is not a valid GenericBudgetConfigType
"""
try:
if model_max_budget is None:
return
if len(model_max_budget) == 0:
return
if model_max_budget is not None:
for _model, _budget_info in model_max_budget.items():
assert isinstance(_model, str)
# CRUD endpoints can pass budget_limit as a string, so we need to convert it to a float
if "budget_limit" in _budget_info:
_budget_info["budget_limit"] = float(_budget_info["budget_limit"])
GenericBudgetInfo(**_budget_info)
except Exception as e:
raise ValueError(
f"Invalid model_max_budget: {str(e)}. Example of valid model_max_budget: https://docs.litellm.ai/docs/proxy/users"
)
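
A short usage sketch of the validator above, assuming it is importable from `litellm.proxy.management_endpoints.key_management_endpoints` (inputs are illustrative):
```python
from litellm.proxy.management_endpoints.key_management_endpoints import (
    validate_model_max_budget,
)

# Valid: budget_limit may be a string; it is coerced to float before validation.
validate_model_max_budget(
    {"gpt-4o": {"budget_limit": "0.001", "time_period": "1d"}}
)

# Invalid: a non-numeric budget_limit surfaces as a ValueError with a docs link.
try:
    validate_model_max_budget(
        {"gpt-4o": {"budget_limit": "one dollar", "time_period": "1d"}}
    )
except ValueError as e:
    print(e)
```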

View file

@ -1,42 +1,10 @@
model_list:
- model_name: fake-openai-endpoint
- model_name: openai/o1-preview
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
tags: ["teamA"]
model_info:
id: "team-a-model"
- model_name: fake-openai-endpoint
model: openai/o1-preview
api_key: os.environ/OPENAI_API_KEY
- model_name: openai/*
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
tags: ["teamB"]
model_info:
id: "team-b-model"
- model_name: rerank-english-v3.0
litellm_params:
model: cohere/rerank-english-v3.0
api_key: os.environ/COHERE_API_KEY
- model_name: fake-azure-endpoint
litellm_params:
model: openai/429
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app
- model_name: llava-hf
litellm_params:
model: openai/llava-hf/llava-v1.6-vicuna-7b-hf
api_base: http://localhost:8000
api_key: fake-key
model_info:
supports_vision: True
model: openai/*
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
cache: true
callbacks: ["otel", "prometheus"]
router_settings:
enable_tag_filtering: True # 👈 Key Change

View file

@ -173,6 +173,9 @@ from litellm.proxy.guardrails.init_guardrails import (
)
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.model_max_budget_limiter import (
_PROXY_VirtualKeyModelMaxBudgetLimiter,
)
from litellm.proxy.hooks.prompt_injection_detection import (
_OPTIONAL_PromptInjectionDetection,
)
@ -495,6 +498,10 @@ prisma_client: Optional[PrismaClient] = None
user_api_key_cache = DualCache(
default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
)
model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
dual_cache=user_api_key_cache
)
litellm.callbacks.append(model_max_budget_limiter)
redis_usage_cache: Optional[RedisCache] = (
None # redis cache used for tracking spend, tpm/rpm limits
)

View file

@ -631,7 +631,7 @@ class Router:
_callback = PromptCachingDeploymentCheck(cache=self.cache)
elif pre_call_check == "router_budget_limiting":
_callback = RouterBudgetLimiting(
router_cache=self.cache,
dual_cache=self.cache,
provider_budget_config=self.provider_budget_config,
model_list=self.model_list,
)
@ -5292,14 +5292,6 @@ class Router:
healthy_deployments=healthy_deployments,
)
# if self.router_budget_logger:
# healthy_deployments = (
# await self.router_budget_logger.async_filter_deployments(
# healthy_deployments=healthy_deployments,
# request_kwargs=request_kwargs,
# )
# )
if len(healthy_deployments) == 0:
exception = await async_raise_no_deployment_exception(
litellm_router_instance=self,

View file

@ -49,13 +49,13 @@ DEFAULT_REDIS_SYNC_INTERVAL = 1
class RouterBudgetLimiting(CustomLogger):
def __init__(
self,
router_cache: DualCache,
dual_cache: DualCache,
provider_budget_config: Optional[dict],
model_list: Optional[
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
] = None,
):
self.router_cache = router_cache
self.dual_cache = dual_cache
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
self.provider_budget_config: Optional[GenericBudgetConfigType] = (
@ -108,7 +108,7 @@ class RouterBudgetLimiting(CustomLogger):
# Single cache read for all spend values
if len(cache_keys) > 0:
_current_spends = await self.router_cache.async_batch_get_cache(
_current_spends = await self.dual_cache.async_batch_get_cache(
keys=cache_keys,
parent_otel_span=parent_otel_span,
)
@ -286,9 +286,9 @@ class RouterBudgetLimiting(CustomLogger):
If it does, return the value.
If it does not, set the key to `current_time` and return the value.
"""
budget_start = await self.router_cache.async_get_cache(start_time_key)
budget_start = await self.dual_cache.async_get_cache(start_time_key)
if budget_start is None:
await self.router_cache.async_set_cache(
await self.dual_cache.async_set_cache(
key=start_time_key, value=current_time, ttl=ttl_seconds
)
return current_time
@ -314,10 +314,10 @@ class RouterBudgetLimiting(CustomLogger):
- stores key: `provider_budget_start_time:{provider}`, value: current_time.
This stores the start time of the new budget window
"""
await self.router_cache.async_set_cache(
await self.dual_cache.async_set_cache(
key=spend_key, value=response_cost, ttl=ttl_seconds
)
await self.router_cache.async_set_cache(
await self.dual_cache.async_set_cache(
key=start_time_key, value=current_time, ttl=ttl_seconds
)
return current_time
@ -333,7 +333,7 @@ class RouterBudgetLimiting(CustomLogger):
- Increments the spend in memory cache (so spend instantly updated in memory)
- Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. Using Redis for multi instance environment of LiteLLM)
"""
await self.router_cache.in_memory_cache.async_increment(
await self.dual_cache.in_memory_cache.async_increment(
key=spend_key,
value=response_cost,
ttl=ttl,
@ -481,7 +481,7 @@ class RouterBudgetLimiting(CustomLogger):
Only runs if Redis is initialized
"""
try:
if not self.router_cache.redis_cache:
if not self.dual_cache.redis_cache:
return # Redis is not initialized
verbose_router_logger.debug(
@ -490,7 +490,7 @@ class RouterBudgetLimiting(CustomLogger):
)
if len(self.redis_increment_operation_queue) > 0:
asyncio.create_task(
self.router_cache.redis_cache.async_increment_pipeline(
self.dual_cache.redis_cache.async_increment_pipeline(
increment_list=self.redis_increment_operation_queue,
)
)
@ -517,7 +517,7 @@ class RouterBudgetLimiting(CustomLogger):
try:
# No need to sync if Redis cache is not initialized
if self.router_cache.redis_cache is None:
if self.dual_cache.redis_cache is None:
return
# 1. Push all provider spend increments to Redis
@ -547,7 +547,7 @@ class RouterBudgetLimiting(CustomLogger):
cache_keys.append(f"tag_spend:{tag}:{config.time_period}")
# Batch fetch current spend values from Redis
redis_values = await self.router_cache.redis_cache.async_batch_get_cache(
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
key_list=cache_keys
)
@ -555,7 +555,7 @@ class RouterBudgetLimiting(CustomLogger):
if isinstance(redis_values, dict): # Check if redis_values is a dictionary
for key, value in redis_values.items():
if value is not None:
await self.router_cache.in_memory_cache.async_set_cache(
await self.dual_cache.in_memory_cache.async_set_cache(
key=key, value=float(value)
)
verbose_router_logger.debug(
@ -639,14 +639,12 @@ class RouterBudgetLimiting(CustomLogger):
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
if self.router_cache.redis_cache:
if self.dual_cache.redis_cache:
# use Redis as source of truth since that has spend across all instances
current_spend = await self.router_cache.redis_cache.async_get_cache(
spend_key
)
current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key)
else:
# use in-memory cache if Redis is not initialized
current_spend = await self.router_cache.async_get_cache(spend_key)
current_spend = await self.dual_cache.async_get_cache(spend_key)
return float(current_spend) if current_spend is not None else 0.0
async def _get_current_provider_budget_reset_at(
@ -657,10 +655,10 @@ class RouterBudgetLimiting(CustomLogger):
return None
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
if self.router_cache.redis_cache:
ttl_seconds = await self.router_cache.redis_cache.async_get_ttl(spend_key)
if self.dual_cache.redis_cache:
ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
else:
ttl_seconds = await self.router_cache.async_get_ttl(spend_key)
ttl_seconds = await self.dual_cache.async_get_ttl(spend_key)
if ttl_seconds is None:
return None
@ -679,16 +677,16 @@ class RouterBudgetLimiting(CustomLogger):
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
start_time_key = f"provider_budget_start_time:{provider}"
ttl_seconds = duration_in_seconds(budget_config.time_period)
budget_start = await self.router_cache.async_get_cache(start_time_key)
budget_start = await self.dual_cache.async_get_cache(start_time_key)
if budget_start is None:
budget_start = datetime.now(timezone.utc).timestamp()
await self.router_cache.async_set_cache(
await self.dual_cache.async_set_cache(
key=start_time_key, value=budget_start, ttl=ttl_seconds
)
_spend_key = await self.router_cache.async_get_cache(spend_key)
_spend_key = await self.dual_cache.async_get_cache(spend_key)
if _spend_key is None:
await self.router_cache.async_set_cache(
await self.dual_cache.async_set_cache(
key=spend_key, value=0.0, ttl=ttl_seconds
)

View file

@ -11,6 +11,8 @@ import httpx
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Required, TypedDict
from litellm.types.utils import GenericBudgetConfigType, GenericBudgetInfo
from ..exceptions import RateLimitError
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
@ -647,14 +649,6 @@ class RoutingStrategy(enum.Enum):
PROVIDER_BUDGET_LIMITING = "provider-budget-routing"
class GenericBudgetInfo(BaseModel):
time_period: str # e.g., '1d', '30d'
budget_limit: float
GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
class RouterCacheEnum(enum.Enum):
TPM = "global_router:{id}:{model}:tpm:{current_minute}"
RPM = "global_router:{id}:{model}:rpm:{current_minute}"

View file

@ -1665,6 +1665,14 @@ class StandardKeyGenerationConfig(TypedDict, total=False):
personal_key_generation: PersonalUIKeyGenerationConfig
class GenericBudgetInfo(BaseModel):
time_period: str # e.g., '1d', '30d'
budget_limit: float
GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
class BudgetConfig(BaseModel):
max_budget: float
budget_duration: str

View file

@ -183,7 +183,7 @@ async def test_get_llm_provider_for_deployment():
"""
cleanup_redis()
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config={}
dual_cache=DualCache(), provider_budget_config={}
)
# Test OpenAI deployment
@ -220,7 +220,7 @@ async def test_get_budget_config_for_provider():
}
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config=config
dual_cache=DualCache(), provider_budget_config=config
)
# Test existing providers
@ -252,7 +252,7 @@ async def test_prometheus_metric_tracking():
# Setup provider budget limiting
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(),
dual_cache=DualCache(),
provider_budget_config={
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100)
},
@ -316,7 +316,7 @@ async def test_handle_new_budget_window():
"""
cleanup_redis()
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config={}
dual_cache=DualCache(), provider_budget_config={}
)
spend_key = "provider_spend:openai:7d"
@ -337,12 +337,12 @@ async def test_handle_new_budget_window():
assert new_start_time == current_time
# Verify the spend was set correctly
spend = await provider_budget.router_cache.async_get_cache(spend_key)
spend = await provider_budget.dual_cache.async_get_cache(spend_key)
print("spend in cache for key", spend_key, "is", spend)
assert float(spend) == response_cost
# Verify start time was set correctly
start_time = await provider_budget.router_cache.async_get_cache(start_time_key)
start_time = await provider_budget.dual_cache.async_get_cache(start_time_key)
print("start time in cache for key", start_time_key, "is", start_time)
assert float(start_time) == current_time
@ -357,7 +357,7 @@ async def test_get_or_set_budget_start_time():
"""
cleanup_redis()
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config={}
dual_cache=DualCache(), provider_budget_config={}
)
start_time_key = "test_start_time"
@ -398,7 +398,7 @@ async def test_increment_spend_in_current_window():
"""
cleanup_redis()
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(), provider_budget_config={}
dual_cache=DualCache(), provider_budget_config={}
)
spend_key = "provider_spend:openai:1d"
@ -406,9 +406,7 @@ async def test_increment_spend_in_current_window():
ttl = 86400 # 1 day
# Set initial spend
await provider_budget.router_cache.async_set_cache(
key=spend_key, value=1.0, ttl=ttl
)
await provider_budget.dual_cache.async_set_cache(key=spend_key, value=1.0, ttl=ttl)
# Test incrementing spend
await provider_budget._increment_spend_in_current_window(
@ -418,7 +416,7 @@ async def test_increment_spend_in_current_window():
)
# Verify the spend was incremented correctly in memory
spend = await provider_budget.router_cache.async_get_cache(spend_key)
spend = await provider_budget.dual_cache.async_get_cache(spend_key)
assert float(spend) == 1.5
# Verify the increment operation was queued for Redis
@ -449,7 +447,7 @@ async def test_sync_in_memory_spend_with_redis():
}
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(
dual_cache=DualCache(
redis_cache=RedisCache(
host=os.getenv("REDIS_HOST"),
port=int(os.getenv("REDIS_PORT")),
@ -463,10 +461,10 @@ async def test_sync_in_memory_spend_with_redis():
spend_key_openai = "provider_spend:openai:1d"
spend_key_anthropic = "provider_spend:anthropic:1d"
await provider_budget.router_cache.redis_cache.async_set_cache(
await provider_budget.dual_cache.redis_cache.async_set_cache(
key=spend_key_openai, value=50.0
)
await provider_budget.router_cache.redis_cache.async_set_cache(
await provider_budget.dual_cache.redis_cache.async_set_cache(
key=spend_key_anthropic, value=75.0
)
@ -474,14 +472,12 @@ async def test_sync_in_memory_spend_with_redis():
await provider_budget._sync_in_memory_spend_with_redis()
# Verify in-memory cache was updated
openai_spend = await provider_budget.router_cache.in_memory_cache.async_get_cache(
openai_spend = await provider_budget.dual_cache.in_memory_cache.async_get_cache(
spend_key_openai
)
anthropic_spend = (
await provider_budget.router_cache.in_memory_cache.async_get_cache(
anthropic_spend = await provider_budget.dual_cache.in_memory_cache.async_get_cache(
spend_key_anthropic
)
)
assert float(openai_spend) == 50.0
assert float(anthropic_spend) == 75.0
@ -499,7 +495,7 @@ async def test_get_current_provider_spend():
"""
cleanup_redis()
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(),
dual_cache=DualCache(),
provider_budget_config={
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
},
@ -515,7 +511,7 @@ async def test_get_current_provider_spend():
# Test provider with budget config and spend
spend_key = "provider_spend:openai:1d"
await provider_budget.router_cache.async_set_cache(key=spend_key, value=50.5)
await provider_budget.dual_cache.async_set_cache(key=spend_key, value=50.5)
spend = await provider_budget._get_current_provider_spend("openai")
assert spend == 50.5
@ -534,7 +530,7 @@ async def test_get_current_provider_budget_reset_at():
"""
cleanup_redis()
provider_budget = RouterBudgetLimiting(
router_cache=DualCache(
dual_cache=DualCache(
redis_cache=RedisCache(
host=os.getenv("REDIS_HOST"),
port=int(os.getenv("REDIS_PORT")),

View file

@ -1684,22 +1684,51 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
print(vars(e))
def test_call_with_key_over_model_budget(prisma_client):
@pytest.mark.asyncio()
@pytest.mark.parametrize(
"request_model,should_pass",
[
("openai/gpt-4o-mini", False),
("gpt-4o-mini", False),
("gpt-4o", True),
],
)
async def test_call_with_key_over_model_budget(
prisma_client, request_model, should_pass
):
# 12. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
verbose_proxy_logger.setLevel(logging.DEBUG)
# init model max budget limiter
from litellm.proxy.hooks.model_max_budget_limiter import (
_PROXY_VirtualKeyModelMaxBudgetLimiter,
)
model_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
dual_cache=DualCache()
)
litellm.callbacks.append(model_budget_limiter)
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
# set budget for chatgpt-v-2 to 0.000001, expect the next request to fail
request = GenerateKeyRequest(
max_budget=1000,
model_max_budget={
"chatgpt-v-2": 0.000001,
model_max_budget = {
"gpt-4o-mini": {
"budget_limit": "0.000001",
"time_period": "1d",
},
metadata={"user_api_key": 0.0001},
"gpt-4o": {
"budget_limit": "200",
"time_period": "30d",
},
}
request = GenerateKeyRequest(
max_budget=100000, # the key itself has a very high budget
model_max_budget=model_max_budget,
)
key = await generate_key_fn(request)
print(key)
@ -1712,7 +1741,8 @@ def test_call_with_key_over_model_budget(prisma_client):
request._url = URL(url="/chat/completions")
async def return_body():
return b'{"model": "chatgpt-v-2"}'
request_str = f'{{"model": "{request_model}"}}' # Added extra curly braces to escape JSON
return request_str.encode()
request.body = return_body
@ -1721,87 +1751,42 @@ def test_call_with_key_over_model_budget(prisma_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm import Choices, Message, ModelResponse, Usage
from litellm.caching.caching import Cache
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
litellm.cache = Cache()
import time
import uuid
request_id = f"chatcmpl-{uuid.uuid4()}"
resp = ModelResponse(
id=request_id,
choices=[
Choices(
finish_reason=None,
index=0,
message=Message(
content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
role="assistant",
),
)
],
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
await track_cost_callback(
kwargs={
"model": "chatgpt-v-2",
"stream": False,
"litellm_params": {
"metadata": {
await litellm.acompletion(
model=request_model,
messages=[{"role": "user", "content": "Hello, how are you?"}],
metadata={
"user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id,
}
"user_api_key_model_max_budget": model_max_budget,
},
"response_cost": 0.00002,
},
completion_response=resp,
start_time=datetime.now(),
end_time=datetime.now(),
)
await update_spend(
prisma_client=prisma_client,
db_writer_client=None,
proxy_logging_obj=proxy_logging_obj,
)
# test spend_log was written and we can read it
spend_logs = await view_spend_logs(
request_id=request_id,
user_api_key_dict=UserAPIKeyAuth(api_key=generated_key),
)
print("read spend logs", spend_logs)
assert len(spend_logs) == 1
spend_log = spend_logs[0]
assert spend_log.request_id == request_id
assert spend_log.spend == float("2e-05")
assert spend_log.model == "chatgpt-v-2"
assert (
spend_log.cache_key
== "c891d64397a472e6deb31b87a5ac4d3ed5b2dcc069bc87e2afe91e6d64e95a1e"
)
await asyncio.sleep(2)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
if should_pass is True:
print(
f"Passed request for model={request_model}, model_max_budget={model_max_budget}"
)
return
print("result from user auth with new key", result)
pytest.fail("This should have failed!. They key crossed it's budget")
asyncio.run(test())
except Exception as e:
# print(f"Error - {str(e)}")
print(
f"Failed request for model={request_model}, model_max_budget={model_max_budget}"
)
assert (
should_pass is False
), f"This should have failed!. They key crossed it's budget for model={request_model}. {e}"
traceback.print_exc()
error_detail = e.message
assert "Budget has been exceeded!" in error_detail
assert f"exceeded budget for model={request_model}" in error_detail
assert isinstance(e, ProxyException)
assert e.type == ProxyErrorTypes.budget_exceeded
print(vars(e))
finally:
litellm.callbacks.remove(model_budget_limiter)
@pytest.mark.asyncio()

View file

@ -0,0 +1,127 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
from pydantic.main import Model
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
from datetime import datetime as dt_object
import time
import pytest
import litellm
import json
from litellm.types.utils import GenericBudgetInfo
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, patch
import pytest
from litellm.caching.caching import DualCache
from litellm.proxy.hooks.model_max_budget_limiter import (
_PROXY_VirtualKeyModelMaxBudgetLimiter,
)
from litellm.proxy._types import UserAPIKeyAuth
import litellm
# Test class setup
@pytest.fixture
def budget_limiter():
dual_cache = DualCache()
return _PROXY_VirtualKeyModelMaxBudgetLimiter(dual_cache=dual_cache)
# Test _get_model_without_custom_llm_provider
def test_get_model_without_custom_llm_provider(budget_limiter):
# Test with custom provider
assert (
budget_limiter._get_model_without_custom_llm_provider("openai/gpt-4") == "gpt-4"
)
# Test without custom provider
assert budget_limiter._get_model_without_custom_llm_provider("gpt-4") == "gpt-4"
# Test _get_request_model_budget_config
def test_get_request_model_budget_config(budget_limiter):
internal_budget = {
"gpt-4": GenericBudgetInfo(budget_limit=100.0, time_period="1d"),
"claude-3": GenericBudgetInfo(budget_limit=50.0, time_period="1d"),
}
# Test direct model match
config = budget_limiter._get_request_model_budget_config(
model="gpt-4", internal_model_max_budget=internal_budget
)
assert config.budget_limit == 100.0
# Test model with provider
config = budget_limiter._get_request_model_budget_config(
model="openai/gpt-4", internal_model_max_budget=internal_budget
)
assert config.budget_limit == 100.0
# Test non-existent model
config = budget_limiter._get_request_model_budget_config(
model="non-existent", internal_model_max_budget=internal_budget
)
assert config is None
# Test is_key_within_model_budget
@pytest.mark.asyncio
async def test_is_key_within_model_budget(budget_limiter):
# Mock user API key dict
user_api_key = UserAPIKeyAuth(
token="test-key",
key_alias="test-alias",
model_max_budget={"gpt-4": {"budget_limit": 100.0, "time_period": "1d"}},
)
# Test when model is within budget
with patch.object(
budget_limiter, "_get_virtual_key_spend_for_model", return_value=50.0
):
assert (
await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-4")
is True
)
# Test when model exceeds budget
with patch.object(
budget_limiter, "_get_virtual_key_spend_for_model", return_value=150.0
):
with pytest.raises(litellm.BudgetExceededError):
await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-4")
# Test model not in budget config
assert (
await budget_limiter.is_key_within_model_budget(user_api_key, "non-existent")
is True
)
# Test _get_virtual_key_spend_for_model
@pytest.mark.asyncio
async def test_get_virtual_key_spend_for_model(budget_limiter):
budget_config = GenericBudgetInfo(budget_limit=100.0, time_period="1d")
# Mock cache get
with patch.object(budget_limiter.dual_cache, "async_get_cache", return_value=50.0):
spend = await budget_limiter._get_virtual_key_spend_for_model(
user_api_key_hash="test-key", model="gpt-4", key_budget_config=budget_config
)
assert spend == 50.0
# Test with provider prefix
spend = await budget_limiter._get_virtual_key_spend_for_model(
user_api_key_hash="test-key",
model="openai/gpt-4",
key_budget_config=budget_config,
)
assert spend == 50.0