From 0601ae55c49f4f01e356acf7e7309bedd7f70c12 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 9 Apr 2025 13:58:32 -0700 Subject: [PATCH] working graph api assignment --- litellm/proxy/management_endpoints/ui_sso.py | 111 +++++++++++++++++- litellm/types/llms/custom_http.py | 1 + .../proxy/management_endpoints/ui_sso.py | 20 ++++ 3 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 litellm/types/proxy/management_endpoints/ui_sso.py diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index c9388bc4eb..96c6f91a5b 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -11,7 +11,7 @@ Has all /sso/* routes import asyncio import os import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse @@ -19,6 +19,11 @@ from fastapi.responses import RedirectResponse import litellm from litellm._logging import verbose_proxy_logger from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) from litellm.proxy._types import ( LiteLLM_UserTable, LitellmUserRoles, @@ -51,6 +56,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import ( from litellm.proxy.management_endpoints.team_endpoints import team_member_add from litellm.proxy.management_endpoints.types import CustomOpenID from litellm.secret_managers.main import str_to_bool +from litellm.types.proxy.management_endpoints.ui_sso import * if TYPE_CHECKING: from fastapi_sso.sso.base import OpenID @@ -841,6 +847,10 @@ class MicrosoftSSOHandler: Handles Microsoft SSO callback response and returns a CustomOpenID object """ + graph_api_base_url = "https://graph.microsoft.com/v1.0" + graph_api_user_groups_endpoint = f"{graph_api_base_url}/me/memberOf" + MAX_GRAPH_API_PAGES = 200 + @staticmethod async def get_microsoft_callback_response( request: Request, @@ -889,15 +899,20 @@ class MicrosoftSSOHandler: if return_raw_sso_response: return original_msft_result or {} + user_team_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api( + access_token=microsoft_sso.access_token + ) + result = MicrosoftSSOHandler.openid_from_response( response=original_msft_result, jwt_handler=jwt_handler, + team_ids=user_team_ids, ) return result @staticmethod def openid_from_response( - response: Optional[dict], jwt_handler: JWTHandler + response: Optional[dict], jwt_handler: JWTHandler, team_ids: List[str] ) -> CustomOpenID: response = response or {} verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}") @@ -908,11 +923,101 @@ class MicrosoftSSOHandler: id=response.get("id"), first_name=response.get("givenName"), last_name=response.get("surname"), - team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)), + team_ids=team_ids, ) verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}") return openid_response + @staticmethod + async def get_user_groups_from_graph_api( + access_token: Optional[str] = None, + ) -> List[str]: + """ + Returns a list of `team_ids` the user belongs to from the Microsoft Graph API + + Args: + access_token (Optional[str]): Microsoft Graph API access token + + Returns: + List[str]: List of group IDs the user belongs to + """ + try: + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SSO_HANDLER + ) + + all_group_ids = [] + next_link = MicrosoftSSOHandler.graph_api_user_groups_endpoint + auth_headers = {"Authorization": f"Bearer {access_token}"} + page_count = 0 + + while next_link and page_count < MicrosoftSSOHandler.MAX_GRAPH_API_PAGES: + group_ids, next_link = await MicrosoftSSOHandler.fetch_and_parse_groups( + url=next_link, headers=auth_headers, async_client=async_client + ) + all_group_ids.extend(group_ids) + page_count += 1 + + if next_link and page_count >= MicrosoftSSOHandler.MAX_GRAPH_API_PAGES: + verbose_proxy_logger.warning( + f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included." + ) + + return all_group_ids + + except Exception as e: + verbose_proxy_logger.error( + f"Error getting user groups from Microsoft Graph API: {e}" + ) + return [] + + @staticmethod + async def fetch_and_parse_groups( + url: str, headers: dict, async_client: AsyncHTTPHandler + ) -> Tuple[List[str], Optional[str]]: + """Helper function to fetch and parse group data from a URL""" + response = await async_client.get(url, headers=headers) + response_json = response.json() + response_typed = await MicrosoftSSOHandler._cast_graph_api_response_dict( + response=response_json + ) + group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response( + response=response_typed + ) + return group_ids, response_typed.get("odata_nextLink") + + @staticmethod + def _get_group_ids_from_graph_api_response( + response: MicrosoftGraphAPIUserGroupResponse, + ) -> List[str]: + group_ids = [] + for _object in response.get("value", []) or []: + if _object.get("id") is not None: + group_ids.append(_object.get("id")) + return group_ids + + @staticmethod + async def _cast_graph_api_response_dict( + response: dict, + ) -> MicrosoftGraphAPIUserGroupResponse: + directory_objects: List[MicrosoftGraphAPIUserGroupDirectoryObject] = [] + for _object in response.get("value", []): + directory_objects.append( + MicrosoftGraphAPIUserGroupDirectoryObject( + odata_type=_object.get("@odata.type"), + id=_object.get("id"), + deletedDateTime=_object.get("deletedDateTime"), + description=_object.get("description"), + displayName=_object.get("displayName"), + roleTemplateId=_object.get("roleTemplateId"), + ) + ) + return MicrosoftGraphAPIUserGroupResponse( + odata_context=response.get("@odata.context"), + odata_nextLink=response.get("@odata.nextLink"), + value=directory_objects, + ) + class GoogleSSOHandler: """ diff --git a/litellm/types/llms/custom_http.py b/litellm/types/llms/custom_http.py index 5eec187dd4..8759dedec6 100644 --- a/litellm/types/llms/custom_http.py +++ b/litellm/types/llms/custom_http.py @@ -19,6 +19,7 @@ class httpxSpecialProvider(str, Enum): SecretManager = "secret_manager" PassThroughEndpoint = "pass_through_endpoint" PromptFactory = "prompt_factory" + SSO_HANDLER = "sso_handler" VerifyTypes = Union[str, bool, ssl.SSLContext] diff --git a/litellm/types/proxy/management_endpoints/ui_sso.py b/litellm/types/proxy/management_endpoints/ui_sso.py new file mode 100644 index 0000000000..a706577f3d --- /dev/null +++ b/litellm/types/proxy/management_endpoints/ui_sso.py @@ -0,0 +1,20 @@ +from typing import List, Optional, TypedDict + + +class MicrosoftGraphAPIUserGroupDirectoryObject(TypedDict, total=False): + """Model for Microsoft Graph API directory object""" + + odata_type: Optional[str] + id: Optional[str] + deletedDateTime: Optional[str] + description: Optional[str] + displayName: Optional[str] + roleTemplateId: Optional[str] + + +class MicrosoftGraphAPIUserGroupResponse(TypedDict, total=False): + """Model for Microsoft Graph API user groups response""" + + odata_context: Optional[str] + odata_nextLink: Optional[str] + value: Optional[List[MicrosoftGraphAPIUserGroupDirectoryObject]]