[Feat - Team Member Permissions] - CRUD Endpoints for managing team member permissions (#9919)

* add team_member_permissions

* add GetTeamMemberPermissionsRequest types

* crud endpoint for team member permissions

* test team member permissions CRUD

* fix GetTeamMemberPermissionsRequest
This commit is contained in:
Ishaan Jaff 2025-04-11 17:15:16 -07:00 committed by GitHub
parent 2d6ad534bc
commit 91c0a794b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 560 additions and 57 deletions

View file

@ -191,6 +191,28 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMPydanticObjectBase):
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None
class KeyManagementRoutes(str, enum.Enum):
"""
Enum for key management routes
"""
# write routes
KEY_GENERATE = "/key/generate"
KEY_UPDATE = "/key/update"
KEY_DELETE = "/key/delete"
KEY_REGENERATE = "/key/regenerate"
KEY_REGENERATE_WITH_PATH_PARAM = "/key/{key_id}/regenerate"
KEY_BLOCK = "/key/block"
KEY_UNBLOCK = "/key/unblock"
# info and health routes
KEY_INFO = "/key/info"
KEY_HEALTH = "/key/health"
# list routes
KEY_LIST = "/key/list"
class LiteLLMRoutes(enum.Enum): class LiteLLMRoutes(enum.Enum):
openai_route_names = [ openai_route_names = [
"chat_completion", "chat_completion",
@ -321,14 +343,19 @@ class LiteLLMRoutes(enum.Enum):
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend # NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes = ["/global/spend/reset"] master_key_only_routes = ["/global/spend/reset"]
management_routes = [ # key key_management_routes = [
"/key/generate", KeyManagementRoutes.KEY_GENERATE,
"/key/{token_id}/regenerate", KeyManagementRoutes.KEY_UPDATE,
"/key/update", KeyManagementRoutes.KEY_DELETE,
"/key/delete", KeyManagementRoutes.KEY_INFO,
"/key/info", KeyManagementRoutes.KEY_REGENERATE,
"/key/health", KeyManagementRoutes.KEY_REGENERATE_WITH_PATH_PARAM,
"/key/list", KeyManagementRoutes.KEY_LIST,
KeyManagementRoutes.KEY_BLOCK,
KeyManagementRoutes.KEY_UNBLOCK,
]
management_routes = [
# user # user
"/user/new", "/user/new",
"/user/update", "/user/update",
@ -348,7 +375,7 @@ class LiteLLMRoutes(enum.Enum):
"/model/update", "/model/update",
"/model/delete", "/model/delete",
"/model/info", "/model/info",
] ] + key_management_routes
spend_tracking_routes = [ spend_tracking_routes = [
# spend # spend
@ -618,9 +645,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
allowed_cache_controls: Optional[list] = [] allowed_cache_controls: Optional[list] = []
config: Optional[dict] = {} config: Optional[dict] = {}
permissions: Optional[dict] = {} permissions: Optional[dict] = {}
model_max_budget: Optional[ model_max_budget: Optional[dict] = (
dict {}
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
model_rpm_limit: Optional[dict] = None model_rpm_limit: Optional[dict] = None
@ -876,12 +903,12 @@ class NewCustomerRequest(BudgetNewRequest):
alias: Optional[str] = None # human-friendly alias alias: Optional[str] = None # human-friendly alias
blocked: bool = False # allow/disallow requests for this end-user blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[ allowed_model_region: Optional[AllowedModelRegion] = (
AllowedModelRegion None # require all user requests to use models in this specific region
] = None # require all user requests to use models in this specific region )
default_model: Optional[ default_model: Optional[str] = (
str None # if no equivalent model in allowed region - default all requests to this model
] = None # if no equivalent model in allowed region - default all requests to this model )
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -903,12 +930,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[ allowed_model_region: Optional[AllowedModelRegion] = (
AllowedModelRegion None # require all user requests to use models in this specific region
] = None # require all user requests to use models in this specific region )
default_model: Optional[ default_model: Optional[str] = (
str None # if no equivalent model in allowed region - default all requests to this model
] = None # if no equivalent model in allowed region - default all requests to this model )
class DeleteCustomerRequest(LiteLLMPydanticObjectBase): class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@ -1043,9 +1070,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
class AddTeamCallback(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase):
callback_name: str callback_name: str
callback_type: Optional[ callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
Literal["success", "failure", "success_and_failure"] "success_and_failure"
] = "success_and_failure" )
callback_vars: Dict[str, str] callback_vars: Dict[str, str]
@model_validator(mode="before") @model_validator(mode="before")
@ -1110,6 +1137,7 @@ class LiteLLM_TeamTable(TeamBase):
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None model_id: Optional[int] = None
team_member_permissions: Optional[List[str]] = None
litellm_model_table: Optional[LiteLLM_ModelTable] = None litellm_model_table: Optional[LiteLLM_ModelTable] = None
created_at: Optional[datetime] = None created_at: Optional[datetime] = None
@ -1302,9 +1330,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
stored_in_db: Optional[bool] stored_in_db: Optional[bool]
field_default_value: Any field_default_value: Any
premium_field: bool = False premium_field: bool = False
nested_fields: Optional[ nested_fields: Optional[List[FieldDetail]] = (
List[FieldDetail] None # For nested dictionary or Pydantic fields
] = None # For nested dictionary or Pydantic fields )
class ConfigGeneralSettings(LiteLLMPydanticObjectBase): class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
@ -1570,9 +1598,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
budget_id: Optional[str] = None budget_id: Optional[str] = None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
user: Optional[ user: Optional[Any] = (
Any None # You might want to replace 'Any' with a more specific type if available
] = None # You might want to replace 'Any' with a more specific type if available )
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -2156,6 +2184,11 @@ class ProxyErrorTypes(str, enum.Enum):
Cache ping error Cache ping error
""" """
team_member_permission_error = "team_member_permission_error"
"""
Team member permission error
"""
@classmethod @classmethod
def get_model_access_error_type_for_object( def get_model_access_error_type_for_object(
cls, object_type: Literal["key", "user", "team"] cls, object_type: Literal["key", "user", "team"]
@ -2313,9 +2346,9 @@ class TeamModelDeleteRequest(BaseModel):
# Organization Member Requests # Organization Member Requests
class OrganizationMemberAddRequest(OrgMemberAddRequest): class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str organization_id: str
max_budget_in_organization: Optional[ max_budget_in_organization: Optional[float] = (
float None # Users max budget within the organization
] = None # Users max budget within the organization )
class OrganizationMemberDeleteRequest(MemberDeleteRequest): class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@ -2504,9 +2537,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
Maps provider names to their budget configs. Maps provider names to their budget configs.
""" """
providers: Dict[ providers: Dict[str, ProviderBudgetResponseObject] = (
str, ProviderBudgetResponseObject {}
] = {} # Dictionary mapping provider names to their budget configurations ) # Dictionary mapping provider names to their budget configurations
class ProxyStateVariables(TypedDict): class ProxyStateVariables(TypedDict):
@ -2634,9 +2667,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
enforce_rbac: bool = False enforce_rbac: bool = False
roles_jwt_field: Optional[str] = None # v2 on role mappings roles_jwt_field: Optional[str] = None # v2 on role mappings
role_mappings: Optional[List[RoleMapping]] = None role_mappings: Optional[List[RoleMapping]] = None
object_id_jwt_field: Optional[ object_id_jwt_field: Optional[str] = (
str None # can be either user / team, inferred from the role mapping
] = None # can be either user / team, inferred from the role mapping )
scope_mappings: Optional[List[ScopeMapping]] = None scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False enforce_scope_based_access: bool = False
enforce_team_based_model_access: bool = False enforce_team_based_model_access: bool = False

View file

@ -1,7 +1,7 @@
""" """
TEAM MANAGEMENT TEAM MANAGEMENT
All /team management endpoints All /team management endpoints
/team/new /team/new
/team/info /team/info
@ -62,6 +62,9 @@ from litellm.proxy.management_endpoints.common_utils import (
_is_user_team_admin, _is_user_team_admin,
_set_object_metadata_field, _set_object_metadata_field,
) )
from litellm.proxy.management_helpers.team_member_permission_checks import (
TeamMemberPermissionChecks,
)
from litellm.proxy.management_helpers.utils import ( from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
management_endpoint_wrapper, management_endpoint_wrapper,
@ -72,6 +75,10 @@ from litellm.proxy.utils import (
handle_exception_on_proxy, handle_exception_on_proxy,
) )
from litellm.router import Router from litellm.router import Router
from litellm.types.proxy.management_endpoints.team_endpoints import (
GetTeamMemberPermissionsResponse,
UpdateTeamMemberPermissionsRequest,
)
router = APIRouter() router = APIRouter()
@ -506,12 +513,12 @@ async def update_team(
updated_kv["model_id"] = _model_id updated_kv["model_id"] = _model_id
updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv) updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
team_row: Optional[ team_row: Optional[LiteLLM_TeamTable] = (
LiteLLM_TeamTable await prisma_client.db.litellm_teamtable.update(
] = await prisma_client.db.litellm_teamtable.update( where={"team_id": data.team_id},
where={"team_id": data.team_id}, data=updated_kv,
data=updated_kv, include={"litellm_model_table": True}, # type: ignore
include={"litellm_model_table": True}, # type: ignore )
) )
if team_row is None or team_row.team_id is None: if team_row is None or team_row.team_id is None:
@ -1137,10 +1144,10 @@ async def delete_team(
team_rows: List[LiteLLM_TeamTable] = [] team_rows: List[LiteLLM_TeamTable] = []
for team_id in data.team_ids: for team_id in data.team_ids:
try: try:
team_row_base: Optional[ team_row_base: Optional[BaseModel] = (
BaseModel await prisma_client.db.litellm_teamtable.find_unique(
] = await prisma_client.db.litellm_teamtable.find_unique( where={"team_id": team_id}
where={"team_id": team_id} )
) )
if team_row_base is None: if team_row_base is None:
raise Exception raise Exception
@ -1298,10 +1305,10 @@ async def team_info(
) )
try: try:
team_info: Optional[ team_info: Optional[BaseModel] = (
BaseModel await prisma_client.db.litellm_teamtable.find_unique(
] = await prisma_client.db.litellm_teamtable.find_unique( where={"team_id": team_id}
where={"team_id": team_id} )
) )
if team_info is None: if team_info is None:
raise Exception raise Exception
@ -1926,3 +1933,89 @@ async def team_model_delete(
) )
return updated_team return updated_team
@router.get(
"/team/permissions_list",
tags=["team management"],
dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def team_member_permissions(
team_id: str = fastapi.Query(
default=None, description="Team ID in the request parameters"
),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> GetTeamMemberPermissionsResponse:
"""
Get the team member permissions for a team
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
team_row = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
if team_row is None:
raise HTTPException(
status_code=404,
detail={"error": f"Team not found, passed team_id={team_id}"},
)
team_obj = LiteLLM_TeamTable(**team_row.model_dump())
if team_obj.team_member_permissions is None:
team_obj.team_member_permissions = (
TeamMemberPermissionChecks.default_team_member_permissions()
)
return GetTeamMemberPermissionsResponse(
team_id=team_id,
team_member_permissions=team_obj.team_member_permissions,
all_available_permissions=TeamMemberPermissionChecks.get_all_available_team_member_permissions(),
)
@router.post(
"/team/permissions_update",
tags=["team management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_team_member_permissions(
data: UpdateTeamMemberPermissionsRequest,
http_request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> LiteLLM_TeamTable:
"""
Update the team member permissions for a team
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
team_row = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": data.team_id}
)
if team_row is None:
raise HTTPException(
status_code=404,
detail={"error": f"Team not found, passed team_id={data.team_id}"},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
raise HTTPException(
status_code=403,
detail={"error": "Only proxy admin can update team member permissions"},
)
# Update the team member permissions
updated_team = await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id},
data={"team_member_permissions": data.team_member_permissions},
)
return updated_team

View file

@ -0,0 +1,181 @@
from typing import List, Optional
from litellm.caching import DualCache
from litellm.proxy._types import (
KeyManagementRoutes,
LiteLLM_TeamTableCachedObj,
LiteLLM_VerificationToken,
LiteLLMRoutes,
LitellmUserRoles,
Member,
ProxyErrorTypes,
ProxyException,
Span,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_checks import get_team_object
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import PrismaClient
DEFAULT_TEAM_MEMBER_PERMISSIONS = [
KeyManagementRoutes.KEY_INFO,
KeyManagementRoutes.KEY_HEALTH,
]
class TeamMemberPermissionChecks:
@staticmethod
def get_permissions_for_team_member(
team_member_object: Member,
team_table: LiteLLM_TeamTableCachedObj,
) -> List[KeyManagementRoutes]:
"""
Returns the permissions for a team member
"""
if team_table.team_member_permissions and isinstance(
team_table.team_member_permissions, list
):
return [
KeyManagementRoutes(permission)
for permission in team_table.team_member_permissions
]
return DEFAULT_TEAM_MEMBER_PERMISSIONS
@staticmethod
def _get_list_of_route_enum_as_str(
route_enum: List[KeyManagementRoutes],
) -> List[str]:
"""
Returns a list of the route enum as a list of strings
"""
return [route.value for route in route_enum]
@staticmethod
async def can_team_member_execute_key_management_endpoint(
user_api_key_dict: UserAPIKeyAuth,
route: KeyManagementRoutes,
prisma_client: PrismaClient,
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
existing_key_row: LiteLLM_VerificationToken,
):
"""
Main handler for checking if a team member can update a key
"""
from litellm.proxy.management_endpoints.key_management_endpoints import (
_get_user_in_team,
)
# 1. Don't execute these checks if the user role is proxy admin
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
return
# 2. Check if the operation is being done on a team key
if existing_key_row.team_id is None:
return
# 3. Get Team Object from DB
team_table = await get_team_object(
team_id=existing_key_row.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=user_api_key_dict.parent_otel_span,
check_db_only=True,
)
# 4. Extract `Member` object from `team_table`
key_assigned_user_in_team = _get_user_in_team(
team_table=team_table, user_id=user_api_key_dict.user_id
)
# 5. Check if the team member has permissions for the endpoint
TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint(
team_member_object=key_assigned_user_in_team,
team_table=team_table,
route=route,
)
@staticmethod
def does_team_member_have_permissions_for_endpoint(
team_member_object: Optional[Member],
team_table: LiteLLM_TeamTableCachedObj,
route: str,
) -> Optional[bool]:
"""
Raises an exception if the team member does not have permissions for calling the endpoint for a team
"""
# permission checks only run for non-admin users
# Non-Admin user trying to access information about a team's key
if team_member_object is None:
return False
if team_member_object.role == "admin":
return True
_team_member_permissions = (
TeamMemberPermissionChecks.get_permissions_for_team_member(
team_member_object=team_member_object,
team_table=team_table,
)
)
team_member_permissions = (
TeamMemberPermissionChecks._get_list_of_route_enum_as_str(
_team_member_permissions
)
)
if not RouteChecks.check_route_access(
route=route, allowed_routes=team_member_permissions
):
raise ProxyException(
message=f"Team member does not have permissions for endpoint: {route}. You only have access to the following endpoints: {team_member_permissions}",
type=ProxyErrorTypes.team_member_permission_error,
param=route,
code=401,
)
return True
@staticmethod
async def user_belongs_to_keys_team(
user_api_key_dict: UserAPIKeyAuth,
existing_key_row: LiteLLM_VerificationToken,
) -> bool:
"""
Returns True if the user belongs to the team that the key is assigned to
"""
from litellm.proxy.management_endpoints.key_management_endpoints import (
_get_user_in_team,
)
from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
if existing_key_row.team_id is None:
return False
team_table = await get_team_object(
team_id=existing_key_row.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=user_api_key_dict.parent_otel_span,
check_db_only=True,
)
# 4. Extract `Member` object from `team_table`
team_member_object = _get_user_in_team(
team_table=team_table, user_id=user_api_key_dict.user_id
)
return team_member_object is not None
@staticmethod
def get_all_available_team_member_permissions() -> List[str]:
"""
Returns all available team member permissions
"""
all_available_permissions = []
for route in LiteLLMRoutes.key_management_routes.value:
all_available_permissions.append(route.value)
return all_available_permissions
@staticmethod
def default_team_member_permissions() -> List[str]:
return [route.value for route in DEFAULT_TEAM_MEMBER_PERMISSIONS]

View file

@ -0,0 +1,35 @@
from typing import List, Optional
from pydantic import BaseModel
class GetTeamMemberPermissionsRequest(BaseModel):
"""Request to get the team member permissions for a team"""
team_id: str
class GetTeamMemberPermissionsResponse(BaseModel):
"""Response to get the team member permissions for a team"""
team_id: str
"""
The team id that the permissions are for
"""
team_member_permissions: Optional[List[str]] = []
"""
The team member permissions currently set for the team
"""
all_available_permissions: List[str]
"""
All available team member permissions
"""
class UpdateTeamMemberPermissionsRequest(BaseModel):
"""Request to update the team member permissions for a team"""
team_id: str
team_member_permissions: List[str]

View file

@ -0,0 +1,161 @@
import asyncio
import json
import os
import sys
import uuid
from typing import Optional, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../")
) # Adds the parent directory to the system path
from litellm.proxy._types import UserAPIKeyAuth # Import UserAPIKeyAuth
from litellm.proxy._types import LiteLLM_TeamTable, LitellmUserRoles
from litellm.proxy.management_endpoints.team_endpoints import (
user_api_key_auth, # Assuming this dependency is needed
)
from litellm.proxy.management_endpoints.team_endpoints import (
GetTeamMemberPermissionsResponse,
UpdateTeamMemberPermissionsRequest,
router,
)
from litellm.proxy.management_helpers.team_member_permission_checks import (
TeamMemberPermissionChecks,
)
from litellm.proxy.proxy_server import app
# Setup TestClient
client = TestClient(app)
# Mock prisma_client
mock_prisma_client = MagicMock()
# Fixture to provide the mock prisma client
@pytest.fixture(autouse=True)
def mock_db_client():
with patch(
"litellm.proxy.proxy_server.prisma_client", mock_prisma_client
): # Mock in both places if necessary
yield mock_prisma_client
mock_prisma_client.reset_mock()
# Fixture to provide a mock admin user auth object
@pytest.fixture
def mock_admin_auth():
mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)
return mock_auth
# Test for /team/permissions_list endpoint (GET)
@pytest.mark.asyncio
async def test_get_team_permissions_list_success(mock_db_client, mock_admin_auth):
"""
Test successful retrieval of team member permissions.
"""
test_team_id = "test-team-123"
mock_team_data = {
"team_id": test_team_id,
"team_alias": "Test Team",
"team_member_permissions": ["/key/generate", "/key/update"],
"spend": 0.0,
}
mock_team_row = MagicMock()
mock_team_row.model_dump.return_value = mock_team_data
mock_db_client.db.litellm_teamtable.find_unique = AsyncMock(
return_value=mock_team_row
)
# Override the dependency for this test
app.dependency_overrides[user_api_key_auth] = lambda: mock_admin_auth
response = client.get(f"/team/permissions_list?team_id={test_team_id}")
assert response.status_code == 200
response_data = response.json()
assert response_data["team_id"] == test_team_id
assert (
response_data["team_member_permissions"]
== mock_team_data["team_member_permissions"]
)
assert (
response_data["all_available_permissions"]
== TeamMemberPermissionChecks.get_all_available_team_member_permissions()
)
mock_db_client.db.litellm_teamtable.find_unique.assert_awaited_once_with(
where={"team_id": test_team_id}
)
# Clean up dependency override
app.dependency_overrides = {}
# Test for /team/permissions_update endpoint (POST)
@pytest.mark.asyncio
async def test_update_team_permissions_success(mock_db_client, mock_admin_auth):
"""
Test successful update of team member permissions by an admin.
"""
test_team_id = "test-team-456"
update_payload = {
"team_id": test_team_id,
"team_member_permissions": ["/key/generate", "/key/update"],
}
mock_existing_team_data = {
"team_id": test_team_id,
"team_alias": "Existing Team",
"team_member_permissions": ["/key/list"],
"spend": 0.0,
"models": [],
}
mock_updated_team_data = {
**mock_existing_team_data,
"team_member_permissions": update_payload["team_member_permissions"],
}
mock_existing_team_row = MagicMock(spec=LiteLLM_TeamTable)
mock_existing_team_row.model_dump.return_value = mock_existing_team_data
# Set attributes directly if model_dump isn't enough for LiteLLM_TeamTable usage
for key, value in mock_existing_team_data.items():
setattr(mock_existing_team_row, key, value)
mock_updated_team_row = MagicMock(spec=LiteLLM_TeamTable)
mock_updated_team_row.model_dump.return_value = mock_updated_team_data
# Set attributes directly if model_dump isn't enough for LiteLLM_TeamTable usage
for key, value in mock_updated_team_data.items():
setattr(mock_updated_team_row, key, value)
mock_db_client.db.litellm_teamtable.find_unique = AsyncMock(
return_value=mock_existing_team_row
)
mock_db_client.db.litellm_teamtable.update = AsyncMock(
return_value=mock_updated_team_row
)
# Override the dependency for this test
app.dependency_overrides[user_api_key_auth] = lambda: mock_admin_auth
response = client.post("/team/permissions_update", json=update_payload)
assert response.status_code == 200
response_data = response.json()
# Use model_dump for comparison if the endpoint returns the Prisma model directly
assert response_data == mock_updated_team_row.model_dump()
mock_db_client.db.litellm_teamtable.find_unique.assert_awaited_once_with(
where={"team_id": test_team_id}
)
mock_db_client.db.litellm_teamtable.update.assert_awaited_once_with(
where={"team_id": test_team_id},
data={"team_member_permissions": update_payload["team_member_permissions"]},
)
# Clean up dependency override
app.dependency_overrides = {}