[Feat] - SSO - Use MSFT Graph API to assign users to teams (#9865)

* refactor SSO handler

* render sso JWT on ui

* docs debug sso

* fix sso login flow use await

* fix ui sso debug JWT

* test ui sso

* remove redis vl

* fix redisvl==0.5.1

* fix ml dtypes

* fix redisvl

* fix redis vl

* fix debug_sso_callback

* fix linting error

* fix redis semantic caching dep

* working graph api assignment

* test msft sso handler openid

* testing for msft group assignment

* fix debug graph api sso flow

* fix linting errors

* add_user_to_teams_from_sso_response

* fix linting error
This commit is contained in:
Ishaan Jaff 2025-04-09 18:26:43 -07:00 committed by GitHub
parent a1433da4a7
commit 4c1bb74c3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 368 additions and 34 deletions

View file

@ -11,7 +11,7 @@ Has all /sso/* routes
import asyncio
import os
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
@ -19,6 +19,11 @@ from fastapi.responses import RedirectResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import (
LiteLLM_UserTable,
LitellmUserRoles,
@ -51,6 +56,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.secret_managers.main import str_to_bool
from litellm.types.proxy.management_endpoints.ui_sso import *
if TYPE_CHECKING:
from fastapi_sso.sso.base import OpenID
@ -357,7 +363,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
request=request,
microsoft_client_id=microsoft_client_id,
redirect_url=redirect_url,
jwt_handler=jwt_handler,
)
elif generic_client_id is not None:
result = await get_generic_sso_response(
@ -490,8 +495,10 @@ async def auth_callback(request: Request): # noqa: PLR0915
user_role = (
user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
)
sso_teams = getattr(result, "team_ids", [])
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
result=result,
user_info=user_info,
)
except Exception as e:
verbose_proxy_logger.debug(
@ -835,18 +842,42 @@ class SSOAuthenticationHandler:
redirect_url += "/" + sso_callback_route
return redirect_url
@staticmethod
async def add_user_to_teams_from_sso_response(
result: Optional[Union[CustomOpenID, OpenID, dict]],
user_info: Union[NewUserResponse, LiteLLM_UserTable],
):
"""
Adds the user as a team member to the teams specified in the SSO responses `team_ids` field
The `team_ids` field is populated by litellm after processing the SSO response
"""
sso_teams = getattr(result, "team_ids", [])
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
class MicrosoftSSOHandler:
"""
Handles Microsoft SSO callback response and returns a CustomOpenID object
"""
graph_api_base_url = "https://graph.microsoft.com/v1.0"
graph_api_user_groups_endpoint = f"{graph_api_base_url}/me/memberOf"
"""
Constants
"""
MAX_GRAPH_API_PAGES = 200
# used for debugging to show the user groups litellm found from Graph API
GRAPH_API_RESPONSE_KEY = "graph_api_user_groups"
@staticmethod
async def get_microsoft_callback_response(
request: Request,
microsoft_client_id: str,
redirect_url: str,
jwt_handler: JWTHandler,
return_raw_sso_response: bool = False,
) -> Union[CustomOpenID, OpenID, dict]:
"""
@ -880,24 +911,34 @@ class MicrosoftSSOHandler:
redirect_uri=redirect_url,
allow_insecure_http=True,
)
original_msft_result = await microsoft_sso.verify_and_process(
request=request,
convert_response=False,
original_msft_result = (
await microsoft_sso.verify_and_process(
request=request,
convert_response=False,
)
or {}
)
user_team_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
access_token=microsoft_sso.access_token
)
# if user is trying to get the raw sso response for debugging, return the raw sso response
if return_raw_sso_response:
original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = (
user_team_ids
)
return original_msft_result or {}
result = MicrosoftSSOHandler.openid_from_response(
response=original_msft_result,
jwt_handler=jwt_handler,
team_ids=user_team_ids,
)
return result
@staticmethod
def openid_from_response(
response: Optional[dict], jwt_handler: JWTHandler
response: Optional[dict], team_ids: List[str]
) -> CustomOpenID:
response = response or {}
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
@ -908,11 +949,110 @@ class MicrosoftSSOHandler:
id=response.get("id"),
first_name=response.get("givenName"),
last_name=response.get("surname"),
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
team_ids=team_ids,
)
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
return openid_response
@staticmethod
async def get_user_groups_from_graph_api(
access_token: Optional[str] = None,
) -> List[str]:
"""
Returns a list of `team_ids` the user belongs to from the Microsoft Graph API
Args:
access_token (Optional[str]): Microsoft Graph API access token
Returns:
List[str]: List of group IDs the user belongs to
"""
try:
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SSO_HANDLER
)
all_group_ids = []
next_link: Optional[str] = (
MicrosoftSSOHandler.graph_api_user_groups_endpoint
)
auth_headers = {"Authorization": f"Bearer {access_token}"}
page_count = 0
while (
next_link is not None
and page_count < MicrosoftSSOHandler.MAX_GRAPH_API_PAGES
):
group_ids, next_link = await MicrosoftSSOHandler.fetch_and_parse_groups(
url=next_link, headers=auth_headers, async_client=async_client
)
all_group_ids.extend(group_ids)
page_count += 1
if (
next_link is not None
and page_count >= MicrosoftSSOHandler.MAX_GRAPH_API_PAGES
):
verbose_proxy_logger.warning(
f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included."
)
return all_group_ids
except Exception as e:
verbose_proxy_logger.error(
f"Error getting user groups from Microsoft Graph API: {e}"
)
return []
@staticmethod
async def fetch_and_parse_groups(
url: str, headers: dict, async_client: AsyncHTTPHandler
) -> Tuple[List[str], Optional[str]]:
"""Helper function to fetch and parse group data from a URL"""
response = await async_client.get(url, headers=headers)
response_json = response.json()
response_typed = await MicrosoftSSOHandler._cast_graph_api_response_dict(
response=response_json
)
group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(
response=response_typed
)
return group_ids, response_typed.get("odata_nextLink")
@staticmethod
def _get_group_ids_from_graph_api_response(
response: MicrosoftGraphAPIUserGroupResponse,
) -> List[str]:
group_ids = []
for _object in response.get("value", []) or []:
_group_id = _object.get("id")
if _group_id is not None:
group_ids.append(_group_id)
return group_ids
@staticmethod
async def _cast_graph_api_response_dict(
response: dict,
) -> MicrosoftGraphAPIUserGroupResponse:
directory_objects: List[MicrosoftGraphAPIUserGroupDirectoryObject] = []
for _object in response.get("value", []):
directory_objects.append(
MicrosoftGraphAPIUserGroupDirectoryObject(
odata_type=_object.get("@odata.type"),
id=_object.get("id"),
deletedDateTime=_object.get("deletedDateTime"),
description=_object.get("description"),
displayName=_object.get("displayName"),
roleTemplateId=_object.get("roleTemplateId"),
)
)
return MicrosoftGraphAPIUserGroupResponse(
odata_context=response.get("@odata.context"),
odata_nextLink=response.get("@odata.nextLink"),
value=directory_objects,
)
class GoogleSSOHandler:
"""
@ -1046,9 +1186,9 @@ async def debug_sso_callback(request: Request):
request=request,
microsoft_client_id=microsoft_client_id,
redirect_url=redirect_url,
jwt_handler=jwt_handler,
return_raw_sso_response=True,
)
elif generic_client_id is not None:
result = await get_generic_sso_response(
request=request,