fix(proxy_server.py): enforce budget limit if global proxy limit reached

This commit is contained in:
Krrish Dholakia 2024-01-24 17:11:40 -08:00
parent 624da17698
commit 30a8071bf1
2 changed files with 70 additions and 24 deletions

View file

@ -370,25 +370,57 @@ async def user_api_key_auth(
) )
# Check 2. If user_id for this token is in budget # Check 2. If user_id for this token is in budget
## Check 2.5 If global proxy is in budget
if valid_token.user_id is not None: if valid_token.user_id is not None:
if prisma_client is not None: if prisma_client is not None:
user_id_information = await prisma_client.get_data( user_id_information = await prisma_client.get_data(
user_id=valid_token.user_id, table_name="user" user_id_list=[valid_token.user_id, "litellm-proxy-budget"],
table_name="user",
query_type="find_all",
) )
if custom_db_client is not None: if custom_db_client is not None:
user_id_information = await custom_db_client.get_data( user_id_information = await custom_db_client.get_data(
key=valid_token.user_id, table_name="user" key=valid_token.user_id, table_name="user"
) )
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"user_id_information: {user_id_information}" f"user_id_information: {user_id_information}"
) )
if user_id_information is not None:
if isinstance(user_id_information, list):
## Check if user in budget
for _user in user_id_information:
if _user is None:
continue
assert isinstance(_user, dict)
# Token exists, not expired now check if its in budget for the user # Token exists, not expired now check if its in budget for the user
if valid_token.spend is not None and valid_token.user_id is not None: user_max_budget = _user.get("max_budget", None)
user_max_budget = getattr(user_id_information, "max_budget", None) user_current_spend = _user.get("spend", None)
verbose_proxy_logger.debug(
f"user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
)
if (
user_max_budget is not None
and user_current_spend is not None
):
if user_current_spend > user_max_budget:
raise Exception(
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
)
else:
# Token exists, not expired now check if its in budget for the user
user_max_budget = getattr(
user_id_information, "max_budget", None
)
user_current_spend = getattr(user_id_information, "spend", None) user_current_spend = getattr(user_id_information, "spend", None)
if user_max_budget is not None and user_current_spend is not None: if (
user_max_budget is not None
and user_current_spend is not None
):
if user_current_spend > user_max_budget: if user_current_spend > user_max_budget:
raise Exception( raise Exception(
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}" f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
@ -1165,6 +1197,7 @@ async def generate_key_helper_fn(
tpm_limit: Optional[int] = None, tpm_limit: Optional[int] = None,
rpm_limit: Optional[int] = None, rpm_limit: Optional[int] = None,
query_type: Literal["insert_data", "update_data"] = "insert_data", query_type: Literal["insert_data", "update_data"] = "insert_data",
update_key_values: Optional[dict] = None,
): ):
global prisma_client, custom_db_client global prisma_client, custom_db_client
@ -1265,7 +1298,9 @@ async def generate_key_helper_fn(
key_data["models"] = user_row.models key_data["models"] = user_row.models
elif query_type == "update_data": elif query_type == "update_data":
user_row = await prisma_client.update_data( user_row = await prisma_client.update_data(
data=user_data, table_name="user" data=user_data,
table_name="user",
update_key_values=update_key_values,
) )
## CREATE KEY ## CREATE KEY
@ -1598,6 +1633,10 @@ async def startup_event():
max_budget=litellm.max_budget, max_budget=litellm.max_budget,
budget_duration=litellm.budget_duration, budget_duration=litellm.budget_duration,
query_type="update_data", query_type="update_data",
update_key_values={
"max_budget": litellm.max_budget,
"budget_duration": litellm.budget_duration,
},
) )
verbose_proxy_logger.debug( verbose_proxy_logger.debug(

View file

@ -361,6 +361,7 @@ class PrismaClient:
self, self,
token: Optional[str] = None, token: Optional[str] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
user_id_list: Optional[list] = None,
key_val: Optional[dict] = None, key_val: Optional[dict] = None,
table_name: Optional[Literal["user", "key", "config", "spend"]] = None, table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
query_type: Literal["find_unique", "find_all"] = "find_unique", query_type: Literal["find_unique", "find_all"] = "find_unique",
@ -442,6 +443,17 @@ class PrismaClient:
"budget_reset_at": {"lt": reset_at}, "budget_reset_at": {"lt": reset_at},
} }
) )
elif query_type == "find_all" and user_id_list is not None:
user_id_values = str(tuple(user_id_list))
sql_query = f"""
SELECT *
FROM "LiteLLM_UserTable"
WHERE "user_id" IN {user_id_values}
"""
# Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query)
return response return response
elif table_name == "user" and query_type == "find_all": elif table_name == "user" and query_type == "find_all":
response = await self.db.litellm_usertable.find_many( # type: ignore response = await self.db.litellm_usertable.find_many( # type: ignore
@ -586,6 +598,7 @@ class PrismaClient:
user_id: Optional[str] = None, user_id: Optional[str] = None,
query_type: Literal["update", "update_many"] = "update", query_type: Literal["update", "update_many"] = "update",
table_name: Optional[Literal["user", "key", "config", "spend"]] = None, table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
update_key_values: Optional[dict] = None,
): ):
""" """
Update existing data Update existing data
@ -612,26 +625,20 @@ class PrismaClient:
user_id is not None user_id is not None
or (table_name is not None and table_name == "user") or (table_name is not None and table_name == "user")
and query_type == "update" and query_type == "update"
and update_key_values is not None
): ):
""" """
If data['spend'] + data['user'], update the user table with spend info as well If data['spend'] + data['user'], update the user table with spend info as well
""" """
if user_id is None: if user_id is None:
user_id = db_data["user_id"] user_id = db_data["user_id"]
update_user_row = await self.db.litellm_usertable.update(
where={"user_id": user_id}, # type: ignore
data={**db_data}, # type: ignore
)
if update_user_row is None:
# if the provided user does not exist, STILL Track this!
# make a new user with {"user_id": user_id, "spend": data['spend']}
db_data["user_id"] = user_id
update_user_row = await self.db.litellm_usertable.upsert( update_user_row = await self.db.litellm_usertable.upsert(
where={"user_id": user_id}, # type: ignore where={"user_id": user_id}, # type: ignore
data={ data={
"create": {**db_data}, # type: ignore "create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists "update": {
**update_key_values # type: ignore
}, # just update user-specified values, if it already exists
}, },
) )
verbose_proxy_logger.info( verbose_proxy_logger.info(