"""Unit tests for the Microsoft and Google SSO handlers in ``ui_sso``.

Covers building ``CustomOpenID`` objects from raw provider responses, the SSO
callback helpers (with ``fastapi_sso``'s ``verify_and_process`` mocked out),
and Microsoft Graph API group lookup, including pagination, empty results and
a failing request.
"""

import asyncio
import json
import os
import sys
from contextlib import contextmanager
from typing import Optional, cast
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import Request
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../../")
)  # Adds the parent directory to the system path
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.proxy.management_endpoints.ui_sso import (
    GoogleSSOHandler,
    MicrosoftSSOHandler,
)
from litellm.types.proxy.management_endpoints.ui_sso import (
    MicrosoftGraphAPIUserGroupDirectoryObject,
    MicrosoftGraphAPIUserGroupResponse,
)

# Target patched so Graph API calls in ``ui_sso`` never hit the network.
_GRAPH_CLIENT_PATH = "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
_GRAPH_CONTEXT = "https://graph.microsoft.com/v1.0/$metadata#directoryObjects"


def _microsoft_user_response() -> dict:
    """Return a fresh raw Microsoft SSO user payload used by the callback tests."""
    return {
        "mail": "microsoft_user@example.com",
        "displayName": "Microsoft User",
        "id": "msft123",
        "givenName": "Microsoft",
        "surname": "User",
    }


def _json_response(payload: dict) -> MagicMock:
    """Build a fake HTTP response whose ``.json()`` returns *payload*."""
    response = MagicMock()
    response.json.return_value = payload
    return response


@contextmanager
def _patched_graph_client(mock_get):
    """Patch ``get_async_httpx_client`` in ``ui_sso`` for the ``with`` body.

    The client returned by the patched factory has its ``get`` replaced by
    *mock_get*, so any ``get`` call made through that client is served by the
    test double.
    """
    with patch(_GRAPH_CLIENT_PATH) as mock_client:
        mock_client.return_value = MagicMock()
        mock_client.return_value.get = mock_get
        yield mock_client


def test_microsoft_sso_handler_openid_from_response():
    """A populated Microsoft payload maps field-by-field onto CustomOpenID."""
    # Arrange: a response similar to what Microsoft SSO would return,
    # including a field the handler should ignore.
    mock_response = {
        "mail": "test@example.com",
        "displayName": "Test User",
        "id": "user123",
        "givenName": "Test",
        "surname": "User",
        "some_other_field": "value",
    }
    expected_team_ids = ["team1", "team2"]

    # Act
    result = MicrosoftSSOHandler.openid_from_response(
        response=mock_response, team_ids=expected_team_ids
    )

    # Assert
    assert isinstance(result, CustomOpenID)
    assert result.email == "test@example.com"
    assert result.display_name == "Test User"
    assert result.provider == "microsoft"
    assert result.id == "user123"
    assert result.first_name == "Test"
    assert result.last_name == "User"
    assert result.team_ids == expected_team_ids


def test_microsoft_sso_handler_with_empty_response():
    """A ``None`` payload yields a CustomOpenID with only the provider set."""
    # Act
    result = MicrosoftSSOHandler.openid_from_response(response=None, team_ids=[])

    # Assert
    assert isinstance(result, CustomOpenID)
    assert result.email is None
    assert result.display_name is None
    assert result.provider == "microsoft"
    assert result.id is None
    assert result.first_name is None
    assert result.last_name is None
    assert result.team_ids == []


def test_get_microsoft_callback_response():
    """The Microsoft callback converts the verified SSO payload to CustomOpenID."""
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = _microsoft_user_response()

    with patch.dict(
        os.environ,
        {"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"},
    ):
        # AsyncMock instead of a pre-built asyncio.Future: creating a Future
        # outside a running loop depends on the deprecated implicit
        # get_event_loop() and binds it to a loop asyncio.run() won't use.
        with patch(
            "fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            # Act
            result = asyncio.run(
                MicrosoftSSOHandler.get_microsoft_callback_response(
                    request=mock_request,
                    microsoft_client_id="mock_client_id",
                    redirect_url="http://mock_redirect_url",
                )
            )

    # Assert
    assert isinstance(result, CustomOpenID)
    assert result.email == "microsoft_user@example.com"
    assert result.display_name == "Microsoft User"
    assert result.provider == "microsoft"
    assert result.id == "msft123"
    assert result.first_name == "Microsoft"
    assert result.last_name == "User"


def test_get_microsoft_callback_response_raw_sso_response():
    """With ``return_raw_sso_response=True`` the raw payload dict is returned."""
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = _microsoft_user_response()

    with patch.dict(
        os.environ,
        {"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"},
    ):
        with patch(
            "fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            # Act
            result = asyncio.run(
                MicrosoftSSOHandler.get_microsoft_callback_response(
                    request=mock_request,
                    microsoft_client_id="mock_client_id",
                    redirect_url="http://mock_redirect_url",
                    return_raw_sso_response=True,
                )
            )

    # Assert
    assert isinstance(result, dict)
    assert result["mail"] == "microsoft_user@example.com"
    assert result["displayName"] == "Microsoft User"
    assert result["id"] == "msft123"
    assert result["givenName"] == "Microsoft"
    assert result["surname"] == "User"


def test_get_google_callback_response():
    """The Google callback returns the verified SSO payload as a dict."""
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = {
        "email": "google_user@example.com",
        "name": "Google User",
        "sub": "google123",
        "given_name": "Google",
        "family_name": "User",
    }

    with patch.dict(os.environ, {"GOOGLE_CLIENT_SECRET": "mock_secret"}):
        with patch(
            "fastapi_sso.sso.google.GoogleSSO.verify_and_process",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            # Act
            result = asyncio.run(
                GoogleSSOHandler.get_google_callback_response(
                    request=mock_request,
                    google_client_id="mock_client_id",
                    redirect_url="http://mock_redirect_url",
                )
            )

    # Assert
    assert isinstance(result, dict)
    assert result.get("email") == "google_user@example.com"
    assert result.get("name") == "Google User"
    assert result.get("sub") == "google123"
    assert result.get("given_name") == "Google"
    assert result.get("family_name") == "User"


@pytest.mark.asyncio
async def test_get_user_groups_from_graph_api():
    """Group ids are collected from a single-page Graph API response."""
    # Arrange
    mock_response = {
        "@odata.context": _GRAPH_CONTEXT,
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group1",
                "displayName": "Group 1",
            },
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group2",
                "displayName": "Group 2",
            },
        ],
    }

    async def mock_get(*args, **kwargs):
        return _json_response(mock_response)

    with _patched_graph_client(mock_get):
        # Act
        result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(result, list)
    assert len(result) == 2
    assert "group1" in result
    assert "group2" in result


@pytest.mark.asyncio
async def test_get_user_groups_pagination():
    """``@odata.nextLink`` is followed and groups from every page are merged."""
    # Arrange
    first_response = {
        "@odata.context": _GRAPH_CONTEXT,
        "@odata.nextLink": "https://graph.microsoft.com/v1.0/me/memberOf?$skiptoken=page2",
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group1",
                "displayName": "Group 1",
            },
        ],
    }
    second_response = {
        "@odata.context": _GRAPH_CONTEXT,
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group2",
                "displayName": "Group 2",
            },
        ],
    }
    pages = [first_response, second_response]
    calls = []  # one entry per get() call; also selects the page to serve

    async def mock_get(*args, **kwargs):
        calls.append((args, kwargs))
        return _json_response(pages[len(calls) - 1])

    with _patched_graph_client(mock_get):
        # Act
        result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(result, list)
    assert len(result) == 2
    assert "group1" in result
    assert "group2" in result
    assert len(calls) == 2  # Verify both pages were fetched


@pytest.mark.asyncio
async def test_get_user_groups_empty_response():
    """An empty ``value`` list yields no groups."""
    # Arrange
    mock_response = {
        "@odata.context": _GRAPH_CONTEXT,
        "value": [],
    }

    async def mock_get(*args, **kwargs):
        return _json_response(mock_response)

    with _patched_graph_client(mock_get):
        # Act
        result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(result, list)
    assert len(result) == 0


@pytest.mark.asyncio
async def test_get_user_groups_error_handling():
    """A failing Graph API request results in an empty group list, not a raise."""

    # Arrange
    async def mock_get(*args, **kwargs):
        raise Exception("API Error")

    with _patched_graph_client(mock_get):
        # Act
        result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(result, list)
    assert len(result) == 0


def test_get_group_ids_from_graph_api_response():
    """Directory objects with a ``None`` id are skipped when extracting ids."""
    # Arrange
    mock_response = MicrosoftGraphAPIUserGroupResponse(
        odata_context=_GRAPH_CONTEXT,
        odata_nextLink=None,
        value=[
            MicrosoftGraphAPIUserGroupDirectoryObject(
                odata_type="#microsoft.graph.group",
                id="group1",
                displayName="Group 1",
                description=None,
                deletedDateTime=None,
                roleTemplateId=None,
            ),
            MicrosoftGraphAPIUserGroupDirectoryObject(
                odata_type="#microsoft.graph.group",
                id="group2",
                displayName="Group 2",
                description=None,
                deletedDateTime=None,
                roleTemplateId=None,
            ),
            MicrosoftGraphAPIUserGroupDirectoryObject(
                odata_type="#microsoft.graph.group",
                id=None,  # Test handling of None id
                displayName="Invalid Group",
                description=None,
                deletedDateTime=None,
                roleTemplateId=None,
            ),
        ],
    )

    # Act
    result = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(mock_response)

    # Assert
    assert isinstance(result, list)
    assert len(result) == 2
    assert "group1" in result
    assert "group2" in result