diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 233c1b642..76d37bddf 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -121,6 +121,7 @@ class GenerateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
+    max_budget: Optional[float] = None
 
 class UpdateKeyRequest(LiteLLMBase):
     key: str
@@ -132,21 +133,7 @@ class UpdateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
-
-class GenerateKeyResponse(LiteLLMBase):
-    key: str
-    expires: datetime
-    user_id: str
-
-
-
-
-class _DeleteKeyObject(LiteLLMBase):
-    key: str
-
-class DeleteKeyRequest(LiteLLMBase):
-    keys: List[_DeleteKeyObject]
-
+    max_budget: Optional[float] = None
 
 class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
@@ -161,6 +148,19 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
     metadata: dict = {}
+    max_budget: Optional[float] = None
+
+class GenerateKeyResponse(LiteLLMBase):
+    key: str
+    expires: datetime
+    user_id: str
+
+class _DeleteKeyObject(LiteLLMBase):
+    key: str
+
+class DeleteKeyRequest(LiteLLMBase):
+    keys: List[_DeleteKeyObject]
+
 
 class ConfigGeneralSettings(LiteLLMBase):
     """
diff --git a/litellm/proxy/hooks/max_budget_limiter.py b/litellm/proxy/hooks/max_budget_limiter.py
new file mode 100644
index 000000000..b2ffbeea8
--- /dev/null
+++ b/litellm/proxy/hooks/max_budget_limiter.py
@@ -0,0 +1,35 @@
+from typing import Optional
+import litellm
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+
+class MaxBudgetLimiter(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    def print_verbose(self, print_statement):
+        if litellm.set_verbose is True:
+            print(print_statement) # noqa
+
+
+    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
+        self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
+        api_key = user_api_key_dict.api_key
+        max_budget = user_api_key_dict.max_budget
+        curr_spend = user_api_key_dict.spend
+
+        if api_key is None:
+            return
+
+        if max_budget is None:
+            return
+
+        if curr_spend is None:
+            return
+
+        # CHECK IF REQUEST ALLOWED
+        if curr_spend >= max_budget:
+            raise HTTPException(status_code=429, detail="Max budget limit reached.")
\ No newline at end of file
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 8a32f1b4f..e6ec64823 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -616,7 +616,16 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         router = litellm.Router(**router_params) # type:ignore
     return router, model_list, general_settings
 
-async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None, metadata: Optional[dict] = {}):
+async def generate_key_helper_fn(duration: Optional[str],
+                                 models: list,
+                                 aliases: dict,
+                                 config: dict,
+                                 spend: float,
+                                 max_budget: Optional[float]=None,
+                                 token: Optional[str]=None,
+                                 user_id: Optional[str]=None,
+                                 max_parallel_requests: Optional[int]=None,
+                                 metadata: Optional[dict] = {},):
     global prisma_client
 
     if prisma_client is None:
@@ -666,7 +675,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
             "spend": spend,
             "user_id": user_id,
             "max_parallel_requests": max_parallel_requests,
-            "metadata": metadata_json
+            "metadata": metadata_json,
+            "max_budget": max_budget
         }
         new_verification_token = await prisma_client.insert_data(data=verification_token_data)
     except Exception as e:
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index 6cfcdb866..e4acd13e5 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -18,4 +18,5 @@ model LiteLLM_VerificationToken {
   user_id String?
   max_parallel_requests Int?
   metadata Json @default("{}")
+  max_budget Float?
 }
\ No newline at end of file
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 52a2fb6aa..3592593d5 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -4,6 +4,7 @@ import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
+from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
 from litellm.integrations.custom_logger import CustomLogger
 def print_verbose(print_statement):
     if litellm.set_verbose:
@@ -23,11 +24,13 @@ class ProxyLogging:
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.max_parallel_request_limiter = MaxParallelRequestsHandler()
+        self.max_budget_limiter = MaxBudgetLimiter()
         pass
 
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
+        litellm.callbacks.append(self.max_budget_limiter)
         for callback in litellm.callbacks:
             if callback not in litellm.input_callback:
                 litellm.input_callback.append(callback)
@@ -203,7 +206,6 @@ class PrismaClient:
             hashed_token = self.hash_token(token=token)
             db_data = self.jsonify_object(data=data)
             db_data["token"] = hashed_token
-
            new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
                 where={
                     'token': hashed_token,
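
For reference, here is a minimal standalone sketch of the budget check this diff wires in as a pre-call hook. It requires no litellm install; `KeyRecord` and `BudgetExceededError` are hypothetical stand-ins for `UserAPIKeyAuth` and the `HTTPException` with status 429 that the real `MaxBudgetLimiter.async_pre_call_hook` raises.

```python
# Minimal sketch of the MaxBudgetLimiter pre-call check, assuming a key
# record carrying api_key, a running spend total, and the new max_budget.
from dataclasses import dataclass
from typing import Optional

@dataclass
class KeyRecord:  # hypothetical stand-in for litellm.proxy._types.UserAPIKeyAuth
    api_key: Optional[str] = None
    spend: Optional[float] = None       # running spend tracked by the proxy
    max_budget: Optional[float] = None  # new per-key field added in this diff

class BudgetExceededError(Exception):
    """Stand-in for the HTTP 429 the real hook raises via HTTPException."""

def pre_call_budget_check(key: KeyRecord) -> None:
    # Mirrors async_pre_call_hook: the check only applies when the key,
    # its spend, and its budget are all known; otherwise it is a no-op.
    if key.api_key is None or key.max_budget is None or key.spend is None:
        return
    # Reject the request once accumulated spend reaches the budget.
    if key.spend >= key.max_budget:
        raise BudgetExceededError("Max budget limit reached.")

# Usage: a key under budget passes; one at or over budget is rejected.
pre_call_budget_check(KeyRecord(api_key="sk-1", spend=0.4, max_budget=1.0))  # ok
try:
    pre_call_budget_check(KeyRecord(api_key="sk-2", spend=1.0, max_budget=1.0))
except BudgetExceededError as e:
    print(e)  # -> Max budget limit reached.
```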