diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 286ccfeea1..1a6418f37a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -370,30 +370,62 @@ async def user_api_key_auth( ) # Check 2. If user_id for this token is in budget + ## Check 2.5 If global proxy is in budget if valid_token.user_id is not None: if prisma_client is not None: user_id_information = await prisma_client.get_data( - user_id=valid_token.user_id, table_name="user" + user_id_list=[valid_token.user_id, "litellm-proxy-budget"], + table_name="user", + query_type="find_all", ) if custom_db_client is not None: user_id_information = await custom_db_client.get_data( key=valid_token.user_id, table_name="user" ) + verbose_proxy_logger.debug( f"user_id_information: {user_id_information}" ) - # Token exists, not expired now check if its in budget for the user - if valid_token.spend is not None and valid_token.user_id is not None: - user_max_budget = getattr(user_id_information, "max_budget", None) - user_current_spend = getattr(user_id_information, "spend", None) + if user_id_information is not None: + if isinstance(user_id_information, list): + ## Check if user in budget + for _user in user_id_information: + if _user is None: + continue + assert isinstance(_user, dict) + # Token exists, not expired now check if its in budget for the user + user_max_budget = _user.get("max_budget", None) + user_current_spend = _user.get("spend", None) - if user_max_budget is not None and user_current_spend is not None: - if user_current_spend > user_max_budget: - raise Exception( - f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}" + verbose_proxy_logger.debug( + f"user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}" ) + if ( + user_max_budget is not None + and user_current_spend is not None + ): + if user_current_spend > user_max_budget: + raise Exception( + f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}" + ) + else: + # Token exists, not expired now check if its in budget for the user + user_max_budget = getattr( + user_id_information, "max_budget", None + ) + user_current_spend = getattr(user_id_information, "spend", None) + + if ( + user_max_budget is not None + and user_current_spend is not None + ): + if user_current_spend > user_max_budget: + raise Exception( + f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}" + ) + # Check 3. If token is expired if valid_token.expires is not None: current_time = datetime.now(timezone.utc) @@ -1165,6 +1197,7 @@ async def generate_key_helper_fn( tpm_limit: Optional[int] = None, rpm_limit: Optional[int] = None, query_type: Literal["insert_data", "update_data"] = "insert_data", + update_key_values: Optional[dict] = None, ): global prisma_client, custom_db_client @@ -1265,7 +1298,9 @@ async def generate_key_helper_fn( key_data["models"] = user_row.models elif query_type == "update_data": user_row = await prisma_client.update_data( - data=user_data, table_name="user" + data=user_data, + table_name="user", + update_key_values=update_key_values, ) ## CREATE KEY @@ -1598,6 +1633,10 @@ async def startup_event(): max_budget=litellm.max_budget, budget_duration=litellm.budget_duration, query_type="update_data", + update_key_values={ + "max_budget": litellm.max_budget, + "budget_duration": litellm.budget_duration, + }, ) verbose_proxy_logger.debug( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 8d06106c09..787c927666 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -361,6 +361,7 @@ class PrismaClient: self, token: Optional[str] = None, user_id: Optional[str] = None, + user_id_list: Optional[list] = None, key_val: Optional[dict] = None, table_name: Optional[Literal["user", "key", "config", "spend"]] = None, query_type: Literal["find_unique", "find_all"] = "find_unique", @@ -442,6 +443,17 @@ class PrismaClient: "budget_reset_at": {"lt": reset_at}, } ) + elif query_type == "find_all" and user_id_list is not None: + user_id_values = str(tuple(user_id_list)) + sql_query = f""" + SELECT * + FROM "LiteLLM_UserTable" + WHERE "user_id" IN {user_id_values} + """ + + # Execute the raw query + # The asterisk before `user_id_list` unpacks the list into separate arguments + response = await self.db.query_raw(sql_query) return response elif table_name == "user" and query_type == "find_all": response = await self.db.litellm_usertable.find_many( # type: ignore @@ -586,6 +598,7 @@ class PrismaClient: user_id: Optional[str] = None, query_type: Literal["update", "update_many"] = "update", table_name: Optional[Literal["user", "key", "config", "spend"]] = None, + update_key_values: Optional[dict] = None, ): """ Update existing data @@ -612,28 +625,22 @@ class PrismaClient: user_id is not None or (table_name is not None and table_name == "user") and query_type == "update" + and update_key_values is not None ): """ If data['spend'] + data['user'], update the user table with spend info as well """ if user_id is None: user_id = db_data["user_id"] - update_user_row = await self.db.litellm_usertable.update( + update_user_row = await self.db.litellm_usertable.upsert( where={"user_id": user_id}, # type: ignore - data={**db_data}, # type: ignore + data={ + "create": {**db_data}, # type: ignore + "update": { + **update_key_values # type: ignore + }, # just update user-specified values, if it already exists + }, ) - if update_user_row is None: - # if the provided user does not exist, STILL Track this! - # make a new user with {"user_id": user_id, "spend": data['spend']} - - db_data["user_id"] = user_id - update_user_row = await self.db.litellm_usertable.upsert( - where={"user_id": user_id}, # type: ignore - data={ - "create": {**db_data}, # type: ignore - "update": {}, # don't do anything if it already exists - }, - ) verbose_proxy_logger.info( "\033[91m" + f"DB User Table - update succeeded {update_user_row}"