diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 778a012b6..908cb58cf 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -458,6 +458,17 @@ class LiteLLM_VerificationToken(LiteLLMBase): protected_namespaces = () +class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): + """ + Combined view of litellm verification token + litellm team table (select values) + """ + + team_spend: Optional[float] = None + team_tpm_limit: Optional[int] = None + team_rpm_limit: Optional[int] = None + team_max_budget: Optional[float] = None + + class UserAPIKeyAuth( LiteLLM_VerificationToken ): # the expected response object for user api key auth diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 56c49ae32..e1f604096 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -350,13 +350,14 @@ async def user_api_key_auth( original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility) if api_key.startswith("sk-"): api_key = hash_token(token=api_key) - valid_token = user_api_key_cache.get_cache(key=api_key) + # valid_token = user_api_key_cache.get_cache(key=api_key) + valid_token = None if valid_token is None: ## check db verbose_proxy_logger.debug(f"api key: {api_key}") if prisma_client is not None: valid_token = await prisma_client.get_data( - token=api_key, + token=api_key, table_name="combined_view" ) elif custom_db_client is not None: @@ -381,6 +382,8 @@ async def user_api_key_auth( # 4. If token is expired # 5. If token spend is under Budget for the token # 6. If token spend per model is under budget per model + # 7. If token spend is under team budget + # 8. If team spend is under team budget request_data = await _read_request_body( request=request @@ -610,6 +613,44 @@ async def user_api_key_auth( f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}" ) + # Check 6. Token spend is under Team budget + if ( + valid_token.spend is not None + and valid_token.team_max_budget is not None + ): + asyncio.create_task( + proxy_logging_obj.budget_alerts( + user_max_budget=valid_token.team_max_budget, + user_current_spend=valid_token.spend, + type="token_budget", + user_info=valid_token, + ) + ) + + if valid_token.spend > valid_token.team_max_budget: + raise Exception( + f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}" + ) + + # Check 7. Team spend is under Team budget + if ( + valid_token.team_spend is not None + and valid_token.team_max_budget is not None + ): + asyncio.create_task( + proxy_logging_obj.budget_alerts( + user_max_budget=valid_token.team_max_budget, + user_current_spend=valid_token.team_spend, + type="token_budget", + user_info=valid_token, + ) + ) + + if valid_token.team_spend > valid_token.team_max_budget: + raise Exception( + f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}" + ) + # Token passed all checks api_key = valid_token.token @@ -2256,6 +2297,10 @@ async def startup_event(): duration=None, models=[], aliases={}, config={}, spend=0, token=master_key ) + ### CHECK IF VIEW EXISTS ### + create_view_response = await prisma_client.check_view_exists() + print(f"create_view_response: {create_view_response}") # noqa + ### START BUDGET SCHEDULER ### if prisma_client is not None: scheduler = AsyncIOScheduler() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 8defd918c..6b945ce72 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -5,6 +5,7 @@ from litellm.proxy._types import ( UserAPIKeyAuth, DynamoDBArgs, LiteLLM_VerificationToken, + LiteLLM_VerificationTokenView, LiteLLM_SpendLogs, ) from litellm.caching import DualCache @@ -479,6 +480,49 @@ class PrismaClient: db_data[k] = json.dumps(v) return db_data + @backoff.on_exception( + backoff.expo, + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for + on_backoff=on_backoff, # specifying the function to call on backoff + ) + async def check_view_exists(self): + """ + Checks if the LiteLLM_VerificationTokenView exists in the user's db. + + This is used for getting the token + team data in user_api_key_auth + + If the view doesn't exist, one will be created. + """ + try: + # Try to select one row from the view + await self.db.execute_raw( + """SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""" + ) + return "LiteLLM_VerificationTokenView Exists!" + except Exception as e: + # If an error occurs, the view does not exist, so create it + value = await self.health_check() + if '"litellm_verificationtokenview" does not exist' in str(e): + await self.db.execute_raw( + """ + CREATE VIEW "LiteLLM_VerificationTokenView" AS + SELECT + v.*, + t.spend AS team_spend, + t.max_budget AS team_max_budget, + t.tpm_limit AS team_tpm_limit, + t.rpm_limit AS team_rpm_limit + FROM "LiteLLM_VerificationToken" v + LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; + """ + ) + else: + raise e + + return "LiteLLM_VerificationTokenView Created!" + @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff @@ -535,7 +579,15 @@ class PrismaClient: team_id_list: Optional[list] = None, key_val: Optional[dict] = None, table_name: Optional[ - Literal["user", "key", "config", "spend", "team", "user_notification"] + Literal[ + "user", + "key", + "config", + "spend", + "team", + "user_notification", + "combined_view", + ] ] = None, query_type: Literal["find_unique", "find_all"] = "find_unique", expires: Optional[datetime] = None, @@ -543,7 +595,9 @@ class PrismaClient: ): try: response: Any = None - if token is not None or (table_name is not None and table_name == "key"): + if (token is not None and table_name is None) or ( + table_name is not None and table_name == "key" + ): # check if plain text or hash if token is not None: if isinstance(token, str): @@ -723,6 +777,38 @@ class PrismaClient: elif query_type == "find_all": response = await self.db.litellm_usernotifications.find_many() # type: ignore return response + elif table_name == "combined_view": + # check if plain text or hash + if token is not None: + if isinstance(token, str): + hashed_token = token + if token.startswith("sk-"): + hashed_token = self.hash_token(token=token) + verbose_proxy_logger.debug( + f"PrismaClient: find_unique for token: {hashed_token}" + ) + if query_type == "find_unique": + if token is None: + raise HTTPException( + status_code=400, + detail={"error": f"No token passed in. Token={token}"}, + ) + + sql_query = f""" + SELECT * + FROM "LiteLLM_VerificationTokenView" + WHERE token = '{token}' + """ + + response = await self.db.query_first(query=sql_query) + if response is not None: + response = LiteLLM_VerificationTokenView(**response) + # for prisma we need to cast the expires time to str + if response.expires is not None and isinstance( + response.expires, datetime + ): + response.expires = response.expires.isoformat() + return response except Exception as e: print_verbose(f"LiteLLM Prisma Client Exception: {e}") import traceback