feat(proxy_server.py): support max budget on proxy

This commit is contained in:
Krrish Dholakia 2023-12-21 16:07:20 +05:30
parent 14115d0d60
commit 1a32228da5
5 changed files with 66 additions and 18 deletions

View file

@ -121,6 +121,7 @@ class GenerateKeyRequest(LiteLLMBase):
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
max_budget: Optional[float] = None
class UpdateKeyRequest(LiteLLMBase):
key: str
@ -132,21 +133,7 @@ class UpdateKeyRequest(LiteLLMBase):
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
class GenerateKeyResponse(LiteLLMBase):
key: str
expires: datetime
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str
class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]
max_budget: Optional[float] = None
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
"""
@ -161,6 +148,19 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
max_parallel_requests: Optional[int] = None
duration: str = "1h"
metadata: dict = {}
max_budget: Optional[float] = None
class GenerateKeyResponse(LiteLLMBase):
key: str
expires: datetime
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str
class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]
class ConfigGeneralSettings(LiteLLMBase):
"""

View file

@ -0,0 +1,35 @@
from typing import Optional
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
class MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
pass
def print_verbose(self, print_statement):
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_budget = user_api_key_dict.max_budget
curr_spend = user_api_key_dict.spend
if api_key is None:
return
if max_budget is None:
return
if curr_spend is None:
return
# CHECK IF REQUEST ALLOWED
if curr_spend >= max_budget:
raise HTTPException(status_code=429, detail="Max budget limit reached.")

View file

@ -616,7 +616,16 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings
async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None, metadata: Optional[dict] = {}):
async def generate_key_helper_fn(duration: Optional[str],
models: list,
aliases: dict,
config: dict,
spend: float,
max_budget: Optional[float]=None,
token: Optional[str]=None,
user_id: Optional[str]=None,
max_parallel_requests: Optional[int]=None,
metadata: Optional[dict] = {},):
global prisma_client
if prisma_client is None:
@ -666,7 +675,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
"spend": spend,
"user_id": user_id,
"max_parallel_requests": max_parallel_requests,
"metadata": metadata_json
"metadata": metadata_json,
"max_budget": max_budget
}
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
except Exception as e:

View file

@ -18,4 +18,5 @@ model LiteLLM_VerificationToken {
user_id String?
max_parallel_requests Int?
metadata Json @default("{}")
max_budget Float?
}

View file

@ -4,6 +4,7 @@ import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
def print_verbose(print_statement):
if litellm.set_verbose:
@ -23,11 +24,13 @@ class ProxyLogging:
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
self.max_budget_limiter = MaxBudgetLimiter()
pass
def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_budget_limiter)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
@ -203,7 +206,6 @@ class PrismaClient:
hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,