[Feat - Team Member Permissions] - CRUD Endpoints for managing team member permissions (#9919)

* add team_member_permissions

* add GetTeamMemberPermissionsRequest types

* crud endpoint for team member permissions

* test team member permissions CRUD

* fix GetTeamMemberPermissionsRequest
This commit is contained in:
Ishaan Jaff 2025-04-11 17:15:16 -07:00 committed by GitHub
parent 2d6ad534bc
commit 91c0a794b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 560 additions and 57 deletions

View file

@ -191,6 +191,28 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMPydanticObjectBase):
rpm_limit: Optional[int] = None
class KeyManagementRoutes(str, enum.Enum):
"""
Enum for key management routes
"""
# write routes
KEY_GENERATE = "/key/generate"
KEY_UPDATE = "/key/update"
KEY_DELETE = "/key/delete"
KEY_REGENERATE = "/key/regenerate"
KEY_REGENERATE_WITH_PATH_PARAM = "/key/{key_id}/regenerate"
KEY_BLOCK = "/key/block"
KEY_UNBLOCK = "/key/unblock"
# info and health routes
KEY_INFO = "/key/info"
KEY_HEALTH = "/key/health"
# list routes
KEY_LIST = "/key/list"
class LiteLLMRoutes(enum.Enum):
openai_route_names = [
"chat_completion",
@ -321,14 +343,19 @@ class LiteLLMRoutes(enum.Enum):
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes = ["/global/spend/reset"]
management_routes = [ # key
"/key/generate",
"/key/{token_id}/regenerate",
"/key/update",
"/key/delete",
"/key/info",
"/key/health",
"/key/list",
key_management_routes = [
KeyManagementRoutes.KEY_GENERATE,
KeyManagementRoutes.KEY_UPDATE,
KeyManagementRoutes.KEY_DELETE,
KeyManagementRoutes.KEY_INFO,
KeyManagementRoutes.KEY_REGENERATE,
KeyManagementRoutes.KEY_REGENERATE_WITH_PATH_PARAM,
KeyManagementRoutes.KEY_LIST,
KeyManagementRoutes.KEY_BLOCK,
KeyManagementRoutes.KEY_UNBLOCK,
]
management_routes = [
# user
"/user/new",
"/user/update",
@ -348,7 +375,7 @@ class LiteLLMRoutes(enum.Enum):
"/model/update",
"/model/delete",
"/model/info",
]
] + key_management_routes
spend_tracking_routes = [
# spend
@ -618,9 +645,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
allowed_cache_controls: Optional[list] = []
config: Optional[dict] = {}
permissions: Optional[dict] = {}
model_max_budget: Optional[
dict
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_max_budget: Optional[dict] = (
{}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_config = ConfigDict(protected_namespaces=())
model_rpm_limit: Optional[dict] = None
@ -876,12 +903,12 @@ class NewCustomerRequest(BudgetNewRequest):
alias: Optional[str] = None # human-friendly alias
blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
@model_validator(mode="before")
@classmethod
@ -903,12 +930,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@ -1043,9 +1070,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
class AddTeamCallback(LiteLLMPydanticObjectBase):
callback_name: str
callback_type: Optional[
Literal["success", "failure", "success_and_failure"]
] = "success_and_failure"
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
"success_and_failure"
)
callback_vars: Dict[str, str]
@model_validator(mode="before")
@ -1110,6 +1137,7 @@ class LiteLLM_TeamTable(TeamBase):
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
team_member_permissions: Optional[List[str]] = None
litellm_model_table: Optional[LiteLLM_ModelTable] = None
created_at: Optional[datetime] = None
@ -1302,9 +1330,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
stored_in_db: Optional[bool]
field_default_value: Any
premium_field: bool = False
nested_fields: Optional[
List[FieldDetail]
] = None # For nested dictionary or Pydantic fields
nested_fields: Optional[List[FieldDetail]] = (
None # For nested dictionary or Pydantic fields
)
class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
@ -1570,9 +1598,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
budget_id: Optional[str] = None
created_at: datetime
updated_at: datetime
user: Optional[
Any
] = None # You might want to replace 'Any' with a more specific type if available
user: Optional[Any] = (
None # You might want to replace 'Any' with a more specific type if available
)
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
model_config = ConfigDict(protected_namespaces=())
@ -2156,6 +2184,11 @@ class ProxyErrorTypes(str, enum.Enum):
Cache ping error
"""
team_member_permission_error = "team_member_permission_error"
"""
Team member permission error
"""
@classmethod
def get_model_access_error_type_for_object(
cls, object_type: Literal["key", "user", "team"]
@ -2313,9 +2346,9 @@ class TeamModelDeleteRequest(BaseModel):
# Organization Member Requests
class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str
max_budget_in_organization: Optional[
float
] = None # Users max budget within the organization
max_budget_in_organization: Optional[float] = (
None # Users max budget within the organization
)
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@ -2504,9 +2537,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
Maps provider names to their budget configs.
"""
providers: Dict[
str, ProviderBudgetResponseObject
] = {} # Dictionary mapping provider names to their budget configurations
providers: Dict[str, ProviderBudgetResponseObject] = (
{}
) # Dictionary mapping provider names to their budget configurations
class ProxyStateVariables(TypedDict):
@ -2634,9 +2667,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
enforce_rbac: bool = False
roles_jwt_field: Optional[str] = None # v2 on role mappings
role_mappings: Optional[List[RoleMapping]] = None
object_id_jwt_field: Optional[
str
] = None # can be either user / team, inferred from the role mapping
object_id_jwt_field: Optional[str] = (
None # can be either user / team, inferred from the role mapping
)
scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False
enforce_team_based_model_access: bool = False

View file

@ -1,7 +1,7 @@
"""
TEAM MANAGEMENT
All /team management endpoints
All /team management endpoints
/team/new
/team/info
@ -62,6 +62,9 @@ from litellm.proxy.management_endpoints.common_utils import (
_is_user_team_admin,
_set_object_metadata_field,
)
from litellm.proxy.management_helpers.team_member_permission_checks import (
TeamMemberPermissionChecks,
)
from litellm.proxy.management_helpers.utils import (
add_new_member,
management_endpoint_wrapper,
@ -72,6 +75,10 @@ from litellm.proxy.utils import (
handle_exception_on_proxy,
)
from litellm.router import Router
from litellm.types.proxy.management_endpoints.team_endpoints import (
GetTeamMemberPermissionsResponse,
UpdateTeamMemberPermissionsRequest,
)
router = APIRouter()
@ -506,12 +513,12 @@ async def update_team(
updated_kv["model_id"] = _model_id
updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
team_row: Optional[
LiteLLM_TeamTable
] = await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id},
data=updated_kv,
include={"litellm_model_table": True}, # type: ignore
team_row: Optional[LiteLLM_TeamTable] = (
await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id},
data=updated_kv,
include={"litellm_model_table": True}, # type: ignore
)
)
if team_row is None or team_row.team_id is None:
@ -1137,10 +1144,10 @@ async def delete_team(
team_rows: List[LiteLLM_TeamTable] = []
for team_id in data.team_ids:
try:
team_row_base: Optional[
BaseModel
] = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
team_row_base: Optional[BaseModel] = (
await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
)
if team_row_base is None:
raise Exception
@ -1298,10 +1305,10 @@ async def team_info(
)
try:
team_info: Optional[
BaseModel
] = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
team_info: Optional[BaseModel] = (
await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
)
if team_info is None:
raise Exception
@ -1926,3 +1933,89 @@ async def team_model_delete(
)
return updated_team
@router.get(
"/team/permissions_list",
tags=["team management"],
dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def team_member_permissions(
team_id: str = fastapi.Query(
default=None, description="Team ID in the request parameters"
),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> GetTeamMemberPermissionsResponse:
"""
Get the team member permissions for a team
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
team_row = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
if team_row is None:
raise HTTPException(
status_code=404,
detail={"error": f"Team not found, passed team_id={team_id}"},
)
team_obj = LiteLLM_TeamTable(**team_row.model_dump())
if team_obj.team_member_permissions is None:
team_obj.team_member_permissions = (
TeamMemberPermissionChecks.default_team_member_permissions()
)
return GetTeamMemberPermissionsResponse(
team_id=team_id,
team_member_permissions=team_obj.team_member_permissions,
all_available_permissions=TeamMemberPermissionChecks.get_all_available_team_member_permissions(),
)
@router.post(
"/team/permissions_update",
tags=["team management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_team_member_permissions(
data: UpdateTeamMemberPermissionsRequest,
http_request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> LiteLLM_TeamTable:
"""
Update the team member permissions for a team
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
team_row = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": data.team_id}
)
if team_row is None:
raise HTTPException(
status_code=404,
detail={"error": f"Team not found, passed team_id={data.team_id}"},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
raise HTTPException(
status_code=403,
detail={"error": "Only proxy admin can update team member permissions"},
)
# Update the team member permissions
updated_team = await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id},
data={"team_member_permissions": data.team_member_permissions},
)
return updated_team

View file

@ -0,0 +1,181 @@
from typing import List, Optional
from litellm.caching import DualCache
from litellm.proxy._types import (
KeyManagementRoutes,
LiteLLM_TeamTableCachedObj,
LiteLLM_VerificationToken,
LiteLLMRoutes,
LitellmUserRoles,
Member,
ProxyErrorTypes,
ProxyException,
Span,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_checks import get_team_object
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import PrismaClient
DEFAULT_TEAM_MEMBER_PERMISSIONS = [
KeyManagementRoutes.KEY_INFO,
KeyManagementRoutes.KEY_HEALTH,
]
class TeamMemberPermissionChecks:
@staticmethod
def get_permissions_for_team_member(
team_member_object: Member,
team_table: LiteLLM_TeamTableCachedObj,
) -> List[KeyManagementRoutes]:
"""
Returns the permissions for a team member
"""
if team_table.team_member_permissions and isinstance(
team_table.team_member_permissions, list
):
return [
KeyManagementRoutes(permission)
for permission in team_table.team_member_permissions
]
return DEFAULT_TEAM_MEMBER_PERMISSIONS
@staticmethod
def _get_list_of_route_enum_as_str(
route_enum: List[KeyManagementRoutes],
) -> List[str]:
"""
Returns a list of the route enum as a list of strings
"""
return [route.value for route in route_enum]
@staticmethod
async def can_team_member_execute_key_management_endpoint(
user_api_key_dict: UserAPIKeyAuth,
route: KeyManagementRoutes,
prisma_client: PrismaClient,
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
existing_key_row: LiteLLM_VerificationToken,
):
"""
Main handler for checking if a team member can update a key
"""
from litellm.proxy.management_endpoints.key_management_endpoints import (
_get_user_in_team,
)
# 1. Don't execute these checks if the user role is proxy admin
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
return
# 2. Check if the operation is being done on a team key
if existing_key_row.team_id is None:
return
# 3. Get Team Object from DB
team_table = await get_team_object(
team_id=existing_key_row.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=user_api_key_dict.parent_otel_span,
check_db_only=True,
)
# 4. Extract `Member` object from `team_table`
key_assigned_user_in_team = _get_user_in_team(
team_table=team_table, user_id=user_api_key_dict.user_id
)
# 5. Check if the team member has permissions for the endpoint
TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint(
team_member_object=key_assigned_user_in_team,
team_table=team_table,
route=route,
)
@staticmethod
def does_team_member_have_permissions_for_endpoint(
team_member_object: Optional[Member],
team_table: LiteLLM_TeamTableCachedObj,
route: str,
) -> Optional[bool]:
"""
Raises an exception if the team member does not have permissions for calling the endpoint for a team
"""
# permission checks only run for non-admin users
# Non-Admin user trying to access information about a team's key
if team_member_object is None:
return False
if team_member_object.role == "admin":
return True
_team_member_permissions = (
TeamMemberPermissionChecks.get_permissions_for_team_member(
team_member_object=team_member_object,
team_table=team_table,
)
)
team_member_permissions = (
TeamMemberPermissionChecks._get_list_of_route_enum_as_str(
_team_member_permissions
)
)
if not RouteChecks.check_route_access(
route=route, allowed_routes=team_member_permissions
):
raise ProxyException(
message=f"Team member does not have permissions for endpoint: {route}. You only have access to the following endpoints: {team_member_permissions}",
type=ProxyErrorTypes.team_member_permission_error,
param=route,
code=401,
)
return True
@staticmethod
async def user_belongs_to_keys_team(
user_api_key_dict: UserAPIKeyAuth,
existing_key_row: LiteLLM_VerificationToken,
) -> bool:
"""
Returns True if the user belongs to the team that the key is assigned to
"""
from litellm.proxy.management_endpoints.key_management_endpoints import (
_get_user_in_team,
)
from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
if existing_key_row.team_id is None:
return False
team_table = await get_team_object(
team_id=existing_key_row.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=user_api_key_dict.parent_otel_span,
check_db_only=True,
)
# 4. Extract `Member` object from `team_table`
team_member_object = _get_user_in_team(
team_table=team_table, user_id=user_api_key_dict.user_id
)
return team_member_object is not None
@staticmethod
def get_all_available_team_member_permissions() -> List[str]:
"""
Returns all available team member permissions
"""
all_available_permissions = []
for route in LiteLLMRoutes.key_management_routes.value:
all_available_permissions.append(route.value)
return all_available_permissions
@staticmethod
def default_team_member_permissions() -> List[str]:
return [route.value for route in DEFAULT_TEAM_MEMBER_PERMISSIONS]

View file

@ -0,0 +1,35 @@
from typing import List, Optional
from pydantic import BaseModel
class GetTeamMemberPermissionsRequest(BaseModel):
"""Request to get the team member permissions for a team"""
team_id: str
class GetTeamMemberPermissionsResponse(BaseModel):
"""Response to get the team member permissions for a team"""
team_id: str
"""
The team id that the permissions are for
"""
team_member_permissions: Optional[List[str]] = []
"""
The team member permissions currently set for the team
"""
all_available_permissions: List[str]
"""
All available team member permissions
"""
class UpdateTeamMemberPermissionsRequest(BaseModel):
"""Request to update the team member permissions for a team"""
team_id: str
team_member_permissions: List[str]

View file

@ -0,0 +1,161 @@
import asyncio
import json
import os
import sys
import uuid
from typing import Optional, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../")
) # Adds the parent directory to the system path
from litellm.proxy._types import UserAPIKeyAuth # Import UserAPIKeyAuth
from litellm.proxy._types import LiteLLM_TeamTable, LitellmUserRoles
from litellm.proxy.management_endpoints.team_endpoints import (
user_api_key_auth, # Assuming this dependency is needed
)
from litellm.proxy.management_endpoints.team_endpoints import (
GetTeamMemberPermissionsResponse,
UpdateTeamMemberPermissionsRequest,
router,
)
from litellm.proxy.management_helpers.team_member_permission_checks import (
TeamMemberPermissionChecks,
)
from litellm.proxy.proxy_server import app
# Setup TestClient
client = TestClient(app)
# Mock prisma_client
mock_prisma_client = MagicMock()
# Fixture to provide the mock prisma client
@pytest.fixture(autouse=True)
def mock_db_client():
with patch(
"litellm.proxy.proxy_server.prisma_client", mock_prisma_client
): # Mock in both places if necessary
yield mock_prisma_client
mock_prisma_client.reset_mock()
# Fixture to provide a mock admin user auth object
@pytest.fixture
def mock_admin_auth():
mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)
return mock_auth
# Test for /team/permissions_list endpoint (GET)
@pytest.mark.asyncio
async def test_get_team_permissions_list_success(mock_db_client, mock_admin_auth):
"""
Test successful retrieval of team member permissions.
"""
test_team_id = "test-team-123"
mock_team_data = {
"team_id": test_team_id,
"team_alias": "Test Team",
"team_member_permissions": ["/key/generate", "/key/update"],
"spend": 0.0,
}
mock_team_row = MagicMock()
mock_team_row.model_dump.return_value = mock_team_data
mock_db_client.db.litellm_teamtable.find_unique = AsyncMock(
return_value=mock_team_row
)
# Override the dependency for this test
app.dependency_overrides[user_api_key_auth] = lambda: mock_admin_auth
response = client.get(f"/team/permissions_list?team_id={test_team_id}")
assert response.status_code == 200
response_data = response.json()
assert response_data["team_id"] == test_team_id
assert (
response_data["team_member_permissions"]
== mock_team_data["team_member_permissions"]
)
assert (
response_data["all_available_permissions"]
== TeamMemberPermissionChecks.get_all_available_team_member_permissions()
)
mock_db_client.db.litellm_teamtable.find_unique.assert_awaited_once_with(
where={"team_id": test_team_id}
)
# Clean up dependency override
app.dependency_overrides = {}
# Test for /team/permissions_update endpoint (POST)
@pytest.mark.asyncio
async def test_update_team_permissions_success(mock_db_client, mock_admin_auth):
"""
Test successful update of team member permissions by an admin.
"""
test_team_id = "test-team-456"
update_payload = {
"team_id": test_team_id,
"team_member_permissions": ["/key/generate", "/key/update"],
}
mock_existing_team_data = {
"team_id": test_team_id,
"team_alias": "Existing Team",
"team_member_permissions": ["/key/list"],
"spend": 0.0,
"models": [],
}
mock_updated_team_data = {
**mock_existing_team_data,
"team_member_permissions": update_payload["team_member_permissions"],
}
mock_existing_team_row = MagicMock(spec=LiteLLM_TeamTable)
mock_existing_team_row.model_dump.return_value = mock_existing_team_data
# Set attributes directly if model_dump isn't enough for LiteLLM_TeamTable usage
for key, value in mock_existing_team_data.items():
setattr(mock_existing_team_row, key, value)
mock_updated_team_row = MagicMock(spec=LiteLLM_TeamTable)
mock_updated_team_row.model_dump.return_value = mock_updated_team_data
# Set attributes directly if model_dump isn't enough for LiteLLM_TeamTable usage
for key, value in mock_updated_team_data.items():
setattr(mock_updated_team_row, key, value)
mock_db_client.db.litellm_teamtable.find_unique = AsyncMock(
return_value=mock_existing_team_row
)
mock_db_client.db.litellm_teamtable.update = AsyncMock(
return_value=mock_updated_team_row
)
# Override the dependency for this test
app.dependency_overrides[user_api_key_auth] = lambda: mock_admin_auth
response = client.post("/team/permissions_update", json=update_payload)
assert response.status_code == 200
response_data = response.json()
# Use model_dump for comparison if the endpoint returns the Prisma model directly
assert response_data == mock_updated_team_row.model_dump()
mock_db_client.db.litellm_teamtable.find_unique.assert_awaited_once_with(
where={"team_id": test_team_id}
)
mock_db_client.db.litellm_teamtable.update.assert_awaited_once_with(
where={"team_id": test_team_id},
data={"team_member_permissions": update_payload["team_member_permissions"]},
)
# Clean up dependency override
app.dependency_overrides = {}