forked from phoenix/litellm-mirror
feat(proxy_server.py): support max budget on proxy
This commit is contained in:
parent
14115d0d60
commit
1a32228da5
5 changed files with 66 additions and 18 deletions
|
@ -121,6 +121,7 @@ class GenerateKeyRequest(LiteLLMBase):
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
max_parallel_requests: Optional[int] = None
|
max_parallel_requests: Optional[int] = None
|
||||||
metadata: Optional[dict] = {}
|
metadata: Optional[dict] = {}
|
||||||
|
max_budget: Optional[float] = None
|
||||||
|
|
||||||
class UpdateKeyRequest(LiteLLMBase):
|
class UpdateKeyRequest(LiteLLMBase):
|
||||||
key: str
|
key: str
|
||||||
|
@ -132,21 +133,7 @@ class UpdateKeyRequest(LiteLLMBase):
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
max_parallel_requests: Optional[int] = None
|
max_parallel_requests: Optional[int] = None
|
||||||
metadata: Optional[dict] = {}
|
metadata: Optional[dict] = {}
|
||||||
|
max_budget: Optional[float] = None
|
||||||
class GenerateKeyResponse(LiteLLMBase):
|
|
||||||
key: str
|
|
||||||
expires: datetime
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class _DeleteKeyObject(LiteLLMBase):
|
|
||||||
key: str
|
|
||||||
|
|
||||||
class DeleteKeyRequest(LiteLLMBase):
|
|
||||||
keys: List[_DeleteKeyObject]
|
|
||||||
|
|
||||||
|
|
||||||
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
|
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
|
||||||
"""
|
"""
|
||||||
|
@ -161,6 +148,19 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
|
||||||
max_parallel_requests: Optional[int] = None
|
max_parallel_requests: Optional[int] = None
|
||||||
duration: str = "1h"
|
duration: str = "1h"
|
||||||
metadata: dict = {}
|
metadata: dict = {}
|
||||||
|
max_budget: Optional[float] = None
|
||||||
|
|
||||||
|
class GenerateKeyResponse(LiteLLMBase):
|
||||||
|
key: str
|
||||||
|
expires: datetime
|
||||||
|
user_id: str
|
||||||
|
|
||||||
|
class _DeleteKeyObject(LiteLLMBase):
|
||||||
|
key: str
|
||||||
|
|
||||||
|
class DeleteKeyRequest(LiteLLMBase):
|
||||||
|
keys: List[_DeleteKeyObject]
|
||||||
|
|
||||||
|
|
||||||
class ConfigGeneralSettings(LiteLLMBase):
|
class ConfigGeneralSettings(LiteLLMBase):
|
||||||
"""
|
"""
|
||||||
|
|
35
litellm/proxy/hooks/max_budget_limiter.py
Normal file
35
litellm/proxy/hooks/max_budget_limiter.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
from typing import Optional
|
||||||
|
import litellm
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
class MaxBudgetLimiter(CustomLogger):
    """Proxy pre-call hook that rejects requests for API keys whose recorded
    spend has reached or exceeded their configured max budget.

    Registered as a litellm callback; only `async_pre_call_hook` does real work.
    """

    def __init__(self):
        pass

    def print_verbose(self, print_statement):
        # Debug output, gated on litellm's global verbose flag.
        # (Idiomatic truthiness check instead of `is True`.)
        if litellm.set_verbose:
            print(print_statement)  # noqa

    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
        """Block the request with HTTP 429 when the key's budget is exhausted.

        Args:
            user_api_key_dict: auth info for the caller's key; provides
                `api_key`, `max_budget`, and `spend`.
            cache: shared dual cache (unused here, part of the hook interface).
            data: the incoming request payload (unused here).
            call_type: the proxy call type (unused here).

        Raises:
            HTTPException: 429 "Max budget limit reached." when
                spend >= max_budget.
        """
        # No f-string needed: the message has no placeholders.
        self.print_verbose("Inside Max Budget Limiter Pre-Call Hook")
        api_key = user_api_key_dict.api_key
        max_budget = user_api_key_dict.max_budget
        curr_spend = user_api_key_dict.spend

        # Budget enforcement is best-effort: if any of the three values is
        # missing (no key, no budget configured, no spend recorded), allow
        # the request through. Collapses the original three sequential guards.
        if api_key is None or max_budget is None or curr_spend is None:
            return

        # CHECK IF REQUEST ALLOWED
        if curr_spend >= max_budget:
            raise HTTPException(status_code=429, detail="Max budget limit reached.")
|
|
@ -616,7 +616,16 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
router = litellm.Router(**router_params) # type:ignore
|
router = litellm.Router(**router_params) # type:ignore
|
||||||
return router, model_list, general_settings
|
return router, model_list, general_settings
|
||||||
|
|
||||||
async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None, metadata: Optional[dict] = {}):
|
async def generate_key_helper_fn(duration: Optional[str],
|
||||||
|
models: list,
|
||||||
|
aliases: dict,
|
||||||
|
config: dict,
|
||||||
|
spend: float,
|
||||||
|
max_budget: Optional[float]=None,
|
||||||
|
token: Optional[str]=None,
|
||||||
|
user_id: Optional[str]=None,
|
||||||
|
max_parallel_requests: Optional[int]=None,
|
||||||
|
metadata: Optional[dict] = {},):
|
||||||
global prisma_client
|
global prisma_client
|
||||||
|
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
|
@ -666,7 +675,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
|
||||||
"spend": spend,
|
"spend": spend,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"max_parallel_requests": max_parallel_requests,
|
"max_parallel_requests": max_parallel_requests,
|
||||||
"metadata": metadata_json
|
"metadata": metadata_json,
|
||||||
|
"max_budget": max_budget
|
||||||
}
|
}
|
||||||
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
|
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -18,4 +18,5 @@ model LiteLLM_VerificationToken {
|
||||||
user_id String?
|
user_id String?
|
||||||
max_parallel_requests Int?
|
max_parallel_requests Int?
|
||||||
metadata Json @default("{}")
|
metadata Json @default("{}")
|
||||||
|
max_budget Float?
|
||||||
}
|
}
|
|
@ -4,6 +4,7 @@ import litellm, backoff
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
||||||
|
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
def print_verbose(print_statement):
|
def print_verbose(print_statement):
|
||||||
if litellm.set_verbose:
|
if litellm.set_verbose:
|
||||||
|
@ -23,11 +24,13 @@ class ProxyLogging:
|
||||||
self.call_details: dict = {}
|
self.call_details: dict = {}
|
||||||
self.call_details["user_api_key_cache"] = user_api_key_cache
|
self.call_details["user_api_key_cache"] = user_api_key_cache
|
||||||
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
|
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
|
||||||
|
self.max_budget_limiter = MaxBudgetLimiter()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _init_litellm_callbacks(self):
|
def _init_litellm_callbacks(self):
|
||||||
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
|
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
|
||||||
litellm.callbacks.append(self.max_parallel_request_limiter)
|
litellm.callbacks.append(self.max_parallel_request_limiter)
|
||||||
|
litellm.callbacks.append(self.max_budget_limiter)
|
||||||
for callback in litellm.callbacks:
|
for callback in litellm.callbacks:
|
||||||
if callback not in litellm.input_callback:
|
if callback not in litellm.input_callback:
|
||||||
litellm.input_callback.append(callback)
|
litellm.input_callback.append(callback)
|
||||||
|
@ -203,7 +206,6 @@ class PrismaClient:
|
||||||
hashed_token = self.hash_token(token=token)
|
hashed_token = self.hash_token(token=token)
|
||||||
db_data = self.jsonify_object(data=data)
|
db_data = self.jsonify_object(data=data)
|
||||||
db_data["token"] = hashed_token
|
db_data["token"] = hashed_token
|
||||||
|
|
||||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||||
where={
|
where={
|
||||||
'token': hashed_token,
|
'token': hashed_token,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue