diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index c9388bc4eb..2fe86d4e6c 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -11,7 +11,7 @@ Has all /sso/* routes import asyncio import os import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse @@ -19,6 +19,11 @@ from fastapi.responses import RedirectResponse import litellm from litellm._logging import verbose_proxy_logger from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) from litellm.proxy._types import ( LiteLLM_UserTable, LitellmUserRoles, @@ -51,6 +56,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import ( from litellm.proxy.management_endpoints.team_endpoints import team_member_add from litellm.proxy.management_endpoints.types import CustomOpenID from litellm.secret_managers.main import str_to_bool +from litellm.types.proxy.management_endpoints.ui_sso import * if TYPE_CHECKING: from fastapi_sso.sso.base import OpenID @@ -357,7 +363,6 @@ async def auth_callback(request: Request): # noqa: PLR0915 request=request, microsoft_client_id=microsoft_client_id, redirect_url=redirect_url, - jwt_handler=jwt_handler, ) elif generic_client_id is not None: result = await get_generic_sso_response( @@ -490,8 +495,10 @@ async def auth_callback(request: Request): # noqa: PLR0915 user_role = ( user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY ) - sso_teams = getattr(result, "team_ids", []) - await add_missing_team_member(user_info=user_info, sso_teams=sso_teams) + await SSOAuthenticationHandler.add_user_to_teams_from_sso_response( + result=result, + user_info=user_info, + ) except Exception as e: verbose_proxy_logger.debug( @@ -835,18 +842,42 @@ class SSOAuthenticationHandler: redirect_url += "/" + sso_callback_route return redirect_url + @staticmethod + async def add_user_to_teams_from_sso_response( + result: Optional[Union[CustomOpenID, OpenID, dict]], + user_info: Union[NewUserResponse, LiteLLM_UserTable], + ): + """ + Adds the user as a team member to the teams specified in the SSO responses `team_ids` field + + + The `team_ids` field is populated by litellm after processing the SSO response + """ + sso_teams = getattr(result, "team_ids", []) + await add_missing_team_member(user_info=user_info, sso_teams=sso_teams) + class MicrosoftSSOHandler: """ Handles Microsoft SSO callback response and returns a CustomOpenID object """ + graph_api_base_url = "https://graph.microsoft.com/v1.0" + graph_api_user_groups_endpoint = f"{graph_api_base_url}/me/memberOf" + + """ + Constants + """ + MAX_GRAPH_API_PAGES = 200 + + # used for debugging to show the user groups litellm found from Graph API + GRAPH_API_RESPONSE_KEY = "graph_api_user_groups" + @staticmethod async def get_microsoft_callback_response( request: Request, microsoft_client_id: str, redirect_url: str, - jwt_handler: JWTHandler, return_raw_sso_response: bool = False, ) -> Union[CustomOpenID, OpenID, dict]: """ @@ -880,24 +911,34 @@ class MicrosoftSSOHandler: redirect_uri=redirect_url, allow_insecure_http=True, ) - original_msft_result = await microsoft_sso.verify_and_process( - request=request, - convert_response=False, + original_msft_result = ( + await microsoft_sso.verify_and_process( + request=request, + convert_response=False, + ) + or {} + ) + + user_team_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api( + access_token=microsoft_sso.access_token ) # if user is trying to get the raw sso response for debugging, return the raw sso response if return_raw_sso_response: + original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = ( + user_team_ids + ) return original_msft_result or {} result = MicrosoftSSOHandler.openid_from_response( response=original_msft_result, - jwt_handler=jwt_handler, + team_ids=user_team_ids, ) return result @staticmethod def openid_from_response( - response: Optional[dict], jwt_handler: JWTHandler + response: Optional[dict], team_ids: List[str] ) -> CustomOpenID: response = response or {} verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}") @@ -908,11 +949,110 @@ class MicrosoftSSOHandler: id=response.get("id"), first_name=response.get("givenName"), last_name=response.get("surname"), - team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)), + team_ids=team_ids, ) verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}") return openid_response + @staticmethod + async def get_user_groups_from_graph_api( + access_token: Optional[str] = None, + ) -> List[str]: + """ + Returns a list of `team_ids` the user belongs to from the Microsoft Graph API + + Args: + access_token (Optional[str]): Microsoft Graph API access token + + Returns: + List[str]: List of group IDs the user belongs to + """ + try: + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SSO_HANDLER + ) + + all_group_ids = [] + next_link: Optional[str] = ( + MicrosoftSSOHandler.graph_api_user_groups_endpoint + ) + auth_headers = {"Authorization": f"Bearer {access_token}"} + page_count = 0 + + while ( + next_link is not None + and page_count < MicrosoftSSOHandler.MAX_GRAPH_API_PAGES + ): + group_ids, next_link = await MicrosoftSSOHandler.fetch_and_parse_groups( + url=next_link, headers=auth_headers, async_client=async_client + ) + all_group_ids.extend(group_ids) + page_count += 1 + + if ( + next_link is not None + and page_count >= MicrosoftSSOHandler.MAX_GRAPH_API_PAGES + ): + verbose_proxy_logger.warning( + f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included." + ) + + return all_group_ids + + except Exception as e: + verbose_proxy_logger.error( + f"Error getting user groups from Microsoft Graph API: {e}" + ) + return [] + + @staticmethod + async def fetch_and_parse_groups( + url: str, headers: dict, async_client: AsyncHTTPHandler + ) -> Tuple[List[str], Optional[str]]: + """Helper function to fetch and parse group data from a URL""" + response = await async_client.get(url, headers=headers) + response_json = response.json() + response_typed = await MicrosoftSSOHandler._cast_graph_api_response_dict( + response=response_json + ) + group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response( + response=response_typed + ) + return group_ids, response_typed.get("odata_nextLink") + + @staticmethod + def _get_group_ids_from_graph_api_response( + response: MicrosoftGraphAPIUserGroupResponse, + ) -> List[str]: + group_ids = [] + for _object in response.get("value", []) or []: + _group_id = _object.get("id") + if _group_id is not None: + group_ids.append(_group_id) + return group_ids + + @staticmethod + async def _cast_graph_api_response_dict( + response: dict, + ) -> MicrosoftGraphAPIUserGroupResponse: + directory_objects: List[MicrosoftGraphAPIUserGroupDirectoryObject] = [] + for _object in response.get("value", []): + directory_objects.append( + MicrosoftGraphAPIUserGroupDirectoryObject( + odata_type=_object.get("@odata.type"), + id=_object.get("id"), + deletedDateTime=_object.get("deletedDateTime"), + description=_object.get("description"), + displayName=_object.get("displayName"), + roleTemplateId=_object.get("roleTemplateId"), + ) + ) + return MicrosoftGraphAPIUserGroupResponse( + odata_context=response.get("@odata.context"), + odata_nextLink=response.get("@odata.nextLink"), + value=directory_objects, + ) + class GoogleSSOHandler: """ @@ -1046,9 +1186,9 @@ async def debug_sso_callback(request: Request): request=request, microsoft_client_id=microsoft_client_id, redirect_url=redirect_url, - jwt_handler=jwt_handler, return_raw_sso_response=True, ) + elif generic_client_id is not None: result = await get_generic_sso_response( request=request, diff --git a/litellm/types/llms/custom_http.py b/litellm/types/llms/custom_http.py index 5eec187dd4..8759dedec6 100644 --- a/litellm/types/llms/custom_http.py +++ b/litellm/types/llms/custom_http.py @@ -19,6 +19,7 @@ class httpxSpecialProvider(str, Enum): SecretManager = "secret_manager" PassThroughEndpoint = "pass_through_endpoint" PromptFactory = "prompt_factory" + SSO_HANDLER = "sso_handler" VerifyTypes = Union[str, bool, ssl.SSLContext] diff --git a/litellm/types/proxy/management_endpoints/ui_sso.py b/litellm/types/proxy/management_endpoints/ui_sso.py new file mode 100644 index 0000000000..a706577f3d --- /dev/null +++ b/litellm/types/proxy/management_endpoints/ui_sso.py @@ -0,0 +1,20 @@ +from typing import List, Optional, TypedDict + + +class MicrosoftGraphAPIUserGroupDirectoryObject(TypedDict, total=False): + """Model for Microsoft Graph API directory object""" + + odata_type: Optional[str] + id: Optional[str] + deletedDateTime: Optional[str] + description: Optional[str] + displayName: Optional[str] + roleTemplateId: Optional[str] + + +class MicrosoftGraphAPIUserGroupResponse(TypedDict, total=False): + """Model for Microsoft Graph API user groups response""" + + odata_context: Optional[str] + odata_nextLink: Optional[str] + value: Optional[List[MicrosoftGraphAPIUserGroupDirectoryObject]] diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py index b785b01f8c..606f3833be 100644 --- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py @@ -19,6 +19,10 @@ from litellm.proxy.management_endpoints.ui_sso import ( GoogleSSOHandler, MicrosoftSSOHandler, ) +from litellm.types.proxy.management_endpoints.ui_sso import ( + MicrosoftGraphAPIUserGroupDirectoryObject, + MicrosoftGraphAPIUserGroupResponse, +) def test_microsoft_sso_handler_openid_from_response(): @@ -32,23 +36,14 @@ def test_microsoft_sso_handler_openid_from_response(): "surname": "User", "some_other_field": "value", } - - # Create a mock JWTHandler that returns predetermined team IDs - mock_jwt_handler = MagicMock(spec=JWTHandler) expected_team_ids = ["team1", "team2"] - mock_jwt_handler.get_team_ids_from_jwt.return_value = expected_team_ids - # Act # Call the method being tested result = MicrosoftSSOHandler.openid_from_response( - response=mock_response, jwt_handler=mock_jwt_handler + response=mock_response, team_ids=expected_team_ids ) # Assert - # Verify the JWT handler was called with the correct parameters - mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with( - cast(dict, mock_response) - ) # Check that the result is a CustomOpenID object with the expected values assert isinstance(result, CustomOpenID) @@ -64,13 +59,9 @@ def test_microsoft_sso_handler_openid_from_response(): def test_microsoft_sso_handler_with_empty_response(): # Arrange # Test with None response - mock_jwt_handler = MagicMock(spec=JWTHandler) - mock_jwt_handler.get_team_ids_from_jwt.return_value = [] # Act - result = MicrosoftSSOHandler.openid_from_response( - response=None, jwt_handler=mock_jwt_handler - ) + result = MicrosoftSSOHandler.openid_from_response(response=None, team_ids=[]) # Assert assert isinstance(result, CustomOpenID) @@ -82,14 +73,10 @@ def test_microsoft_sso_handler_with_empty_response(): assert result.last_name is None assert result.team_ids == [] - # Make sure the JWT handler was called with an empty dict - mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with({}) - def test_get_microsoft_callback_response(): # Arrange mock_request = MagicMock(spec=Request) - mock_jwt_handler = MagicMock(spec=JWTHandler) mock_response = { "mail": "microsoft_user@example.com", "displayName": "Microsoft User", @@ -115,7 +102,6 @@ def test_get_microsoft_callback_response(): request=mock_request, microsoft_client_id="mock_client_id", redirect_url="http://mock_redirect_url", - jwt_handler=mock_jwt_handler, ) ) @@ -132,7 +118,6 @@ def test_get_microsoft_callback_response(): def test_get_microsoft_callback_response_raw_sso_response(): # Arrange mock_request = MagicMock(spec=Request) - mock_jwt_handler = MagicMock(spec=JWTHandler) mock_response = { "mail": "microsoft_user@example.com", "displayName": "Microsoft User", @@ -157,7 +142,6 @@ def test_get_microsoft_callback_response_raw_sso_response(): request=mock_request, microsoft_client_id="mock_client_id", redirect_url="http://mock_redirect_url", - jwt_handler=mock_jwt_handler, return_raw_sso_response=True, ) ) @@ -206,3 +190,192 @@ def test_get_google_callback_response(): assert result.get("sub") == "google123" assert result.get("given_name") == "Google" assert result.get("family_name") == "User" + + +@pytest.mark.asyncio +async def test_get_user_groups_from_graph_api(): + # Arrange + mock_response = { + "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects", + "value": [ + { + "@odata.type": "#microsoft.graph.group", + "id": "group1", + "displayName": "Group 1", + }, + { + "@odata.type": "#microsoft.graph.group", + "id": "group2", + "displayName": "Group 2", + }, + ], + } + + async def mock_get(*args, **kwargs): + mock = MagicMock() + mock.json.return_value = mock_response + return mock + + with patch( + "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client" + ) as mock_client: + mock_client.return_value = MagicMock() + mock_client.return_value.get = mock_get + + # Act + result = await MicrosoftSSOHandler.get_user_groups_from_graph_api( + access_token="mock_token" + ) + + # Assert + assert isinstance(result, list) + assert len(result) == 2 + assert "group1" in result + assert "group2" in result + + +@pytest.mark.asyncio +async def test_get_user_groups_pagination(): + # Arrange + first_response = { + "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects", + "@odata.nextLink": "https://graph.microsoft.com/v1.0/me/memberOf?$skiptoken=page2", + "value": [ + { + "@odata.type": "#microsoft.graph.group", + "id": "group1", + "displayName": "Group 1", + }, + ], + } + second_response = { + "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects", + "value": [ + { + "@odata.type": "#microsoft.graph.group", + "id": "group2", + "displayName": "Group 2", + }, + ], + } + + responses = [first_response, second_response] + current_response = {"index": 0} + + async def mock_get(*args, **kwargs): + mock = MagicMock() + mock.json.return_value = responses[current_response["index"]] + current_response["index"] += 1 + return mock + + with patch( + "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client" + ) as mock_client: + mock_client.return_value = MagicMock() + mock_client.return_value.get = mock_get + + # Act + result = await MicrosoftSSOHandler.get_user_groups_from_graph_api( + access_token="mock_token" + ) + + # Assert + assert isinstance(result, list) + assert len(result) == 2 + assert "group1" in result + assert "group2" in result + assert current_response["index"] == 2 # Verify both pages were fetched + + +@pytest.mark.asyncio +async def test_get_user_groups_empty_response(): + # Arrange + mock_response = { + "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects", + "value": [], + } + + async def mock_get(*args, **kwargs): + mock = MagicMock() + mock.json.return_value = mock_response + return mock + + with patch( + "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client" + ) as mock_client: + mock_client.return_value = MagicMock() + mock_client.return_value.get = mock_get + + # Act + result = await MicrosoftSSOHandler.get_user_groups_from_graph_api( + access_token="mock_token" + ) + + # Assert + assert isinstance(result, list) + assert len(result) == 0 + + +@pytest.mark.asyncio +async def test_get_user_groups_error_handling(): + # Arrange + async def mock_get(*args, **kwargs): + raise Exception("API Error") + + with patch( + "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client" + ) as mock_client: + mock_client.return_value = MagicMock() + mock_client.return_value.get = mock_get + + # Act + result = await MicrosoftSSOHandler.get_user_groups_from_graph_api( + access_token="mock_token" + ) + + # Assert + assert isinstance(result, list) + assert len(result) == 0 + + +def test_get_group_ids_from_graph_api_response(): + # Arrange + mock_response = MicrosoftGraphAPIUserGroupResponse( + odata_context="https://graph.microsoft.com/v1.0/$metadata#directoryObjects", + odata_nextLink=None, + value=[ + MicrosoftGraphAPIUserGroupDirectoryObject( + odata_type="#microsoft.graph.group", + id="group1", + displayName="Group 1", + description=None, + deletedDateTime=None, + roleTemplateId=None, + ), + MicrosoftGraphAPIUserGroupDirectoryObject( + odata_type="#microsoft.graph.group", + id="group2", + displayName="Group 2", + description=None, + deletedDateTime=None, + roleTemplateId=None, + ), + MicrosoftGraphAPIUserGroupDirectoryObject( + odata_type="#microsoft.graph.group", + id=None, # Test handling of None id + displayName="Invalid Group", + description=None, + deletedDateTime=None, + roleTemplateId=None, + ), + ], + ) + + # Act + result = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(mock_response) + + # Assert + assert isinstance(result, list) + assert len(result) == 2 + assert "group1" in result + assert "group2" in result