fix(proxy_server.py): enforce end user budgets with 'litellm.max_end_user_budget' param

This commit is contained in:
Krrish Dholakia 2024-03-29 17:14:40 -07:00
parent 786116783f
commit 5280fc809f
5 changed files with 22 additions and 14 deletions

View file

@ -174,6 +174,7 @@ upperbound_key_generate_params: Optional[Dict] = None
default_user_params: Optional[Dict] = None default_user_params: Optional[Dict] = None
default_team_settings: Optional[List] = None default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None max_user_budget: Optional[float] = None
max_end_user_budget: Optional[float] = None
#### RELIABILITY #### #### RELIABILITY ####
request_timeout: Optional[float] = 6000 request_timeout: Optional[float] = 6000
num_retries: Optional[int] = None # per model endpoint num_retries: Optional[int] = None # per model endpoint

View file

@ -69,7 +69,7 @@ def common_checks(
end_user_budget = end_user_object.litellm_budget_table.max_budget end_user_budget = end_user_object.litellm_budget_table.max_budget
if end_user_budget is not None and end_user_object.spend > end_user_budget: if end_user_budget is not None and end_user_object.spend > end_user_budget:
raise Exception( raise Exception(
f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}" f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
) )
# 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints # 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
if ( if (

View file

@ -692,7 +692,6 @@ async def user_api_key_auth(
) )
# Check 2. If user_id for this token is in budget # Check 2. If user_id for this token is in budget
## Check 2.2 [OPTIONAL - checked only if litellm.max_user_budget is not None] If 'user' passed in /chat/completions is in budget
if valid_token.user_id is not None: if valid_token.user_id is not None:
user_id_list = [valid_token.user_id] user_id_list = [valid_token.user_id]
for id in user_id_list: for id in user_id_list:
@ -909,6 +908,13 @@ async def user_api_key_auth(
models=valid_token.team_models, models=valid_token.team_models,
) )
_end_user_object = None
if "user" in request_data:
_id = "end_user_id:{}".format(request_data["user"])
_end_user_object = await user_api_key_cache.async_get_cache(key=_id)
if _end_user_object is not None:
_end_user_object = LiteLLM_EndUserTable(**_end_user_object)
global_proxy_spend = None global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget if litellm.max_budget > 0: # user set proxy max budget
# check cache # check cache
@ -947,7 +953,7 @@ async def user_api_key_auth(
_ = common_checks( _ = common_checks(
request_body=request_data, request_body=request_data,
team_object=_team_obj, team_object=_team_obj,
end_user_object=None, end_user_object=_end_user_object,
general_settings=general_settings, general_settings=general_settings,
global_proxy_spend=global_proxy_spend, global_proxy_spend=global_proxy_spend,
route=route, route=route,
@ -1617,7 +1623,7 @@ async def update_cache(
async def _update_user_cache(): async def _update_user_cache():
## UPDATE CACHE FOR USER ID + GLOBAL PROXY ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
user_ids = [user_id, litellm_proxy_budget_name, end_user_id] user_ids = [user_id, litellm_proxy_budget_name]
try: try:
for _id in user_ids: for _id in user_ids:
# Fetch the existing cost for the given user # Fetch the existing cost for the given user
@ -1664,8 +1670,7 @@ async def update_cache(
) )
async def _update_end_user_cache(): async def _update_end_user_cache():
## UPDATE CACHE FOR USER ID + GLOBAL PROXY _id = "end_user_id:{}".format(end_user_id)
_id = end_user_id
try: try:
# Fetch the existing cost for the given user # Fetch the existing cost for the given user
existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id) existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
@ -1673,14 +1678,14 @@ async def update_cache(
# if user does not exist in LiteLLM_UserTable, create a new user # if user does not exist in LiteLLM_UserTable, create a new user
existing_spend = 0 existing_spend = 0
max_user_budget = None max_user_budget = None
if litellm.max_user_budget is not None: if litellm.max_end_user_budget is not None:
max_user_budget = litellm.max_user_budget max_end_user_budget = litellm.max_end_user_budget
existing_spend_obj = LiteLLM_EndUserTable( existing_spend_obj = LiteLLM_EndUserTable(
user_id=_id, user_id=_id,
spend=0, spend=0,
blocked=False, blocked=False,
litellm_budget_table=LiteLLM_BudgetTable( litellm_budget_table=LiteLLM_BudgetTable(
max_budget=max_user_budget max_budget=max_end_user_budget
), ),
) )
verbose_proxy_logger.debug( verbose_proxy_logger.debug(

View file

@ -1941,9 +1941,9 @@ async def update_spend(
end_user_id, end_user_id,
response_cost, response_cost,
) in prisma_client.end_user_list_transactons.items(): ) in prisma_client.end_user_list_transactons.items():
max_user_budget = None max_end_user_budget = None
if litellm.max_user_budget is not None: if litellm.max_end_user_budget is not None:
max_user_budget = litellm.max_user_budget max_end_user_budget = litellm.max_end_user_budget
new_user_obj = LiteLLM_EndUserTable( new_user_obj = LiteLLM_EndUserTable(
user_id=end_user_id, spend=response_cost, blocked=False user_id=end_user_id, spend=response_cost, blocked=False
) )

View file

@ -324,7 +324,7 @@ def test_call_with_end_user_over_budget(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm, "max_user_budget", 0.00001) setattr(litellm, "max_end_user_budget", 0.00001)
try: try:
async def test(): async def test():
@ -378,7 +378,9 @@ def test_call_with_end_user_over_budget(prisma_client):
"user_api_key_user_id": user, "user_api_key_user_id": user,
}, },
"proxy_server_request": { "proxy_server_request": {
"user": user, "body": {
"user": user,
}
}, },
}, },
"response_cost": 10, "response_cost": 10,