feat(proxy_server.py): support max budget on proxy

This commit is contained in:
parent 14115d0d60
commit 1a32228da5

5 changed files with 66 additions and 18 deletions
litellm/proxy/_types.py

@@ -121,6 +121,7 @@ class GenerateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
+    max_budget: Optional[float] = None
 
 class UpdateKeyRequest(LiteLLMBase):
     key: str
@@ -132,21 +133,7 @@ class UpdateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
-
-class GenerateKeyResponse(LiteLLMBase):
-    key: str
-    expires: datetime
-    user_id: str
-
-
-
-class _DeleteKeyObject(LiteLLMBase):
-    key: str
-
-class DeleteKeyRequest(LiteLLMBase):
-    keys: List[_DeleteKeyObject]
-
-
+    max_budget: Optional[float] = None
 
 class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
@@ -161,6 +148,19 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
     metadata: dict = {}
+    max_budget: Optional[float] = None
+
+class GenerateKeyResponse(LiteLLMBase):
+    key: str
+    expires: datetime
+    user_id: str
+
+class _DeleteKeyObject(LiteLLMBase):
+    key: str
+
+class DeleteKeyRequest(LiteLLMBase):
+    keys: List[_DeleteKeyObject]
+
 
 class ConfigGeneralSettings(LiteLLMBase):
     """
litellm/proxy/hooks/max_budget_limiter.py (new file, 35 lines)

@@ -0,0 +1,35 @@
+from typing import Optional
+import litellm
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+
+class MaxBudgetLimiter(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    def print_verbose(self, print_statement):
+        if litellm.set_verbose is True:
+            print(print_statement) # noqa
+
+
+    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
+        self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
+        api_key = user_api_key_dict.api_key
+        max_budget = user_api_key_dict.max_budget
+        curr_spend = user_api_key_dict.spend
+
+        if api_key is None:
+            return
+
+        if max_budget is None:
+            return
+
+        if curr_spend is None:
+            return
+
+        # CHECK IF REQUEST ALLOWED
+        if curr_spend >= max_budget:
+            raise HTTPException(status_code=429, detail="Max budget limit reached.")
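
A hedged sketch of exercising the new hook directly. The attribute names (api_key, spend, max_budget) are taken from the hunk above; the UserAPIKeyAuth constructor kwargs and the call_type value are assumptions for illustration:

    import asyncio
    from fastapi import HTTPException
    from litellm.caching import DualCache
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter

    async def demo():
        limiter = MaxBudgetLimiter()
        # A key whose tracked spend already exceeds its max_budget.
        key = UserAPIKeyAuth(api_key="sk-1234", spend=11.0, max_budget=10.0)
        try:
            await limiter.async_pre_call_hook(key, DualCache(), data={}, call_type="completion")
        except HTTPException as e:
            print(e.status_code, e.detail)  # 429 Max budget limit reached.

    asyncio.run(demo())
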
litellm/proxy/proxy_server.py

@@ -616,7 +616,16 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         router = litellm.Router(**router_params) # type:ignore
     return router, model_list, general_settings
 
-async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None, metadata: Optional[dict] = {}):
+async def generate_key_helper_fn(duration: Optional[str],
+                                 models: list,
+                                 aliases: dict,
+                                 config: dict,
+                                 spend: float,
+                                 max_budget: Optional[float]=None,
+                                 token: Optional[str]=None,
+                                 user_id: Optional[str]=None,
+                                 max_parallel_requests: Optional[int]=None,
+                                 metadata: Optional[dict] = {},):
     global prisma_client
 
     if prisma_client is None:
@@ -666,7 +675,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
             "spend": spend,
             "user_id": user_id,
             "max_parallel_requests": max_parallel_requests,
-            "metadata": metadata_json
+            "metadata": metadata_json,
+            "max_budget": max_budget
         }
         new_verification_token = await prisma_client.insert_data(data=verification_token_data)
     except Exception as e:
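
For reference, a hedged end-to-end sketch of requesting a budget-capped key from a running proxy via /key/generate (which calls generate_key_helper_fn); the host, port, and master key below are placeholders:

    import requests

    resp = requests.post(
        "http://0.0.0.0:8000/key/generate",  # placeholder proxy address
        headers={"Authorization": "Bearer sk-master-key"},  # placeholder master key
        json={"duration": "24h", "max_budget": 10.0},
    )
    # Shape per GenerateKeyResponse above: {"key": "...", "expires": "...", "user_id": "..."}
    print(resp.json())
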
litellm/proxy/schema.prisma

@@ -18,4 +18,5 @@ model LiteLLM_VerificationToken {
     user_id String?
     max_parallel_requests Int?
     metadata Json @default("{}")
+    max_budget Float?
 }
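
Note: because LiteLLM_VerificationToken gains a new column, existing proxy databases presumably need the schema re-applied (e.g. prisma db push against this schema file) before keys carrying max_budget can be stored.
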
litellm/proxy/utils.py

@@ -4,6 +4,7 @@ import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
+from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
 from litellm.integrations.custom_logger import CustomLogger
 
 def print_verbose(print_statement):
     if litellm.set_verbose:
@@ -23,11 +24,13 @@ class ProxyLogging:
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.max_parallel_request_limiter = MaxParallelRequestsHandler()
+        self.max_budget_limiter = MaxBudgetLimiter()
         pass
 
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
+        litellm.callbacks.append(self.max_budget_limiter)
         for callback in litellm.callbacks:
             if callback not in litellm.input_callback:
                 litellm.input_callback.append(callback)
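
A small sketch of the wiring this enables, assuming ProxyLogging is constructed with a user_api_key_cache as the hunk above suggests (proxy_server.py normally does this once at startup):

    from litellm.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
    # Appends max_parallel_request_limiter and max_budget_limiter to
    # litellm.callbacks, so both pre-call hooks run for every proxied request.
    proxy_logging_obj._init_litellm_callbacks()
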
@@ -203,7 +206,6 @@ class PrismaClient:
         hashed_token = self.hash_token(token=token)
         db_data = self.jsonify_object(data=data)
         db_data["token"] = hashed_token
 
         new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
             where={
                 'token': hashed_token,