mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(feat proxy) v2 - model max budgets (#7302)
* clean up unused code * add _PROXY_VirtualKeyModelMaxBudgetLimiter * adjust type imports * working _PROXY_VirtualKeyModelMaxBudgetLimiter * fix user_api_key_model_max_budget * fix user_api_key_model_max_budget * update naming * update naming * fix changes to RouterBudgetLimiting * test_call_with_key_over_model_budget * test_call_with_key_over_model_budget * handle _get_request_model_budget_config * e2e test for test_call_with_key_over_model_budget * clean up test * run ci/cd again * add validate_model_max_budget * docs fix * update doc * add e2e testing for _PROXY_VirtualKeyModelMaxBudgetLimiter * test_unit_test_max_model_budget_limiter.py
This commit is contained in:
parent
1a4910f6c0
commit
6220e17ebf
14 changed files with 628 additions and 261 deletions
|
@ -10,7 +10,7 @@ Requirements:
|
||||||
|
|
||||||
## Set Budgets
|
## Set Budgets
|
||||||
|
|
||||||
You can set budgets at 3 levels:
|
You can set budgets at 5 levels:
|
||||||
- For the proxy
|
- For the proxy
|
||||||
- For an internal user
|
- For an internal user
|
||||||
- For a customer (end-user)
|
- For a customer (end-user)
|
||||||
|
@ -392,32 +392,88 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
|
||||||
<TabItem value="per-model-key" label="For Key (model specific)">
|
<TabItem value="per-model-key" label="For Key (model specific)">
|
||||||
|
|
||||||
Apply model specific budgets on a key.
|
Apply model specific budgets on a key. Example:
|
||||||
|
- Budget for `gpt-4o` is $0.0000001, for time period `1d` for `key = "sk-12345"`
|
||||||
**Expected Behaviour**
|
- Budget for `gpt-4o-mini` is $10, for time period `30d` for `key = "sk-12345"`
|
||||||
- `model_spend` gets auto-populated in `LiteLLM_VerificationToken` Table
|
|
||||||
- After the key crosses the budget set for the `model` in `model_max_budget`, calls fail
|
|
||||||
|
|
||||||
By default the `model_max_budget` is set to `{}` and is not checked for keys
|
|
||||||
|
|
||||||
:::info
|
|
||||||
|
|
||||||
- LiteLLM will track the cost/budgets for the `model` passed to LLM endpoints (`/chat/completions`, `/embeddings`)
|
|
||||||
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
#### **Add model specific budgets to keys**
|
#### **Add model specific budgets to keys**
|
||||||
|
|
||||||
|
The spec for `model_max_budget` is **[`Dict[str, GenericBudgetInfo]`](#genericbudgetinfo)**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl 'http://0.0.0.0:4000/key/generate' \
|
curl 'http://0.0.0.0:4000/key/generate' \
|
||||||
--header 'Authorization: Bearer <your-master-key>' \
|
--header 'Authorization: Bearer <your-master-key>' \
|
||||||
--header 'Content-Type: application/json' \
|
--header 'Content-Type: application/json' \
|
||||||
--data-raw '{
|
--data-raw '{
|
||||||
model_max_budget={"gpt4": 0.5, "gpt-5": 0.01}
|
"model_max_budget": {"gpt-4o": {"budget_limit": "0.0000001", "time_period": "1d"}}
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
#### Make a test request
|
||||||
|
|
||||||
|
We expect the first request to succeed, and the second request to fail since we cross the budget for `gpt-4o` on the Virtual Key
|
||||||
|
|
||||||
|
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem label="Successful Call " value = "allowed">
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer <sk-generated-key>' \
|
||||||
|
--data ' {
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "testing request"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem label="Unsuccessful call" value = "not-allowed">
|
||||||
|
|
||||||
|
Expect this to fail since since we cross the budget `model=gpt-4o` on the Virtual Key
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer <sk-generated-key>' \
|
||||||
|
--data ' {
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "testing request"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected response on failure
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"message": "LiteLLM Virtual Key: 9769f3f6768a199f76cc29xxxx, key_alias: None, exceeded budget for model=gpt-4o",
|
||||||
|
"type": "budget_exceeded",
|
||||||
|
"param": null,
|
||||||
|
"code": "400"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
@ -783,3 +839,32 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
|
||||||
--header 'Content-Type: application/json' \
|
--header 'Content-Type: application/json' \
|
||||||
--data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}'
|
--data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## API Specification
|
||||||
|
|
||||||
|
### `GenericBudgetInfo`
|
||||||
|
|
||||||
|
A Pydantic model that defines budget information with a time period and limit.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class GenericBudgetInfo(BaseModel):
|
||||||
|
budget_limit: float # The maximum budget amount in USD
|
||||||
|
time_period: str # Duration string like "1d", "30d", etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Fields:
|
||||||
|
- `budget_limit` (float): The maximum budget amount in USD
|
||||||
|
- `time_period` (str): Duration string specifying the time period for the budget. Supported formats:
|
||||||
|
- Seconds: "30s"
|
||||||
|
- Minutes: "30m"
|
||||||
|
- Hours: "30h"
|
||||||
|
- Days: "30d"
|
||||||
|
|
||||||
|
#### Example:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"budget_limit": "0.0001",
|
||||||
|
"time_period": "1d"
|
||||||
|
}
|
||||||
|
```
|
|
@ -262,6 +262,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
||||||
llm_model_list,
|
llm_model_list,
|
||||||
llm_router,
|
llm_router,
|
||||||
master_key,
|
master_key,
|
||||||
|
model_max_budget_limiter,
|
||||||
open_telemetry_logger,
|
open_telemetry_logger,
|
||||||
prisma_client,
|
prisma_client,
|
||||||
proxy_logging_obj,
|
proxy_logging_obj,
|
||||||
|
@ -1053,37 +1054,10 @@ async def user_api_key_auth( # noqa: PLR0915
|
||||||
and valid_token.token is not None
|
and valid_token.token is not None
|
||||||
):
|
):
|
||||||
## GET THE SPEND FOR THIS MODEL
|
## GET THE SPEND FOR THIS MODEL
|
||||||
twenty_eight_days_ago = datetime.now() - timedelta(days=28)
|
await model_max_budget_limiter.is_key_within_model_budget(
|
||||||
model_spend = await prisma_client.db.litellm_spendlogs.group_by(
|
user_api_key_dict=valid_token,
|
||||||
by=["model"],
|
model=current_model,
|
||||||
sum={"spend": True},
|
|
||||||
where={
|
|
||||||
"AND": [
|
|
||||||
{"api_key": valid_token.token},
|
|
||||||
{"startTime": {"gt": twenty_eight_days_ago}},
|
|
||||||
{"model": current_model},
|
|
||||||
]
|
|
||||||
}, # type: ignore
|
|
||||||
)
|
)
|
||||||
if (
|
|
||||||
len(model_spend) > 0
|
|
||||||
and max_budget_per_model.get(current_model, None) is not None
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
"model" in model_spend[0]
|
|
||||||
and model_spend[0].get("model") == current_model
|
|
||||||
and "_sum" in model_spend[0]
|
|
||||||
and "spend" in model_spend[0]["_sum"]
|
|
||||||
and model_spend[0]["_sum"]["spend"]
|
|
||||||
>= max_budget_per_model[current_model]
|
|
||||||
):
|
|
||||||
current_model_spend = model_spend[0]["_sum"]["spend"]
|
|
||||||
current_model_budget = max_budget_per_model[current_model]
|
|
||||||
raise litellm.BudgetExceededError(
|
|
||||||
current_cost=current_model_spend,
|
|
||||||
max_budget=current_model_budget,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check 6. Team spend is under Team budget
|
# Check 6. Team spend is under Team budget
|
||||||
if (
|
if (
|
||||||
hasattr(valid_token, "team_spend")
|
hasattr(valid_token, "team_spend")
|
||||||
|
|
196
litellm/proxy/hooks/model_max_budget_limiter.py
Normal file
196
litellm/proxy/hooks/model_max_budget_limiter.py
Normal file
|
@ -0,0 +1,196 @@
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import verbose_logger
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.caching.caching import DualCache
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger, Span
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
from litellm.types.utils import (
|
||||||
|
GenericBudgetConfigType,
|
||||||
|
GenericBudgetInfo,
|
||||||
|
StandardLoggingPayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
|
||||||
|
|
||||||
|
|
||||||
|
class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
|
||||||
|
"""
|
||||||
|
Handles budgets for model + virtual key
|
||||||
|
|
||||||
|
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dual_cache: DualCache):
|
||||||
|
self.dual_cache = dual_cache
|
||||||
|
self.redis_increment_operation_queue = []
|
||||||
|
|
||||||
|
async def is_key_within_model_budget(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
model: str,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the user_api_key_dict is within the model budget
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BudgetExceededError: If the user_api_key_dict has exceeded the model budget
|
||||||
|
"""
|
||||||
|
_model_max_budget = user_api_key_dict.model_max_budget
|
||||||
|
internal_model_max_budget: GenericBudgetConfigType = {}
|
||||||
|
|
||||||
|
# case each element in _model_max_budget to GenericBudgetInfo
|
||||||
|
for _model, _budget_info in _model_max_budget.items():
|
||||||
|
internal_model_max_budget[_model] = GenericBudgetInfo(
|
||||||
|
time_period=_budget_info.get("time_period"),
|
||||||
|
budget_limit=float(_budget_info.get("budget_limit")),
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"internal_model_max_budget %s",
|
||||||
|
json.dumps(internal_model_max_budget, indent=4, default=str),
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if current model is in internal_model_max_budget
|
||||||
|
_current_model_budget_info = self._get_request_model_budget_config(
|
||||||
|
model=model, internal_model_max_budget=internal_model_max_budget
|
||||||
|
)
|
||||||
|
if _current_model_budget_info is None:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Model {model} not found in internal_model_max_budget"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# check if current model is within budget
|
||||||
|
if _current_model_budget_info.budget_limit > 0:
|
||||||
|
_current_spend = await self._get_virtual_key_spend_for_model(
|
||||||
|
user_api_key_hash=user_api_key_dict.token,
|
||||||
|
model=model,
|
||||||
|
key_budget_config=_current_model_budget_info,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
_current_spend is not None
|
||||||
|
and _current_spend > _current_model_budget_info.budget_limit
|
||||||
|
):
|
||||||
|
raise litellm.BudgetExceededError(
|
||||||
|
message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
|
||||||
|
current_cost=_current_spend,
|
||||||
|
max_budget=_current_model_budget_info.budget_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _get_virtual_key_spend_for_model(
|
||||||
|
self,
|
||||||
|
user_api_key_hash: Optional[str],
|
||||||
|
model: str,
|
||||||
|
key_budget_config: GenericBudgetInfo,
|
||||||
|
) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
Get the current spend for a virtual key for a model
|
||||||
|
|
||||||
|
Lookup model in this order:
|
||||||
|
1. model: directly look up `model`
|
||||||
|
2. If 1, does not exist, check if passed as {custom_llm_provider}/model
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 1. model: directly look up `model`
|
||||||
|
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.time_period}"
|
||||||
|
_current_spend = await self.dual_cache.async_get_cache(
|
||||||
|
key=virtual_key_model_spend_cache_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _current_spend is None:
|
||||||
|
# 2. If 1, does not exist, check if passed as {custom_llm_provider}/model
|
||||||
|
# if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview
|
||||||
|
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.time_period}"
|
||||||
|
_current_spend = await self.dual_cache.async_get_cache(
|
||||||
|
key=virtual_key_model_spend_cache_key,
|
||||||
|
)
|
||||||
|
return _current_spend
|
||||||
|
|
||||||
|
def _get_request_model_budget_config(
|
||||||
|
self, model: str, internal_model_max_budget: GenericBudgetConfigType
|
||||||
|
) -> Optional[GenericBudgetInfo]:
|
||||||
|
"""
|
||||||
|
Get the budget config for the request model
|
||||||
|
|
||||||
|
1. Check if `model` is in `internal_model_max_budget`
|
||||||
|
2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
|
||||||
|
"""
|
||||||
|
return internal_model_max_budget.get(
|
||||||
|
model, None
|
||||||
|
) or internal_model_max_budget.get(
|
||||||
|
self._get_model_without_custom_llm_provider(model), None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_model_without_custom_llm_provider(self, model: str) -> str:
|
||||||
|
if "/" in model:
|
||||||
|
return model.split("/")[-1]
|
||||||
|
return model
|
||||||
|
|
||||||
|
async def async_filter_deployments(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
healthy_deployments: List,
|
||||||
|
messages: Optional[List[AllMessageValues]],
|
||||||
|
request_kwargs: Optional[dict] = None,
|
||||||
|
parent_otel_span: Optional[Span] = None, # type: ignore
|
||||||
|
) -> List[dict]:
|
||||||
|
return healthy_deployments
|
||||||
|
|
||||||
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
"""
|
||||||
|
Track spend for virtual key + model in DualCache
|
||||||
|
|
||||||
|
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
|
||||||
|
"""
|
||||||
|
verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event")
|
||||||
|
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||||
|
"standard_logging_object", None
|
||||||
|
)
|
||||||
|
if standard_logging_payload is None:
|
||||||
|
raise ValueError("standard_logging_payload is required")
|
||||||
|
|
||||||
|
_litellm_params = kwargs.get("litellm_params", {})
|
||||||
|
_metadata = _litellm_params.get("metadata", {})
|
||||||
|
user_api_key_model_max_budget: Optional[dict] = _metadata.get(
|
||||||
|
"user_api_key_model_max_budget", None
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
user_api_key_model_max_budget is None
|
||||||
|
or len(user_api_key_model_max_budget) == 0
|
||||||
|
):
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
|
||||||
|
user_api_key_model_max_budget,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
response_cost: float = standard_logging_payload.get("response_cost", 0)
|
||||||
|
model = standard_logging_payload.get("model")
|
||||||
|
|
||||||
|
virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
|
||||||
|
model = standard_logging_payload.get("model")
|
||||||
|
if virtual_key is not None:
|
||||||
|
budget_config = GenericBudgetInfo(time_period="1d", budget_limit=0.1)
|
||||||
|
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.time_period}"
|
||||||
|
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
|
||||||
|
await self._increment_spend_for_key(
|
||||||
|
budget_config=budget_config,
|
||||||
|
spend_key=virtual_spend_key,
|
||||||
|
start_time_key=virtual_start_time_key,
|
||||||
|
response_cost=response_cost,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"current state of in memory cache %s",
|
||||||
|
json.dumps(
|
||||||
|
self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
|
||||||
|
),
|
||||||
|
)
|
|
@ -499,6 +499,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
data[_metadata_variable_name][
|
data[_metadata_variable_name][
|
||||||
"user_api_key_max_budget"
|
"user_api_key_max_budget"
|
||||||
] = user_api_key_dict.max_budget
|
] = user_api_key_dict.max_budget
|
||||||
|
data[_metadata_variable_name][
|
||||||
|
"user_api_key_model_max_budget"
|
||||||
|
] = user_api_key_dict.model_max_budget
|
||||||
|
|
||||||
data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
|
data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
|
||||||
_headers = dict(request.headers)
|
_headers = dict(request.headers)
|
||||||
|
|
|
@ -40,7 +40,11 @@ from litellm.proxy.utils import (
|
||||||
handle_exception_on_proxy,
|
handle_exception_on_proxy,
|
||||||
)
|
)
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import get_secret
|
||||||
from litellm.types.utils import PersonalUIKeyGenerationConfig, TeamUIKeyGenerationConfig
|
from litellm.types.utils import (
|
||||||
|
GenericBudgetInfo,
|
||||||
|
PersonalUIKeyGenerationConfig,
|
||||||
|
TeamUIKeyGenerationConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_team_key(data: GenerateKeyRequest):
|
def _is_team_key(data: GenerateKeyRequest):
|
||||||
|
@ -246,7 +250,7 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
|
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
|
||||||
- guardrails: Optional[List[str]] - List of active guardrails for the key
|
- guardrails: Optional[List[str]] - List of active guardrails for the key
|
||||||
- permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
|
- permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
|
||||||
- model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget.
|
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget.
|
||||||
- model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
|
- model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
|
||||||
- model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit.
|
- model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit.
|
||||||
- allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request
|
- allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request
|
||||||
|
@ -515,6 +519,10 @@ def prepare_key_update_data(
|
||||||
|
|
||||||
_metadata = existing_key_row.metadata or {}
|
_metadata = existing_key_row.metadata or {}
|
||||||
|
|
||||||
|
# validate model_max_budget
|
||||||
|
if "model_max_budget" in non_default_values:
|
||||||
|
validate_model_max_budget(non_default_values["model_max_budget"])
|
||||||
|
|
||||||
non_default_values = prepare_metadata_fields(
|
non_default_values = prepare_metadata_fields(
|
||||||
data=data, non_default_values=non_default_values, existing_metadata=_metadata
|
data=data, non_default_values=non_default_values, existing_metadata=_metadata
|
||||||
)
|
)
|
||||||
|
@ -548,7 +556,7 @@ async def update_key_fn(
|
||||||
- enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
|
- enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
|
||||||
- spend: Optional[float] - Amount spent by key
|
- spend: Optional[float] - Amount spent by key
|
||||||
- max_budget: Optional[float] - Max budget for key
|
- max_budget: Optional[float] - Max budget for key
|
||||||
- model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}
|
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
|
||||||
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
|
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
|
||||||
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
|
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
|
||||||
- max_parallel_requests: Optional[int] - Rate limit for parallel requests
|
- max_parallel_requests: Optional[int] - Rate limit for parallel requests
|
||||||
|
@ -1035,6 +1043,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
||||||
metadata["guardrails"] = guardrails
|
metadata["guardrails"] = guardrails
|
||||||
|
|
||||||
metadata_json = json.dumps(metadata)
|
metadata_json = json.dumps(metadata)
|
||||||
|
validate_model_max_budget(model_max_budget)
|
||||||
model_max_budget_json = json.dumps(model_max_budget)
|
model_max_budget_json = json.dumps(model_max_budget)
|
||||||
user_role = user_role
|
user_role = user_role
|
||||||
tpm_limit = tpm_limit
|
tpm_limit = tpm_limit
|
||||||
|
@ -1266,7 +1275,7 @@ async def regenerate_key_fn(
|
||||||
- tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
|
- tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
|
||||||
- spend: Optional[float] - Amount spent by key
|
- spend: Optional[float] - Amount spent by key
|
||||||
- max_budget: Optional[float] - Max budget for key
|
- max_budget: Optional[float] - Max budget for key
|
||||||
- model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}
|
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
|
||||||
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
|
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
|
||||||
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
|
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
|
||||||
- max_parallel_requests: Optional[int] - Rate limit for parallel requests
|
- max_parallel_requests: Optional[int] - Rate limit for parallel requests
|
||||||
|
@ -1293,8 +1302,7 @@ async def regenerate_key_fn(
|
||||||
--data-raw '{
|
--data-raw '{
|
||||||
"max_budget": 100,
|
"max_budget": 100,
|
||||||
"metadata": {"team": "core-infra"},
|
"metadata": {"team": "core-infra"},
|
||||||
"models": ["gpt-4", "gpt-3.5-turbo"],
|
"models": ["gpt-4", "gpt-3.5-turbo"]
|
||||||
"model_max_budget": {"gpt-4": 50, "gpt-3.5-turbo": 50}
|
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -1949,3 +1957,29 @@ async def _enforce_unique_key_alias(
|
||||||
param="key_alias",
|
param="key_alias",
|
||||||
code=status.HTTP_400_BAD_REQUEST,
|
code=status.HTTP_400_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_model_max_budget(model_max_budget: Optional[Dict]) -> None:
|
||||||
|
"""
|
||||||
|
Validate the model_max_budget is GenericBudgetConfigType
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If model_max_budget is not a valid GenericBudgetConfigType
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if model_max_budget is None:
|
||||||
|
return
|
||||||
|
if len(model_max_budget) == 0:
|
||||||
|
return
|
||||||
|
if model_max_budget is not None:
|
||||||
|
for _model, _budget_info in model_max_budget.items():
|
||||||
|
assert isinstance(_model, str)
|
||||||
|
|
||||||
|
# /CRUD endpoints can pass budget_limit as a string, so we need to convert it to a float
|
||||||
|
if "budget_limit" in _budget_info:
|
||||||
|
_budget_info["budget_limit"] = float(_budget_info["budget_limit"])
|
||||||
|
GenericBudgetInfo(**_budget_info)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid model_max_budget: {str(e)}. Example of valid model_max_budget: https://docs.litellm.ai/docs/proxy/users"
|
||||||
|
)
|
||||||
|
|
|
@ -1,42 +1,10 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: fake-openai-endpoint
|
- model_name: openai/o1-preview
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/fake
|
model: openai/o1-preview
|
||||||
api_key: fake-key
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
- model_name: openai/*
|
||||||
tags: ["teamA"]
|
litellm_params:
|
||||||
model_info:
|
model: openai/*
|
||||||
id: "team-a-model"
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
- model_name: fake-openai-endpoint
|
|
||||||
litellm_params:
|
|
||||||
model: openai/fake
|
|
||||||
api_key: fake-key
|
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
|
||||||
tags: ["teamB"]
|
|
||||||
model_info:
|
|
||||||
id: "team-b-model"
|
|
||||||
- model_name: rerank-english-v3.0
|
|
||||||
litellm_params:
|
|
||||||
model: cohere/rerank-english-v3.0
|
|
||||||
api_key: os.environ/COHERE_API_KEY
|
|
||||||
- model_name: fake-azure-endpoint
|
|
||||||
litellm_params:
|
|
||||||
model: openai/429
|
|
||||||
api_key: fake-key
|
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
|
||||||
- model_name: llava-hf
|
|
||||||
litellm_params:
|
|
||||||
model: openai/llava-hf/llava-v1.6-vicuna-7b-hf
|
|
||||||
api_base: http://localhost:8000
|
|
||||||
api_key: fake-key
|
|
||||||
model_info:
|
|
||||||
supports_vision: True
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
litellm_settings:
|
|
||||||
cache: true
|
|
||||||
callbacks: ["otel", "prometheus"]
|
|
||||||
|
|
||||||
router_settings:
|
|
||||||
enable_tag_filtering: True # 👈 Key Change
|
|
|
@ -173,6 +173,9 @@ from litellm.proxy.guardrails.init_guardrails import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.health_check import perform_health_check
|
from litellm.proxy.health_check import perform_health_check
|
||||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||||
|
from litellm.proxy.hooks.model_max_budget_limiter import (
|
||||||
|
_PROXY_VirtualKeyModelMaxBudgetLimiter,
|
||||||
|
)
|
||||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||||
_OPTIONAL_PromptInjectionDetection,
|
_OPTIONAL_PromptInjectionDetection,
|
||||||
)
|
)
|
||||||
|
@ -495,6 +498,10 @@ prisma_client: Optional[PrismaClient] = None
|
||||||
user_api_key_cache = DualCache(
|
user_api_key_cache = DualCache(
|
||||||
default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
|
default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
|
||||||
)
|
)
|
||||||
|
model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
|
||||||
|
dual_cache=user_api_key_cache
|
||||||
|
)
|
||||||
|
litellm.callbacks.append(model_max_budget_limiter)
|
||||||
redis_usage_cache: Optional[RedisCache] = (
|
redis_usage_cache: Optional[RedisCache] = (
|
||||||
None # redis cache used for tracking spend, tpm/rpm limits
|
None # redis cache used for tracking spend, tpm/rpm limits
|
||||||
)
|
)
|
||||||
|
|
|
@ -631,7 +631,7 @@ class Router:
|
||||||
_callback = PromptCachingDeploymentCheck(cache=self.cache)
|
_callback = PromptCachingDeploymentCheck(cache=self.cache)
|
||||||
elif pre_call_check == "router_budget_limiting":
|
elif pre_call_check == "router_budget_limiting":
|
||||||
_callback = RouterBudgetLimiting(
|
_callback = RouterBudgetLimiting(
|
||||||
router_cache=self.cache,
|
dual_cache=self.cache,
|
||||||
provider_budget_config=self.provider_budget_config,
|
provider_budget_config=self.provider_budget_config,
|
||||||
model_list=self.model_list,
|
model_list=self.model_list,
|
||||||
)
|
)
|
||||||
|
@ -5292,14 +5292,6 @@ class Router:
|
||||||
healthy_deployments=healthy_deployments,
|
healthy_deployments=healthy_deployments,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if self.router_budget_logger:
|
|
||||||
# healthy_deployments = (
|
|
||||||
# await self.router_budget_logger.async_filter_deployments(
|
|
||||||
# healthy_deployments=healthy_deployments,
|
|
||||||
# request_kwargs=request_kwargs,
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
|
|
||||||
if len(healthy_deployments) == 0:
|
if len(healthy_deployments) == 0:
|
||||||
exception = await async_raise_no_deployment_exception(
|
exception = await async_raise_no_deployment_exception(
|
||||||
litellm_router_instance=self,
|
litellm_router_instance=self,
|
||||||
|
|
|
@ -49,13 +49,13 @@ DEFAULT_REDIS_SYNC_INTERVAL = 1
|
||||||
class RouterBudgetLimiting(CustomLogger):
|
class RouterBudgetLimiting(CustomLogger):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
router_cache: DualCache,
|
dual_cache: DualCache,
|
||||||
provider_budget_config: Optional[dict],
|
provider_budget_config: Optional[dict],
|
||||||
model_list: Optional[
|
model_list: Optional[
|
||||||
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
|
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
|
||||||
] = None,
|
] = None,
|
||||||
):
|
):
|
||||||
self.router_cache = router_cache
|
self.dual_cache = dual_cache
|
||||||
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
|
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
|
||||||
asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
|
asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
|
||||||
self.provider_budget_config: Optional[GenericBudgetConfigType] = (
|
self.provider_budget_config: Optional[GenericBudgetConfigType] = (
|
||||||
|
@ -108,7 +108,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
|
|
||||||
# Single cache read for all spend values
|
# Single cache read for all spend values
|
||||||
if len(cache_keys) > 0:
|
if len(cache_keys) > 0:
|
||||||
_current_spends = await self.router_cache.async_batch_get_cache(
|
_current_spends = await self.dual_cache.async_batch_get_cache(
|
||||||
keys=cache_keys,
|
keys=cache_keys,
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
)
|
)
|
||||||
|
@ -286,9 +286,9 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
If it does, return the value.
|
If it does, return the value.
|
||||||
If it does not, set the key to `current_time` and return the value.
|
If it does not, set the key to `current_time` and return the value.
|
||||||
"""
|
"""
|
||||||
budget_start = await self.router_cache.async_get_cache(start_time_key)
|
budget_start = await self.dual_cache.async_get_cache(start_time_key)
|
||||||
if budget_start is None:
|
if budget_start is None:
|
||||||
await self.router_cache.async_set_cache(
|
await self.dual_cache.async_set_cache(
|
||||||
key=start_time_key, value=current_time, ttl=ttl_seconds
|
key=start_time_key, value=current_time, ttl=ttl_seconds
|
||||||
)
|
)
|
||||||
return current_time
|
return current_time
|
||||||
|
@ -314,10 +314,10 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
- stores key: `provider_budget_start_time:{provider}`, value: current_time.
|
- stores key: `provider_budget_start_time:{provider}`, value: current_time.
|
||||||
This stores the start time of the new budget window
|
This stores the start time of the new budget window
|
||||||
"""
|
"""
|
||||||
await self.router_cache.async_set_cache(
|
await self.dual_cache.async_set_cache(
|
||||||
key=spend_key, value=response_cost, ttl=ttl_seconds
|
key=spend_key, value=response_cost, ttl=ttl_seconds
|
||||||
)
|
)
|
||||||
await self.router_cache.async_set_cache(
|
await self.dual_cache.async_set_cache(
|
||||||
key=start_time_key, value=current_time, ttl=ttl_seconds
|
key=start_time_key, value=current_time, ttl=ttl_seconds
|
||||||
)
|
)
|
||||||
return current_time
|
return current_time
|
||||||
|
@ -333,7 +333,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
- Increments the spend in memory cache (so spend instantly updated in memory)
|
- Increments the spend in memory cache (so spend instantly updated in memory)
|
||||||
- Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. Using Redis for multi instance environment of LiteLLM)
|
- Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. Using Redis for multi instance environment of LiteLLM)
|
||||||
"""
|
"""
|
||||||
await self.router_cache.in_memory_cache.async_increment(
|
await self.dual_cache.in_memory_cache.async_increment(
|
||||||
key=spend_key,
|
key=spend_key,
|
||||||
value=response_cost,
|
value=response_cost,
|
||||||
ttl=ttl,
|
ttl=ttl,
|
||||||
|
@ -481,7 +481,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
Only runs if Redis is initialized
|
Only runs if Redis is initialized
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not self.router_cache.redis_cache:
|
if not self.dual_cache.redis_cache:
|
||||||
return # Redis is not initialized
|
return # Redis is not initialized
|
||||||
|
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
|
@ -490,7 +490,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
)
|
)
|
||||||
if len(self.redis_increment_operation_queue) > 0:
|
if len(self.redis_increment_operation_queue) > 0:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.router_cache.redis_cache.async_increment_pipeline(
|
self.dual_cache.redis_cache.async_increment_pipeline(
|
||||||
increment_list=self.redis_increment_operation_queue,
|
increment_list=self.redis_increment_operation_queue,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -517,7 +517,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# No need to sync if Redis cache is not initialized
|
# No need to sync if Redis cache is not initialized
|
||||||
if self.router_cache.redis_cache is None:
|
if self.dual_cache.redis_cache is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 1. Push all provider spend increments to Redis
|
# 1. Push all provider spend increments to Redis
|
||||||
|
@ -547,7 +547,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
cache_keys.append(f"tag_spend:{tag}:{config.time_period}")
|
cache_keys.append(f"tag_spend:{tag}:{config.time_period}")
|
||||||
|
|
||||||
# Batch fetch current spend values from Redis
|
# Batch fetch current spend values from Redis
|
||||||
redis_values = await self.router_cache.redis_cache.async_batch_get_cache(
|
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
|
||||||
key_list=cache_keys
|
key_list=cache_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -555,7 +555,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
if isinstance(redis_values, dict): # Check if redis_values is a dictionary
|
if isinstance(redis_values, dict): # Check if redis_values is a dictionary
|
||||||
for key, value in redis_values.items():
|
for key, value in redis_values.items():
|
||||||
if value is not None:
|
if value is not None:
|
||||||
await self.router_cache.in_memory_cache.async_set_cache(
|
await self.dual_cache.in_memory_cache.async_set_cache(
|
||||||
key=key, value=float(value)
|
key=key, value=float(value)
|
||||||
)
|
)
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
|
@ -639,14 +639,12 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
|
|
||||||
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
||||||
|
|
||||||
if self.router_cache.redis_cache:
|
if self.dual_cache.redis_cache:
|
||||||
# use Redis as source of truth since that has spend across all instances
|
# use Redis as source of truth since that has spend across all instances
|
||||||
current_spend = await self.router_cache.redis_cache.async_get_cache(
|
current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key)
|
||||||
spend_key
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# use in-memory cache if Redis is not initialized
|
# use in-memory cache if Redis is not initialized
|
||||||
current_spend = await self.router_cache.async_get_cache(spend_key)
|
current_spend = await self.dual_cache.async_get_cache(spend_key)
|
||||||
return float(current_spend) if current_spend is not None else 0.0
|
return float(current_spend) if current_spend is not None else 0.0
|
||||||
|
|
||||||
async def _get_current_provider_budget_reset_at(
|
async def _get_current_provider_budget_reset_at(
|
||||||
|
@ -657,10 +655,10 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
||||||
if self.router_cache.redis_cache:
|
if self.dual_cache.redis_cache:
|
||||||
ttl_seconds = await self.router_cache.redis_cache.async_get_ttl(spend_key)
|
ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
|
||||||
else:
|
else:
|
||||||
ttl_seconds = await self.router_cache.async_get_ttl(spend_key)
|
ttl_seconds = await self.dual_cache.async_get_ttl(spend_key)
|
||||||
|
|
||||||
if ttl_seconds is None:
|
if ttl_seconds is None:
|
||||||
return None
|
return None
|
||||||
|
@ -679,16 +677,16 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
|
||||||
start_time_key = f"provider_budget_start_time:{provider}"
|
start_time_key = f"provider_budget_start_time:{provider}"
|
||||||
ttl_seconds = duration_in_seconds(budget_config.time_period)
|
ttl_seconds = duration_in_seconds(budget_config.time_period)
|
||||||
budget_start = await self.router_cache.async_get_cache(start_time_key)
|
budget_start = await self.dual_cache.async_get_cache(start_time_key)
|
||||||
if budget_start is None:
|
if budget_start is None:
|
||||||
budget_start = datetime.now(timezone.utc).timestamp()
|
budget_start = datetime.now(timezone.utc).timestamp()
|
||||||
await self.router_cache.async_set_cache(
|
await self.dual_cache.async_set_cache(
|
||||||
key=start_time_key, value=budget_start, ttl=ttl_seconds
|
key=start_time_key, value=budget_start, ttl=ttl_seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
_spend_key = await self.router_cache.async_get_cache(spend_key)
|
_spend_key = await self.dual_cache.async_get_cache(spend_key)
|
||||||
if _spend_key is None:
|
if _spend_key is None:
|
||||||
await self.router_cache.async_set_cache(
|
await self.dual_cache.async_set_cache(
|
||||||
key=spend_key, value=0.0, ttl=ttl_seconds
|
key=spend_key, value=0.0, ttl=ttl_seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,8 @@ import httpx
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
||||||
|
from litellm.types.utils import GenericBudgetConfigType, GenericBudgetInfo
|
||||||
|
|
||||||
from ..exceptions import RateLimitError
|
from ..exceptions import RateLimitError
|
||||||
from .completion import CompletionRequest
|
from .completion import CompletionRequest
|
||||||
from .embedding import EmbeddingRequest
|
from .embedding import EmbeddingRequest
|
||||||
|
@ -647,14 +649,6 @@ class RoutingStrategy(enum.Enum):
|
||||||
PROVIDER_BUDGET_LIMITING = "provider-budget-routing"
|
PROVIDER_BUDGET_LIMITING = "provider-budget-routing"
|
||||||
|
|
||||||
|
|
||||||
class GenericBudgetInfo(BaseModel):
|
|
||||||
time_period: str # e.g., '1d', '30d'
|
|
||||||
budget_limit: float
|
|
||||||
|
|
||||||
|
|
||||||
GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
|
|
||||||
|
|
||||||
|
|
||||||
class RouterCacheEnum(enum.Enum):
|
class RouterCacheEnum(enum.Enum):
|
||||||
TPM = "global_router:{id}:{model}:tpm:{current_minute}"
|
TPM = "global_router:{id}:{model}:tpm:{current_minute}"
|
||||||
RPM = "global_router:{id}:{model}:rpm:{current_minute}"
|
RPM = "global_router:{id}:{model}:rpm:{current_minute}"
|
||||||
|
|
|
@ -1665,6 +1665,14 @@ class StandardKeyGenerationConfig(TypedDict, total=False):
|
||||||
personal_key_generation: PersonalUIKeyGenerationConfig
|
personal_key_generation: PersonalUIKeyGenerationConfig
|
||||||
|
|
||||||
|
|
||||||
|
class GenericBudgetInfo(BaseModel):
|
||||||
|
time_period: str # e.g., '1d', '30d'
|
||||||
|
budget_limit: float
|
||||||
|
|
||||||
|
|
||||||
|
GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
|
||||||
|
|
||||||
|
|
||||||
class BudgetConfig(BaseModel):
|
class BudgetConfig(BaseModel):
|
||||||
max_budget: float
|
max_budget: float
|
||||||
budget_duration: str
|
budget_duration: str
|
||||||
|
|
|
@ -183,7 +183,7 @@ async def test_get_llm_provider_for_deployment():
|
||||||
"""
|
"""
|
||||||
cleanup_redis()
|
cleanup_redis()
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(), provider_budget_config={}
|
dual_cache=DualCache(), provider_budget_config={}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test OpenAI deployment
|
# Test OpenAI deployment
|
||||||
|
@ -220,7 +220,7 @@ async def test_get_budget_config_for_provider():
|
||||||
}
|
}
|
||||||
|
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(), provider_budget_config=config
|
dual_cache=DualCache(), provider_budget_config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test existing providers
|
# Test existing providers
|
||||||
|
@ -252,7 +252,7 @@ async def test_prometheus_metric_tracking():
|
||||||
|
|
||||||
# Setup provider budget limiting
|
# Setup provider budget limiting
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(),
|
dual_cache=DualCache(),
|
||||||
provider_budget_config={
|
provider_budget_config={
|
||||||
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100)
|
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100)
|
||||||
},
|
},
|
||||||
|
@ -316,7 +316,7 @@ async def test_handle_new_budget_window():
|
||||||
"""
|
"""
|
||||||
cleanup_redis()
|
cleanup_redis()
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(), provider_budget_config={}
|
dual_cache=DualCache(), provider_budget_config={}
|
||||||
)
|
)
|
||||||
|
|
||||||
spend_key = "provider_spend:openai:7d"
|
spend_key = "provider_spend:openai:7d"
|
||||||
|
@ -337,12 +337,12 @@ async def test_handle_new_budget_window():
|
||||||
assert new_start_time == current_time
|
assert new_start_time == current_time
|
||||||
|
|
||||||
# Verify the spend was set correctly
|
# Verify the spend was set correctly
|
||||||
spend = await provider_budget.router_cache.async_get_cache(spend_key)
|
spend = await provider_budget.dual_cache.async_get_cache(spend_key)
|
||||||
print("spend in cache for key", spend_key, "is", spend)
|
print("spend in cache for key", spend_key, "is", spend)
|
||||||
assert float(spend) == response_cost
|
assert float(spend) == response_cost
|
||||||
|
|
||||||
# Verify start time was set correctly
|
# Verify start time was set correctly
|
||||||
start_time = await provider_budget.router_cache.async_get_cache(start_time_key)
|
start_time = await provider_budget.dual_cache.async_get_cache(start_time_key)
|
||||||
print("start time in cache for key", start_time_key, "is", start_time)
|
print("start time in cache for key", start_time_key, "is", start_time)
|
||||||
assert float(start_time) == current_time
|
assert float(start_time) == current_time
|
||||||
|
|
||||||
|
@ -357,7 +357,7 @@ async def test_get_or_set_budget_start_time():
|
||||||
"""
|
"""
|
||||||
cleanup_redis()
|
cleanup_redis()
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(), provider_budget_config={}
|
dual_cache=DualCache(), provider_budget_config={}
|
||||||
)
|
)
|
||||||
|
|
||||||
start_time_key = "test_start_time"
|
start_time_key = "test_start_time"
|
||||||
|
@ -398,7 +398,7 @@ async def test_increment_spend_in_current_window():
|
||||||
"""
|
"""
|
||||||
cleanup_redis()
|
cleanup_redis()
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(), provider_budget_config={}
|
dual_cache=DualCache(), provider_budget_config={}
|
||||||
)
|
)
|
||||||
|
|
||||||
spend_key = "provider_spend:openai:1d"
|
spend_key = "provider_spend:openai:1d"
|
||||||
|
@ -406,9 +406,7 @@ async def test_increment_spend_in_current_window():
|
||||||
ttl = 86400 # 1 day
|
ttl = 86400 # 1 day
|
||||||
|
|
||||||
# Set initial spend
|
# Set initial spend
|
||||||
await provider_budget.router_cache.async_set_cache(
|
await provider_budget.dual_cache.async_set_cache(key=spend_key, value=1.0, ttl=ttl)
|
||||||
key=spend_key, value=1.0, ttl=ttl
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test incrementing spend
|
# Test incrementing spend
|
||||||
await provider_budget._increment_spend_in_current_window(
|
await provider_budget._increment_spend_in_current_window(
|
||||||
|
@ -418,7 +416,7 @@ async def test_increment_spend_in_current_window():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the spend was incremented correctly in memory
|
# Verify the spend was incremented correctly in memory
|
||||||
spend = await provider_budget.router_cache.async_get_cache(spend_key)
|
spend = await provider_budget.dual_cache.async_get_cache(spend_key)
|
||||||
assert float(spend) == 1.5
|
assert float(spend) == 1.5
|
||||||
|
|
||||||
# Verify the increment operation was queued for Redis
|
# Verify the increment operation was queued for Redis
|
||||||
|
@ -449,7 +447,7 @@ async def test_sync_in_memory_spend_with_redis():
|
||||||
}
|
}
|
||||||
|
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(
|
dual_cache=DualCache(
|
||||||
redis_cache=RedisCache(
|
redis_cache=RedisCache(
|
||||||
host=os.getenv("REDIS_HOST"),
|
host=os.getenv("REDIS_HOST"),
|
||||||
port=int(os.getenv("REDIS_PORT")),
|
port=int(os.getenv("REDIS_PORT")),
|
||||||
|
@ -463,10 +461,10 @@ async def test_sync_in_memory_spend_with_redis():
|
||||||
spend_key_openai = "provider_spend:openai:1d"
|
spend_key_openai = "provider_spend:openai:1d"
|
||||||
spend_key_anthropic = "provider_spend:anthropic:1d"
|
spend_key_anthropic = "provider_spend:anthropic:1d"
|
||||||
|
|
||||||
await provider_budget.router_cache.redis_cache.async_set_cache(
|
await provider_budget.dual_cache.redis_cache.async_set_cache(
|
||||||
key=spend_key_openai, value=50.0
|
key=spend_key_openai, value=50.0
|
||||||
)
|
)
|
||||||
await provider_budget.router_cache.redis_cache.async_set_cache(
|
await provider_budget.dual_cache.redis_cache.async_set_cache(
|
||||||
key=spend_key_anthropic, value=75.0
|
key=spend_key_anthropic, value=75.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -474,13 +472,11 @@ async def test_sync_in_memory_spend_with_redis():
|
||||||
await provider_budget._sync_in_memory_spend_with_redis()
|
await provider_budget._sync_in_memory_spend_with_redis()
|
||||||
|
|
||||||
# Verify in-memory cache was updated
|
# Verify in-memory cache was updated
|
||||||
openai_spend = await provider_budget.router_cache.in_memory_cache.async_get_cache(
|
openai_spend = await provider_budget.dual_cache.in_memory_cache.async_get_cache(
|
||||||
spend_key_openai
|
spend_key_openai
|
||||||
)
|
)
|
||||||
anthropic_spend = (
|
anthropic_spend = await provider_budget.dual_cache.in_memory_cache.async_get_cache(
|
||||||
await provider_budget.router_cache.in_memory_cache.async_get_cache(
|
spend_key_anthropic
|
||||||
spend_key_anthropic
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert float(openai_spend) == 50.0
|
assert float(openai_spend) == 50.0
|
||||||
|
@ -499,7 +495,7 @@ async def test_get_current_provider_spend():
|
||||||
"""
|
"""
|
||||||
cleanup_redis()
|
cleanup_redis()
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(),
|
dual_cache=DualCache(),
|
||||||
provider_budget_config={
|
provider_budget_config={
|
||||||
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
|
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100),
|
||||||
},
|
},
|
||||||
|
@ -515,7 +511,7 @@ async def test_get_current_provider_spend():
|
||||||
|
|
||||||
# Test provider with budget config and spend
|
# Test provider with budget config and spend
|
||||||
spend_key = "provider_spend:openai:1d"
|
spend_key = "provider_spend:openai:1d"
|
||||||
await provider_budget.router_cache.async_set_cache(key=spend_key, value=50.5)
|
await provider_budget.dual_cache.async_set_cache(key=spend_key, value=50.5)
|
||||||
|
|
||||||
spend = await provider_budget._get_current_provider_spend("openai")
|
spend = await provider_budget._get_current_provider_spend("openai")
|
||||||
assert spend == 50.5
|
assert spend == 50.5
|
||||||
|
@ -534,7 +530,7 @@ async def test_get_current_provider_budget_reset_at():
|
||||||
"""
|
"""
|
||||||
cleanup_redis()
|
cleanup_redis()
|
||||||
provider_budget = RouterBudgetLimiting(
|
provider_budget = RouterBudgetLimiting(
|
||||||
router_cache=DualCache(
|
dual_cache=DualCache(
|
||||||
redis_cache=RedisCache(
|
redis_cache=RedisCache(
|
||||||
host=os.getenv("REDIS_HOST"),
|
host=os.getenv("REDIS_HOST"),
|
||||||
port=int(os.getenv("REDIS_PORT")),
|
port=int(os.getenv("REDIS_PORT")),
|
||||||
|
|
|
@ -1684,124 +1684,109 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
|
||||||
print(vars(e))
|
print(vars(e))
|
||||||
|
|
||||||
|
|
||||||
def test_call_with_key_over_model_budget(prisma_client):
|
@pytest.mark.asyncio()
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"request_model,should_pass",
|
||||||
|
[
|
||||||
|
("openai/gpt-4o-mini", False),
|
||||||
|
("gpt-4o-mini", False),
|
||||||
|
("gpt-4o", True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_call_with_key_over_model_budget(
|
||||||
|
prisma_client, request_model, should_pass
|
||||||
|
):
|
||||||
# 12. Make a call with a key over budget, expect to fail
|
# 12. Make a call with a key over budget, expect to fail
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
     setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    verbose_proxy_logger.setLevel(logging.DEBUG)
+
+    # init model max budget limiter
+    from litellm.proxy.hooks.model_max_budget_limiter import (
+        _PROXY_VirtualKeyModelMaxBudgetLimiter,
+    )
+
+    model_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
+        dual_cache=DualCache()
+    )
+    litellm.callbacks.append(model_budget_limiter)
+
     try:
-
-        async def test():
-            # set budget for chatgpt-v-2 to 0.000001, expect the next request to fail
-            await litellm.proxy.proxy_server.prisma_client.connect()
-            request = GenerateKeyRequest(
-                max_budget=1000,
-                model_max_budget={
-                    "chatgpt-v-2": 0.000001,
-                },
-                metadata={"user_api_key": 0.0001},
-            )
-            key = await generate_key_fn(request)
-            print(key)
-
-            generated_key = key.key
-            user_id = key.user_id
-            bearer_token = "Bearer " + generated_key
-
-            request = Request(scope={"type": "http"})
-            request._url = URL(url="/chat/completions")
-
-            async def return_body():
-                return b'{"model": "chatgpt-v-2"}'
-
-            request.body = return_body
-
-            # use generated key to auth in
-            result = await user_api_key_auth(request=request, api_key=bearer_token)
-            print("result from user auth with new key", result)
-
-            # update spend using track_cost callback, make 2nd request, it should fail
-            from litellm import Choices, Message, ModelResponse, Usage
-            from litellm.caching.caching import Cache
-            from litellm.proxy.proxy_server import (
-                _PROXY_track_cost_callback as track_cost_callback,
-            )
-
-            litellm.cache = Cache()
-            import time
-            import uuid
-
-            request_id = f"chatcmpl-{uuid.uuid4()}"
-
-            resp = ModelResponse(
-                id=request_id,
-                choices=[
-                    Choices(
-                        finish_reason=None,
-                        index=0,
-                        message=Message(
-                            content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
-                            role="assistant",
-                        ),
-                    )
-                ],
-                model="gpt-35-turbo",  # azure always has model written like this
-                usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
-            )
-            await track_cost_callback(
-                kwargs={
-                    "model": "chatgpt-v-2",
-                    "stream": False,
-                    "litellm_params": {
-                        "metadata": {
-                            "user_api_key": hash_token(generated_key),
-                            "user_api_key_user_id": user_id,
-                        }
-                    },
-                    "response_cost": 0.00002,
-                },
-                completion_response=resp,
-                start_time=datetime.now(),
-                end_time=datetime.now(),
-            )
-            await update_spend(
-                prisma_client=prisma_client,
-                db_writer_client=None,
-                proxy_logging_obj=proxy_logging_obj,
-            )
-            # test spend_log was written and we can read it
-            spend_logs = await view_spend_logs(
-                request_id=request_id,
-                user_api_key_dict=UserAPIKeyAuth(api_key=generated_key),
-            )
-
-            print("read spend logs", spend_logs)
-            assert len(spend_logs) == 1
-
-            spend_log = spend_logs[0]
-
-            assert spend_log.request_id == request_id
-            assert spend_log.spend == float("2e-05")
-            assert spend_log.model == "chatgpt-v-2"
-            assert (
-                spend_log.cache_key
-                == "c891d64397a472e6deb31b87a5ac4d3ed5b2dcc069bc87e2afe91e6d64e95a1e"
-            )
-
-            # use generated key to auth in
-            result = await user_api_key_auth(request=request, api_key=bearer_token)
-            print("result from user auth with new key", result)
-            pytest.fail("This should have failed!. They key crossed it's budget")
-
-        asyncio.run(test())
+        # set budget for gpt-4o-mini to 0.000001, expect requests to gpt-4o-mini to fail
+        model_max_budget = {
+            "gpt-4o-mini": {
+                "budget_limit": "0.000001",
+                "time_period": "1d",
+            },
+            "gpt-4o": {
+                "budget_limit": "200",
+                "time_period": "30d",
+            },
+        }
+        request = GenerateKeyRequest(
+            max_budget=100000,  # the key itself has a very high budget
+            model_max_budget=model_max_budget,
+        )
+        key = await generate_key_fn(request)
+        print(key)
+
+        generated_key = key.key
+        user_id = key.user_id
+        bearer_token = "Bearer " + generated_key
+
+        request = Request(scope={"type": "http"})
+        request._url = URL(url="/chat/completions")
+
+        async def return_body():
+            request_str = f'{{"model": "{request_model}"}}'  # Added extra curly braces to escape JSON
+            return request_str.encode()
+
+        request.body = return_body
+
+        # use generated key to auth in
+        result = await user_api_key_auth(request=request, api_key=bearer_token)
+        print("result from user auth with new key", result)
+
+        # update spend using track_cost callback, make 2nd request, it should fail
+        await litellm.acompletion(
+            model=request_model,
+            messages=[{"role": "user", "content": "Hello, how are you?"}],
+            metadata={
+                "user_api_key": hash_token(generated_key),
+                "user_api_key_model_max_budget": model_max_budget,
+            },
+        )
+        await asyncio.sleep(2)
+
+        # use generated key to auth in
+        result = await user_api_key_auth(request=request, api_key=bearer_token)
+        if should_pass is True:
+            print(
+                f"Passed request for model={request_model}, model_max_budget={model_max_budget}"
+            )
+            return
+        print("result from user auth with new key", result)
+        pytest.fail("This should have failed!. They key crossed it's budget")
     except Exception as e:
         # print(f"Error - {str(e)}")
+        print(
+            f"Failed request for model={request_model}, model_max_budget={model_max_budget}"
+        )
+        assert (
+            should_pass is False
+        ), f"This should have failed!. They key crossed it's budget for model={request_model}. {e}"
         traceback.print_exc()
         error_detail = e.message
-        assert "Budget has been exceeded!" in error_detail
+        assert f"exceeded budget for model={request_model}" in error_detail
         assert isinstance(e, ProxyException)
         assert e.type == ProxyErrorTypes.budget_exceeded
         print(vars(e))
+    finally:
+        litellm.callbacks.remove(model_budget_limiter)


 @pytest.mark.asyncio()
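The updated test above and the unit tests in the new file below both express per-model budgets as `{"budget_limit": ..., "time_period": ...}` entries. As a reference, here is a minimal, hypothetical sketch of turning such a raw `model_max_budget` dict into the `GenericBudgetInfo` objects the tests construct; the helper is illustrative and not part of this commit:

```python
from litellm.types.utils import GenericBudgetInfo

# Hypothetical normalization step; field names come from the unit tests below.
raw_model_max_budget = {
    "gpt-4o-mini": {"budget_limit": "0.000001", "time_period": "1d"},
    "gpt-4o": {"budget_limit": "200", "time_period": "30d"},
}

parsed = {
    model: GenericBudgetInfo(
        budget_limit=float(cfg["budget_limit"]),  # USD spend ceiling for this model
        time_period=cfg["time_period"],  # reset window, e.g. "1d" or "30d"
    )
    for model, cfg in raw_model_max_budget.items()
}
```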
@@ -0,0 +1,127 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+from pydantic.main import Model
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system-path
+from datetime import datetime as dt_object
+import time
+import pytest
+import litellm
+
+import json
+from litellm.types.utils import GenericBudgetInfo
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock, patch
+import pytest
+from litellm.caching.caching import DualCache
+from litellm.proxy.hooks.model_max_budget_limiter import (
+    _PROXY_VirtualKeyModelMaxBudgetLimiter,
+)
+from litellm.proxy._types import UserAPIKeyAuth
+import litellm
+
+
+# Test class setup
+@pytest.fixture
+def budget_limiter():
+    dual_cache = DualCache()
+    return _PROXY_VirtualKeyModelMaxBudgetLimiter(dual_cache=dual_cache)
+
+
+# Test _get_model_without_custom_llm_provider
+def test_get_model_without_custom_llm_provider(budget_limiter):
+    # Test with custom provider
+    assert (
+        budget_limiter._get_model_without_custom_llm_provider("openai/gpt-4") == "gpt-4"
+    )
+
+    # Test without custom provider
+    assert budget_limiter._get_model_without_custom_llm_provider("gpt-4") == "gpt-4"
+
+
+# Test _get_request_model_budget_config
+def test_get_request_model_budget_config(budget_limiter):
+    internal_budget = {
+        "gpt-4": GenericBudgetInfo(budget_limit=100.0, time_period="1d"),
+        "claude-3": GenericBudgetInfo(budget_limit=50.0, time_period="1d"),
+    }
+
+    # Test direct model match
+    config = budget_limiter._get_request_model_budget_config(
+        model="gpt-4", internal_model_max_budget=internal_budget
+    )
+    assert config.budget_limit == 100.0
+
+    # Test model with provider
+    config = budget_limiter._get_request_model_budget_config(
+        model="openai/gpt-4", internal_model_max_budget=internal_budget
+    )
+    assert config.budget_limit == 100.0
+
+    # Test non-existent model
+    config = budget_limiter._get_request_model_budget_config(
+        model="non-existent", internal_model_max_budget=internal_budget
+    )
+    assert config is None
+
+
+# Test is_key_within_model_budget
+@pytest.mark.asyncio
+async def test_is_key_within_model_budget(budget_limiter):
+    # Mock user API key dict
+    user_api_key = UserAPIKeyAuth(
+        token="test-key",
+        key_alias="test-alias",
+        model_max_budget={"gpt-4": {"budget_limit": 100.0, "time_period": "1d"}},
+    )
+
+    # Test when model is within budget
+    with patch.object(
+        budget_limiter, "_get_virtual_key_spend_for_model", return_value=50.0
+    ):
+        assert (
+            await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-4")
+            is True
+        )
+
+    # Test when model exceeds budget
+    with patch.object(
+        budget_limiter, "_get_virtual_key_spend_for_model", return_value=150.0
+    ):
+        with pytest.raises(litellm.BudgetExceededError):
+            await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-4")
+
+    # Test model not in budget config
+    assert (
+        await budget_limiter.is_key_within_model_budget(user_api_key, "non-existent")
+        is True
+    )
+
+
+# Test _get_virtual_key_spend_for_model
+@pytest.mark.asyncio
+async def test_get_virtual_key_spend_for_model(budget_limiter):
+    budget_config = GenericBudgetInfo(budget_limit=100.0, time_period="1d")
+
+    # Mock cache get
+    with patch.object(budget_limiter.dual_cache, "async_get_cache", return_value=50.0):
+        spend = await budget_limiter._get_virtual_key_spend_for_model(
+            user_api_key_hash="test-key", model="gpt-4", key_budget_config=budget_config
+        )
+        assert spend == 50.0
+
+        # Test with provider prefix
+        spend = await budget_limiter._get_virtual_key_spend_for_model(
+            user_api_key_hash="test-key",
+            model="openai/gpt-4",
+            key_budget_config=budget_config,
+        )
+        assert spend == 50.0
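Read together, these unit tests pin down the limiter's contract: budget lookups fall back to the model name without its provider prefix, models with no configured budget are allowed through, and spend at or above `budget_limit` raises `litellm.BudgetExceededError`. Below is a simplified sketch of that check, assuming only what the tests assert; it is not the actual `_PROXY_VirtualKeyModelMaxBudgetLimiter` implementation, and in the real hook the spend value comes from the `DualCache` lookup exercised by `test_get_virtual_key_spend_for_model`:

```python
# Simplified sketch of the budget check implied by the unit tests above.
from typing import Dict, Optional

import litellm
from litellm.types.utils import GenericBudgetInfo


def check_model_budget(
    model: str, spend: float, budgets: Dict[str, GenericBudgetInfo]
) -> bool:
    # Try a direct match first, then retry without the provider prefix,
    # e.g. "openai/gpt-4" -> "gpt-4" (mirrors _get_model_without_custom_llm_provider).
    config: Optional[GenericBudgetInfo] = budgets.get(model) or budgets.get(
        model.split("/", 1)[-1]
    )
    if config is None:
        # No budget configured for this model: the request is allowed.
        return True
    if spend >= config.budget_limit:
        # BudgetExceededError carries the current cost and the max budget.
        raise litellm.BudgetExceededError(
            current_cost=spend, max_budget=config.budget_limit
        )
    return True


# Usage, matching the spend values mocked in the tests:
budgets = {"gpt-4": GenericBudgetInfo(budget_limit=100.0, time_period="1d")}
assert check_model_budget("openai/gpt-4", spend=50.0, budgets=budgets) is True
```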