diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e834d98b0..480a7b333 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -648,6 +648,20 @@ class LiteLLM_BudgetTable(LiteLLMBase): protected_namespaces = () +class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): + """ + Used to track spend of a user_id within a team_id + """ + + spend: Optional[float] = None + user_id: Optional[str] = None + team_id: Optional[str] = None + budget_id: Optional[str] = None + + class Config: + protected_namespaces = () + + class NewOrganizationRequest(LiteLLM_BudgetTable): organization_id: Optional[str] = None organization_alias: str @@ -938,6 +952,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): team_blocked: bool = False soft_budget: Optional[float] = None team_model_aliases: Optional[Dict] = None + team_member_spend: Optional[float] = None # End User Params end_user_id: Optional[str] = None diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1db7150f0..89d5e6871 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -798,12 +798,13 @@ async def user_api_key_auth( # Run checks for # 1. If token can call model # 2. If user_id for this token is in budget - # 3. If 'user' passed to /chat/completions, /embeddings endpoint is in budget - # 4. If token is expired - # 5. If token spend is under Budget for the token - # 6. If token spend per model is under budget per model - # 7. If token spend is under team budget - # 8. If team spend is under team budget + # 3. If the user spend within their own team is within budget + # 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget + # 5. If token is expired + # 6. If token spend is under Budget for the token + # 7. If token spend per model is under budget per model + # 8. If token spend is under team budget + # 9. If team spend is under team budget # Check 1. If token can call model _model_alias_map = {} @@ -1001,6 +1002,43 @@ async def user_api_key_auth( raise Exception( f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}" ) + # Check 3. Check if user is in their team budget + if valid_token.team_member_spend is not None: + if prisma_client is not None: + + _cache_key = f"{valid_token.team_id}_{valid_token.user_id}" + + team_member_info = await user_api_key_cache.async_get_cache( + key=_cache_key + ) + if team_member_info is None: + team_member_info = ( + await prisma_client.db.litellm_teammembership.find_first( + where={ + "user_id": valid_token.user_id, + "team_id": valid_token.team_id, + }, # type: ignore + include={"litellm_budget_table": True}, + ) + ) + await user_api_key_cache.async_set_cache( + key=_cache_key, + value=team_member_info, + ttl=UserAPIKeyCacheTTLEnum.user_information_cache.value, + ) + + if ( + team_member_info is not None + and team_member_info.litellm_budget_table is not None + ): + team_member_budget = ( + team_member_info.litellm_budget_table.max_budget + ) + if team_member_budget is not None and team_member_budget > 0: + if valid_token.team_member_spend > team_member_budget: + raise Exception( + f"ExceededBudget: Crossed spend within team. UserID: {valid_token.user_id}, in team {valid_token.team_id} has exceeded their budget. Current spend: {valid_token.team_member_spend}; Max Budget: {team_member_budget}" + ) # Check 3. If token is expired if valid_token.expires is not None: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0742c2109..2cd3e0c37 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1097,9 +1097,11 @@ class PrismaClient: t.models AS team_models, t.blocked AS team_blocked, t.team_alias AS team_alias, + tm.spend AS team_member_spend, m.aliases as team_model_aliases FROM "LiteLLM_VerificationToken" AS v LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id + LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id WHERE v.token = '{token}' """