Mirror of https://github.com/BerriAI/litellm.git
(feat proxy) v2 - model max budgets (#7302)
* clean up unused code
* add _PROXY_VirtualKeyModelMaxBudgetLimiter
* adjust type imports
* working _PROXY_VirtualKeyModelMaxBudgetLimiter
* fix user_api_key_model_max_budget
* fix user_api_key_model_max_budget
* update naming
* update naming
* fix changes to RouterBudgetLimiting
* test_call_with_key_over_model_budget
* test_call_with_key_over_model_budget
* handle _get_request_model_budget_config
* e2e test for test_call_with_key_over_model_budget
* clean up test
* run ci/cd again
* add validate_model_max_budget
* docs fix
* update doc
* add e2e testing for _PROXY_VirtualKeyModelMaxBudgetLimiter
* test_unit_test_max_model_budget_limiter.py
This commit is contained in:
parent
1a4910f6c0
commit
6220e17ebf
14 changed files with 628 additions and 261 deletions
@@ -10,7 +10,7 @@ Requirements:

## Set Budgets

You can set budgets at 3 levels:
You can set budgets at 5 levels:
- For the proxy
- For an internal user
- For a customer (end-user)
@@ -392,32 +392,88 @@ curl --location 'http://0.0.0.0:4000/key/generate' \

<TabItem value="per-model-key" label="For Key (model specific)">

Apply model specific budgets on a key.

**Expected Behaviour**
- `model_spend` gets auto-populated in `LiteLLM_VerificationToken` Table
- After the key crosses the budget set for the `model` in `model_max_budget`, calls fail

By default the `model_max_budget` is set to `{}` and is not checked for keys

:::info

- LiteLLM will track the cost/budgets for the `model` passed to LLM endpoints (`/chat/completions`, `/embeddings`)

:::

Apply model specific budgets on a key. Example:
- Budget for `gpt-4o` is $0.0000001, for time period `1d` for `key = "sk-12345"`
- Budget for `gpt-4o-mini` is $10, for time period `30d` for `key = "sk-12345"`

#### **Add model specific budgets to keys**

The spec for `model_max_budget` is **[`Dict[str, GenericBudgetInfo]`](#genericbudgetinfo)**

```bash
curl 'http://0.0.0.0:4000/key/generate' \
    --header 'Authorization: Bearer <your-master-key>' \
    --header 'Content-Type: application/json' \
    --data-raw '{
        model_max_budget={"gpt4": 0.5, "gpt-5": 0.01}
        "model_max_budget": {"gpt-4o": {"budget_limit": "0.0000001", "time_period": "1d"}}
    }'
```

#### Make a test request

We expect the first request to succeed, and the second request to fail since we cross the budget for `gpt-4o` on the Virtual Key

**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**

<Tabs>
<TabItem label="Successful Call " value = "allowed">

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
    --header 'Content-Type: application/json' \
    --header 'Authorization: Bearer <sk-generated-key>' \
    --data ' {
      "model": "gpt-4o",
      "messages": [
        {
          "role": "user",
          "content": "testing request"
        }
      ]
    }
'
```

</TabItem>
<TabItem label="Unsuccessful call" value = "not-allowed">

Expect this to fail, since we cross the budget for `model=gpt-4o` on the Virtual Key

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
    --header 'Content-Type: application/json' \
    --header 'Authorization: Bearer <sk-generated-key>' \
    --data ' {
      "model": "gpt-4o",
      "messages": [
        {
          "role": "user",
          "content": "testing request"
        }
      ]
    }
'
```

Expected response on failure

```json
{
  "error": {
    "message": "LiteLLM Virtual Key: 9769f3f6768a199f76cc29xxxx, key_alias: None, exceeded budget for model=gpt-4o",
    "type": "budget_exceeded",
    "param": null,
    "code": "400"
  }
}
```

</TabItem>

</Tabs>

</TabItem>
</Tabs>

@@ -783,3 +839,32 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
    --header 'Content-Type: application/json' \
    --data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}'
```

## API Specification

### `GenericBudgetInfo`

A Pydantic model that defines budget information with a time period and limit.

```python
class GenericBudgetInfo(BaseModel):
    budget_limit: float  # The maximum budget amount in USD
    time_period: str  # Duration string like "1d", "30d", etc.
```

#### Fields:
- `budget_limit` (float): The maximum budget amount in USD
- `time_period` (str): Duration string specifying the time period for the budget. Supported formats:
  - Seconds: "30s"
  - Minutes: "30m"
  - Hours: "30h"
  - Days: "30d"
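For illustration, here is a rough sketch of how these duration strings translate to seconds. This is not the proxy's internal parser, just a minimal example of the format:

```python
def duration_to_seconds(duration: str) -> int:
    """Convert a duration string like "30s", "30m", "30h", "30d" to seconds."""
    units = {"s": 1, "m": 60, "h": 3600, "d": 86400}
    value, unit = int(duration[:-1]), duration[-1]
    return value * units[unit]

assert duration_to_seconds("1d") == 86400   # daily budget window
assert duration_to_seconds("30m") == 1800   # 30 minute budget window
```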
#### Example:
```json
{
  "budget_limit": "0.0001",
  "time_period": "1d"
}
```
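As a sanity check, the JSON above maps directly onto the Pydantic model. A minimal sketch, assuming `GenericBudgetInfo` is importable from `litellm.types.utils` (where this change defines it):

```python
from litellm.types.utils import GenericBudgetInfo

# budget_limit may arrive as a string from the key CRUD endpoints; coerce to float
budget = GenericBudgetInfo(budget_limit=float("0.0001"), time_period="1d")
print(budget.budget_limit, budget.time_period)  # 0.0001 1d
```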
|
|
@ -262,6 +262,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
llm_model_list,
|
||||
llm_router,
|
||||
master_key,
|
||||
model_max_budget_limiter,
|
||||
open_telemetry_logger,
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
|
@ -1053,37 +1054,10 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
and valid_token.token is not None
|
||||
):
|
||||
## GET THE SPEND FOR THIS MODEL
|
||||
twenty_eight_days_ago = datetime.now() - timedelta(days=28)
|
||||
model_spend = await prisma_client.db.litellm_spendlogs.group_by(
|
||||
by=["model"],
|
||||
sum={"spend": True},
|
||||
where={
|
||||
"AND": [
|
||||
{"api_key": valid_token.token},
|
||||
{"startTime": {"gt": twenty_eight_days_ago}},
|
||||
{"model": current_model},
|
||||
]
|
||||
}, # type: ignore
|
||||
await model_max_budget_limiter.is_key_within_model_budget(
|
||||
user_api_key_dict=valid_token,
|
||||
model=current_model,
|
||||
)
|
||||
if (
|
||||
len(model_spend) > 0
|
||||
and max_budget_per_model.get(current_model, None) is not None
|
||||
):
|
||||
if (
|
||||
"model" in model_spend[0]
|
||||
and model_spend[0].get("model") == current_model
|
||||
and "_sum" in model_spend[0]
|
||||
and "spend" in model_spend[0]["_sum"]
|
||||
and model_spend[0]["_sum"]["spend"]
|
||||
>= max_budget_per_model[current_model]
|
||||
):
|
||||
current_model_spend = model_spend[0]["_sum"]["spend"]
|
||||
current_model_budget = max_budget_per_model[current_model]
|
||||
raise litellm.BudgetExceededError(
|
||||
current_cost=current_model_spend,
|
||||
max_budget=current_model_budget,
|
||||
)
|
||||
|
||||
# Check 6. Team spend is under Team budget
|
||||
if (
|
||||
hasattr(valid_token, "team_spend")
|
||||
|
|
196 litellm/proxy/hooks/model_max_budget_limiter.py (new file)
|
@ -0,0 +1,196 @@
|
|||
import json
|
||||
import traceback
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger, Span
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
GenericBudgetConfigType,
|
||||
GenericBudgetInfo,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
|
||||
|
||||
|
||||
class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
|
||||
"""
|
||||
Handles budgets for model + virtual key
|
||||
|
||||
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
|
||||
"""
|
||||
|
||||
def __init__(self, dual_cache: DualCache):
|
||||
self.dual_cache = dual_cache
|
||||
self.redis_increment_operation_queue = []
|
||||
|
||||
async def is_key_within_model_budget(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
model: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the user_api_key_dict is within the model budget
|
||||
|
||||
Raises:
|
||||
BudgetExceededError: If the user_api_key_dict has exceeded the model budget
|
||||
"""
|
||||
_model_max_budget = user_api_key_dict.model_max_budget
|
||||
internal_model_max_budget: GenericBudgetConfigType = {}
|
||||
|
||||
# cast each element in _model_max_budget to GenericBudgetInfo
|
||||
for _model, _budget_info in _model_max_budget.items():
|
||||
internal_model_max_budget[_model] = GenericBudgetInfo(
|
||||
time_period=_budget_info.get("time_period"),
|
||||
budget_limit=float(_budget_info.get("budget_limit")),
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"internal_model_max_budget %s",
|
||||
json.dumps(internal_model_max_budget, indent=4, default=str),
|
||||
)
|
||||
|
||||
# check if current model is in internal_model_max_budget
|
||||
_current_model_budget_info = self._get_request_model_budget_config(
|
||||
model=model, internal_model_max_budget=internal_model_max_budget
|
||||
)
|
||||
if _current_model_budget_info is None:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Model {model} not found in internal_model_max_budget"
|
||||
)
|
||||
return True
|
||||
|
||||
# check if current model is within budget
|
||||
if _current_model_budget_info.budget_limit > 0:
|
||||
_current_spend = await self._get_virtual_key_spend_for_model(
|
||||
user_api_key_hash=user_api_key_dict.token,
|
||||
model=model,
|
||||
key_budget_config=_current_model_budget_info,
|
||||
)
|
||||
if (
|
||||
_current_spend is not None
|
||||
and _current_spend > _current_model_budget_info.budget_limit
|
||||
):
|
||||
raise litellm.BudgetExceededError(
|
||||
message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
|
||||
current_cost=_current_spend,
|
||||
max_budget=_current_model_budget_info.budget_limit,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def _get_virtual_key_spend_for_model(
|
||||
self,
|
||||
user_api_key_hash: Optional[str],
|
||||
model: str,
|
||||
key_budget_config: GenericBudgetInfo,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get the current spend for a virtual key for a model
|
||||
|
||||
Lookup model in this order:
|
||||
1. model: directly look up `model`
|
||||
2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
|
||||
"""
|
||||
|
||||
# 1. model: directly look up `model`
|
||||
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.time_period}"
|
||||
_current_spend = await self.dual_cache.async_get_cache(
|
||||
key=virtual_key_model_spend_cache_key,
|
||||
)
|
||||
|
||||
if _current_spend is None:
|
||||
# 2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
|
||||
# if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview
|
||||
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.time_period}"
|
||||
_current_spend = await self.dual_cache.async_get_cache(
|
||||
key=virtual_key_model_spend_cache_key,
|
||||
)
|
||||
return _current_spend
|
||||
|
||||
def _get_request_model_budget_config(
|
||||
self, model: str, internal_model_max_budget: GenericBudgetConfigType
|
||||
) -> Optional[GenericBudgetInfo]:
|
||||
"""
|
||||
Get the budget config for the request model
|
||||
|
||||
1. Check if `model` is in `internal_model_max_budget`
|
||||
2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
|
||||
"""
|
||||
return internal_model_max_budget.get(
|
||||
model, None
|
||||
) or internal_model_max_budget.get(
|
||||
self._get_model_without_custom_llm_provider(model), None
|
||||
)
|
||||
|
||||
def _get_model_without_custom_llm_provider(self, model: str) -> str:
|
||||
if "/" in model:
|
||||
return model.split("/")[-1]
|
||||
return model
|
||||
|
||||
async def async_filter_deployments(
|
||||
self,
|
||||
model: str,
|
||||
healthy_deployments: List,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
request_kwargs: Optional[dict] = None,
|
||||
parent_otel_span: Optional[Span] = None, # type: ignore
|
||||
) -> List[dict]:
|
||||
return healthy_deployments
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Track spend for virtual key + model in DualCache
|
||||
|
||||
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
|
||||
"""
|
||||
verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event")
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
if standard_logging_payload is None:
|
||||
raise ValueError("standard_logging_payload is required")
|
||||
|
||||
_litellm_params = kwargs.get("litellm_params", {})
|
||||
_metadata = _litellm_params.get("metadata", {})
|
||||
user_api_key_model_max_budget: Optional[dict] = _metadata.get(
|
||||
"user_api_key_model_max_budget", None
|
||||
)
|
||||
if (
|
||||
user_api_key_model_max_budget is None
|
||||
or len(user_api_key_model_max_budget) == 0
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
|
||||
user_api_key_model_max_budget,
|
||||
)
|
||||
return
|
||||
response_cost: float = standard_logging_payload.get("response_cost", 0)
|
||||
model = standard_logging_payload.get("model")
|
||||
|
||||
virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
|
||||
model = standard_logging_payload.get("model")
|
||||
if virtual_key is not None:
|
||||
budget_config = GenericBudgetInfo(time_period="1d", budget_limit=0.1)
|
||||
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.time_period}"
|
||||
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
|
||||
await self._increment_spend_for_key(
|
||||
budget_config=budget_config,
|
||||
spend_key=virtual_spend_key,
|
||||
start_time_key=virtual_start_time_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"current state of in memory cache %s",
|
||||
json.dumps(
|
||||
self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
|
||||
),
|
||||
)
|
|
@ -499,6 +499,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
data[_metadata_variable_name][
|
||||
"user_api_key_max_budget"
|
||||
] = user_api_key_dict.max_budget
|
||||
data[_metadata_variable_name][
|
||||
"user_api_key_model_max_budget"
|
||||
] = user_api_key_dict.model_max_budget
|
||||
|
||||
data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
|
||||
_headers = dict(request.headers)
|
||||
|
|
|
@ -40,7 +40,11 @@ from litellm.proxy.utils import (
|
|||
handle_exception_on_proxy,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.utils import PersonalUIKeyGenerationConfig, TeamUIKeyGenerationConfig
|
||||
from litellm.types.utils import (
|
||||
GenericBudgetInfo,
|
||||
PersonalUIKeyGenerationConfig,
|
||||
TeamUIKeyGenerationConfig,
|
||||
)
|
||||
|
||||
|
||||
def _is_team_key(data: GenerateKeyRequest):
|
||||
|
@ -246,7 +250,7 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
|
||||
- guardrails: Optional[List[str]] - List of active guardrails for the key
|
||||
- permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
|
||||
- model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget.
|
||||
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget.
|
||||
- model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
|
||||
- model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit.
|
||||
- allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request
|
||||
|
@ -515,6 +519,10 @@ def prepare_key_update_data(
|
|||
|
||||
_metadata = existing_key_row.metadata or {}
|
||||
|
||||
# validate model_max_budget
|
||||
if "model_max_budget" in non_default_values:
|
||||
validate_model_max_budget(non_default_values["model_max_budget"])
|
||||
|
||||
non_default_values = prepare_metadata_fields(
|
||||
data=data, non_default_values=non_default_values, existing_metadata=_metadata
|
||||
)
|
||||
|
@ -548,7 +556,7 @@ async def update_key_fn(
|
|||
- enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
|
||||
- spend: Optional[float] - Amount spent by key
|
||||
- max_budget: Optional[float] - Max budget for key
|
||||
- model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}
|
||||
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
|
||||
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
|
||||
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
|
||||
- max_parallel_requests: Optional[int] - Rate limit for parallel requests
|
||||
|
@ -1035,6 +1043,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
|||
metadata["guardrails"] = guardrails
|
||||
|
||||
metadata_json = json.dumps(metadata)
|
||||
validate_model_max_budget(model_max_budget)
|
||||
model_max_budget_json = json.dumps(model_max_budget)
|
||||
user_role = user_role
|
||||
tpm_limit = tpm_limit
|
||||
|
@ -1266,7 +1275,7 @@ async def regenerate_key_fn(
|
|||
- tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
|
||||
- spend: Optional[float] - Amount spent by key
|
||||
- max_budget: Optional[float] - Max budget for key
|
||||
- model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}
|
||||
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
|
||||
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
|
||||
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
|
||||
- max_parallel_requests: Optional[int] - Rate limit for parallel requests
|
||||
|
@ -1293,8 +1302,7 @@ async def regenerate_key_fn(
|
|||
--data-raw '{
|
||||
"max_budget": 100,
|
||||
"metadata": {"team": "core-infra"},
|
||||
"models": ["gpt-4", "gpt-3.5-turbo"],
|
||||
"model_max_budget": {"gpt-4": 50, "gpt-3.5-turbo": 50}
|
||||
"models": ["gpt-4", "gpt-3.5-turbo"]
|
||||
}'
|
||||
```
|
||||
|
||||
|
@ -1949,3 +1957,29 @@ async def _enforce_unique_key_alias(
|
|||
param="key_alias",
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
def validate_model_max_budget(model_max_budget: Optional[Dict]) -> None:
|
||||
"""
|
||||
Validate the model_max_budget is GenericBudgetConfigType
|
||||
|
||||
Raises:
|
||||
Exception: If model_max_budget is not a valid GenericBudgetConfigType
|
||||
"""
|
||||
try:
|
||||
if model_max_budget is None:
|
||||
return
|
||||
if len(model_max_budget) == 0:
|
||||
return
|
||||
if model_max_budget is not None:
|
||||
for _model, _budget_info in model_max_budget.items():
|
||||
assert isinstance(_model, str)
|
||||
|
||||
# /CRUD endpoints can pass budget_limit as a string, so we need to convert it to a float
|
||||
if "budget_limit" in _budget_info:
|
||||
_budget_info["budget_limit"] = float(_budget_info["budget_limit"])
|
||||
GenericBudgetInfo(**_budget_info)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Invalid model_max_budget: {str(e)}. Example of valid model_max_budget: https://docs.litellm.ai/docs/proxy/users"
|
||||
)
|
||||
|
|
|
@ -1,42 +1,10 @@
|
|||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
- model_name: openai/o1-preview
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
tags: ["teamA"]
|
||||
model_info:
|
||||
id: "team-a-model"
|
||||
- model_name: fake-openai-endpoint
|
||||
model: openai/o1-preview
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: openai/*
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
tags: ["teamB"]
|
||||
model_info:
|
||||
id: "team-b-model"
|
||||
- model_name: rerank-english-v3.0
|
||||
litellm_params:
|
||||
model: cohere/rerank-english-v3.0
|
||||
api_key: os.environ/COHERE_API_KEY
|
||||
- model_name: fake-azure-endpoint
|
||||
litellm_params:
|
||||
model: openai/429
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
- model_name: llava-hf
|
||||
litellm_params:
|
||||
model: openai/llava-hf/llava-v1.6-vicuna-7b-hf
|
||||
api_base: http://localhost:8000
|
||||
api_key: fake-key
|
||||
model_info:
|
||||
supports_vision: True
|
||||
model: openai/*
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
|
||||
|
||||
litellm_settings:
|
||||
cache: true
|
||||
callbacks: ["otel", "prometheus"]
|
||||
|
||||
router_settings:
|
||||
enable_tag_filtering: True # 👈 Key Change
|
|
@ -173,6 +173,9 @@ from litellm.proxy.guardrails.init_guardrails import (
|
|||
)
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||
from litellm.proxy.hooks.model_max_budget_limiter import (
|
||||
_PROXY_VirtualKeyModelMaxBudgetLimiter,
|
||||
)
|
||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||
_OPTIONAL_PromptInjectionDetection,
|
||||
)
|
||||
|
@ -495,6 +498,10 @@ prisma_client: Optional[PrismaClient] = None
|
|||
user_api_key_cache = DualCache(
|
||||
default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
|
||||
)
|
||||
model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
|
||||
dual_cache=user_api_key_cache
|
||||
)
|
||||
litellm.callbacks.append(model_max_budget_limiter)
|
||||
redis_usage_cache: Optional[RedisCache] = (
|
||||
None # redis cache used for tracking spend, tpm/rpm limits
|
||||
)
|
||||
|
|
|
@ -631,7 +631,7 @@ class Router:
|
|||
_callback = PromptCachingDeploymentCheck(cache=self.cache)
|
||||
elif pre_call_check == "router_budget_limiting":
|
||||
_callback = RouterBudgetLimiting(
|
||||
router_cache=self.cache,
|
||||
dual_cache=self.cache,
|
||||
provider_budget_config=self.provider_budget_config,
|
||||
model_list=self.model_list,
|
||||
)
|
||||
|
@ -5292,14 +5292,6 @@ class Router:
|
|||
healthy_deployments=healthy_deployments,
|
||||
)
|
||||
|
||||
# if self.router_budget_logger:
|
||||
# healthy_deployments = (
|
||||
# await self.router_budget_logger.async_filter_deployments(
|
||||
# healthy_deployments=healthy_deployments,
|
||||
# request_kwargs=request_kwargs,
|
||||
# )
|
||||
# )
|
||||
|
||||
if len(healthy_deployments) == 0:
|
||||
exception = await async_raise_no_deployment_exception(
|
||||
litellm_router_instance=self,
|
||||
|
|
|
@ -49,13 +49,13 @@ DEFAULT_REDIS_SYNC_INTERVAL = 1
|
|||
class RouterBudgetLimiting(CustomLogger):
|
||||
def __init__(
|
||||
self,
|
||||
router_cache: DualCache,
|
||||
dual_cache: DualCache,
|
||||
provider_budget_config: Optional[dict],
|
||||
model_list: Optional[
|
||||
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
|
||||
] = None,
|
||||
):
|
||||
self.router_cache = router_cache
|
||||
self.dual_cache = dual_cache
|
||||
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
|
||||
asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
|
||||
self.provider_budget_config: Optional[GenericBudgetConfigType] = (
|
||||
|
@ -108,7 +108,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
|
||||
# Single cache read for all spend values
|
||||
if len(cache_keys) > 0:
|
||||
_current_spends = await self.router_cache.async_batch_get_cache(
|
||||
_current_spends = await self.dual_cache.async_batch_get_cache(
|
||||
keys=cache_keys,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
@ -286,9 +286,9 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
If it does, return the value.
|
||||
If it does not, set the key to `current_time` and return the value.
|
||||
"""
|
||||
budget_start = await self.router_cache.async_get_cache(start_time_key)
|
||||
budget_start = await self.dual_cache.async_get_cache(start_time_key)
|
||||
if budget_start is None:
|
||||
await self.router_cache.async_set_cache(
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=start_time_key, value=current_time, ttl=ttl_seconds
|
||||
)
|
||||
return current_time
|
||||
|
@ -314,10 +314,10 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
- stores key: `provider_budget_start_time:{provider}`, value: current_time.
|
||||
This stores the start time of the new budget window
|
||||
"""
|
||||
await self.router_cache.async_set_cache(
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=spend_key, value=response_cost, ttl=ttl_seconds
|
||||
)
|
||||
await self.router_cache.async_set_cache(
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=start_time_key, value=current_time, ttl=ttl_seconds
|
||||
)
|
||||
return current_time
|
||||
|
@ -333,7 +333,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
- Increments the spend in memory cache (so spend instantly updated in memory)
|
||||
- Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. Using Redis for multi instance environment of LiteLLM)
|
||||
"""
|
||||
await self.router_cache.in_memory_cache.async_increment(
|
||||
await self.dual_cache.in_memory_cache.async_increment(
|
||||
key=spend_key,
|
||||
value=response_cost,
|
||||
ttl=ttl,
|
||||
|
@ -481,7 +481,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
Only runs if Redis is initialized
|
||||
"""
|
||||
try:
|
||||
if not self.router_cache.redis_cache:
|
||||
if not self.dual_cache.redis_cache:
|
||||
return # Redis is not initialized
|
||||
|
||||
verbose_router_logger.debug(
|
||||
|
@ -490,7 +490,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
)
|
||||
if len(self.redis_increment_operation_queue) > 0:
|
||||
asyncio.create_task(
|
||||
self.router_cache.redis_cache.async_increment_pipeline(
|
||||
self.dual_cache.redis_cache.async_increment_pipeline(
|
||||
increment_list=self.redis_increment_operation_queue,
|
||||
)
|
||||
)
|
||||
|
@ -517,7 +517,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
|
||||
try:
|
||||
# No need to sync if Redis cache is not initialized
|
||||
if self.router_cache.redis_cache is None:
|
||||
if self.dual_cache.redis_cache is None:
|
||||
return
|
||||
|
||||
# 1. Push all provider spend increments to Redis
|
||||
|
@ -547,7 +547,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
cache_keys.append(f"tag_spend:{tag}:{config.time_period}")
|
||||
|
||||
# Batch fetch current spend values from Redis
|
||||
redis_values = await self.router_cache.redis_cache.async_batch_get_cache(
|
||||
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
|
||||
key_list=cache_keys
|
||||
)
|
||||
|
||||
|
@ -555,7 +555,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
if isinstance(redis_values, dict): # Check if redis_values is a dictionary
|
||||
for key, value in redis_values.items():
|
||||
if value is not None:
|
||||
await self.router_cache.in_memory_cache.async_set_cache(
|
||||
await self.dual_cache.in_memory_cache.async_set_cache(
|
||||
key=key, value=float(value)
|
||||
)
|
||||
verbose_router_logger.debug(
|
||||
|
@ -639,14 +639,12 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
|
||||
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
||||
|
||||
if self.router_cache.redis_cache:
|
||||
if self.dual_cache.redis_cache:
|
||||
# use Redis as source of truth since that has spend across all instances
|
||||
current_spend = await self.router_cache.redis_cache.async_get_cache(
|
||||
spend_key
|
||||
)
|
||||
current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key)
|
||||
else:
|
||||
# use in-memory cache if Redis is not initialized
|
||||
current_spend = await self.router_cache.async_get_cache(spend_key)
|
||||
current_spend = await self.dual_cache.async_get_cache(spend_key)
|
||||
return float(current_spend) if current_spend is not None else 0.0
|
||||
|
||||
async def _get_current_provider_budget_reset_at(
|
||||
|
@ -657,10 +655,10 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
return None
|
||||
|
||||
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
||||
if self.router_cache.redis_cache:
|
||||
ttl_seconds = await self.router_cache.redis_cache.async_get_ttl(spend_key)
|
||||
if self.dual_cache.redis_cache:
|
||||
ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
|
||||
else:
|
||||
ttl_seconds = await self.router_cache.async_get_ttl(spend_key)
|
||||
ttl_seconds = await self.dual_cache.async_get_ttl(spend_key)
|
||||
|
||||
if ttl_seconds is None:
|
||||
return None
|
||||
|
@ -679,16 +677,16 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
||||
start_time_key = f"provider_budget_start_time:{provider}"
|
||||
ttl_seconds = duration_in_seconds(budget_config.time_period)
|
||||
budget_start = await self.router_cache.async_get_cache(start_time_key)
|
||||
budget_start = await self.dual_cache.async_get_cache(start_time_key)
|
||||
if budget_start is None:
|
||||
budget_start = datetime.now(timezone.utc).timestamp()
|
||||
await self.router_cache.async_set_cache(
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=start_time_key, value=budget_start, ttl=ttl_seconds
|
||||
)
|
||||
|
||||
_spend_key = await self.router_cache.async_get_cache(spend_key)
|
||||
_spend_key = await self.dual_cache.async_get_cache(spend_key)
|
||||
if _spend_key is None:
|
||||
await self.router_cache.async_set_cache(
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=spend_key, value=0.0, ttl=ttl_seconds
|
||||
)
|
||||
|
||||
|
|
|
@ -11,6 +11,8 @@ import httpx
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from litellm.types.utils import GenericBudgetConfigType, GenericBudgetInfo
|
||||
|
||||
from ..exceptions import RateLimitError
|
||||
from .completion import CompletionRequest
|
||||
from .embedding import EmbeddingRequest
|
||||
|
@ -647,14 +649,6 @@ class RoutingStrategy(enum.Enum):
|
|||
PROVIDER_BUDGET_LIMITING = "provider-budget-routing"
|
||||
|
||||
|
||||
class GenericBudgetInfo(BaseModel):
|
||||
time_period: str # e.g., '1d', '30d'
|
||||
budget_limit: float
|
||||
|
||||
|
||||
GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
|
||||
|
||||
|
||||
class RouterCacheEnum(enum.Enum):
|
||||
TPM = "global_router:{id}:{model}:tpm:{current_minute}"
|
||||
RPM = "global_router:{id}:{model}:rpm:{current_minute}"
|
||||
|
|
|
@ -1665,6 +1665,14 @@ class StandardKeyGenerationConfig(TypedDict, total=False):
|
|||
personal_key_generation: PersonalUIKeyGenerationConfig
|
||||
|
||||
|
||||
class GenericBudgetInfo(BaseModel):
|
||||
time_period: str # e.g., '1d', '30d'
|
||||
budget_limit: float
|
||||
|
||||
|
||||
GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
|
||||
|
||||
|
||||
class BudgetConfig(BaseModel):
|
||||
max_budget: float
|
||||
budget_duration: str
|
||||
|
|
|
@ -183,7 +183,7 @@ async def test_get_llm_provider_for_deployment():
|
|||
"""
|
||||
cleanup_redis()
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(), provider_budget_config={}
|
||||
dual_cache=DualCache(), provider_budget_config={}
|
||||
)
|
||||
|
||||
# Test OpenAI deployment
|
||||
|
@ -220,7 +220,7 @@ async def test_get_budget_config_for_provider():
|
|||
}
|
||||
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(), provider_budget_config=config
|
||||
dual_cache=DualCache(), provider_budget_config=config
|
||||
)
|
||||
|
||||
# Test existing providers
|
||||
|
@ -252,7 +252,7 @@ async def test_prometheus_metric_tracking():
|
|||
|
||||
# Setup provider budget limiting
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(),
|
||||
dual_cache=DualCache(),
|
||||
provider_budget_config={
|
||||
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100)
|
||||
},
|
||||
|
@ -316,7 +316,7 @@ async def test_handle_new_budget_window():
|
|||
"""
|
||||
cleanup_redis()
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(), provider_budget_config={}
|
||||
dual_cache=DualCache(), provider_budget_config={}
|
||||
)
|
||||
|
||||
spend_key = "provider_spend:openai:7d"
|
||||
|
@ -337,12 +337,12 @@ async def test_handle_new_budget_window():
|
|||
assert new_start_time == current_time
|
||||
|
||||
# Verify the spend was set correctly
|
||||
spend = await provider_budget.router_cache.async_get_cache(spend_key)
|
||||
spend = await provider_budget.dual_cache.async_get_cache(spend_key)
|
||||
print("spend in cache for key", spend_key, "is", spend)
|
||||
assert float(spend) == response_cost
|
||||
|
||||
# Verify start time was set correctly
|
||||
start_time = await provider_budget.router_cache.async_get_cache(start_time_key)
|
||||
start_time = await provider_budget.dual_cache.async_get_cache(start_time_key)
|
||||
print("start time in cache for key", start_time_key, "is", start_time)
|
||||
assert float(start_time) == current_time
|
||||
|
||||
|
@ -357,7 +357,7 @@ async def test_get_or_set_budget_start_time():
|
|||
"""
|
||||
cleanup_redis()
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(), provider_budget_config={}
|
||||
dual_cache=DualCache(), provider_budget_config={}
|
||||
)
|
||||
|
||||
start_time_key = "test_start_time"
|
||||
|
@ -398,7 +398,7 @@ async def test_increment_spend_in_current_window():
|
|||
"""
|
||||
cleanup_redis()
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(), provider_budget_config={}
|
||||
dual_cache=DualCache(), provider_budget_config={}
|
||||
)
|
||||
|
||||
spend_key = "provider_spend:openai:1d"
|
||||
|
@ -406,9 +406,7 @@ async def test_increment_spend_in_current_window():
|
|||
ttl = 86400 # 1 day
|
||||
|
||||
# Set initial spend
|
||||
await provider_budget.router_cache.async_set_cache(
|
||||
key=spend_key, value=1.0, ttl=ttl
|
||||
)
|
||||
await provider_budget.dual_cache.async_set_cache(key=spend_key, value=1.0, ttl=ttl)
|
||||
|
||||
# Test incrementing spend
|
||||
await provider_budget._increment_spend_in_current_window(
|
||||
|
@ -418,7 +416,7 @@ async def test_increment_spend_in_current_window():
|
|||
)
|
||||
|
||||
# Verify the spend was incremented correctly in memory
|
||||
spend = await provider_budget.router_cache.async_get_cache(spend_key)
|
||||
spend = await provider_budget.dual_cache.async_get_cache(spend_key)
|
||||
assert float(spend) == 1.5
|
||||
|
||||
# Verify the increment operation was queued for Redis
|
||||
|
@ -449,7 +447,7 @@ async def test_sync_in_memory_spend_with_redis():
|
|||
}
|
||||
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(
|
||||
dual_cache=DualCache(
|
||||
redis_cache=RedisCache(
|
||||
host=os.getenv("REDIS_HOST"),
|
||||
port=int(os.getenv("REDIS_PORT")),
|
||||
|
@ -463,10 +461,10 @@ async def test_sync_in_memory_spend_with_redis():
|
|||
spend_key_openai = "provider_spend:openai:1d"
|
||||
spend_key_anthropic = "provider_spend:anthropic:1d"
|
||||
|
||||
await provider_budget.router_cache.redis_cache.async_set_cache(
|
||||
await provider_budget.dual_cache.redis_cache.async_set_cache(
|
||||
key=spend_key_openai, value=50.0
|
||||
)
|
||||
await provider_budget.router_cache.redis_cache.async_set_cache(
|
||||
await provider_budget.dual_cache.redis_cache.async_set_cache(
|
||||
key=spend_key_anthropic, value=75.0
|
||||
)
|
||||
|
||||
|
@ -474,14 +472,12 @@ async def test_sync_in_memory_spend_with_redis():
|
|||
await provider_budget._sync_in_memory_spend_with_redis()
|
||||
|
||||
# Verify in-memory cache was updated
|
||||
openai_spend = await provider_budget.router_cache.in_memory_cache.async_get_cache(
|
||||
openai_spend = await provider_budget.dual_cache.in_memory_cache.async_get_cache(
|
||||
spend_key_openai
|
||||
)
|
||||
anthropic_spend = (
|
||||
await provider_budget.router_cache.in_memory_cache.async_get_cache(
|
||||
anthropic_spend = await provider_budget.dual_cache.in_memory_cache.async_get_cache(
|
||||
spend_key_anthropic
|
||||
)
|
||||
)
|
||||
|
||||
assert float(openai_spend) == 50.0
|
||||
assert float(anthropic_spend) == 75.0
|
||||
|
@ -499,7 +495,7 @@ async def test_get_current_provider_spend():
|
|||
"""
|
||||
cleanup_redis()
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(),
|
||||
dual_cache=DualCache(),
|
||||
provider_budget_config={
|
||||
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
|
||||
},
|
||||
|
@ -515,7 +511,7 @@ async def test_get_current_provider_spend():
|
|||
|
||||
# Test provider with budget config and spend
|
||||
spend_key = "provider_spend:openai:1d"
|
||||
await provider_budget.router_cache.async_set_cache(key=spend_key, value=50.5)
|
||||
await provider_budget.dual_cache.async_set_cache(key=spend_key, value=50.5)
|
||||
|
||||
spend = await provider_budget._get_current_provider_spend("openai")
|
||||
assert spend == 50.5
|
||||
|
@ -534,7 +530,7 @@ async def test_get_current_provider_budget_reset_at():
|
|||
"""
|
||||
cleanup_redis()
|
||||
provider_budget = RouterBudgetLimiting(
|
||||
router_cache=DualCache(
|
||||
dual_cache=DualCache(
|
||||
redis_cache=RedisCache(
|
||||
host=os.getenv("REDIS_HOST"),
|
||||
port=int(os.getenv("REDIS_PORT")),
|
||||
|
|
|
@ -1684,22 +1684,51 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
|
|||
print(vars(e))
|
||||
|
||||
|
||||
def test_call_with_key_over_model_budget(prisma_client):
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize(
|
||||
"request_model,should_pass",
|
||||
[
|
||||
("openai/gpt-4o-mini", False),
|
||||
("gpt-4o-mini", False),
|
||||
("gpt-4o", True),
|
||||
],
|
||||
)
|
||||
async def test_call_with_key_over_model_budget(
|
||||
prisma_client, request_model, should_pass
|
||||
):
|
||||
# 12. Make a call with a key over budget, expect to fail
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||
|
||||
# init model max budget limiter
|
||||
from litellm.proxy.hooks.model_max_budget_limiter import (
|
||||
_PROXY_VirtualKeyModelMaxBudgetLimiter,
|
||||
)
|
||||
|
||||
model_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
|
||||
dual_cache=DualCache()
|
||||
)
|
||||
litellm.callbacks.append(model_budget_limiter)
|
||||
|
||||
try:
|
||||
|
||||
async def test():
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
# set budget for chatgpt-v-2 to 0.000001, expect the next request to fail
|
||||
request = GenerateKeyRequest(
|
||||
max_budget=1000,
|
||||
model_max_budget={
|
||||
"chatgpt-v-2": 0.000001,
|
||||
model_max_budget = {
|
||||
"gpt-4o-mini": {
|
||||
"budget_limit": "0.000001",
|
||||
"time_period": "1d",
|
||||
},
|
||||
metadata={"user_api_key": 0.0001},
|
||||
"gpt-4o": {
|
||||
"budget_limit": "200",
|
||||
"time_period": "30d",
|
||||
},
|
||||
}
|
||||
|
||||
request = GenerateKeyRequest(
|
||||
max_budget=100000, # the key itself has a very high budget
|
||||
model_max_budget=model_max_budget,
|
||||
)
|
||||
key = await generate_key_fn(request)
|
||||
print(key)
|
||||
|
@ -1712,7 +1741,8 @@ def test_call_with_key_over_model_budget(prisma_client):
|
|||
request._url = URL(url="/chat/completions")
|
||||
|
||||
async def return_body():
|
||||
return b'{"model": "chatgpt-v-2"}'
|
||||
request_str = f'{{"model": "{request_model}"}}'  # doubled curly braces produce literal braces in the f-string
|
||||
return request_str.encode()
|
||||
|
||||
request.body = return_body
|
||||
|
||||
|
@ -1721,87 +1751,42 @@ def test_call_with_key_over_model_budget(prisma_client):
|
|||
print("result from user auth with new key", result)
|
||||
|
||||
# update spend using track_cost callback, make 2nd request, it should fail
|
||||
from litellm import Choices, Message, ModelResponse, Usage
|
||||
from litellm.caching.caching import Cache
|
||||
from litellm.proxy.proxy_server import (
|
||||
_PROXY_track_cost_callback as track_cost_callback,
|
||||
)
|
||||
|
||||
litellm.cache = Cache()
|
||||
import time
|
||||
import uuid
|
||||
|
||||
request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||
|
||||
resp = ModelResponse(
|
||||
id=request_id,
|
||||
choices=[
|
||||
Choices(
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
message=Message(
|
||||
content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-35-turbo", # azure always has model written like this
|
||||
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
|
||||
)
|
||||
await track_cost_callback(
|
||||
kwargs={
|
||||
"model": "chatgpt-v-2",
|
||||
"stream": False,
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
await litellm.acompletion(
|
||||
model=request_model,
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
metadata={
|
||||
"user_api_key": hash_token(generated_key),
|
||||
"user_api_key_user_id": user_id,
|
||||
}
|
||||
"user_api_key_model_max_budget": model_max_budget,
|
||||
},
|
||||
"response_cost": 0.00002,
|
||||
},
|
||||
completion_response=resp,
|
||||
start_time=datetime.now(),
|
||||
end_time=datetime.now(),
|
||||
)
|
||||
await update_spend(
|
||||
prisma_client=prisma_client,
|
||||
db_writer_client=None,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# test spend_log was written and we can read it
|
||||
spend_logs = await view_spend_logs(
|
||||
request_id=request_id,
|
||||
user_api_key_dict=UserAPIKeyAuth(api_key=generated_key),
|
||||
)
|
||||
|
||||
print("read spend logs", spend_logs)
|
||||
assert len(spend_logs) == 1
|
||||
|
||||
spend_log = spend_logs[0]
|
||||
|
||||
assert spend_log.request_id == request_id
|
||||
assert spend_log.spend == float("2e-05")
|
||||
assert spend_log.model == "chatgpt-v-2"
|
||||
assert (
|
||||
spend_log.cache_key
|
||||
== "c891d64397a472e6deb31b87a5ac4d3ed5b2dcc069bc87e2afe91e6d64e95a1e"
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# use generated key to auth in
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
if should_pass is True:
|
||||
print(
|
||||
f"Passed request for model={request_model}, model_max_budget={model_max_budget}"
|
||||
)
|
||||
return
|
||||
print("result from user auth with new key", result)
|
||||
pytest.fail("This should have failed! The key crossed its budget")
|
||||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
# print(f"Error - {str(e)}")
|
||||
print(
|
||||
f"Failed request for model={request_model}, model_max_budget={model_max_budget}"
|
||||
)
|
||||
assert (
|
||||
should_pass is False
|
||||
), f"This should have failed!. They key crossed it's budget for model={request_model}. {e}"
|
||||
traceback.print_exc()
|
||||
error_detail = e.message
|
||||
assert "Budget has been exceeded!" in error_detail
|
||||
assert f"exceeded budget for model={request_model}" in error_detail
|
||||
assert isinstance(e, ProxyException)
|
||||
assert e.type == ProxyErrorTypes.budget_exceeded
|
||||
print(vars(e))
|
||||
finally:
|
||||
litellm.callbacks.remove(model_budget_limiter)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pydantic.main import Model
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
from datetime import datetime as dt_object
|
||||
import time
|
||||
import pytest
|
||||
import litellm
|
||||
|
||||
import json
|
||||
from litellm.types.utils import GenericBudgetInfo
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy.hooks.model_max_budget_limiter import (
|
||||
_PROXY_VirtualKeyModelMaxBudgetLimiter,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
import litellm
|
||||
|
||||
|
||||
# Test class setup
|
||||
@pytest.fixture
|
||||
def budget_limiter():
|
||||
dual_cache = DualCache()
|
||||
return _PROXY_VirtualKeyModelMaxBudgetLimiter(dual_cache=dual_cache)
|
||||
|
||||
|
||||
# Test _get_model_without_custom_llm_provider
|
||||
def test_get_model_without_custom_llm_provider(budget_limiter):
|
||||
# Test with custom provider
|
||||
assert (
|
||||
budget_limiter._get_model_without_custom_llm_provider("openai/gpt-4") == "gpt-4"
|
||||
)
|
||||
|
||||
# Test without custom provider
|
||||
assert budget_limiter._get_model_without_custom_llm_provider("gpt-4") == "gpt-4"
|
||||
|
||||
|
||||
# Test _get_request_model_budget_config
|
||||
def test_get_request_model_budget_config(budget_limiter):
|
||||
internal_budget = {
|
||||
"gpt-4": GenericBudgetInfo(budget_limit=100.0, time_period="1d"),
|
||||
"claude-3": GenericBudgetInfo(budget_limit=50.0, time_period="1d"),
|
||||
}
|
||||
|
||||
# Test direct model match
|
||||
config = budget_limiter._get_request_model_budget_config(
|
||||
model="gpt-4", internal_model_max_budget=internal_budget
|
||||
)
|
||||
assert config.budget_limit == 100.0
|
||||
|
||||
# Test model with provider
|
||||
config = budget_limiter._get_request_model_budget_config(
|
||||
model="openai/gpt-4", internal_model_max_budget=internal_budget
|
||||
)
|
||||
assert config.budget_limit == 100.0
|
||||
|
||||
# Test non-existent model
|
||||
config = budget_limiter._get_request_model_budget_config(
|
||||
model="non-existent", internal_model_max_budget=internal_budget
|
||||
)
|
||||
assert config is None
|
||||
|
||||
|
||||
# Test is_key_within_model_budget
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_key_within_model_budget(budget_limiter):
|
||||
# Mock user API key dict
|
||||
user_api_key = UserAPIKeyAuth(
|
||||
token="test-key",
|
||||
key_alias="test-alias",
|
||||
model_max_budget={"gpt-4": {"budget_limit": 100.0, "time_period": "1d"}},
|
||||
)
|
||||
|
||||
# Test when model is within budget
|
||||
with patch.object(
|
||||
budget_limiter, "_get_virtual_key_spend_for_model", return_value=50.0
|
||||
):
|
||||
assert (
|
||||
await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-4")
|
||||
is True
|
||||
)
|
||||
|
||||
# Test when model exceeds budget
|
||||
with patch.object(
|
||||
budget_limiter, "_get_virtual_key_spend_for_model", return_value=150.0
|
||||
):
|
||||
with pytest.raises(litellm.BudgetExceededError):
|
||||
await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-4")
|
||||
|
||||
# Test model not in budget config
|
||||
assert (
|
||||
await budget_limiter.is_key_within_model_budget(user_api_key, "non-existent")
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
# Test _get_virtual_key_spend_for_model
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_virtual_key_spend_for_model(budget_limiter):
|
||||
budget_config = GenericBudgetInfo(budget_limit=100.0, time_period="1d")
|
||||
|
||||
# Mock cache get
|
||||
with patch.object(budget_limiter.dual_cache, "async_get_cache", return_value=50.0):
|
||||
spend = await budget_limiter._get_virtual_key_spend_for_model(
|
||||
user_api_key_hash="test-key", model="gpt-4", key_budget_config=budget_config
|
||||
)
|
||||
assert spend == 50.0
|
||||
|
||||
# Test with provider prefix
|
||||
spend = await budget_limiter._get_virtual_key_spend_for_model(
|
||||
user_api_key_hash="test-key",
|
||||
model="openai/gpt-4",
|
||||
key_budget_config=budget_config,
|
||||
)
|
||||
assert spend == 50.0
|