diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 4fd1bf3b0..2cd979b4b 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -82,6 +82,8 @@ class LiteLLMRoutes(enum.Enum): "/team/update", "/team/delete", "/team/info", + "/team/block", + "/team/unblock", # model "/model/new", "/model/update", @@ -396,6 +398,7 @@ class TeamBase(LiteLLMBase): rpm_limit: Optional[int] = None max_budget: Optional[float] = None models: list = [] + blocked: bool = False class NewTeamRequest(TeamBase): @@ -436,6 +439,10 @@ class DeleteTeamRequest(LiteLLMBase): team_ids: List[str] # required +class BlockTeamRequest(LiteLLMBase): + team_id: str # required + + class LiteLLM_TeamTable(TeamBase): spend: Optional[float] = None max_parallel_requests: Optional[int] = None diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index b8f7c6e3f..37ec2065f 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -30,12 +30,17 @@ def common_checks( """ Common checks across jwt + key-based auth. - 1. If user can call model - 2. If user is in budget - 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget + 1. If team is blocked + 2. If team can call model + 3. If team is in budget + 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget """ _model = request_body.get("model", None) - # 1. If user can call model + if team_object.blocked == True: + raise Exception( + f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin." + ) + # 2. If user can call model if ( _model is not None and len(team_object.models) > 0 @@ -44,7 +49,7 @@ def common_checks( raise Exception( f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}" ) - # 2. If team is in budget + # 3. If team is in budget if ( team_object.max_budget is not None and team_object.spend is not None @@ -53,7 +58,7 @@ def common_checks( raise Exception( f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}" ) - # 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget + # 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget if end_user_object is not None and end_user_object.litellm_budget_table is not None: end_user_budget = end_user_object.litellm_budget_table.max_budget if end_user_budget is not None and end_user_object.spend > end_user_budget: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1ccab49e4..209634fe5 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -6140,6 +6140,50 @@ async def team_info( ) +@router.post( + "/team/block", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def block_team( + data: BlockTeamRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Blocks all calls from keys with this team id. + """ + global prisma_client + + if prisma_client is None: + raise Exception("No DB Connected.") + + record = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"blocked": True} + ) + + return record + + +@router.post( + "/team/unblock", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def unblock_team( + data: BlockTeamRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Blocks all calls from keys with this team id. + """ + global prisma_client + + if prisma_client is None: + raise Exception("No DB Connected.") + + record = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"blocked": False} + ) + + return record + + #### ORGANIZATION MANAGEMENT #### diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 598848776..323d6d1b9 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -71,6 +71,7 @@ model LiteLLM_TeamTable { max_parallel_requests Int? tpm_limit BigInt? rpm_limit BigInt? + blocked Boolean @default(false) budget_duration String? budget_reset_at DateTime? blocked Boolean @default(false) diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py index 311bbfa57..13ad42b50 100644 --- a/litellm/tests/test_exceptions.py +++ b/litellm/tests/test_exceptions.py @@ -85,6 +85,8 @@ def test_context_window_with_fallbacks(model): ) except litellm.ServiceUnavailableError as e: pass + except litellm.APIConnectionError as e: + pass # for model in litellm.models_by_provider["bedrock"]: