fix(proxy_server.py): enforce budget limit if global proxy limit reached

This commit is contained in:
Krrish Dholakia 2024-01-24 17:11:40 -08:00
parent 624da17698
commit 30a8071bf1
2 changed files with 70 additions and 24 deletions

View file

@ -370,30 +370,62 @@ async def user_api_key_auth(
)
# Check 2. If user_id for this token is in budget
## Check 2.5 If global proxy is in budget
if valid_token.user_id is not None:
if prisma_client is not None:
user_id_information = await prisma_client.get_data(
user_id=valid_token.user_id, table_name="user"
user_id_list=[valid_token.user_id, "litellm-proxy-budget"],
table_name="user",
query_type="find_all",
)
if custom_db_client is not None:
user_id_information = await custom_db_client.get_data(
key=valid_token.user_id, table_name="user"
)
verbose_proxy_logger.debug(
f"user_id_information: {user_id_information}"
)
# Token exists, not expired now check if its in budget for the user
if valid_token.spend is not None and valid_token.user_id is not None:
user_max_budget = getattr(user_id_information, "max_budget", None)
user_current_spend = getattr(user_id_information, "spend", None)
if user_id_information is not None:
if isinstance(user_id_information, list):
## Check if user in budget
for _user in user_id_information:
if _user is None:
continue
assert isinstance(_user, dict)
# Token exists, not expired now check if its in budget for the user
user_max_budget = _user.get("max_budget", None)
user_current_spend = _user.get("spend", None)
if user_max_budget is not None and user_current_spend is not None:
if user_current_spend > user_max_budget:
raise Exception(
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
verbose_proxy_logger.debug(
f"user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
)
if (
user_max_budget is not None
and user_current_spend is not None
):
if user_current_spend > user_max_budget:
raise Exception(
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
)
else:
# Token exists, not expired now check if its in budget for the user
user_max_budget = getattr(
user_id_information, "max_budget", None
)
user_current_spend = getattr(user_id_information, "spend", None)
if (
user_max_budget is not None
and user_current_spend is not None
):
if user_current_spend > user_max_budget:
raise Exception(
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
)
# Check 3. If token is expired
if valid_token.expires is not None:
current_time = datetime.now(timezone.utc)
@ -1165,6 +1197,7 @@ async def generate_key_helper_fn(
tpm_limit: Optional[int] = None,
rpm_limit: Optional[int] = None,
query_type: Literal["insert_data", "update_data"] = "insert_data",
update_key_values: Optional[dict] = None,
):
global prisma_client, custom_db_client
@ -1265,7 +1298,9 @@ async def generate_key_helper_fn(
key_data["models"] = user_row.models
elif query_type == "update_data":
user_row = await prisma_client.update_data(
data=user_data, table_name="user"
data=user_data,
table_name="user",
update_key_values=update_key_values,
)
## CREATE KEY
@ -1598,6 +1633,10 @@ async def startup_event():
max_budget=litellm.max_budget,
budget_duration=litellm.budget_duration,
query_type="update_data",
update_key_values={
"max_budget": litellm.max_budget,
"budget_duration": litellm.budget_duration,
},
)
verbose_proxy_logger.debug(