Merge pull request #2757 from BerriAI/litellm_fix_budget_alerts

fix(auth_checks.py): make global spend checks more accurate
This commit is contained in:
Krish Dholakia 2024-03-29 21:13:27 -07:00 committed by GitHub
commit 6d9887969f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 147 additions and 39 deletions

View file

@@ -174,6 +174,7 @@ upperbound_key_generate_params: Optional[Dict] = None
default_user_params: Optional[Dict] = None default_user_params: Optional[Dict] = None
default_team_settings: Optional[List] = None default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None max_user_budget: Optional[float] = None
max_end_user_budget: Optional[float] = None
#### RELIABILITY #### #### RELIABILITY ####
request_timeout: Optional[float] = 6000 request_timeout: Optional[float] = 6000
num_retries: Optional[int] = None # per model endpoint num_retries: Optional[int] = None # per model endpoint

View file

@@ -5,10 +5,15 @@ model_list:
api_key: my-fake-key api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/ api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
max_budget: 600020
budget_duration: 30d
general_settings: general_settings:
master_key: sk-1234 master_key: sk-1234
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds) proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
enable_jwt_auth: True enable_jwt_auth: True
alerting: ["slack"]
litellm_jwtauth: litellm_jwtauth:
admin_jwt_scope: "litellm_proxy_admin" admin_jwt_scope: "litellm_proxy_admin"
team_jwt_scope: "litellm_team" team_jwt_scope: "litellm_team"

View file

@@ -18,6 +18,7 @@ from litellm.proxy._types import (
from typing import Optional, Literal, Union from typing import Optional, Literal, Union
from litellm.proxy.utils import PrismaClient from litellm.proxy.utils import PrismaClient
from litellm.caching import DualCache from litellm.caching import DualCache
import litellm
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
@@ -26,6 +27,7 @@ def common_checks(
request_body: dict, request_body: dict,
team_object: LiteLLM_TeamTable, team_object: LiteLLM_TeamTable,
end_user_object: Optional[LiteLLM_EndUserTable], end_user_object: Optional[LiteLLM_EndUserTable],
global_proxy_spend: Optional[float],
general_settings: dict, general_settings: dict,
route: str, route: str,
) -> bool: ) -> bool:
@@ -37,6 +39,7 @@ def common_checks(
3. If team is in budget 3. If team is in budget
4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints 5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
""" """
_model = request_body.get("model", None) _model = request_body.get("model", None)
if team_object.blocked == True: if team_object.blocked == True:
@@ -66,7 +69,7 @@ def common_checks(
end_user_budget = end_user_object.litellm_budget_table.max_budget end_user_budget = end_user_object.litellm_budget_table.max_budget
if end_user_budget is not None and end_user_object.spend > end_user_budget: if end_user_budget is not None and end_user_object.spend > end_user_budget:
raise Exception( raise Exception(
f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}" f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
) )
# 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints # 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
if ( if (
@@ -77,7 +80,12 @@ def common_checks(
raise Exception( raise Exception(
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}" f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
) )
# 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
if litellm.max_budget > 0 and global_proxy_spend is not None:
if global_proxy_spend > litellm.max_budget:
raise Exception(
f"ExceededBudget: LiteLLM Proxy has exceeded its budget. Current spend: {global_proxy_spend}; Max Budget: {litellm.max_budget}"
)
return True return True

View file

@@ -437,12 +437,49 @@ async def user_api_key_auth(
key=end_user_id, value=end_user_object key=end_user_id, value=end_user_object
) )
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
)
if global_proxy_spend is None and prisma_client is not None:
# get from db
sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
response = await prisma_client.db.query_raw(query=sql_query)
global_proxy_spend = response[0]["total_spend"]
await user_api_key_cache.async_set_cache(
key="{}:spend".format(litellm_proxy_admin_name),
value=global_proxy_spend,
ttl=60,
)
if global_proxy_spend is not None:
user_info = {
"user_id": litellm_proxy_admin_name,
"max_budget": litellm.max_budget,
"spend": global_proxy_spend,
"user_email": "",
}
asyncio.create_task(
proxy_logging_obj.budget_alerts(
user_max_budget=litellm.max_budget,
user_current_spend=global_proxy_spend,
type="user_and_proxy_budget",
user_info=user_info,
)
)
# run through common checks # run through common checks
_ = common_checks( _ = common_checks(
request_body=request_data, request_body=request_data,
team_object=team_object, team_object=team_object,
end_user_object=end_user_object, end_user_object=end_user_object,
general_settings=general_settings, general_settings=general_settings,
global_proxy_spend=global_proxy_spend,
route=route, route=route,
) )
# save user object in cache # save user object in cache
@@ -656,17 +693,8 @@ async def user_api_key_auth(
) )
# Check 2. If user_id for this token is in budget # Check 2. If user_id for this token is in budget
## Check 2.1 If global proxy is in budget
## Check 2.2 [OPTIONAL - checked only if litellm.max_user_budget is not None] If 'user' passed in /chat/completions is in budget
if valid_token.user_id is not None: if valid_token.user_id is not None:
user_id_list = [valid_token.user_id, litellm_proxy_budget_name] user_id_list = [valid_token.user_id]
if (
litellm.max_user_budget is not None
): # Check if 'user' passed in /chat/completions is in budget, only checked if litellm.max_user_budget is set
user_passed_to_chat_completions = request_data.get("user", None)
if user_passed_to_chat_completions is not None:
user_id_list.append(user_passed_to_chat_completions)
for id in user_id_list: for id in user_id_list:
value = user_api_key_cache.get_cache(key=id) value = user_api_key_cache.get_cache(key=id)
if value is not None: if value is not None:
@@ -675,13 +703,12 @@ async def user_api_key_auth(
user_id_information.append(value) user_id_information.append(value)
if user_id_information is None or ( if user_id_information is None or (
isinstance(user_id_information, list) isinstance(user_id_information, list)
and len(user_id_information) < 2 and len(user_id_information) < 1
): ):
if prisma_client is not None: if prisma_client is not None:
user_id_information = await prisma_client.get_data( user_id_information = await prisma_client.get_data(
user_id_list=[ user_id_list=[
valid_token.user_id, valid_token.user_id,
litellm_proxy_budget_name,
], ],
table_name="user", table_name="user",
query_type="find_all", query_type="find_all",
@@ -881,11 +908,54 @@ async def user_api_key_auth(
blocked=valid_token.team_blocked, blocked=valid_token.team_blocked,
models=valid_token.team_models, models=valid_token.team_models,
) )
_end_user_object = None
if "user" in request_data:
_id = "end_user_id:{}".format(request_data["user"])
_end_user_object = await user_api_key_cache.async_get_cache(key=_id)
if _end_user_object is not None:
_end_user_object = LiteLLM_EndUserTable(**_end_user_object)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
)
if global_proxy_spend is None:
# get from db
sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
response = await prisma_client.db.query_raw(query=sql_query)
global_proxy_spend = response[0]["total_spend"]
await user_api_key_cache.async_set_cache(
key="{}:spend".format(litellm_proxy_admin_name),
value=global_proxy_spend,
ttl=60,
)
if global_proxy_spend is not None:
user_info = {
"user_id": litellm_proxy_admin_name,
"max_budget": litellm.max_budget,
"spend": global_proxy_spend,
"user_email": "",
}
asyncio.create_task(
proxy_logging_obj.budget_alerts(
user_max_budget=litellm.max_budget,
user_current_spend=global_proxy_spend,
type="user_and_proxy_budget",
user_info=user_info,
)
)
_ = common_checks( _ = common_checks(
request_body=request_data, request_body=request_data,
team_object=_team_obj, team_object=_team_obj,
end_user_object=None, end_user_object=_end_user_object,
general_settings=general_settings, general_settings=general_settings,
global_proxy_spend=global_proxy_spend,
route=route, route=route,
) )
# Token passed all checks # Token passed all checks
@@ -1553,7 +1623,7 @@ async def update_cache(
async def _update_user_cache(): async def _update_user_cache():
## UPDATE CACHE FOR USER ID + GLOBAL PROXY ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
user_ids = [user_id, litellm_proxy_budget_name, end_user_id] user_ids = [user_id]
try: try:
for _id in user_ids: for _id in user_ids:
# Fetch the existing cost for the given user # Fetch the existing cost for the given user
@@ -1594,14 +1664,26 @@ async def update_cache(
user_api_key_cache.set_cache( user_api_key_cache.set_cache(
key=_id, value=existing_spend_obj.json() key=_id, value=existing_spend_obj.json()
) )
## UPDATE GLOBAL PROXY ##
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
)
if global_proxy_spend is None:
await user_api_key_cache.async_set_cache(
key="{}:spend".format(litellm_proxy_admin_name), value=response_cost
)
elif response_cost is not None and global_proxy_spend is not None:
increment = global_proxy_spend + response_cost
await user_api_key_cache.async_set_cache(
key="{}:spend".format(litellm_proxy_admin_name), value=increment
)
except Exception as e: except Exception as e:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}" f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
) )
async def _update_end_user_cache(): async def _update_end_user_cache():
## UPDATE CACHE FOR USER ID + GLOBAL PROXY _id = "end_user_id:{}".format(end_user_id)
_id = end_user_id
try: try:
# Fetch the existing cost for the given user # Fetch the existing cost for the given user
existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id) existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
@@ -1609,14 +1691,14 @@ async def update_cache(
# if user does not exist in LiteLLM_UserTable, create a new user # if user does not exist in LiteLLM_UserTable, create a new user
existing_spend = 0 existing_spend = 0
max_user_budget = None max_user_budget = None
if litellm.max_user_budget is not None: if litellm.max_end_user_budget is not None:
max_user_budget = litellm.max_user_budget max_end_user_budget = litellm.max_end_user_budget
existing_spend_obj = LiteLLM_EndUserTable( existing_spend_obj = LiteLLM_EndUserTable(
user_id=_id, user_id=_id,
spend=0, spend=0,
blocked=False, blocked=False,
litellm_budget_table=LiteLLM_BudgetTable( litellm_budget_table=LiteLLM_BudgetTable(
max_budget=max_user_budget max_budget=max_end_user_budget
), ),
) )
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@@ -4049,7 +4131,6 @@ async def generate_key_fn(
) )
_budget_id = getattr(_budget, "budget_id", None) _budget_id = getattr(_budget, "budget_id", None)
data_json = data.json() # type: ignore data_json = data.json() # type: ignore
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
if "max_budget" in data_json: if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None) data_json["key_max_budget"] = data_json.pop("max_budget", None)

View file

@@ -1941,9 +1941,9 @@ async def update_spend(
end_user_id, end_user_id,
response_cost, response_cost,
) in prisma_client.end_user_list_transactons.items(): ) in prisma_client.end_user_list_transactons.items():
max_user_budget = None max_end_user_budget = None
if litellm.max_user_budget is not None: if litellm.max_end_user_budget is not None:
max_user_budget = litellm.max_user_budget max_end_user_budget = litellm.max_end_user_budget
new_user_obj = LiteLLM_EndUserTable( new_user_obj = LiteLLM_EndUserTable(
user_id=end_user_id, spend=response_cost, blocked=False user_id=end_user_id, spend=response_cost, blocked=False
) )

View file

@@ -324,7 +324,7 @@ def test_call_with_end_user_over_budget(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm, "max_user_budget", 0.00001) setattr(litellm, "max_end_user_budget", 0.00001)
try: try:
async def test(): async def test():
@@ -378,7 +378,9 @@ def test_call_with_end_user_over_budget(prisma_client):
"user_api_key_user_id": user, "user_api_key_user_id": user,
}, },
"proxy_server_request": { "proxy_server_request": {
"body": {
"user": user, "user": user,
}
}, },
}, },
"response_cost": 10, "response_cost": 10,
@@ -407,18 +409,20 @@ def test_call_with_proxy_over_budget(prisma_client):
litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}" litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}"
setattr( setattr(
litellm.proxy.proxy_server, litellm.proxy.proxy_server,
"litellm_proxy_budget_name", "litellm_proxy_admin_name",
litellm_proxy_budget_name, litellm_proxy_budget_name,
) )
setattr(litellm, "max_budget", 0.00001)
from litellm.proxy.proxy_server import user_api_key_cache
user_api_key_cache.set_cache(
key="{}:spend".format(litellm_proxy_budget_name), value=0
)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
try: try:
async def test(): async def test():
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
## CREATE PROXY + USER BUDGET ##
request = NewUserRequest(
max_budget=0.00001, user_id=litellm_proxy_budget_name
)
await new_user(request)
request = NewUserRequest() request = NewUserRequest()
key = await new_user(request) key = await new_user(request)
print(key) print(key)
@@ -470,6 +474,7 @@ def test_call_with_proxy_over_budget(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5) await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
@@ -571,9 +576,17 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}" litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}"
setattr( setattr(
litellm.proxy.proxy_server, litellm.proxy.proxy_server,
"litellm_proxy_budget_name", "litellm_proxy_admin_name",
litellm_proxy_budget_name, litellm_proxy_budget_name,
) )
setattr(litellm, "max_budget", 0.00001)
from litellm.proxy.proxy_server import user_api_key_cache
user_api_key_cache.set_cache(
key="{}:spend".format(litellm_proxy_budget_name), value=0
)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
import logging import logging
@@ -584,10 +597,10 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
async def test(): async def test():
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
## CREATE PROXY + USER BUDGET ## ## CREATE PROXY + USER BUDGET ##
request = NewUserRequest( # request = NewUserRequest(
max_budget=0.00001, user_id=litellm_proxy_budget_name # max_budget=0.00001, user_id=litellm_proxy_budget_name
) # )
await new_user(request) # await new_user(request)
request = NewUserRequest() request = NewUserRequest()
key = await new_user(request) key = await new_user(request)
print(key) print(key)