diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 96c6f91a5b..d356835678 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -363,7 +363,6 @@ async def auth_callback(request: Request): # noqa: PLR0915 request=request, microsoft_client_id=microsoft_client_id, redirect_url=redirect_url, - jwt_handler=jwt_handler, ) elif generic_client_id is not None: result = await get_generic_sso_response( @@ -856,7 +855,6 @@ class MicrosoftSSOHandler: request: Request, microsoft_client_id: str, redirect_url: str, - jwt_handler: JWTHandler, return_raw_sso_response: bool = False, ) -> Union[CustomOpenID, OpenID, dict]: """ @@ -905,14 +903,13 @@ class MicrosoftSSOHandler: result = MicrosoftSSOHandler.openid_from_response( response=original_msft_result, - jwt_handler=jwt_handler, team_ids=user_team_ids, ) return result @staticmethod def openid_from_response( - response: Optional[dict], jwt_handler: JWTHandler, team_ids: List[str] + response: Optional[dict], team_ids: List[str] ) -> CustomOpenID: response = response or {} verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}") @@ -1151,7 +1148,6 @@ async def debug_sso_callback(request: Request): request=request, microsoft_client_id=microsoft_client_id, redirect_url=redirect_url, - jwt_handler=jwt_handler, return_raw_sso_response=True, ) elif generic_client_id is not None: diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py index b785b01f8c..14b6688361 100644 --- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py @@ -32,23 +32,14 @@ def test_microsoft_sso_handler_openid_from_response(): "surname": "User", "some_other_field": "value", } - - # Create a mock JWTHandler that returns predetermined team IDs - mock_jwt_handler = MagicMock(spec=JWTHandler) expected_team_ids = ["team1", "team2"] - mock_jwt_handler.get_team_ids_from_jwt.return_value = expected_team_ids - # Act # Call the method being tested result = MicrosoftSSOHandler.openid_from_response( - response=mock_response, jwt_handler=mock_jwt_handler + response=mock_response, team_ids=expected_team_ids ) # Assert - # Verify the JWT handler was called with the correct parameters - mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with( - cast(dict, mock_response) - ) # Check that the result is a CustomOpenID object with the expected values assert isinstance(result, CustomOpenID) @@ -64,13 +55,9 @@ def test_microsoft_sso_handler_openid_from_response(): def test_microsoft_sso_handler_with_empty_response(): # Arrange # Test with None response - mock_jwt_handler = MagicMock(spec=JWTHandler) - mock_jwt_handler.get_team_ids_from_jwt.return_value = [] # Act - result = MicrosoftSSOHandler.openid_from_response( - response=None, jwt_handler=mock_jwt_handler - ) + result = MicrosoftSSOHandler.openid_from_response(response=None, team_ids=[]) # Assert assert isinstance(result, CustomOpenID) @@ -82,14 +69,10 @@ def test_microsoft_sso_handler_with_empty_response(): assert result.last_name is None assert result.team_ids == [] - # Make sure the JWT handler was called with an empty dict - mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with({}) - def test_get_microsoft_callback_response(): # Arrange mock_request = MagicMock(spec=Request) - mock_jwt_handler = MagicMock(spec=JWTHandler) mock_response = { "mail": "microsoft_user@example.com", "displayName": "Microsoft User", @@ -115,7 +98,6 @@ def test_get_microsoft_callback_response(): request=mock_request, microsoft_client_id="mock_client_id", redirect_url="http://mock_redirect_url", - jwt_handler=mock_jwt_handler, ) ) @@ -132,7 +114,6 @@ def test_get_microsoft_callback_response(): def test_get_microsoft_callback_response_raw_sso_response(): # Arrange mock_request = MagicMock(spec=Request) - mock_jwt_handler = MagicMock(spec=JWTHandler) mock_response = { "mail": "microsoft_user@example.com", "displayName": "Microsoft User", @@ -157,7 +138,6 @@ def test_get_microsoft_callback_response_raw_sso_response(): request=mock_request, microsoft_client_id="mock_client_id", redirect_url="http://mock_redirect_url", - jwt_handler=mock_jwt_handler, return_raw_sso_response=True, ) )