Support budget/rate limit tiers for keys (#7429)

* feat(proxy/utils.py): get associated litellm budget from db in combined_view for key

allows user to create rate limit tiers and associate those to keys

* feat(proxy/_types.py): update the value of key-level tpm/rpm/model max budget metrics with the associated budget table values if set

allows rate limit tiers to be easily applied to keys

* docs(rate_limit_tiers.md): add doc on setting rate limit / budget tiers

make feature discoverable

* feat(key_management_endpoints.py): return litellm_budget_table value in key generate

make it easy for user to know associated budget on key creation

* fix(key_management_endpoints.py): document 'budget_id' param in `/key/generate`

* docs(key_management_endpoints.py): document budget_id usage

* refactor(budget_management_endpoints.py): refactor budget endpoints into separate file - makes it easier to run documentation testing against it

* docs(test_api_docs.py): add budget endpoints to ci/cd doc test + add missing param info to docs

* fix(customer_endpoints.py): use new pydantic obj name

* docs(user_management_heirarchy.md): add simple doc explaining teams/keys/org/users on litellm

* Litellm dev 12 26 2024 p2 (#7432)

* (Feat) Add logging for `POST v1/fine_tuning/jobs`  (#7426)

* init commit ft jobs logging

* add ft logging

* add logging for FineTuningJob

* simple FT Job create test

* (docs) - show all supported Azure OpenAI endpoints in overview  (#7428)

* azure batches

* update doc

* docs azure endpoints

* docs endpoints on azure

* docs azure batches api

* docs azure batches api

* fix(key_management_endpoints.py): fix key update to actually work

* test(test_key_management.py): add e2e test asserting ui key update call works

* fix: proxy/_types - fix linting errors

* test: update test

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* fix: test

* fix(parallel_request_limiter.py): enforce tpm/rpm limits on key from tiers

* fix: fix linting errors

* test: fix test

* fix: remove unused import

* test: update test

* docs(customer_endpoints.py): document new model_max_budget param

* test: specify unique key alias

* docs(budget_management_endpoints.py): document new model_max_budget param

* test: fix test

* test: fix tests

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
This commit is contained in:
Krish Dholakia 2024-12-26 19:05:27 -08:00 committed by GitHub
parent 12c4e7e695
commit 539f166166
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 764 additions and 376 deletions

View file

@ -2,11 +2,11 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# 🙋‍♂️ Customers # 🙋‍♂️ Customers / End-User Budgets
Track spend, set budgets for your customers. Track spend, set budgets for your customers.
## Tracking Customer Credit ## Tracking Customer Spend
### 1. Make LLM API call w/ Customer ID ### 1. Make LLM API call w/ Customer ID

View file

@ -0,0 +1,68 @@
# ✨ Budget / Rate Limit Tiers
Create tiers with different budgets and rate limits. Making it easy to manage different users and their usage.
:::info
This is a LiteLLM Enterprise feature.
Get a 7 day free trial + get in touch [here](https://litellm.ai/#trial).
See pricing [here](https://litellm.ai/#pricing).
:::
## 1. Create a budget
```bash
curl -L -X POST 'http://0.0.0.0:4000/budget/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"budget_id": "my-test-tier",
"rpm_limit": 0
}'
```
## 2. Assign budget to a key
```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"budget_id": "my-test-tier"
}'
```
Expected Response:
```json
{
"key": "sk-...",
"budget_id": "my-test-tier",
"litellm_budget_table": {
"budget_id": "my-test-tier",
"rpm_limit": 0
}
}
```
## 3. Check if budget is enforced on key

Use the key returned in step 2 (`sk-...`) as the Authorization bearer token:

```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-...' \
-d '{
"model": "<REPLACE_WITH_MODEL_NAME_FROM_CONFIG.YAML>",
"messages": [
{"role": "user", "content": "hi my email is ishaan"}
]
}'
```
## [API Reference](https://litellm-api.up.railway.app/#/budget%20management)

View file

@ -0,0 +1,13 @@
import Image from '@theme/IdealImage';
# User Management Hierarchy
<Image img={require('../../img/litellm_user_heirarchy.png')} style={{ width: '100%', maxWidth: '4000px' }} />
LiteLLM supports a hierarchy of users, teams, organizations, and budgets.
- Organizations can have multiple teams. [API Reference](https://litellm-api.up.railway.app/#/organization%20management)
- Teams can have multiple users. [API Reference](https://litellm-api.up.railway.app/#/team%20management)
- Users can have multiple keys. [API Reference](https://litellm-api.up.railway.app/#/budget%20management)
- Keys can belong to either a team or a user. [API Reference](https://litellm-api.up.railway.app/#/end-user%20management)

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

View file

@ -51,7 +51,7 @@ const sidebars = {
{ {
type: "category", type: "category",
label: "Architecture", label: "Architecture",
items: ["proxy/architecture", "proxy/db_info", "router_architecture"], items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy"],
}, },
{ {
type: "link", type: "link",
@ -99,8 +99,13 @@ const sidebars = {
}, },
{ {
type: "category", type: "category",
label: "Spend Tracking + Budgets", label: "Spend Tracking",
items: ["proxy/cost_tracking", "proxy/users", "proxy/custom_pricing", "proxy/team_budgets", "proxy/billing", "proxy/customers"], items: ["proxy/cost_tracking", "proxy/custom_pricing", "proxy/billing",],
},
{
type: "category",
label: "Budgets + Rate Limits",
items: ["proxy/users", "proxy/rate_limit_tiers", "proxy/team_budgets", "proxy/customers"],
}, },
{ {
type: "link", type: "link",
@ -135,11 +140,19 @@ const sidebars = {
"oidc" "oidc"
] ]
}, },
"proxy/caching", {
type: "category",
label: "Create Custom Plugins",
description: "Modify requests, responses, and more",
items: [
"proxy/call_hooks", "proxy/call_hooks",
"proxy/rules", "proxy/rules",
] ]
}, },
"proxy/caching",
]
},
{ {
type: "category", type: "category",
label: "Supported Models & Providers", label: "Supported Models & Providers",

View file

@ -633,8 +633,12 @@ class PrometheusLogger(CustomLogger):
) )
remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}" remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
remaining_requests = metadata.get(remaining_requests_variable_name, sys.maxsize) remaining_requests = (
remaining_tokens = metadata.get(remaining_tokens_variable_name, sys.maxsize) metadata.get(remaining_requests_variable_name, sys.maxsize) or sys.maxsize
)
remaining_tokens = (
metadata.get(remaining_tokens_variable_name, sys.maxsize) or sys.maxsize
)
self.litellm_remaining_api_key_requests_for_model.labels( self.litellm_remaining_api_key_requests_for_model.labels(
user_api_key, user_api_key_alias, model_group user_api_key, user_api_key_alias, model_group

View file

@ -12,6 +12,7 @@ from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.router import RouterErrors, UpdateRouterConfig from litellm.types.router import RouterErrors, UpdateRouterConfig
from litellm.types.utils import ( from litellm.types.utils import (
EmbeddingResponse, EmbeddingResponse,
GenericBudgetConfigType,
ImageResponse, ImageResponse,
LiteLLMPydanticObjectBase, LiteLLMPydanticObjectBase,
ModelResponse, ModelResponse,
@ -614,7 +615,6 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
allowed_cache_controls: Optional[list] = [] allowed_cache_controls: Optional[list] = []
soft_budget: Optional[float] = None
config: Optional[dict] = {} config: Optional[dict] = {}
permissions: Optional[dict] = {} permissions: Optional[dict] = {}
model_max_budget: Optional[dict] = ( model_max_budget: Optional[dict] = (
@ -622,7 +622,6 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
send_invite_email: Optional[bool] = None
model_rpm_limit: Optional[dict] = None model_rpm_limit: Optional[dict] = None
model_tpm_limit: Optional[dict] = None model_tpm_limit: Optional[dict] = None
guardrails: Optional[List[str]] = None guardrails: Optional[List[str]] = None
@ -630,21 +629,25 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
aliases: Optional[dict] = {} aliases: Optional[dict] = {}
class _GenerateKeyRequest(GenerateRequestBase): class KeyRequestBase(GenerateRequestBase):
key: Optional[str] = None key: Optional[str] = None
budget_id: Optional[str] = None
class GenerateKeyRequest(_GenerateKeyRequest):
tags: Optional[List[str]] = None tags: Optional[List[str]] = None
enforced_params: Optional[List[str]] = None enforced_params: Optional[List[str]] = None
class GenerateKeyResponse(_GenerateKeyRequest): class GenerateKeyRequest(KeyRequestBase):
soft_budget: Optional[float] = None
send_invite_email: Optional[bool] = None
class GenerateKeyResponse(KeyRequestBase):
key: str # type: ignore key: str # type: ignore
key_name: Optional[str] = None key_name: Optional[str] = None
expires: Optional[datetime] expires: Optional[datetime]
user_id: Optional[str] = None user_id: Optional[str] = None
token_id: Optional[str] = None token_id: Optional[str] = None
litellm_budget_table: Optional[Any] = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -669,7 +672,7 @@ class GenerateKeyResponse(_GenerateKeyRequest):
return values return values
class UpdateKeyRequest(GenerateKeyRequest): class UpdateKeyRequest(KeyRequestBase):
# Note: the defaults of all Params here MUST BE NONE # Note: the defaults of all Params here MUST BE NONE
# else they will get overwritten # else they will get overwritten
key: str # type: ignore key: str # type: ignore
@ -765,7 +768,7 @@ class DeleteUserRequest(LiteLLMPydanticObjectBase):
AllowedModelRegion = Literal["eu", "us"] AllowedModelRegion = Literal["eu", "us"]
class BudgetNew(LiteLLMPydanticObjectBase): class BudgetNewRequest(LiteLLMPydanticObjectBase):
budget_id: Optional[str] = Field(default=None, description="The unique budget id.") budget_id: Optional[str] = Field(default=None, description="The unique budget id.")
max_budget: Optional[float] = Field( max_budget: Optional[float] = Field(
default=None, default=None,
@ -788,6 +791,10 @@ class BudgetNew(LiteLLMPydanticObjectBase):
default=None, default=None,
description="Max duration budget should be set for (e.g. '1hr', '1d', '28d')", description="Max duration budget should be set for (e.g. '1hr', '1d', '28d')",
) )
model_max_budget: Optional[GenericBudgetConfigType] = Field(
default=None,
description="Max budget for each model (e.g. {'gpt-4o': {'max_budget': '0.0000001', 'budget_duration': '1d', 'tpm_limit': 1000, 'rpm_limit': 1000}})",
)
class BudgetRequest(LiteLLMPydanticObjectBase): class BudgetRequest(LiteLLMPydanticObjectBase):
@ -805,11 +812,11 @@ class CustomerBase(LiteLLMPydanticObjectBase):
allowed_model_region: Optional[AllowedModelRegion] = None allowed_model_region: Optional[AllowedModelRegion] = None
default_model: Optional[str] = None default_model: Optional[str] = None
budget_id: Optional[str] = None budget_id: Optional[str] = None
litellm_budget_table: Optional[BudgetNew] = None litellm_budget_table: Optional[BudgetNewRequest] = None
blocked: bool = False blocked: bool = False
class NewCustomerRequest(BudgetNew): class NewCustomerRequest(BudgetNewRequest):
""" """
Create a new customer, allocate a budget to them Create a new customer, allocate a budget to them
""" """
@ -1426,6 +1433,19 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
# Time stamps # Time stamps
last_refreshed_at: Optional[float] = None # last time joint view was pulled from db last_refreshed_at: Optional[float] = None # last time joint view was pulled from db
def __init__(self, **kwargs):
    """
    Promote joint-view columns prefixed with ``litellm_budget_table_`` onto the
    token view itself, so budget-table values (e.g. tpm/rpm limits from a rate
    limit tier) fill in key-level fields that were not set directly on the key.
    """
    # Handle litellm_budget_table_* keys
    for key, value in list(kwargs.items()):
        if key.startswith("litellm_budget_table_") and value is not None:
            # Extract the corresponding attribute name
            # NOTE: str.replace removes the prefix here because these keys
            # always start with "litellm_budget_table_"
            attr_name = key.replace("litellm_budget_table_", "")
            # Check if the value is None and set the corresponding attribute
            # NOTE(review): the instance is not initialized yet at this point,
            # so getattr reads the class-level default (typically None) rather
            # than any value passed in kwargs - confirm this precedence is
            # intended
            if getattr(self, attr_name, None) is None:
                kwargs[attr_name] = value
    # Initialize the superclass
    super().__init__(**kwargs)
class UserAPIKeyAuth( class UserAPIKeyAuth(
LiteLLM_VerificationTokenView LiteLLM_VerificationTokenView
@ -2194,9 +2214,9 @@ class ProviderBudgetResponseObject(LiteLLMPydanticObjectBase):
Configuration for a single provider's budget settings Configuration for a single provider's budget settings
""" """
budget_limit: float # Budget limit in USD for the time period budget_limit: Optional[float] # Budget limit in USD for the time period
time_period: str # Time period for budget (e.g., '1d', '30d', '1mo') time_period: Optional[str] # Time period for budget (e.g., '1d', '30d', '1mo')
spend: float = 0.0 # Current spend for this provider spend: Optional[float] = 0.0 # Current spend for this provider
budget_reset_at: Optional[str] = None # When the current budget period resets budget_reset_at: Optional[str] = None # When the current budget period resets

View file

@ -418,6 +418,12 @@ def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]
if user_api_key_dict.metadata: if user_api_key_dict.metadata:
if "model_rpm_limit" in user_api_key_dict.metadata: if "model_rpm_limit" in user_api_key_dict.metadata:
return user_api_key_dict.metadata["model_rpm_limit"] return user_api_key_dict.metadata["model_rpm_limit"]
elif user_api_key_dict.model_max_budget:
model_rpm_limit: Dict[str, Any] = {}
for model, budget in user_api_key_dict.model_max_budget.items():
if "rpm_limit" in budget and budget["rpm_limit"] is not None:
model_rpm_limit[model] = budget["rpm_limit"]
return model_rpm_limit
return None return None
@ -426,6 +432,9 @@ def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]
if user_api_key_dict.metadata: if user_api_key_dict.metadata:
if "model_tpm_limit" in user_api_key_dict.metadata: if "model_tpm_limit" in user_api_key_dict.metadata:
return user_api_key_dict.metadata["model_tpm_limit"] return user_api_key_dict.metadata["model_tpm_limit"]
elif user_api_key_dict.model_max_budget:
if "tpm_limit" in user_api_key_dict.model_max_budget:
return user_api_key_dict.model_max_budget["tpm_limit"]
return None return None

View file

@ -9,8 +9,8 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ( from litellm.types.utils import (
BudgetConfig,
GenericBudgetConfigType, GenericBudgetConfigType,
GenericBudgetInfo,
StandardLoggingPayload, StandardLoggingPayload,
) )
@ -42,12 +42,8 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
_model_max_budget = user_api_key_dict.model_max_budget _model_max_budget = user_api_key_dict.model_max_budget
internal_model_max_budget: GenericBudgetConfigType = {} internal_model_max_budget: GenericBudgetConfigType = {}
# case each element in _model_max_budget to GenericBudgetInfo
for _model, _budget_info in _model_max_budget.items(): for _model, _budget_info in _model_max_budget.items():
internal_model_max_budget[_model] = GenericBudgetInfo( internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
time_period=_budget_info.get("time_period"),
budget_limit=float(_budget_info.get("budget_limit")),
)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"internal_model_max_budget %s", "internal_model_max_budget %s",
@ -65,7 +61,10 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
return True return True
# check if current model is within budget # check if current model is within budget
if _current_model_budget_info.budget_limit > 0: if (
_current_model_budget_info.max_budget
and _current_model_budget_info.max_budget > 0
):
_current_spend = await self._get_virtual_key_spend_for_model( _current_spend = await self._get_virtual_key_spend_for_model(
user_api_key_hash=user_api_key_dict.token, user_api_key_hash=user_api_key_dict.token,
model=model, model=model,
@ -73,12 +72,13 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
) )
if ( if (
_current_spend is not None _current_spend is not None
and _current_spend > _current_model_budget_info.budget_limit and _current_model_budget_info.max_budget is not None
and _current_spend > _current_model_budget_info.max_budget
): ):
raise litellm.BudgetExceededError( raise litellm.BudgetExceededError(
message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}", message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
current_cost=_current_spend, current_cost=_current_spend,
max_budget=_current_model_budget_info.budget_limit, max_budget=_current_model_budget_info.max_budget,
) )
return True return True
@ -87,7 +87,7 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
self, self,
user_api_key_hash: Optional[str], user_api_key_hash: Optional[str],
model: str, model: str,
key_budget_config: GenericBudgetInfo, key_budget_config: BudgetConfig,
) -> Optional[float]: ) -> Optional[float]:
""" """
Get the current spend for a virtual key for a model Get the current spend for a virtual key for a model
@ -98,7 +98,7 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
""" """
# 1. model: directly look up `model` # 1. model: directly look up `model`
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.time_period}" virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.budget_duration}"
_current_spend = await self.dual_cache.async_get_cache( _current_spend = await self.dual_cache.async_get_cache(
key=virtual_key_model_spend_cache_key, key=virtual_key_model_spend_cache_key,
) )
@ -106,7 +106,7 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
if _current_spend is None: if _current_spend is None:
# 2. If 1, does not exist, check if passed as {custom_llm_provider}/model # 2. If 1, does not exist, check if passed as {custom_llm_provider}/model
# if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview # if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.time_period}" virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
_current_spend = await self.dual_cache.async_get_cache( _current_spend = await self.dual_cache.async_get_cache(
key=virtual_key_model_spend_cache_key, key=virtual_key_model_spend_cache_key,
) )
@ -114,7 +114,7 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
def _get_request_model_budget_config( def _get_request_model_budget_config(
self, model: str, internal_model_max_budget: GenericBudgetConfigType self, model: str, internal_model_max_budget: GenericBudgetConfigType
) -> Optional[GenericBudgetInfo]: ) -> Optional[BudgetConfig]:
""" """
Get the budget config for the request model Get the budget config for the request model
@ -175,8 +175,8 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash") virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
model = standard_logging_payload.get("model") model = standard_logging_payload.get("model")
if virtual_key is not None: if virtual_key is not None:
budget_config = GenericBudgetInfo(time_period="1d", budget_limit=0.1) budget_config = BudgetConfig(time_period="1d", budget_limit=0.1)
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.time_period}" virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.budget_duration}"
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}" virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
await self._increment_spend_for_key( await self._increment_spend_for_key(
budget_config=budget_config, budget_config=budget_config,

View file

@ -317,7 +317,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
if _model is not None: if _model is not None:
if _tpm_limit_for_key_model: if _tpm_limit_for_key_model:
@ -325,6 +324,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if _rpm_limit_for_key_model: if _rpm_limit_for_key_model:
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model) rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
if current is None: if current is None:
new_val = { new_val = {
"current_requests": 1, "current_requests": 1,
@ -485,6 +485,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
) )
try: try:
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING") self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None "global_max_parallel_requests", None
) )
@ -495,6 +496,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_team_id", None "user_api_key_team_id", None
) )
user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get(
"user_api_key_model_max_budget", None
)
user_api_key_end_user_id = kwargs.get("user") user_api_key_end_user_id = kwargs.get("user")
user_api_key_metadata = ( user_api_key_metadata = (
@ -568,6 +572,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
and ( and (
"model_rpm_limit" in user_api_key_metadata "model_rpm_limit" in user_api_key_metadata
or "model_tpm_limit" in user_api_key_metadata or "model_tpm_limit" in user_api_key_metadata
or user_api_key_model_max_budget is not None
) )
): ):
request_count_api_key = ( request_count_api_key = (

View file

@ -0,0 +1,287 @@
"""
BUDGET MANAGEMENT
All /budget management endpoints
/budget/new
/budget/info
/budget/update
/budget/delete
/budget/settings
/budget/list
"""
#### BUDGET TABLE MANAGEMENT ####
from fastapi import APIRouter, Depends, HTTPException
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.utils import jsonify_object
router = APIRouter()
@router.post(
    "/budget/new",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_budget(
    budget_obj: BudgetNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new budget object. Can apply this to teams, orgs, end-users, keys.

    Parameters:
    - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
    - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated.
    - max_budget: Optional[float] - The max budget for the budget.
    - soft_budget: Optional[float] - The soft budget for the budget.
    - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget.
    - tpm_limit: Optional[int] - The tokens per minute limit for the budget.
    - rpm_limit: Optional[int] - The requests per minute limit for the budget.
    - model_max_budget: Optional[dict] - Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}}
    """
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # json-dump any nested dicts (e.g. model_max_budget) so they fit the
    # DB Json column
    record = jsonify_object(budget_obj.model_dump(exclude_none=True))
    actor = user_api_key_dict.user_id or litellm_proxy_admin_name
    return await prisma_client.db.litellm_budgettable.create(
        data={
            **record,  # type: ignore
            "created_by": actor,
            "updated_by": actor,
        }  # type: ignore
    )
@router.post(
    "/budget/update",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_budget(
    budget_obj: BudgetNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing budget object.

    Parameters:
    - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
    - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated.
    - max_budget: Optional[float] - The max budget for the budget.
    - soft_budget: Optional[float] - The soft budget for the budget.
    - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget.
    - tpm_limit: Optional[int] - The tokens per minute limit for the budget.
    - rpm_limit: Optional[int] - The requests per minute limit for the budget.
    - model_max_budget: Optional[dict] - Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}}
    """
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if budget_obj.budget_id is None:
        raise HTTPException(status_code=400, detail={"error": "budget_id is required"})

    # FIX: json-dump nested dict fields (e.g. model_max_budget) before writing
    # to the DB Json column - mirrors the serialization done in /budget/new;
    # previously the raw model_dump was passed through, so dict-valued fields
    # failed/mis-stored on update
    budget_obj_jsonified = jsonify_object(budget_obj.model_dump(exclude_none=True))
    response = await prisma_client.db.litellm_budgettable.update(
        where={"budget_id": budget_obj.budget_id},
        data={
            **budget_obj_jsonified,  # type: ignore
            "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        },  # type: ignore
    )
    return response
@router.post(
    "/budget/info",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_budget(data: BudgetRequest):
    """
    Get the budget id specific information

    Parameters:
    - budgets: List[str] - The list of budget ids to get information for
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})
    # an empty id list is a client error - nothing to look up
    if len(data.budgets) == 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Specify list of budget id's to query. Passed in={data.budgets}"
            },
        )

    return await prisma_client.db.litellm_budgettable.find_many(
        where={"budget_id": {"in": data.budgets}},
    )
@router.get(
    "/budget/settings",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def budget_settings(
    budget_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get list of configurable params + current value for a budget item + description of each field

    Used on Admin UI.

    Query Parameters:
    - budget_id: str - The budget id to get information for
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        # FIX: a missing DB connection is a server-side failure -> 500,
        # consistent with /budget/new, /budget/update and /budget/delete
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    # admin-only endpoint
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "{}, your role={}".format(
                    CommonProxyErrors.not_allowed_access.value,
                    user_api_key_dict.user_role,
                )
            },
        )

    ## get budget item from db
    db_budget_row = await prisma_client.db.litellm_budgettable.find_first(
        where={"budget_id": budget_id}
    )
    db_budget_row_dict = (
        db_budget_row.model_dump(exclude_none=True) if db_budget_row is not None else {}
    )

    # UI-configurable fields and their display types
    allowed_args = {
        "max_parallel_requests": {"type": "Integer"},
        "tpm_limit": {"type": "Integer"},
        "rpm_limit": {"type": "Integer"},
        "budget_duration": {"type": "String"},
        "max_budget": {"type": "Float"},
        "soft_budget": {"type": "Float"},
    }

    return_val = []
    for field_name, field_info in BudgetNewRequest.model_fields.items():
        if field_name in allowed_args:
            _stored_in_db = True  # all budget settings live in the budget table
            _response_obj = ConfigList(
                field_name=field_name,
                field_type=allowed_args[field_name]["type"],
                field_description=field_info.description or "",
                field_value=db_budget_row_dict.get(field_name, None),
                stored_in_db=_stored_in_db,
                field_default_value=field_info.default,
            )
            return_val.append(_response_obj)

    return return_val
@router.get(
    "/budget/list",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def list_budget(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """List all the created budgets in proxy db. Used on Admin UI."""
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    # admin-only endpoint
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "{}, your role={}".format(
                    CommonProxyErrors.not_allowed_access.value,
                    user_api_key_dict.user_role,
                )
            },
        )

    # no filter - return every budget row
    return await prisma_client.db.litellm_budgettable.find_many()
@router.post(
    "/budget/delete",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_budget(
    data: BudgetDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete budget

    Parameters:
    - id: str - The budget id to delete
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    # admin-only endpoint
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "{}, your role={}".format(
                    CommonProxyErrors.not_allowed_access.value,
                    user_api_key_dict.user_role,
                )
            },
        )

    deleted_row = await prisma_client.db.litellm_budgettable.delete(
        where={"budget_id": data.id}
    )
    return deleted_row

View file

@ -131,11 +131,11 @@ async def unblock_user(data: BlockUsers):
return {"blocked_users": litellm.blocked_user_list} return {"blocked_users": litellm.blocked_user_list}
def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNew]: def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNewRequest]:
""" """
Return a new budget object if new budget params are passed. Return a new budget object if new budget params are passed.
""" """
budget_params = BudgetNew.model_fields.keys() budget_params = BudgetNewRequest.model_fields.keys()
budget_kv_pairs = {} budget_kv_pairs = {}
# Get the actual values from the data object using getattr # Get the actual values from the data object using getattr
@ -147,7 +147,7 @@ def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNew]:
budget_kv_pairs[field_name] = value budget_kv_pairs[field_name] = value
if budget_kv_pairs: if budget_kv_pairs:
return BudgetNew(**budget_kv_pairs) return BudgetNewRequest(**budget_kv_pairs)
return None return None
@ -182,6 +182,7 @@ async def new_end_user(
- budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
- tpm_limit: Optional[int] - [Not Implemented Yet] Specify tpm limit for a given customer (Tokens per minute) - tpm_limit: Optional[int] - [Not Implemented Yet] Specify tpm limit for a given customer (Tokens per minute)
- rpm_limit: Optional[int] - [Not Implemented Yet] Specify rpm limit for a given customer (Requests per minute) - rpm_limit: Optional[int] - [Not Implemented Yet] Specify rpm limit for a given customer (Requests per minute)
- model_max_budget: Optional[dict] - [Not Implemented Yet] Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d"}}
- max_parallel_requests: Optional[int] - [Not Implemented Yet] Specify max parallel requests for a given customer. - max_parallel_requests: Optional[int] - [Not Implemented Yet] Specify max parallel requests for a given customer.
- soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests. - soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests.
@ -271,7 +272,7 @@ async def new_end_user(
_user_data = data.dict(exclude_none=True) _user_data = data.dict(exclude_none=True)
for k, v in _user_data.items(): for k, v in _user_data.items():
if k not in BudgetNew.model_fields.keys(): if k not in BudgetNewRequest.model_fields.keys():
new_end_user_obj[k] = v new_end_user_obj[k] = v
## WRITE TO DB ## ## WRITE TO DB ##

View file

@ -40,7 +40,7 @@ from litellm.proxy.utils import (
) )
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret
from litellm.types.utils import ( from litellm.types.utils import (
GenericBudgetInfo, BudgetConfig,
PersonalUIKeyGenerationConfig, PersonalUIKeyGenerationConfig,
TeamUIKeyGenerationConfig, TeamUIKeyGenerationConfig,
) )
@ -238,6 +238,7 @@ async def generate_key_fn( # noqa: PLR0915
- key: Optional[str] - User defined key value. If not set, a 16-digit unique sk-key is created for you. - key: Optional[str] - User defined key value. If not set, a 16-digit unique sk-key is created for you.
- team_id: Optional[str] - The team id of the key - team_id: Optional[str] - The team id of the key
- user_id: Optional[str] - The user id of the key - user_id: Optional[str] - The user id of the key
- budget_id: Optional[str] - The budget id associated with the key. Created by calling `/budget/new`.
- models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
@ -249,7 +250,7 @@ async def generate_key_fn( # noqa: PLR0915
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
- guardrails: Optional[List[str]] - List of active guardrails for the key - guardrails: Optional[List[str]] - List of active guardrails for the key
- permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false} - permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget. - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget.
- model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit. - model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
- model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit. - model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit.
- allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request - allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request
@ -376,7 +377,7 @@ async def generate_key_fn( # noqa: PLR0915
) )
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
_budget_id = None _budget_id = data.budget_id
if prisma_client is not None and data.soft_budget is not None: if prisma_client is not None and data.soft_budget is not None:
# create the Budget Row for the LiteLLM Verification Token # create the Budget Row for the LiteLLM Verification Token
budget_row = LiteLLM_BudgetTable( budget_row = LiteLLM_BudgetTable(
@ -547,14 +548,15 @@ async def update_key_fn(
- key_alias: Optional[str] - User-friendly key alias - key_alias: Optional[str] - User-friendly key alias
- user_id: Optional[str] - User ID associated with key - user_id: Optional[str] - User ID associated with key
- team_id: Optional[str] - Team ID associated with key - team_id: Optional[str] - Team ID associated with key
- budget_id: Optional[str] - The budget id associated with the key. Created by calling `/budget/new`.
- models: Optional[list] - Model_name's a user is allowed to call - models: Optional[list] - Model_name's a user is allowed to call
- tags: Optional[List[str]] - Tags for organizing keys (Enterprise only) - tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
- enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests) - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
- spend: Optional[float] - Amount spent by key - spend: Optional[float] - Amount spent by key
- max_budget: Optional[float] - Max budget for key - max_budget: Optional[float] - Max budget for key
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}} - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.) - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached. - soft_budget: Optional[float] - [TODO] Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
- max_parallel_requests: Optional[int] - Rate limit for parallel requests - max_parallel_requests: Optional[int] - Rate limit for parallel requests
- metadata: Optional[dict] - Metadata for key. Example {"team": "core-infra", "app": "app2"} - metadata: Optional[dict] - Metadata for key. Example {"team": "core-infra", "app": "app2"}
- tpm_limit: Optional[int] - Tokens per minute limit - tpm_limit: Optional[int] - Tokens per minute limit
@ -592,7 +594,7 @@ async def update_key_fn(
) )
try: try:
data_json: dict = data.model_dump(exclude_unset=True) data_json: dict = data.model_dump(exclude_unset=True, exclude_none=True)
key = data_json.pop("key") key = data_json.pop("key")
# get the row from db # get the row from db
if prisma_client is None: if prisma_client is None:
@ -1135,6 +1137,9 @@ async def generate_key_helper_fn( # noqa: PLR0915
data=key_data, table_name="key" data=key_data, table_name="key"
) )
key_data["token_id"] = getattr(create_key_response, "token", None) key_data["token_id"] = getattr(create_key_response, "token", None)
key_data["litellm_budget_table"] = getattr(
create_key_response, "litellm_budget_table", None
)
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(
"litellm.proxy.proxy_server.generate_key_helper_fn(): Exception occured - {}".format( "litellm.proxy.proxy_server.generate_key_helper_fn(): Exception occured - {}".format(
@ -1247,7 +1252,7 @@ async def regenerate_key_fn(
- tags: Optional[List[str]] - Tags for organizing keys (Enterprise only) - tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
- spend: Optional[float] - Amount spent by key - spend: Optional[float] - Amount spent by key
- max_budget: Optional[float] - Max budget for key - max_budget: Optional[float] - Max budget for key
- model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}} - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}
- budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.) - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
- soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached. - soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached.
- max_parallel_requests: Optional[int] - Rate limit for parallel requests - max_parallel_requests: Optional[int] - Rate limit for parallel requests
@ -1956,7 +1961,7 @@ def validate_model_max_budget(model_max_budget: Optional[Dict]) -> None:
# /CRUD endpoints can pass budget_limit as a string, so we need to convert it to a float # /CRUD endpoints can pass budget_limit as a string, so we need to convert it to a float
if "budget_limit" in _budget_info: if "budget_limit" in _budget_info:
_budget_info["budget_limit"] = float(_budget_info["budget_limit"]) _budget_info["budget_limit"] = float(_budget_info["budget_limit"])
GenericBudgetInfo(**_budget_info) BudgetConfig(**_budget_info)
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f"Invalid model_max_budget: {str(e)}. Example of valid model_max_budget: https://docs.litellm.ai/docs/proxy/users" f"Invalid model_max_budget: {str(e)}. Example of valid model_max_budget: https://docs.litellm.ai/docs/proxy/users"

View file

@ -178,6 +178,9 @@ from litellm.proxy.hooks.prompt_injection_detection import (
from litellm.proxy.hooks.proxy_failure_handler import _PROXY_failure_handler from litellm.proxy.hooks.proxy_failure_handler import _PROXY_failure_handler
from litellm.proxy.hooks.proxy_track_cost_callback import _PROXY_track_cost_callback from litellm.proxy.hooks.proxy_track_cost_callback import _PROXY_track_cost_callback
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.management_endpoints.budget_management_endpoints import (
router as budget_management_router,
)
from litellm.proxy.management_endpoints.customer_endpoints import ( from litellm.proxy.management_endpoints.customer_endpoints import (
router as customer_router, router as customer_router,
) )
@ -5531,238 +5534,6 @@ async def supported_openai_params(model: str):
) )
#### BUDGET TABLE MANAGEMENT ####
@router.post(
"/budget/new",
tags=["budget management"],
dependencies=[Depends(user_api_key_auth)],
)
async def new_budget(
budget_obj: BudgetNew,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create a new budget object. Can apply this to teams, orgs, end-users, keys.
"""
global prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
response = await prisma_client.db.litellm_budgettable.create(
data={
**budget_obj.model_dump(exclude_none=True), # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
} # type: ignore
)
return response
@router.post(
"/budget/update",
tags=["budget management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_budget(
budget_obj: BudgetNew,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create a new budget object. Can apply this to teams, orgs, end-users, keys.
"""
global prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if budget_obj.budget_id is None:
raise HTTPException(status_code=400, detail={"error": "budget_id is required"})
response = await prisma_client.db.litellm_budgettable.update(
where={"budget_id": budget_obj.budget_id},
data={
**budget_obj.model_dump(exclude_none=True), # type: ignore
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}, # type: ignore
)
return response
@router.post(
"/budget/info",
tags=["budget management"],
dependencies=[Depends(user_api_key_auth)],
)
async def info_budget(data: BudgetRequest):
"""
Get the budget id specific information
"""
global prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
if len(data.budgets) == 0:
raise HTTPException(
status_code=400,
detail={
"error": f"Specify list of budget id's to query. Passed in={data.budgets}"
},
)
response = await prisma_client.db.litellm_budgettable.find_many(
where={"budget_id": {"in": data.budgets}},
)
return response
@router.get(
"/budget/settings",
tags=["budget management"],
dependencies=[Depends(user_api_key_auth)],
)
async def budget_settings(
budget_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Get list of configurable params + current value for a budget item + description of each field
Used on Admin UI.
"""
if prisma_client is None:
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
"error": "{}, your role={}".format(
CommonProxyErrors.not_allowed_access.value,
user_api_key_dict.user_role,
)
},
)
## get budget item from db
db_budget_row = await prisma_client.db.litellm_budgettable.find_first(
where={"budget_id": budget_id}
)
if db_budget_row is not None:
db_budget_row_dict = db_budget_row.model_dump(exclude_none=True)
else:
db_budget_row_dict = {}
allowed_args = {
"max_parallel_requests": {"type": "Integer"},
"tpm_limit": {"type": "Integer"},
"rpm_limit": {"type": "Integer"},
"budget_duration": {"type": "String"},
"max_budget": {"type": "Float"},
"soft_budget": {"type": "Float"},
}
return_val = []
for field_name, field_info in BudgetNew.model_fields.items():
if field_name in allowed_args:
_stored_in_db = True
_response_obj = ConfigList(
field_name=field_name,
field_type=allowed_args[field_name]["type"],
field_description=field_info.description or "",
field_value=db_budget_row_dict.get(field_name, None),
stored_in_db=_stored_in_db,
field_default_value=field_info.default,
)
return_val.append(_response_obj)
return return_val
@router.get(
"/budget/list",
tags=["budget management"],
dependencies=[Depends(user_api_key_auth)],
)
async def list_budget(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""List all the created budgets in proxy db. Used on Admin UI."""
if prisma_client is None:
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
"error": "{}, your role={}".format(
CommonProxyErrors.not_allowed_access.value,
user_api_key_dict.user_role,
)
},
)
response = await prisma_client.db.litellm_budgettable.find_many()
return response
@router.post(
"/budget/delete",
tags=["budget management"],
dependencies=[Depends(user_api_key_auth)],
)
async def delete_budget(
data: BudgetDeleteRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""Delete budget"""
global prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
"error": "{}, your role={}".format(
CommonProxyErrors.not_allowed_access.value,
user_api_key_dict.user_role,
)
},
)
response = await prisma_client.db.litellm_budgettable.delete(
where={"budget_id": data.id}
)
return response
#### MODEL MANAGEMENT #### #### MODEL MANAGEMENT ####
@ -8856,3 +8627,4 @@ app.include_router(debugging_endpoints_router)
app.include_router(ui_crud_endpoints_router) app.include_router(ui_crud_endpoints_router)
app.include_router(openai_files_router) app.include_router(openai_files_router)
app.include_router(team_callback_router) app.include_router(team_callback_router)
app.include_router(budget_management_router)

View file

@ -2533,8 +2533,8 @@ async def provider_budgets() -> ProviderBudgetResponse:
_provider _provider
) )
provider_budget_response_object = ProviderBudgetResponseObject( provider_budget_response_object = ProviderBudgetResponseObject(
budget_limit=_budget_info.budget_limit, budget_limit=_budget_info.max_budget,
time_period=_budget_info.time_period, time_period=_budget_info.budget_duration,
spend=_provider_spend, spend=_provider_spend,
budget_reset_at=_provider_budget_ttl, budget_reset_at=_provider_budget_ttl,
) )

View file

@ -1018,6 +1018,19 @@ def on_backoff(details):
print_verbose(f"Backing off... this was attempt #{details['tries']}") print_verbose(f"Backing off... this was attempt #{details['tries']}")
def jsonify_object(data: dict) -> dict:
db_data = copy.deepcopy(data)
for k, v in db_data.items():
if isinstance(v, dict):
try:
db_data[k] = json.dumps(v)
except Exception:
# This avoids Prisma retrying this 5 times, and making 5 clients
db_data[k] = "failed-to-serialize-json"
return db_data
class PrismaClient: class PrismaClient:
user_list_transactons: dict = {} user_list_transactons: dict = {}
end_user_list_transactons: dict = {} end_user_list_transactons: dict = {}
@ -1516,11 +1529,17 @@ class PrismaClient:
t.metadata AS team_metadata, t.metadata AS team_metadata,
t.members_with_roles AS team_members_with_roles, t.members_with_roles AS team_members_with_roles,
tm.spend AS team_member_spend, tm.spend AS team_member_spend,
m.aliases as team_model_aliases m.aliases AS team_model_aliases,
-- Added comma to separate b.* columns
b.max_budget AS litellm_budget_table_max_budget,
b.tpm_limit AS litellm_budget_table_tpm_limit,
b.rpm_limit AS litellm_budget_table_rpm_limit,
b.model_max_budget as litellm_budget_table_model_max_budget
FROM "LiteLLM_VerificationToken" AS v FROM "LiteLLM_VerificationToken" AS v
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
LEFT JOIN "LiteLLM_BudgetTable" AS b ON v.budget_id = b.budget_id
WHERE v.token = '{token}' WHERE v.token = '{token}'
""" """
@ -1634,6 +1653,7 @@ class PrismaClient:
"create": {**db_data}, # type: ignore "create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists "update": {}, # don't do anything if it already exists
}, },
include={"litellm_budget_table": True},
) )
verbose_proxy_logger.info("Data Inserted into Keys Table") verbose_proxy_logger.info("Data Inserted into Keys Table")
return new_verification_token return new_verification_token

View file

@ -98,7 +98,6 @@ from litellm.types.router import (
CustomRoutingStrategyBase, CustomRoutingStrategyBase,
Deployment, Deployment,
DeploymentTypedDict, DeploymentTypedDict,
GenericBudgetConfigType,
LiteLLM_Params, LiteLLM_Params,
ModelGroupInfo, ModelGroupInfo,
OptionalPreCallChecks, OptionalPreCallChecks,
@ -111,6 +110,7 @@ from litellm.types.router import (
RoutingStrategy, RoutingStrategy,
) )
from litellm.types.services import ServiceTypes from litellm.types.services import ServiceTypes
from litellm.types.utils import GenericBudgetConfigType
from litellm.types.utils import ModelInfo as ModelMapInfo from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload from litellm.types.utils import StandardLoggingPayload
from litellm.utils import ( from litellm.utils import (

View file

@ -33,14 +33,10 @@ from litellm.router_utils.cooldown_callbacks import (
_get_prometheus_logger_from_callbacks, _get_prometheus_logger_from_callbacks,
) )
from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import ( from litellm.types.router import DeploymentTypedDict, LiteLLM_Params, RouterErrors
DeploymentTypedDict, from litellm.types.utils import BudgetConfig
GenericBudgetConfigType, from litellm.types.utils import BudgetConfig as GenericBudgetInfo
GenericBudgetInfo, from litellm.types.utils import GenericBudgetConfigType, StandardLoggingPayload
LiteLLM_Params,
RouterErrors,
)
from litellm.types.utils import BudgetConfig, StandardLoggingPayload
DEFAULT_REDIS_SYNC_INTERVAL = 1 DEFAULT_REDIS_SYNC_INTERVAL = 1
@ -170,17 +166,19 @@ class RouterBudgetLimiting(CustomLogger):
provider = self._get_llm_provider_for_deployment(deployment) provider = self._get_llm_provider_for_deployment(deployment)
if provider in provider_configs: if provider in provider_configs:
config = provider_configs[provider] config = provider_configs[provider]
if config.max_budget is None:
continue
current_spend = spend_map.get( current_spend = spend_map.get(
f"provider_spend:{provider}:{config.time_period}", 0.0 f"provider_spend:{provider}:{config.budget_duration}", 0.0
) )
self._track_provider_remaining_budget_prometheus( self._track_provider_remaining_budget_prometheus(
provider=provider, provider=provider,
spend=current_spend, spend=current_spend,
budget_limit=config.budget_limit, budget_limit=config.max_budget,
) )
if current_spend >= config.budget_limit: if config.max_budget and current_spend >= config.max_budget:
debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.budget_limit}" debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.max_budget}"
deployment_above_budget_info += f"{debug_msg}\n" deployment_above_budget_info += f"{debug_msg}\n"
is_within_budget = False is_within_budget = False
continue continue
@ -194,30 +192,32 @@ class RouterBudgetLimiting(CustomLogger):
if model_id in deployment_configs: if model_id in deployment_configs:
config = deployment_configs[model_id] config = deployment_configs[model_id]
current_spend = spend_map.get( current_spend = spend_map.get(
f"deployment_spend:{model_id}:{config.time_period}", 0.0 f"deployment_spend:{model_id}:{config.budget_duration}", 0.0
) )
if current_spend >= config.budget_limit: if config.max_budget and current_spend >= config.max_budget:
debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_limit}" debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_duration}"
verbose_router_logger.debug(debug_msg) verbose_router_logger.debug(debug_msg)
deployment_above_budget_info += f"{debug_msg}\n" deployment_above_budget_info += f"{debug_msg}\n"
is_within_budget = False is_within_budget = False
continue continue
# Check tag budget # Check tag budget
if self.tag_budget_config and is_within_budget: if self.tag_budget_config and is_within_budget:
for _tag in request_tags: for _tag in request_tags:
_tag_budget_config = self._get_budget_config_for_tag(_tag) _tag_budget_config = self._get_budget_config_for_tag(_tag)
if _tag_budget_config: if _tag_budget_config:
_tag_spend = spend_map.get( _tag_spend = spend_map.get(
f"tag_spend:{_tag}:{_tag_budget_config.time_period}", 0.0 f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}",
0.0,
) )
if _tag_spend >= _tag_budget_config.budget_limit: if (
debug_msg = f"Exceeded budget for tag='{_tag}', tag_spend={_tag_spend}, tag_budget_limit={_tag_budget_config.budget_limit}" _tag_budget_config.max_budget
and _tag_spend >= _tag_budget_config.max_budget
):
debug_msg = f"Exceeded budget for tag='{_tag}', tag_spend={_tag_spend}, tag_budget_limit={_tag_budget_config.max_budget}"
verbose_router_logger.debug(debug_msg) verbose_router_logger.debug(debug_msg)
deployment_above_budget_info += f"{debug_msg}\n" deployment_above_budget_info += f"{debug_msg}\n"
is_within_budget = False is_within_budget = False
continue continue
if is_within_budget: if is_within_budget:
potential_deployments.append(deployment) potential_deployments.append(deployment)
@ -247,10 +247,13 @@ class RouterBudgetLimiting(CustomLogger):
provider = self._get_llm_provider_for_deployment(deployment) provider = self._get_llm_provider_for_deployment(deployment)
if provider is not None: if provider is not None:
budget_config = self._get_budget_config_for_provider(provider) budget_config = self._get_budget_config_for_provider(provider)
if budget_config is not None: if (
budget_config is not None
and budget_config.budget_duration is not None
):
provider_configs[provider] = budget_config provider_configs[provider] = budget_config
cache_keys.append( cache_keys.append(
f"provider_spend:{provider}:{budget_config.time_period}" f"provider_spend:{provider}:{budget_config.budget_duration}"
) )
# Check deployment budgets # Check deployment budgets
@ -261,7 +264,7 @@ class RouterBudgetLimiting(CustomLogger):
if budget_config is not None: if budget_config is not None:
deployment_configs[model_id] = budget_config deployment_configs[model_id] = budget_config
cache_keys.append( cache_keys.append(
f"deployment_spend:{model_id}:{budget_config.time_period}" f"deployment_spend:{model_id}:{budget_config.budget_duration}"
) )
# Check tag budgets # Check tag budgets
if self.tag_budget_config: if self.tag_budget_config:
@ -272,7 +275,7 @@ class RouterBudgetLimiting(CustomLogger):
_tag_budget_config = self._get_budget_config_for_tag(_tag) _tag_budget_config = self._get_budget_config_for_tag(_tag)
if _tag_budget_config: if _tag_budget_config:
cache_keys.append( cache_keys.append(
f"tag_spend:{_tag}:{_tag_budget_config.time_period}" f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
) )
return cache_keys, provider_configs, deployment_configs return cache_keys, provider_configs, deployment_configs
@ -365,7 +368,7 @@ class RouterBudgetLimiting(CustomLogger):
if budget_config: if budget_config:
# increment spend for provider # increment spend for provider
spend_key = ( spend_key = (
f"provider_spend:{custom_llm_provider}:{budget_config.time_period}" f"provider_spend:{custom_llm_provider}:{budget_config.budget_duration}"
) )
start_time_key = f"provider_budget_start_time:{custom_llm_provider}" start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
await self._increment_spend_for_key( await self._increment_spend_for_key(
@ -378,9 +381,7 @@ class RouterBudgetLimiting(CustomLogger):
deployment_budget_config = self._get_budget_config_for_deployment(model_id) deployment_budget_config = self._get_budget_config_for_deployment(model_id)
if deployment_budget_config: if deployment_budget_config:
# increment spend for specific deployment id # increment spend for specific deployment id
deployment_spend_key = ( deployment_spend_key = f"deployment_spend:{model_id}:{deployment_budget_config.budget_duration}"
f"deployment_spend:{model_id}:{deployment_budget_config.time_period}"
)
deployment_start_time_key = f"deployment_budget_start_time:{model_id}" deployment_start_time_key = f"deployment_budget_start_time:{model_id}"
await self._increment_spend_for_key( await self._increment_spend_for_key(
budget_config=deployment_budget_config, budget_config=deployment_budget_config,
@ -395,7 +396,7 @@ class RouterBudgetLimiting(CustomLogger):
_tag_budget_config = self._get_budget_config_for_tag(_tag) _tag_budget_config = self._get_budget_config_for_tag(_tag)
if _tag_budget_config: if _tag_budget_config:
_tag_spend_key = ( _tag_spend_key = (
f"tag_spend:{_tag}:{_tag_budget_config.time_period}" f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
) )
_tag_start_time_key = f"tag_budget_start_time:{_tag}" _tag_start_time_key = f"tag_budget_start_time:{_tag}"
await self._increment_spend_for_key( await self._increment_spend_for_key(
@ -412,8 +413,11 @@ class RouterBudgetLimiting(CustomLogger):
start_time_key: str, start_time_key: str,
response_cost: float, response_cost: float,
): ):
if budget_config.budget_duration is None:
return
current_time = datetime.now(timezone.utc).timestamp() current_time = datetime.now(timezone.utc).timestamp()
ttl_seconds = duration_in_seconds(budget_config.time_period) ttl_seconds = duration_in_seconds(budget_config.budget_duration)
budget_start = await self._get_or_set_budget_start_time( budget_start = await self._get_or_set_budget_start_time(
start_time_key=start_time_key, start_time_key=start_time_key,
@ -529,21 +533,23 @@ class RouterBudgetLimiting(CustomLogger):
for provider, config in self.provider_budget_config.items(): for provider, config in self.provider_budget_config.items():
if config is None: if config is None:
continue continue
cache_keys.append(f"provider_spend:{provider}:{config.time_period}") cache_keys.append(
f"provider_spend:{provider}:{config.budget_duration}"
)
if self.deployment_budget_config is not None: if self.deployment_budget_config is not None:
for model_id, config in self.deployment_budget_config.items(): for model_id, config in self.deployment_budget_config.items():
if config is None: if config is None:
continue continue
cache_keys.append( cache_keys.append(
f"deployment_spend:{model_id}:{config.time_period}" f"deployment_spend:{model_id}:{config.budget_duration}"
) )
if self.tag_budget_config is not None: if self.tag_budget_config is not None:
for tag, config in self.tag_budget_config.items(): for tag, config in self.tag_budget_config.items():
if config is None: if config is None:
continue continue
cache_keys.append(f"tag_spend:{tag}:{config.time_period}") cache_keys.append(f"tag_spend:{tag}:{config.budget_duration}")
# Batch fetch current spend values from Redis # Batch fetch current spend values from Redis
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache( redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
@ -635,7 +641,7 @@ class RouterBudgetLimiting(CustomLogger):
if budget_config is None: if budget_config is None:
return None return None
spend_key = f"provider_spend:{provider}:{budget_config.time_period}" spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
if self.dual_cache.redis_cache: if self.dual_cache.redis_cache:
# use Redis as source of truth since that has spend across all instances # use Redis as source of truth since that has spend across all instances
@ -652,7 +658,7 @@ class RouterBudgetLimiting(CustomLogger):
if budget_config is None: if budget_config is None:
return None return None
spend_key = f"provider_spend:{provider}:{budget_config.time_period}" spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
if self.dual_cache.redis_cache: if self.dual_cache.redis_cache:
ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key) ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
else: else:
@ -672,9 +678,13 @@ class RouterBudgetLimiting(CustomLogger):
- provider_budget_start_time:{provider} - stores the start time of the budget window - provider_budget_start_time:{provider} - stores the start time of the budget window
""" """
spend_key = f"provider_spend:{provider}:{budget_config.time_period}"
spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
start_time_key = f"provider_budget_start_time:{provider}" start_time_key = f"provider_budget_start_time:{provider}"
ttl_seconds = duration_in_seconds(budget_config.time_period) ttl_seconds: Optional[int] = None
if budget_config.budget_duration is not None:
ttl_seconds = duration_in_seconds(budget_config.budget_duration)
budget_start = await self.dual_cache.async_get_cache(start_time_key) budget_start = await self.dual_cache.async_get_cache(start_time_key)
if budget_start is None: if budget_start is None:
budget_start = datetime.now(timezone.utc).timestamp() budget_start = datetime.now(timezone.utc).timestamp()

View file

@ -11,8 +11,6 @@ import httpx
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict
from litellm.types.utils import GenericBudgetConfigType, GenericBudgetInfo
from ..exceptions import RateLimitError from ..exceptions import RateLimitError
from .completion import CompletionRequest from .completion import CompletionRequest
from .embedding import EmbeddingRequest from .embedding import EmbeddingRequest

View file

@ -1694,17 +1694,25 @@ class StandardKeyGenerationConfig(TypedDict, total=False):
personal_key_generation: PersonalUIKeyGenerationConfig personal_key_generation: PersonalUIKeyGenerationConfig
class GenericBudgetInfo(BaseModel):
time_period: str # e.g., '1d', '30d'
budget_limit: float
GenericBudgetConfigType = Dict[str, GenericBudgetInfo]
class BudgetConfig(BaseModel): class BudgetConfig(BaseModel):
max_budget: float max_budget: Optional[float] = None
budget_duration: str budget_duration: Optional[str] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
def __init__(self, **data: Any) -> None:
# Map time_period to budget_duration if present
if "time_period" in data:
data["budget_duration"] = data.pop("time_period")
# Map budget_limit to max_budget if present
if "budget_limit" in data:
data["max_budget"] = data.pop("budget_limit")
super().__init__(**data)
GenericBudgetConfigType = Dict[str, BudgetConfig]
class LlmProviders(str, Enum): class LlmProviders(str, Enum):

View file

@ -172,6 +172,11 @@ def main():
"delete_organization", "delete_organization",
"list_organization", "list_organization",
"user_update", "user_update",
"new_budget",
"info_budget",
"update_budget",
"delete_budget",
"list_budget",
] ]
# directory = "../../litellm/proxy/management_endpoints" # LOCAL # directory = "../../litellm/proxy/management_endpoints" # LOCAL
directory = "./litellm/proxy/management_endpoints" directory = "./litellm/proxy/management_endpoints"

View file

@ -14,15 +14,13 @@ from litellm import Router
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.types.router import ( from litellm.types.router import (
RoutingStrategy, RoutingStrategy,
GenericBudgetConfigType,
GenericBudgetInfo,
) )
from litellm.types.utils import GenericBudgetConfigType, BudgetConfig
from litellm.caching.caching import DualCache, RedisCache from litellm.caching.caching import DualCache, RedisCache
import logging import logging
from litellm._logging import verbose_router_logger from litellm._logging import verbose_router_logger
import litellm import litellm
from datetime import timezone, timedelta from datetime import timezone, timedelta
from litellm.types.utils import BudgetConfig
verbose_router_logger.setLevel(logging.DEBUG) verbose_router_logger.setLevel(logging.DEBUG)
@ -67,8 +65,8 @@ async def test_provider_budgets_e2e_test():
cleanup_redis() cleanup_redis()
# Modify for test # Modify for test
provider_budget_config: GenericBudgetConfigType = { provider_budget_config: GenericBudgetConfigType = {
"openai": GenericBudgetInfo(time_period="1d", budget_limit=0.000000000001), "openai": BudgetConfig(time_period="1d", budget_limit=0.000000000001),
"azure": GenericBudgetInfo(time_period="1d", budget_limit=100), "azure": BudgetConfig(time_period="1d", budget_limit=100),
} }
router = Router( router = Router(
@ -215,8 +213,8 @@ async def test_get_budget_config_for_provider():
""" """
cleanup_redis() cleanup_redis()
config = { config = {
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100), "openai": BudgetConfig(budget_duration="1d", max_budget=100),
"anthropic": GenericBudgetInfo(time_period="7d", budget_limit=500), "anthropic": BudgetConfig(budget_duration="7d", max_budget=500),
} }
provider_budget = RouterBudgetLimiting( provider_budget = RouterBudgetLimiting(
@ -226,13 +224,13 @@ async def test_get_budget_config_for_provider():
# Test existing providers # Test existing providers
openai_config = provider_budget._get_budget_config_for_provider("openai") openai_config = provider_budget._get_budget_config_for_provider("openai")
assert openai_config is not None assert openai_config is not None
assert openai_config.time_period == "1d" assert openai_config.budget_duration == "1d"
assert openai_config.budget_limit == 100 assert openai_config.max_budget == 100
anthropic_config = provider_budget._get_budget_config_for_provider("anthropic") anthropic_config = provider_budget._get_budget_config_for_provider("anthropic")
assert anthropic_config is not None assert anthropic_config is not None
assert anthropic_config.time_period == "7d" assert anthropic_config.budget_duration == "7d"
assert anthropic_config.budget_limit == 500 assert anthropic_config.max_budget == 500
# Test non-existent provider # Test non-existent provider
assert provider_budget._get_budget_config_for_provider("unknown") is None assert provider_budget._get_budget_config_for_provider("unknown") is None
@ -254,15 +252,15 @@ async def test_prometheus_metric_tracking():
provider_budget = RouterBudgetLimiting( provider_budget = RouterBudgetLimiting(
dual_cache=DualCache(), dual_cache=DualCache(),
provider_budget_config={ provider_budget_config={
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100) "openai": BudgetConfig(budget_duration="1d", max_budget=100)
}, },
) )
litellm._async_success_callback = [mock_prometheus] litellm._async_success_callback = [mock_prometheus]
provider_budget_config: GenericBudgetConfigType = { provider_budget_config: GenericBudgetConfigType = {
"openai": GenericBudgetInfo(time_period="1d", budget_limit=0.000000000001), "openai": BudgetConfig(budget_duration="1d", max_budget=0.000000000001),
"azure": GenericBudgetInfo(time_period="1d", budget_limit=100), "azure": BudgetConfig(budget_duration="1d", max_budget=100),
} }
router = Router( router = Router(
@ -442,8 +440,8 @@ async def test_sync_in_memory_spend_with_redis():
""" """
cleanup_redis() cleanup_redis()
provider_budget_config = { provider_budget_config = {
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100), "openai": BudgetConfig(time_period="1d", budget_limit=100),
"anthropic": GenericBudgetInfo(time_period="1d", budget_limit=200), "anthropic": BudgetConfig(time_period="1d", budget_limit=200),
} }
provider_budget = RouterBudgetLimiting( provider_budget = RouterBudgetLimiting(
@ -497,7 +495,7 @@ async def test_get_current_provider_spend():
provider_budget = RouterBudgetLimiting( provider_budget = RouterBudgetLimiting(
dual_cache=DualCache(), dual_cache=DualCache(),
provider_budget_config={ provider_budget_config={
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100), "openai": BudgetConfig(time_period="1d", budget_limit=100),
}, },
) )
@ -538,8 +536,8 @@ async def test_get_current_provider_budget_reset_at():
) )
), ),
provider_budget_config={ provider_budget_config={
"openai": GenericBudgetInfo(time_period="1d", budget_limit=100), "openai": BudgetConfig(budget_duration="1d", max_budget=100),
"vertex_ai": GenericBudgetInfo(time_period="1h", budget_limit=100), "vertex_ai": BudgetConfig(budget_duration="1h", max_budget=100),
}, },
) )

View file

@ -777,3 +777,68 @@ async def test_user_info_as_proxy_admin(prisma_client):
assert user_info_response.keys is not None assert user_info_response.keys is not None
assert len(user_info_response.keys) > 0, "Expected at least one key in response" assert len(user_info_response.keys) > 0, "Expected at least one key in response"
@pytest.mark.asyncio
async def test_key_update_with_model_specific_params(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
from litellm.proxy.management_endpoints.key_management_endpoints import (
update_key_fn,
)
from litellm.proxy._types import UpdateKeyRequest
new_key = await generate_key_fn(
data=GenerateKeyRequest(models=["gpt-4"]),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = new_key.key
token_hash = new_key.token_id
print(generated_key)
request = Request(scope={"type": "http"})
request._url = URL(url="/update/key")
args = {
"key_alias": f"test-key_{uuid.uuid4()}",
"duration": None,
"models": ["all-team-models"],
"spend": 0,
"max_budget": None,
"user_id": "default_user_id",
"team_id": None,
"max_parallel_requests": None,
"metadata": {
"model_tpm_limit": {"fake-openai-endpoint": 10},
"model_rpm_limit": {"fake-openai-endpoint": 0},
},
"tpm_limit": None,
"rpm_limit": None,
"budget_duration": None,
"allowed_cache_controls": [],
"soft_budget": None,
"config": {},
"permissions": {},
"model_max_budget": {},
"send_invite_email": None,
"model_rpm_limit": None,
"model_tpm_limit": None,
"guardrails": None,
"blocked": None,
"aliases": {},
"key": token_hash,
"budget_id": None,
"key_name": "sk-...2GWA",
"expires": None,
"token_id": token_hash,
"litellm_budget_table": None,
"token": token_hash,
}
await update_key_fn(request=request, data=UpdateKeyRequest(**args))

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
import os import os
import sys import sys
from typing import Any, Dict
from unittest.mock import Mock from unittest.mock import Mock
from litellm.proxy.utils import _get_redoc_url, _get_docs_url from litellm.proxy.utils import _get_redoc_url, _get_docs_url
import json import json
@ -1104,3 +1105,89 @@ def test_proxy_config_state_post_init_callback_call():
config = pc.get_config_state() config = pc.get_config_state()
assert config["litellm_settings"]["default_team_settings"][0]["team_id"] == "test" assert config["litellm_settings"]["default_team_settings"][0]["team_id"] == "test"
@pytest.mark.parametrize(
"associated_budget_table, expected_user_api_key_auth_key, expected_user_api_key_auth_value",
[
(
{
"litellm_budget_table_max_budget": None,
"litellm_budget_table_tpm_limit": None,
"litellm_budget_table_rpm_limit": 1,
"litellm_budget_table_model_max_budget": None,
},
"rpm_limit",
1,
),
(
{},
None,
None,
),
(
{
"litellm_budget_table_max_budget": None,
"litellm_budget_table_tpm_limit": None,
"litellm_budget_table_rpm_limit": None,
"litellm_budget_table_model_max_budget": {"gpt-4o": 100},
},
"model_max_budget",
{"gpt-4o": 100},
),
],
)
def test_litellm_verification_token_view_response_with_budget_table(
associated_budget_table,
expected_user_api_key_auth_key,
expected_user_api_key_auth_value,
):
from litellm.proxy._types import LiteLLM_VerificationTokenView
args: Dict[str, Any] = {
"token": "78b627d4d14bc3acf5571ae9cb6834e661bc8794d1209318677387add7621ce1",
"key_name": "sk-...if_g",
"key_alias": None,
"soft_budget_cooldown": False,
"spend": 0.011441999999999997,
"expires": None,
"models": [],
"aliases": {},
"config": {},
"user_id": None,
"team_id": "test",
"permissions": {},
"max_parallel_requests": None,
"metadata": {},
"blocked": None,
"tpm_limit": None,
"rpm_limit": None,
"max_budget": None,
"budget_duration": None,
"budget_reset_at": None,
"allowed_cache_controls": [],
"model_spend": {},
"model_max_budget": {},
"budget_id": "my-test-tier",
"created_at": "2024-12-26T02:28:52.615+00:00",
"updated_at": "2024-12-26T03:01:51.159+00:00",
"team_spend": 0.012134999999999998,
"team_max_budget": None,
"team_tpm_limit": None,
"team_rpm_limit": None,
"team_models": [],
"team_metadata": {},
"team_blocked": False,
"team_alias": None,
"team_members_with_roles": [{"role": "admin", "user_id": "default_user_id"}],
"team_member_spend": None,
"team_model_aliases": None,
"team_member": None,
**associated_budget_table,
}
resp = LiteLLM_VerificationTokenView(**args)
if expected_user_api_key_auth_key is not None:
assert (
getattr(resp, expected_user_api_key_auth_key)
== expected_user_api_key_auth_value
)

View file

@ -13,7 +13,7 @@ import pytest
import litellm import litellm
import json import json
from litellm.types.utils import GenericBudgetInfo from litellm.types.utils import BudgetConfig as GenericBudgetInfo
import os import os
import sys import sys
from datetime import datetime from datetime import datetime
@ -56,13 +56,13 @@ def test_get_request_model_budget_config(budget_limiter):
config = budget_limiter._get_request_model_budget_config( config = budget_limiter._get_request_model_budget_config(
model="gpt-4", internal_model_max_budget=internal_budget model="gpt-4", internal_model_max_budget=internal_budget
) )
assert config.budget_limit == 100.0 assert config.max_budget == 100.0
# Test model with provider # Test model with provider
config = budget_limiter._get_request_model_budget_config( config = budget_limiter._get_request_model_budget_config(
model="openai/gpt-4", internal_model_max_budget=internal_budget model="openai/gpt-4", internal_model_max_budget=internal_budget
) )
assert config.budget_limit == 100.0 assert config.max_budget == 100.0
# Test non-existent model # Test non-existent model
config = budget_limiter._get_request_model_budget_config( config = budget_limiter._get_request_model_budget_config(