mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
[Feat - Team Member Permissions] - CRUD Endpoints for managing team member permissions (#9919)
* add team_member_permissions * add GetTeamMemberPermissionsRequest types * crud endpoint for team member permissions * test team member permissions CRUD * fix GetTeamMemberPermissionsRequest
This commit is contained in:
parent
2d6ad534bc
commit
91c0a794b9
5 changed files with 560 additions and 57 deletions
|
@ -191,6 +191,28 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMPydanticObjectBase):
|
|||
rpm_limit: Optional[int] = None
|
||||
|
||||
|
||||
class KeyManagementRoutes(str, enum.Enum):
|
||||
"""
|
||||
Enum for key management routes
|
||||
"""
|
||||
|
||||
# write routes
|
||||
KEY_GENERATE = "/key/generate"
|
||||
KEY_UPDATE = "/key/update"
|
||||
KEY_DELETE = "/key/delete"
|
||||
KEY_REGENERATE = "/key/regenerate"
|
||||
KEY_REGENERATE_WITH_PATH_PARAM = "/key/{key_id}/regenerate"
|
||||
KEY_BLOCK = "/key/block"
|
||||
KEY_UNBLOCK = "/key/unblock"
|
||||
|
||||
# info and health routes
|
||||
KEY_INFO = "/key/info"
|
||||
KEY_HEALTH = "/key/health"
|
||||
|
||||
# list routes
|
||||
KEY_LIST = "/key/list"
|
||||
|
||||
|
||||
class LiteLLMRoutes(enum.Enum):
|
||||
openai_route_names = [
|
||||
"chat_completion",
|
||||
|
@ -321,14 +343,19 @@ class LiteLLMRoutes(enum.Enum):
|
|||
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
|
||||
master_key_only_routes = ["/global/spend/reset"]
|
||||
|
||||
management_routes = [ # key
|
||||
"/key/generate",
|
||||
"/key/{token_id}/regenerate",
|
||||
"/key/update",
|
||||
"/key/delete",
|
||||
"/key/info",
|
||||
"/key/health",
|
||||
"/key/list",
|
||||
key_management_routes = [
|
||||
KeyManagementRoutes.KEY_GENERATE,
|
||||
KeyManagementRoutes.KEY_UPDATE,
|
||||
KeyManagementRoutes.KEY_DELETE,
|
||||
KeyManagementRoutes.KEY_INFO,
|
||||
KeyManagementRoutes.KEY_REGENERATE,
|
||||
KeyManagementRoutes.KEY_REGENERATE_WITH_PATH_PARAM,
|
||||
KeyManagementRoutes.KEY_LIST,
|
||||
KeyManagementRoutes.KEY_BLOCK,
|
||||
KeyManagementRoutes.KEY_UNBLOCK,
|
||||
]
|
||||
|
||||
management_routes = [
|
||||
# user
|
||||
"/user/new",
|
||||
"/user/update",
|
||||
|
@ -348,7 +375,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/model/update",
|
||||
"/model/delete",
|
||||
"/model/info",
|
||||
]
|
||||
] + key_management_routes
|
||||
|
||||
spend_tracking_routes = [
|
||||
# spend
|
||||
|
@ -618,9 +645,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
|
|||
allowed_cache_controls: Optional[list] = []
|
||||
config: Optional[dict] = {}
|
||||
permissions: Optional[dict] = {}
|
||||
model_max_budget: Optional[
|
||||
dict
|
||||
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
|
||||
model_max_budget: Optional[dict] = (
|
||||
{}
|
||||
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
model_rpm_limit: Optional[dict] = None
|
||||
|
@ -876,12 +903,12 @@ class NewCustomerRequest(BudgetNewRequest):
|
|||
alias: Optional[str] = None # human-friendly alias
|
||||
blocked: bool = False # allow/disallow requests for this end-user
|
||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||
allowed_model_region: Optional[
|
||||
AllowedModelRegion
|
||||
] = None # require all user requests to use models in this specific region
|
||||
default_model: Optional[
|
||||
str
|
||||
] = None # if no equivalent model in allowed region - default all requests to this model
|
||||
allowed_model_region: Optional[AllowedModelRegion] = (
|
||||
None # require all user requests to use models in this specific region
|
||||
)
|
||||
default_model: Optional[str] = (
|
||||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
@ -903,12 +930,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
|
|||
blocked: bool = False # allow/disallow requests for this end-user
|
||||
max_budget: Optional[float] = None
|
||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||
allowed_model_region: Optional[
|
||||
AllowedModelRegion
|
||||
] = None # require all user requests to use models in this specific region
|
||||
default_model: Optional[
|
||||
str
|
||||
] = None # if no equivalent model in allowed region - default all requests to this model
|
||||
allowed_model_region: Optional[AllowedModelRegion] = (
|
||||
None # require all user requests to use models in this specific region
|
||||
)
|
||||
default_model: Optional[str] = (
|
||||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
|
||||
|
||||
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
|
||||
|
@ -1043,9 +1070,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
|
|||
|
||||
class AddTeamCallback(LiteLLMPydanticObjectBase):
|
||||
callback_name: str
|
||||
callback_type: Optional[
|
||||
Literal["success", "failure", "success_and_failure"]
|
||||
] = "success_and_failure"
|
||||
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
|
||||
"success_and_failure"
|
||||
)
|
||||
callback_vars: Dict[str, str]
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
@ -1110,6 +1137,7 @@ class LiteLLM_TeamTable(TeamBase):
|
|||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
model_id: Optional[int] = None
|
||||
team_member_permissions: Optional[List[str]] = None
|
||||
litellm_model_table: Optional[LiteLLM_ModelTable] = None
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
|
@ -1302,9 +1330,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
|
|||
stored_in_db: Optional[bool]
|
||||
field_default_value: Any
|
||||
premium_field: bool = False
|
||||
nested_fields: Optional[
|
||||
List[FieldDetail]
|
||||
] = None # For nested dictionary or Pydantic fields
|
||||
nested_fields: Optional[List[FieldDetail]] = (
|
||||
None # For nested dictionary or Pydantic fields
|
||||
)
|
||||
|
||||
|
||||
class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
|
||||
|
@ -1570,9 +1598,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
|
|||
budget_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
user: Optional[
|
||||
Any
|
||||
] = None # You might want to replace 'Any' with a more specific type if available
|
||||
user: Optional[Any] = (
|
||||
None # You might want to replace 'Any' with a more specific type if available
|
||||
)
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
@ -2156,6 +2184,11 @@ class ProxyErrorTypes(str, enum.Enum):
|
|||
Cache ping error
|
||||
"""
|
||||
|
||||
team_member_permission_error = "team_member_permission_error"
|
||||
"""
|
||||
Team member permission error
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_model_access_error_type_for_object(
|
||||
cls, object_type: Literal["key", "user", "team"]
|
||||
|
@ -2313,9 +2346,9 @@ class TeamModelDeleteRequest(BaseModel):
|
|||
# Organization Member Requests
|
||||
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
||||
organization_id: str
|
||||
max_budget_in_organization: Optional[
|
||||
float
|
||||
] = None # Users max budget within the organization
|
||||
max_budget_in_organization: Optional[float] = (
|
||||
None # Users max budget within the organization
|
||||
)
|
||||
|
||||
|
||||
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
|
||||
|
@ -2504,9 +2537,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
|
|||
Maps provider names to their budget configs.
|
||||
"""
|
||||
|
||||
providers: Dict[
|
||||
str, ProviderBudgetResponseObject
|
||||
] = {} # Dictionary mapping provider names to their budget configurations
|
||||
providers: Dict[str, ProviderBudgetResponseObject] = (
|
||||
{}
|
||||
) # Dictionary mapping provider names to their budget configurations
|
||||
|
||||
|
||||
class ProxyStateVariables(TypedDict):
|
||||
|
@ -2634,9 +2667,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||
enforce_rbac: bool = False
|
||||
roles_jwt_field: Optional[str] = None # v2 on role mappings
|
||||
role_mappings: Optional[List[RoleMapping]] = None
|
||||
object_id_jwt_field: Optional[
|
||||
str
|
||||
] = None # can be either user / team, inferred from the role mapping
|
||||
object_id_jwt_field: Optional[str] = (
|
||||
None # can be either user / team, inferred from the role mapping
|
||||
)
|
||||
scope_mappings: Optional[List[ScopeMapping]] = None
|
||||
enforce_scope_based_access: bool = False
|
||||
enforce_team_based_model_access: bool = False
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
TEAM MANAGEMENT
|
||||
|
||||
All /team management endpoints
|
||||
All /team management endpoints
|
||||
|
||||
/team/new
|
||||
/team/info
|
||||
|
@ -62,6 +62,9 @@ from litellm.proxy.management_endpoints.common_utils import (
|
|||
_is_user_team_admin,
|
||||
_set_object_metadata_field,
|
||||
)
|
||||
from litellm.proxy.management_helpers.team_member_permission_checks import (
|
||||
TeamMemberPermissionChecks,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
management_endpoint_wrapper,
|
||||
|
@ -72,6 +75,10 @@ from litellm.proxy.utils import (
|
|||
handle_exception_on_proxy,
|
||||
)
|
||||
from litellm.router import Router
|
||||
from litellm.types.proxy.management_endpoints.team_endpoints import (
|
||||
GetTeamMemberPermissionsResponse,
|
||||
UpdateTeamMemberPermissionsRequest,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -506,12 +513,12 @@ async def update_team(
|
|||
updated_kv["model_id"] = _model_id
|
||||
|
||||
updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
|
||||
team_row: Optional[
|
||||
LiteLLM_TeamTable
|
||||
] = await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": data.team_id},
|
||||
data=updated_kv,
|
||||
include={"litellm_model_table": True}, # type: ignore
|
||||
team_row: Optional[LiteLLM_TeamTable] = (
|
||||
await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": data.team_id},
|
||||
data=updated_kv,
|
||||
include={"litellm_model_table": True}, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
if team_row is None or team_row.team_id is None:
|
||||
|
@ -1137,10 +1144,10 @@ async def delete_team(
|
|||
team_rows: List[LiteLLM_TeamTable] = []
|
||||
for team_id in data.team_ids:
|
||||
try:
|
||||
team_row_base: Optional[
|
||||
BaseModel
|
||||
] = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
team_row_base: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
)
|
||||
if team_row_base is None:
|
||||
raise Exception
|
||||
|
@ -1298,10 +1305,10 @@ async def team_info(
|
|||
)
|
||||
|
||||
try:
|
||||
team_info: Optional[
|
||||
BaseModel
|
||||
] = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
team_info: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
)
|
||||
if team_info is None:
|
||||
raise Exception
|
||||
|
@ -1926,3 +1933,89 @@ async def team_model_delete(
|
|||
)
|
||||
|
||||
return updated_team
|
||||
|
||||
|
||||
@router.get(
|
||||
"/team/permissions_list",
|
||||
tags=["team management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
@management_endpoint_wrapper
|
||||
async def team_member_permissions(
|
||||
team_id: str = fastapi.Query(
|
||||
default=None, description="Team ID in the request parameters"
|
||||
),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
) -> GetTeamMemberPermissionsResponse:
|
||||
"""
|
||||
Get the team member permissions for a team
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
|
||||
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
if team_row is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": f"Team not found, passed team_id={team_id}"},
|
||||
)
|
||||
|
||||
team_obj = LiteLLM_TeamTable(**team_row.model_dump())
|
||||
if team_obj.team_member_permissions is None:
|
||||
team_obj.team_member_permissions = (
|
||||
TeamMemberPermissionChecks.default_team_member_permissions()
|
||||
)
|
||||
|
||||
return GetTeamMemberPermissionsResponse(
|
||||
team_id=team_id,
|
||||
team_member_permissions=team_obj.team_member_permissions,
|
||||
all_available_permissions=TeamMemberPermissionChecks.get_all_available_team_member_permissions(),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/team/permissions_update",
|
||||
tags=["team management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def update_team_member_permissions(
|
||||
data: UpdateTeamMemberPermissionsRequest,
|
||||
http_request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
) -> LiteLLM_TeamTable:
|
||||
"""
|
||||
Update the team member permissions for a team
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
|
||||
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": data.team_id}
|
||||
)
|
||||
|
||||
if team_row is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
||||
)
|
||||
|
||||
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": "Only proxy admin can update team member permissions"},
|
||||
)
|
||||
|
||||
# Update the team member permissions
|
||||
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": data.team_id},
|
||||
data={"team_member_permissions": data.team_member_permissions},
|
||||
)
|
||||
|
||||
return updated_team
|
||||
|
|
|
@ -0,0 +1,181 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import (
|
||||
KeyManagementRoutes,
|
||||
LiteLLM_TeamTableCachedObj,
|
||||
LiteLLM_VerificationToken,
|
||||
LiteLLMRoutes,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
Span,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import get_team_object
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
DEFAULT_TEAM_MEMBER_PERMISSIONS = [
|
||||
KeyManagementRoutes.KEY_INFO,
|
||||
KeyManagementRoutes.KEY_HEALTH,
|
||||
]
|
||||
|
||||
|
||||
class TeamMemberPermissionChecks:
|
||||
@staticmethod
|
||||
def get_permissions_for_team_member(
|
||||
team_member_object: Member,
|
||||
team_table: LiteLLM_TeamTableCachedObj,
|
||||
) -> List[KeyManagementRoutes]:
|
||||
"""
|
||||
Returns the permissions for a team member
|
||||
"""
|
||||
if team_table.team_member_permissions and isinstance(
|
||||
team_table.team_member_permissions, list
|
||||
):
|
||||
return [
|
||||
KeyManagementRoutes(permission)
|
||||
for permission in team_table.team_member_permissions
|
||||
]
|
||||
|
||||
return DEFAULT_TEAM_MEMBER_PERMISSIONS
|
||||
|
||||
@staticmethod
|
||||
def _get_list_of_route_enum_as_str(
|
||||
route_enum: List[KeyManagementRoutes],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns a list of the route enum as a list of strings
|
||||
"""
|
||||
return [route.value for route in route_enum]
|
||||
|
||||
@staticmethod
|
||||
async def can_team_member_execute_key_management_endpoint(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
route: KeyManagementRoutes,
|
||||
prisma_client: PrismaClient,
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
existing_key_row: LiteLLM_VerificationToken,
|
||||
):
|
||||
"""
|
||||
Main handler for checking if a team member can update a key
|
||||
"""
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_get_user_in_team,
|
||||
)
|
||||
|
||||
# 1. Don't execute these checks if the user role is proxy admin
|
||||
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
||||
return
|
||||
|
||||
# 2. Check if the operation is being done on a team key
|
||||
if existing_key_row.team_id is None:
|
||||
return
|
||||
|
||||
# 3. Get Team Object from DB
|
||||
team_table = await get_team_object(
|
||||
team_id=existing_key_row.team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
check_db_only=True,
|
||||
)
|
||||
|
||||
# 4. Extract `Member` object from `team_table`
|
||||
key_assigned_user_in_team = _get_user_in_team(
|
||||
team_table=team_table, user_id=user_api_key_dict.user_id
|
||||
)
|
||||
|
||||
# 5. Check if the team member has permissions for the endpoint
|
||||
TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint(
|
||||
team_member_object=key_assigned_user_in_team,
|
||||
team_table=team_table,
|
||||
route=route,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def does_team_member_have_permissions_for_endpoint(
|
||||
team_member_object: Optional[Member],
|
||||
team_table: LiteLLM_TeamTableCachedObj,
|
||||
route: str,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Raises an exception if the team member does not have permissions for calling the endpoint for a team
|
||||
"""
|
||||
|
||||
# permission checks only run for non-admin users
|
||||
# Non-Admin user trying to access information about a team's key
|
||||
if team_member_object is None:
|
||||
return False
|
||||
if team_member_object.role == "admin":
|
||||
return True
|
||||
|
||||
_team_member_permissions = (
|
||||
TeamMemberPermissionChecks.get_permissions_for_team_member(
|
||||
team_member_object=team_member_object,
|
||||
team_table=team_table,
|
||||
)
|
||||
)
|
||||
team_member_permissions = (
|
||||
TeamMemberPermissionChecks._get_list_of_route_enum_as_str(
|
||||
_team_member_permissions
|
||||
)
|
||||
)
|
||||
|
||||
if not RouteChecks.check_route_access(
|
||||
route=route, allowed_routes=team_member_permissions
|
||||
):
|
||||
raise ProxyException(
|
||||
message=f"Team member does not have permissions for endpoint: {route}. You only have access to the following endpoints: {team_member_permissions}",
|
||||
type=ProxyErrorTypes.team_member_permission_error,
|
||||
param=route,
|
||||
code=401,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def user_belongs_to_keys_team(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
existing_key_row: LiteLLM_VerificationToken,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if the user belongs to the team that the key is assigned to
|
||||
"""
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_get_user_in_team,
|
||||
)
|
||||
from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
|
||||
|
||||
if existing_key_row.team_id is None:
|
||||
return False
|
||||
team_table = await get_team_object(
|
||||
team_id=existing_key_row.team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
check_db_only=True,
|
||||
)
|
||||
|
||||
# 4. Extract `Member` object from `team_table`
|
||||
team_member_object = _get_user_in_team(
|
||||
team_table=team_table, user_id=user_api_key_dict.user_id
|
||||
)
|
||||
return team_member_object is not None
|
||||
|
||||
@staticmethod
|
||||
def get_all_available_team_member_permissions() -> List[str]:
|
||||
"""
|
||||
Returns all available team member permissions
|
||||
"""
|
||||
all_available_permissions = []
|
||||
for route in LiteLLMRoutes.key_management_routes.value:
|
||||
all_available_permissions.append(route.value)
|
||||
return all_available_permissions
|
||||
|
||||
@staticmethod
|
||||
def default_team_member_permissions() -> List[str]:
|
||||
return [route.value for route in DEFAULT_TEAM_MEMBER_PERMISSIONS]
|
35
litellm/types/proxy/management_endpoints/team_endpoints.py
Normal file
35
litellm/types/proxy/management_endpoints/team_endpoints.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GetTeamMemberPermissionsRequest(BaseModel):
|
||||
"""Request to get the team member permissions for a team"""
|
||||
|
||||
team_id: str
|
||||
|
||||
|
||||
class GetTeamMemberPermissionsResponse(BaseModel):
|
||||
"""Response to get the team member permissions for a team"""
|
||||
|
||||
team_id: str
|
||||
"""
|
||||
The team id that the permissions are for
|
||||
"""
|
||||
|
||||
team_member_permissions: Optional[List[str]] = []
|
||||
"""
|
||||
The team member permissions currently set for the team
|
||||
"""
|
||||
|
||||
all_available_permissions: List[str]
|
||||
"""
|
||||
All available team member permissions
|
||||
"""
|
||||
|
||||
|
||||
class UpdateTeamMemberPermissionsRequest(BaseModel):
|
||||
"""Request to update the team member permissions for a team"""
|
||||
|
||||
team_id: str
|
||||
team_member_permissions: List[str]
|
161
tests/litellm/proxy/management_endpoints/test_team_endpoints.py
Normal file
161
tests/litellm/proxy/management_endpoints/test_team_endpoints.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Optional, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../")
|
||||
) # Adds the parent directory to the system path
|
||||
from litellm.proxy._types import UserAPIKeyAuth # Import UserAPIKeyAuth
|
||||
from litellm.proxy._types import LiteLLM_TeamTable, LitellmUserRoles
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
user_api_key_auth, # Assuming this dependency is needed
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
GetTeamMemberPermissionsResponse,
|
||||
UpdateTeamMemberPermissionsRequest,
|
||||
router,
|
||||
)
|
||||
from litellm.proxy.management_helpers.team_member_permission_checks import (
|
||||
TeamMemberPermissionChecks,
|
||||
)
|
||||
from litellm.proxy.proxy_server import app
|
||||
|
||||
# Setup TestClient
|
||||
client = TestClient(app)
|
||||
|
||||
# Mock prisma_client
|
||||
mock_prisma_client = MagicMock()
|
||||
|
||||
|
||||
# Fixture to provide the mock prisma client
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db_client():
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client", mock_prisma_client
|
||||
): # Mock in both places if necessary
|
||||
yield mock_prisma_client
|
||||
mock_prisma_client.reset_mock()
|
||||
|
||||
|
||||
# Fixture to provide a mock admin user auth object
|
||||
@pytest.fixture
|
||||
def mock_admin_auth():
|
||||
mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)
|
||||
return mock_auth
|
||||
|
||||
|
||||
# Test for /team/permissions_list endpoint (GET)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_team_permissions_list_success(mock_db_client, mock_admin_auth):
|
||||
"""
|
||||
Test successful retrieval of team member permissions.
|
||||
"""
|
||||
test_team_id = "test-team-123"
|
||||
mock_team_data = {
|
||||
"team_id": test_team_id,
|
||||
"team_alias": "Test Team",
|
||||
"team_member_permissions": ["/key/generate", "/key/update"],
|
||||
"spend": 0.0,
|
||||
}
|
||||
mock_team_row = MagicMock()
|
||||
mock_team_row.model_dump.return_value = mock_team_data
|
||||
mock_db_client.db.litellm_teamtable.find_unique = AsyncMock(
|
||||
return_value=mock_team_row
|
||||
)
|
||||
|
||||
# Override the dependency for this test
|
||||
app.dependency_overrides[user_api_key_auth] = lambda: mock_admin_auth
|
||||
|
||||
response = client.get(f"/team/permissions_list?team_id={test_team_id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["team_id"] == test_team_id
|
||||
assert (
|
||||
response_data["team_member_permissions"]
|
||||
== mock_team_data["team_member_permissions"]
|
||||
)
|
||||
assert (
|
||||
response_data["all_available_permissions"]
|
||||
== TeamMemberPermissionChecks.get_all_available_team_member_permissions()
|
||||
)
|
||||
mock_db_client.db.litellm_teamtable.find_unique.assert_awaited_once_with(
|
||||
where={"team_id": test_team_id}
|
||||
)
|
||||
|
||||
# Clean up dependency override
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
# Test for /team/permissions_update endpoint (POST)
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_team_permissions_success(mock_db_client, mock_admin_auth):
|
||||
"""
|
||||
Test successful update of team member permissions by an admin.
|
||||
"""
|
||||
test_team_id = "test-team-456"
|
||||
update_payload = {
|
||||
"team_id": test_team_id,
|
||||
"team_member_permissions": ["/key/generate", "/key/update"],
|
||||
}
|
||||
|
||||
mock_existing_team_data = {
|
||||
"team_id": test_team_id,
|
||||
"team_alias": "Existing Team",
|
||||
"team_member_permissions": ["/key/list"],
|
||||
"spend": 0.0,
|
||||
"models": [],
|
||||
}
|
||||
mock_updated_team_data = {
|
||||
**mock_existing_team_data,
|
||||
"team_member_permissions": update_payload["team_member_permissions"],
|
||||
}
|
||||
|
||||
mock_existing_team_row = MagicMock(spec=LiteLLM_TeamTable)
|
||||
mock_existing_team_row.model_dump.return_value = mock_existing_team_data
|
||||
# Set attributes directly if model_dump isn't enough for LiteLLM_TeamTable usage
|
||||
for key, value in mock_existing_team_data.items():
|
||||
setattr(mock_existing_team_row, key, value)
|
||||
|
||||
mock_updated_team_row = MagicMock(spec=LiteLLM_TeamTable)
|
||||
mock_updated_team_row.model_dump.return_value = mock_updated_team_data
|
||||
# Set attributes directly if model_dump isn't enough for LiteLLM_TeamTable usage
|
||||
for key, value in mock_updated_team_data.items():
|
||||
setattr(mock_updated_team_row, key, value)
|
||||
|
||||
mock_db_client.db.litellm_teamtable.find_unique = AsyncMock(
|
||||
return_value=mock_existing_team_row
|
||||
)
|
||||
mock_db_client.db.litellm_teamtable.update = AsyncMock(
|
||||
return_value=mock_updated_team_row
|
||||
)
|
||||
|
||||
# Override the dependency for this test
|
||||
app.dependency_overrides[user_api_key_auth] = lambda: mock_admin_auth
|
||||
|
||||
response = client.post("/team/permissions_update", json=update_payload)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
|
||||
# Use model_dump for comparison if the endpoint returns the Prisma model directly
|
||||
assert response_data == mock_updated_team_row.model_dump()
|
||||
|
||||
mock_db_client.db.litellm_teamtable.find_unique.assert_awaited_once_with(
|
||||
where={"team_id": test_team_id}
|
||||
)
|
||||
mock_db_client.db.litellm_teamtable.update.assert_awaited_once_with(
|
||||
where={"team_id": test_team_id},
|
||||
data={"team_member_permissions": update_payload["team_member_permissions"]},
|
||||
)
|
||||
|
||||
# Clean up dependency override
|
||||
app.dependency_overrides = {}
|
Loading…
Add table
Add a link
Reference in a new issue