[Bug Fix MSFT SSO] Use correct field for user email when using MSFT SSO (#9886)

* fix openid_from_response

* test_microsoft_sso_handler_openid_from_response_user_principal_name

* test upsert_sso_user
This commit is contained in:
Ishaan Jaff 2025-04-10 17:40:58 -07:00 committed by GitHub
parent 94a553dbb2
commit 72a12e91c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 83 additions and 23 deletions

View file

@ -468,9 +468,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
result=result,
user_info=user_info,
user_email=user_email,
user_id_models=user_id_models,
max_internal_user_budget=max_internal_user_budget,
internal_user_budget_duration=internal_user_budget_duration,
user_defined_values=user_defined_values,
prisma_client=prisma_client,
)
@ -831,37 +828,20 @@ class SSOAuthenticationHandler:
result: Optional[Union[CustomOpenID, OpenID, dict]],
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
user_email: Optional[str],
user_id_models: List[str],
max_internal_user_budget: Optional[float],
internal_user_budget_duration: Optional[str],
user_defined_values: Optional[SSOUserDefinedValues],
prisma_client: PrismaClient,
):
"""
Connects the SSO Users to the User Table in LiteLLM DB
- If user on LiteLLM DB, update the user_id with the SSO user_id
- If user on LiteLLM DB, update the user_email with the SSO user_email
- If user not on LiteLLM DB, insert the user into LiteLLM DB
"""
try:
if user_info is not None:
user_id = user_info.user_id
user_defined_values = SSOUserDefinedValues(
models=getattr(user_info, "models", user_id_models),
user_id=user_info.user_id or "",
user_email=getattr(user_info, "user_email", user_email),
user_role=getattr(user_info, "user_role", None),
max_budget=getattr(
user_info, "max_budget", max_internal_user_budget
),
budget_duration=getattr(
user_info, "budget_duration", internal_user_budget_duration
),
)
# update id
await prisma_client.db.litellm_usertable.update_many(
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
where={"user_id": user_id}, data={"user_email": user_email}
)
else:
verbose_proxy_logger.info(
@ -1045,7 +1025,7 @@ class MicrosoftSSOHandler:
response = response or {}
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
openid_response = CustomOpenID(
email=response.get("mail"),
email=response.get("userPrincipalName") or response.get("mail"),
display_name=response.get("displayName"),
provider="microsoft",
id=response.get("id"),

View file

@ -21,6 +21,7 @@ from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.proxy.management_endpoints.ui_sso import (
GoogleSSOHandler,
MicrosoftSSOHandler,
SSOAuthenticationHandler,
)
from litellm.types.proxy.management_endpoints.ui_sso import (
MicrosoftGraphAPIUserGroupDirectoryObject,
@ -29,6 +30,37 @@ from litellm.types.proxy.management_endpoints.ui_sso import (
)
def test_microsoft_sso_handler_openid_from_response_user_principal_name():
# Arrange
# Create a mock response similar to what Microsoft SSO would return
mock_response = {
"userPrincipalName": "test@example.com",
"displayName": "Test User",
"id": "user123",
"givenName": "Test",
"surname": "User",
"some_other_field": "value",
}
expected_team_ids = ["team1", "team2"]
# Act
# Call the method being tested
result = MicrosoftSSOHandler.openid_from_response(
response=mock_response, team_ids=expected_team_ids
)
# Assert
# Check that the result is a CustomOpenID object with the expected values
assert isinstance(result, CustomOpenID)
assert result.email == "test@example.com"
assert result.display_name == "Test User"
assert result.provider == "microsoft"
assert result.id == "user123"
assert result.first_name == "Test"
assert result.last_name == "User"
assert result.team_ids == expected_team_ids
def test_microsoft_sso_handler_openid_from_response():
# Arrange
# Create a mock response similar to what Microsoft SSO would return
@ -386,6 +418,54 @@ def test_get_group_ids_from_graph_api_response():
@pytest.mark.asyncio
async def test_upsert_sso_user_existing_user():
"""
If a user_id is already in the LiteLLM DB and the user signed in with SSO. Ensure that the user_id is updated with the SSO user_email
SSO Test
"""
# Arrange
mock_prisma = MagicMock()
mock_prisma.db = MagicMock()
mock_prisma.db.litellm_usertable = MagicMock()
mock_prisma.db.litellm_usertable.update_many = AsyncMock()
# Create a mock existing user
mock_user = MagicMock()
mock_user.user_id = "existing_user_123"
mock_user.user_email = "old_email@example.com"
# Create mock SSO response
mock_sso_response = CustomOpenID(
email="new_email@example.com",
display_name="Test User",
provider="microsoft",
id="existing_user_123",
first_name="Test",
last_name="User",
team_ids=[],
)
# Create mock user defined values
mock_user_defined_values = MagicMock()
# Act
result = await SSOAuthenticationHandler.upsert_sso_user(
result=mock_sso_response,
user_info=mock_user,
user_email="new_email@example.com",
user_defined_values=mock_user_defined_values,
prisma_client=mock_prisma,
)
# Assert
mock_prisma.db.litellm_usertable.update_many.assert_called_once_with(
where={"user_id": "existing_user_123"},
data={"user_email": "new_email@example.com"},
)
assert result == mock_user
async def test_default_team_params():
"""
When litellm.default_team_params is set, it should be used to create a new team