working team / group provisioning on SCIM

This commit is contained in:
Ishaan Jaff 2025-04-16 16:09:32 -07:00
parent 251c39846b
commit a6e2988efb
2 changed files with 86 additions and 77 deletions

View file

@ -1146,6 +1146,7 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
litellm_model_table: Optional[LiteLLM_ModelTable] = None
updated_at: Optional[datetime] = None
created_at: Optional[datetime] = None
model_config = ConfigDict(protected_namespaces=())

View file

@ -8,12 +8,30 @@ and integration purposes.
import uuid
from typing import List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Response
from fastapi.responses import JSONResponse
from fastapi import (
APIRouter,
Body,
Depends,
HTTPException,
Path,
Query,
Request,
Response,
)
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import LiteLLM_UserTable, NewUserRequest, NewUserResponse
from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_UserTable,
LitellmUserRoles,
Member,
NewTeamRequest,
NewUserRequest,
NewUserResponse,
UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.types.proxy.management_endpoints.scim_v2 import *
scim_router = APIRouter(
@ -34,6 +52,7 @@ class ScimTransformations:
DEFAULT_SCIM_NAME = "Unknown User"
DEFAULT_SCIM_FAMILY_NAME = "Unknown Family Name"
DEFAULT_SCIM_DISPLAY_NAME = "Unknown Display Name"
DEFAULT_SCIM_MEMBER_VALUE = "Unknown Member Value"
@staticmethod
async def transform_litellm_user_to_scim_user(
@ -124,9 +143,55 @@ class ScimTransformations:
return scim_metadata.givenName
if user.user_alias and len(user.user_alias) > 0:
return user.user_alias
return user.user_alias or ScimTransformations.DEFAULT_SCIM_NAME
return ScimTransformations.DEFAULT_SCIM_NAME
@staticmethod
async def transform_litellm_team_to_scim_group(
team: Union[LiteLLM_TeamTable, dict],
) -> SCIMGroup:
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500, detail={"error": "No database connected"}
)
if isinstance(team, dict):
team = LiteLLM_TeamTable(**team)
# Get team members
scim_members: List[SCIMMember] = []
for member in team.members_with_roles or []:
scim_members.append(
SCIMMember(
value=ScimTransformations._get_scim_member_value(member),
display=member.user_email,
)
)
team_alias = getattr(team, "team_alias", team.team_id)
team_created_at = team.created_at.isoformat() if team.created_at else None
team_updated_at = team.updated_at.isoformat() if team.updated_at else None
return SCIMGroup(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=team.team_id,
displayName=team_alias,
members=scim_members,
meta={
"resourceType": "Group",
"created": team_created_at,
"lastModified": team_updated_at,
},
)
@staticmethod
def _get_scim_member_value(member: Member) -> str:
if member.user_email:
return member.user_email
return ScimTransformations.DEFAULT_SCIM_MEMBER_VALUE
# User Endpoints
@scim_router.get(
@ -501,31 +566,10 @@ async def get_group(
detail={"error": f"Group not found with ID: {group_id}"},
)
# Get team members
members = []
for member_id in team.members or []:
member = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": member_id}
)
if member:
display_name = member.user_email or member.user_id
members.append(SCIMMember(value=member.user_id, display=display_name))
team_alias = getattr(team, "team_alias", team.team_id)
team_created_at = team.created_at.isoformat() if team.created_at else None
team_updated_at = team.updated_at.isoformat() if team.updated_at else None
return SCIMGroup(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=team.team_id,
displayName=team_alias,
members=members,
meta={
"resourceType": "Group",
"created": team_created_at,
"lastModified": team_updated_at,
},
scim_group = await ScimTransformations.transform_litellm_team_to_scim_group(
team
)
return scim_group
except HTTPException:
raise
@ -568,7 +612,7 @@ async def create_group(
)
# Extract members
member_ids = []
members_with_roles: List[Member] = []
if group.members:
for member in group.members:
# Check if user exists
@ -576,59 +620,23 @@ async def create_group(
where={"user_id": member.value}
)
if user:
member_ids.append(member.value)
members_with_roles.append(Member(user_id=member.value, role="user"))
# Create team in database
new_team = await prisma_client.db.litellm_teamtable.create(
data={
"team_id": team_id,
"team_alias": group.displayName,
"members": member_ids,
"metadata": {"scim_data": group.model_dump()},
}
created_team = await new_team(
data=NewTeamRequest(
team_id=team_id,
team_alias=group.displayName,
members_with_roles=members_with_roles,
),
http_request=Request(scope={"type": "http", "path": "/scim/v2/Groups"}),
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
# For each member, update their teams list
for member_id in member_ids:
user = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": member_id}
)
if user:
current_teams = user.teams or []
if team_id not in current_teams:
await prisma_client.db.litellm_usertable.update(
where={"user_id": member_id}, data={"teams": {"push": team_id}}
)
# Get updated members for response
members = []
for member_id in member_ids:
user = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": member_id}
)
if user:
display_name = user.user_email or user.user_id
members.append(SCIMMember(value=user.user_id, display=display_name))
team_created_at = (
new_team.created_at.isoformat() if new_team.created_at else None
scim_group = await ScimTransformations.transform_litellm_team_to_scim_group(
created_team
)
team_updated_at = (
new_team.updated_at.isoformat() if new_team.updated_at else None
)
return SCIMGroup(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=team_id,
displayName=group.displayName,
members=members,
meta={
"resourceType": "Group",
"created": team_created_at,
"lastModified": team_updated_at,
},
)
return scim_group
except HTTPException:
raise
except Exception as e: