mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(proxy_server.py): enforce budget limit if global proxy limit reached
This commit is contained in:
parent
624da17698
commit
30a8071bf1
2 changed files with 70 additions and 24 deletions
|
@ -370,30 +370,62 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check 2. If user_id for this token is in budget
|
# Check 2. If user_id for this token is in budget
|
||||||
|
## Check 2.5 If global proxy is in budget
|
||||||
if valid_token.user_id is not None:
|
if valid_token.user_id is not None:
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
user_id_information = await prisma_client.get_data(
|
user_id_information = await prisma_client.get_data(
|
||||||
user_id=valid_token.user_id, table_name="user"
|
user_id_list=[valid_token.user_id, "litellm-proxy-budget"],
|
||||||
|
table_name="user",
|
||||||
|
query_type="find_all",
|
||||||
)
|
)
|
||||||
if custom_db_client is not None:
|
if custom_db_client is not None:
|
||||||
user_id_information = await custom_db_client.get_data(
|
user_id_information = await custom_db_client.get_data(
|
||||||
key=valid_token.user_id, table_name="user"
|
key=valid_token.user_id, table_name="user"
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"user_id_information: {user_id_information}"
|
f"user_id_information: {user_id_information}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Token exists and is not expired; now check whether it's in budget for the user
|
if user_id_information is not None:
|
||||||
if valid_token.spend is not None and valid_token.user_id is not None:
|
if isinstance(user_id_information, list):
|
||||||
user_max_budget = getattr(user_id_information, "max_budget", None)
|
## Check if user in budget
|
||||||
user_current_spend = getattr(user_id_information, "spend", None)
|
for _user in user_id_information:
|
||||||
|
if _user is None:
|
||||||
|
continue
|
||||||
|
assert isinstance(_user, dict)
|
||||||
|
# Token exists and is not expired; now check whether it's in budget for the user
|
||||||
|
user_max_budget = _user.get("max_budget", None)
|
||||||
|
user_current_spend = _user.get("spend", None)
|
||||||
|
|
||||||
if user_max_budget is not None and user_current_spend is not None:
|
verbose_proxy_logger.debug(
|
||||||
if user_current_spend > user_max_budget:
|
f"user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
|
||||||
raise Exception(
|
|
||||||
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_max_budget is not None
|
||||||
|
and user_current_spend is not None
|
||||||
|
):
|
||||||
|
if user_current_spend > user_max_budget:
|
||||||
|
raise Exception(
|
||||||
|
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Token exists and is not expired; now check whether it's in budget for the user
|
||||||
|
user_max_budget = getattr(
|
||||||
|
user_id_information, "max_budget", None
|
||||||
|
)
|
||||||
|
user_current_spend = getattr(user_id_information, "spend", None)
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_max_budget is not None
|
||||||
|
and user_current_spend is not None
|
||||||
|
):
|
||||||
|
if user_current_spend > user_max_budget:
|
||||||
|
raise Exception(
|
||||||
|
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check 3. If token is expired
|
# Check 3. If token is expired
|
||||||
if valid_token.expires is not None:
|
if valid_token.expires is not None:
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
|
@ -1165,6 +1197,7 @@ async def generate_key_helper_fn(
|
||||||
tpm_limit: Optional[int] = None,
|
tpm_limit: Optional[int] = None,
|
||||||
rpm_limit: Optional[int] = None,
|
rpm_limit: Optional[int] = None,
|
||||||
query_type: Literal["insert_data", "update_data"] = "insert_data",
|
query_type: Literal["insert_data", "update_data"] = "insert_data",
|
||||||
|
update_key_values: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
global prisma_client, custom_db_client
|
global prisma_client, custom_db_client
|
||||||
|
|
||||||
|
@ -1265,7 +1298,9 @@ async def generate_key_helper_fn(
|
||||||
key_data["models"] = user_row.models
|
key_data["models"] = user_row.models
|
||||||
elif query_type == "update_data":
|
elif query_type == "update_data":
|
||||||
user_row = await prisma_client.update_data(
|
user_row = await prisma_client.update_data(
|
||||||
data=user_data, table_name="user"
|
data=user_data,
|
||||||
|
table_name="user",
|
||||||
|
update_key_values=update_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
## CREATE KEY
|
## CREATE KEY
|
||||||
|
@ -1598,6 +1633,10 @@ async def startup_event():
|
||||||
max_budget=litellm.max_budget,
|
max_budget=litellm.max_budget,
|
||||||
budget_duration=litellm.budget_duration,
|
budget_duration=litellm.budget_duration,
|
||||||
query_type="update_data",
|
query_type="update_data",
|
||||||
|
update_key_values={
|
||||||
|
"max_budget": litellm.max_budget,
|
||||||
|
"budget_duration": litellm.budget_duration,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
|
|
@ -361,6 +361,7 @@ class PrismaClient:
|
||||||
self,
|
self,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
|
user_id_list: Optional[list] = None,
|
||||||
key_val: Optional[dict] = None,
|
key_val: Optional[dict] = None,
|
||||||
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
|
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
|
||||||
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
||||||
|
@ -442,6 +443,17 @@ class PrismaClient:
|
||||||
"budget_reset_at": {"lt": reset_at},
|
"budget_reset_at": {"lt": reset_at},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
elif query_type == "find_all" and user_id_list is not None:
|
||||||
|
user_id_values = str(tuple(user_id_list))
|
||||||
|
sql_query = f"""
|
||||||
|
SELECT *
|
||||||
|
FROM "LiteLLM_UserTable"
|
||||||
|
WHERE "user_id" IN {user_id_values}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Execute the raw query
|
||||||
|
# The ids are already interpolated into sql_query above; no arguments are unpacked into this call
|
||||||
|
response = await self.db.query_raw(sql_query)
|
||||||
return response
|
return response
|
||||||
elif table_name == "user" and query_type == "find_all":
|
elif table_name == "user" and query_type == "find_all":
|
||||||
response = await self.db.litellm_usertable.find_many( # type: ignore
|
response = await self.db.litellm_usertable.find_many( # type: ignore
|
||||||
|
@ -586,6 +598,7 @@ class PrismaClient:
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
query_type: Literal["update", "update_many"] = "update",
|
query_type: Literal["update", "update_many"] = "update",
|
||||||
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
|
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
|
||||||
|
update_key_values: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Update existing data
|
Update existing data
|
||||||
|
@ -612,28 +625,22 @@ class PrismaClient:
|
||||||
user_id is not None
|
user_id is not None
|
||||||
or (table_name is not None and table_name == "user")
|
or (table_name is not None and table_name == "user")
|
||||||
and query_type == "update"
|
and query_type == "update"
|
||||||
|
and update_key_values is not None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
If data['spend'] + data['user'], update the user table with spend info as well
|
If data['spend'] + data['user'], update the user table with spend info as well
|
||||||
"""
|
"""
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
user_id = db_data["user_id"]
|
user_id = db_data["user_id"]
|
||||||
update_user_row = await self.db.litellm_usertable.update(
|
update_user_row = await self.db.litellm_usertable.upsert(
|
||||||
where={"user_id": user_id}, # type: ignore
|
where={"user_id": user_id}, # type: ignore
|
||||||
data={**db_data}, # type: ignore
|
data={
|
||||||
|
"create": {**db_data}, # type: ignore
|
||||||
|
"update": {
|
||||||
|
**update_key_values # type: ignore
|
||||||
|
}, # just update user-specified values, if it already exists
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if update_user_row is None:
|
|
||||||
# if the provided user does not exist, STILL Track this!
|
|
||||||
# make a new user with {"user_id": user_id, "spend": data['spend']}
|
|
||||||
|
|
||||||
db_data["user_id"] = user_id
|
|
||||||
update_user_row = await self.db.litellm_usertable.upsert(
|
|
||||||
where={"user_id": user_id}, # type: ignore
|
|
||||||
data={
|
|
||||||
"create": {**db_data}, # type: ignore
|
|
||||||
"update": {}, # don't do anything if it already exists
|
|
||||||
},
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
"\033[91m"
|
"\033[91m"
|
||||||
+ f"DB User Table - update succeeded {update_user_row}"
|
+ f"DB User Table - update succeeded {update_user_row}"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue