diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 9704a2f19..75396233e 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -698,6 +698,8 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): team_tpm_limit: Optional[int] = None team_rpm_limit: Optional[int] = None team_max_budget: Optional[float] = None + team_models: List = [] + team_blocked: bool = False soft_budget: Optional[float] = None team_model_aliases: Optional[Dict] = None diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 37ec2065f..5246fb94d 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -15,7 +15,7 @@ from litellm.proxy._types import ( LiteLLM_TeamTable, LiteLLMRoutes, ) -from typing import Optional, Literal +from typing import Optional, Literal, Union from litellm.proxy.utils import PrismaClient from litellm.caching import DualCache @@ -26,6 +26,8 @@ def common_checks( request_body: dict, team_object: LiteLLM_TeamTable, end_user_object: Optional[LiteLLM_EndUserTable], + general_settings: dict, + route: str, ) -> bool: """ Common checks across jwt + key-based auth. @@ -34,6 +36,7 @@ def common_checks( 2. If team can call model 3. If team is in budget 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget + 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints """ _model = request_body.get("model", None) if team_object.blocked == True: @@ -65,6 +68,16 @@ def common_checks( raise Exception( f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}" ) + # 5. 
[OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints + if ( + general_settings.get("enforce_user_param", None) is not None + and general_settings["enforce_user_param"] == True + ): + if route in LiteLLMRoutes.openai_routes.value and "user" not in request_body: + raise Exception( + f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}" + ) + return True diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8fa2862f2..537ed1bab 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -440,6 +440,8 @@ async def user_api_key_auth( request_body=request_data, team_object=team_object, end_user_object=end_user_object, + general_settings=general_settings, + route=route, ) # save user object in cache await user_api_key_cache.async_set_cache( @@ -867,6 +869,23 @@ async def user_api_key_auth( f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}" ) + # Check 8: Additional Common Checks across jwt + key auth + _team_obj = LiteLLM_TeamTable( + team_id=valid_token.team_id, + max_budget=valid_token.team_max_budget, + spend=valid_token.team_spend, + tpm_limit=valid_token.team_tpm_limit, + rpm_limit=valid_token.team_rpm_limit, + blocked=valid_token.team_blocked, + models=valid_token.team_models, + ) + _ = common_checks( + request_body=request_data, + team_object=_team_obj, + end_user_object=None, + general_settings=general_settings, + route=route, + ) # Token passed all checks api_key = valid_token.token diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index ba8d70804..fd8421b50 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1013,6 +1013,8 @@ class PrismaClient: t.max_budget AS team_max_budget, t.tpm_limit AS team_tpm_limit, t.rpm_limit AS team_rpm_limit, + t.models AS team_models, + t.blocked AS team_blocked, m.aliases as team_model_aliases FROM 
"LiteLLM_VerificationToken" AS v LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id @@ -1023,6 +1025,10 @@ class PrismaClient: response = await self.db.query_first(query=sql_query) if response is not None: + if response["team_models"] is None: + response["team_models"] = [] + if response["team_blocked"] is None: + response["team_blocked"] = False response = LiteLLM_VerificationTokenView(**response) # for prisma we need to cast the expires time to str if response.expires is not None and isinstance(