fix(auth_checks.py): make global spend checks more accurate

This commit is contained in:
Krrish Dholakia 2024-03-29 14:57:44 -07:00
parent bbd94f504c
commit d8c15a5677
2 changed files with 59 additions and 11 deletions

View file

@ -437,12 +437,34 @@ async def user_api_key_auth(
key=end_user_id, value=end_user_object
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
)
if global_proxy_spend is None and prisma_client is not None:
# get from db
sql_query = """SELECT SUM(spend) as total_spend FROM MONTHLYGLOBALSPEND;"""
response = await prisma_client.db.query_raw(query=sql_query)
global_proxy_spend = response[0]["total_spend"]
await user_api_key_cache.async_set_cache(
key="{}:spend".format(litellm_proxy_admin_name),
value=global_proxy_spend,
ttl=60,
)
# run through common checks
_ = common_checks(
request_body=request_data,
team_object=team_object,
end_user_object=end_user_object,
general_settings=general_settings,
global_proxy_spend=global_proxy_spend,
route=route,
)
# save user object in cache
@ -656,17 +678,9 @@ async def user_api_key_auth(
)
# Check 2. If user_id for this token is in budget
## Check 2.1 If global proxy is in budget
## Check 2.2 [OPTIONAL - checked only if litellm.max_user_budget is not None] If 'user' passed in /chat/completions is in budget
if valid_token.user_id is not None:
user_id_list = [valid_token.user_id, litellm_proxy_budget_name]
if (
litellm.max_user_budget is not None
): # Check if 'user' passed in /chat/completions is in budget, only checked if litellm.max_user_budget is set
user_passed_to_chat_completions = request_data.get("user", None)
if user_passed_to_chat_completions is not None:
user_id_list.append(user_passed_to_chat_completions)
user_id_list = [valid_token.user_id]
for id in user_id_list:
value = user_api_key_cache.get_cache(key=id)
if value is not None:
@ -681,7 +695,6 @@ async def user_api_key_auth(
user_id_information = await prisma_client.get_data(
user_id_list=[
valid_token.user_id,
litellm_proxy_budget_name,
],
table_name="user",
query_type="find_all",
@ -881,11 +894,35 @@ async def user_api_key_auth(
blocked=valid_token.team_blocked,
models=valid_token.team_models,
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
)
if global_proxy_spend is None:
# get from db
sql_query = (
"""SELECT SUM(spend) as total_spend FROM MONTHLYGLOBALSPEND;"""
)
response = await prisma_client.db.query_raw(query=sql_query)
global_proxy_spend = response[0].total_spend
await user_api_key_cache.async_set_cache(
key="{}:spend".format(litellm_proxy_admin_name),
value=global_proxy_spend,
ttl=60,
)
_ = common_checks(
request_body=request_data,
team_object=_team_obj,
end_user_object=None,
general_settings=general_settings,
global_proxy_spend=global_proxy_spend,
route=route,
)
# Token passed all checks