mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge pull request #2757 from BerriAI/litellm_fix_budget_alerts
fix(auth_checks.py): make global spend checks more accurate
This commit is contained in:
commit
6d9887969f
6 changed files with 147 additions and 39 deletions
|
@ -437,12 +437,49 @@ async def user_api_key_auth(
|
|||
key=end_user_id, value=end_user_object
|
||||
)
|
||||
|
||||
global_proxy_spend = None
|
||||
|
||||
if litellm.max_budget > 0: # user set proxy max budget
|
||||
# check cache
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name)
|
||||
)
|
||||
if global_proxy_spend is None and prisma_client is not None:
|
||||
# get from db
|
||||
sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
|
||||
|
||||
response = await prisma_client.db.query_raw(query=sql_query)
|
||||
|
||||
global_proxy_spend = response[0]["total_spend"]
|
||||
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name),
|
||||
value=global_proxy_spend,
|
||||
ttl=60,
|
||||
)
|
||||
if global_proxy_spend is not None:
|
||||
user_info = {
|
||||
"user_id": litellm_proxy_admin_name,
|
||||
"max_budget": litellm.max_budget,
|
||||
"spend": global_proxy_spend,
|
||||
"user_email": "",
|
||||
}
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
user_max_budget=litellm.max_budget,
|
||||
user_current_spend=global_proxy_spend,
|
||||
type="user_and_proxy_budget",
|
||||
user_info=user_info,
|
||||
)
|
||||
)
|
||||
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
team_object=team_object,
|
||||
end_user_object=end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
)
|
||||
# save user object in cache
|
||||
|
@ -656,17 +693,8 @@ async def user_api_key_auth(
|
|||
)
|
||||
|
||||
# Check 2. If user_id for this token is in budget
|
||||
## Check 2.1 If global proxy is in budget
|
||||
## Check 2.2 [OPTIONAL - checked only if litellm.max_user_budget is not None] If 'user' passed in /chat/completions is in budget
|
||||
if valid_token.user_id is not None:
|
||||
user_id_list = [valid_token.user_id, litellm_proxy_budget_name]
|
||||
if (
|
||||
litellm.max_user_budget is not None
|
||||
): # Check if 'user' passed in /chat/completions is in budget, only checked if litellm.max_user_budget is set
|
||||
user_passed_to_chat_completions = request_data.get("user", None)
|
||||
if user_passed_to_chat_completions is not None:
|
||||
user_id_list.append(user_passed_to_chat_completions)
|
||||
|
||||
user_id_list = [valid_token.user_id]
|
||||
for id in user_id_list:
|
||||
value = user_api_key_cache.get_cache(key=id)
|
||||
if value is not None:
|
||||
|
@ -675,13 +703,12 @@ async def user_api_key_auth(
|
|||
user_id_information.append(value)
|
||||
if user_id_information is None or (
|
||||
isinstance(user_id_information, list)
|
||||
and len(user_id_information) < 2
|
||||
and len(user_id_information) < 1
|
||||
):
|
||||
if prisma_client is not None:
|
||||
user_id_information = await prisma_client.get_data(
|
||||
user_id_list=[
|
||||
valid_token.user_id,
|
||||
litellm_proxy_budget_name,
|
||||
],
|
||||
table_name="user",
|
||||
query_type="find_all",
|
||||
|
@ -881,11 +908,54 @@ async def user_api_key_auth(
|
|||
blocked=valid_token.team_blocked,
|
||||
models=valid_token.team_models,
|
||||
)
|
||||
|
||||
_end_user_object = None
|
||||
if "user" in request_data:
|
||||
_id = "end_user_id:{}".format(request_data["user"])
|
||||
_end_user_object = await user_api_key_cache.async_get_cache(key=_id)
|
||||
if _end_user_object is not None:
|
||||
_end_user_object = LiteLLM_EndUserTable(**_end_user_object)
|
||||
|
||||
global_proxy_spend = None
|
||||
if litellm.max_budget > 0: # user set proxy max budget
|
||||
# check cache
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name)
|
||||
)
|
||||
if global_proxy_spend is None:
|
||||
# get from db
|
||||
sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
|
||||
|
||||
response = await prisma_client.db.query_raw(query=sql_query)
|
||||
|
||||
global_proxy_spend = response[0]["total_spend"]
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name),
|
||||
value=global_proxy_spend,
|
||||
ttl=60,
|
||||
)
|
||||
|
||||
if global_proxy_spend is not None:
|
||||
user_info = {
|
||||
"user_id": litellm_proxy_admin_name,
|
||||
"max_budget": litellm.max_budget,
|
||||
"spend": global_proxy_spend,
|
||||
"user_email": "",
|
||||
}
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
user_max_budget=litellm.max_budget,
|
||||
user_current_spend=global_proxy_spend,
|
||||
type="user_and_proxy_budget",
|
||||
user_info=user_info,
|
||||
)
|
||||
)
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
team_object=_team_obj,
|
||||
end_user_object=None,
|
||||
end_user_object=_end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
)
|
||||
# Token passed all checks
|
||||
|
@ -1553,7 +1623,7 @@ async def update_cache(
|
|||
|
||||
async def _update_user_cache():
|
||||
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||
user_ids = [user_id, litellm_proxy_budget_name, end_user_id]
|
||||
user_ids = [user_id]
|
||||
try:
|
||||
for _id in user_ids:
|
||||
# Fetch the existing cost for the given user
|
||||
|
@ -1594,14 +1664,26 @@ async def update_cache(
|
|||
user_api_key_cache.set_cache(
|
||||
key=_id, value=existing_spend_obj.json()
|
||||
)
|
||||
## UPDATE GLOBAL PROXY ##
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name)
|
||||
)
|
||||
if global_proxy_spend is None:
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name), value=response_cost
|
||||
)
|
||||
elif response_cost is not None and global_proxy_spend is not None:
|
||||
increment = global_proxy_spend + response_cost
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name), value=increment
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
async def _update_end_user_cache():
|
||||
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||
_id = end_user_id
|
||||
_id = "end_user_id:{}".format(end_user_id)
|
||||
try:
|
||||
# Fetch the existing cost for the given user
|
||||
existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
|
||||
|
@ -1609,14 +1691,14 @@ async def update_cache(
|
|||
# if user does not exist in LiteLLM_UserTable, create a new user
|
||||
existing_spend = 0
|
||||
max_user_budget = None
|
||||
if litellm.max_user_budget is not None:
|
||||
max_user_budget = litellm.max_user_budget
|
||||
if litellm.max_end_user_budget is not None:
|
||||
max_end_user_budget = litellm.max_end_user_budget
|
||||
existing_spend_obj = LiteLLM_EndUserTable(
|
||||
user_id=_id,
|
||||
spend=0,
|
||||
blocked=False,
|
||||
litellm_budget_table=LiteLLM_BudgetTable(
|
||||
max_budget=max_user_budget
|
||||
max_budget=max_end_user_budget
|
||||
),
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -4049,7 +4131,6 @@ async def generate_key_fn(
|
|||
)
|
||||
_budget_id = getattr(_budget, "budget_id", None)
|
||||
data_json = data.json() # type: ignore
|
||||
|
||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||
if "max_budget" in data_json:
|
||||
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue