mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
create_litellm_teams_from_service_principal_team_ids
This commit is contained in:
parent
5a91142d31
commit
958c284957
1 changed files with 132 additions and 1 deletions
|
@ -28,6 +28,7 @@ from litellm.proxy._types import (
|
|||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
NewTeamRequest,
|
||||
NewUserRequest,
|
||||
NewUserResponse,
|
||||
ProxyErrorTypes,
|
||||
|
@ -53,7 +54,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
|||
check_is_admin_only_access,
|
||||
has_admin_ui_access,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
||||
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
@ -1011,6 +1012,27 @@ class MicrosoftSSOHandler:
|
|||
llm_provider=httpxSpecialProvider.SSO_HANDLER
|
||||
)
|
||||
|
||||
# Handle MSFT Enterprise Application Groups
|
||||
service_principal_id = os.getenv("MICROSOFT_SERVICE_PRINCIPAL_ID", None)
|
||||
service_principal_group_ids: Optional[List[str]] = []
|
||||
service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = []
|
||||
if service_principal_id:
|
||||
service_principal_group_ids, service_principal_teams = (
|
||||
await MicrosoftSSOHandler.get_group_ids_from_service_principal(
|
||||
service_principal_id=service_principal_id,
|
||||
async_client=async_client,
|
||||
access_token=access_token,
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Service principal group IDs: {service_principal_group_ids}"
|
||||
)
|
||||
if len(service_principal_group_ids) > 0:
|
||||
await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
|
||||
service_principal_teams=service_principal_teams,
|
||||
)
|
||||
|
||||
# Fetch user membership from Microsoft Graph API
|
||||
all_group_ids = []
|
||||
next_link: Optional[str] = (
|
||||
MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
||||
|
@ -1036,6 +1058,14 @@ class MicrosoftSSOHandler:
|
|||
f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included."
|
||||
)
|
||||
|
||||
# If service_principal_group_ids is not empty, only return group_ids that are in both all_group_ids and service_principal_group_ids
|
||||
if len(service_principal_group_ids) > 0:
|
||||
all_group_ids = [
|
||||
group_id
|
||||
for group_id in all_group_ids
|
||||
if group_id in service_principal_group_ids
|
||||
]
|
||||
|
||||
return all_group_ids
|
||||
|
||||
except Exception as e:
|
||||
|
@ -1092,6 +1122,107 @@ class MicrosoftSSOHandler:
|
|||
value=directory_objects,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_group_ids_from_service_principal(
|
||||
service_principal_id: str,
|
||||
async_client: AsyncHTTPHandler,
|
||||
access_token: Optional[str] = None,
|
||||
) -> Tuple[List[str], List[MicrosoftServicePrincipalTeam]]:
|
||||
"""
|
||||
Gets the groups belonging to the Service Principal Application
|
||||
|
||||
Service Principal Id is an `Enterprise Application` in Azure AD
|
||||
|
||||
Users use Enterprise Applications to manage Groups and Users on Microsoft Entra ID
|
||||
"""
|
||||
base_url = "https://graph.microsoft.com/v1.0"
|
||||
# Endpoint to get app role assignments for the given service principal
|
||||
endpoint = f"/servicePrincipals/{service_principal_id}/appRoleAssignedTo"
|
||||
url = base_url + endpoint
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await async_client.get(url, headers=headers)
|
||||
response_json = response.json()
|
||||
verbose_proxy_logger.debug(
|
||||
f"Response from service principal app role assigned to: {response_json}"
|
||||
)
|
||||
group_ids: List[str] = []
|
||||
service_principal_teams: List[MicrosoftServicePrincipalTeam] = []
|
||||
|
||||
for _object in response_json.get("value", []):
|
||||
if _object.get("principalType") == "Group":
|
||||
# Append the group ID to the list
|
||||
group_ids.append(_object.get("principalId"))
|
||||
# Append the service principal team to the list
|
||||
service_principal_teams.append(
|
||||
MicrosoftServicePrincipalTeam(
|
||||
principalDisplayName=_object.get("principalDisplayName"),
|
||||
principalId=_object.get("principalId"),
|
||||
)
|
||||
)
|
||||
|
||||
return group_ids, service_principal_teams
|
||||
|
||||
@staticmethod
|
||||
async def create_litellm_teams_from_service_principal_team_ids(
|
||||
service_principal_teams: List[MicrosoftServicePrincipalTeam],
|
||||
):
|
||||
"""
|
||||
Creates Litellm Teams from the Service Principal Group IDs
|
||||
|
||||
When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise ProxyException(
|
||||
message="Prisma client not found. Set it in the proxy_server.py file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="prisma_client",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
|
||||
)
|
||||
for service_principal_team in service_principal_teams:
|
||||
litellm_team_id: Optional[str] = service_principal_team.get("principalId")
|
||||
litellm_team_name: Optional[str] = service_principal_team.get(
|
||||
"principalDisplayName"
|
||||
)
|
||||
if litellm_team_id:
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}"
|
||||
)
|
||||
|
||||
team_obj = await prisma_client.db.litellm_teamtable.find_first(
|
||||
where={"team_id": litellm_team_id}
|
||||
)
|
||||
verbose_proxy_logger.debug(f"Team object: {team_obj}")
|
||||
if team_obj:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Team already exists: {litellm_team_id} - {litellm_team_name}"
|
||||
)
|
||||
continue
|
||||
await new_team(
|
||||
data=NewTeamRequest(
|
||||
team_id=litellm_team_id,
|
||||
team_alias=litellm_team_name,
|
||||
),
|
||||
# params used for Audit Logging
|
||||
http_request=Request(scope={"type": "http", "method": "POST"}),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
token="",
|
||||
key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
|
||||
|
||||
|
||||
class GoogleSSOHandler:
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue