mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
[Bug Fix MSFT SSO] Use correct field for user email when using MSFT SSO (#9886)
* fix openid_from_response * test_microsoft_sso_handler_openid_from_response_user_principal_name * test upsert_sso_user
This commit is contained in:
parent
94a553dbb2
commit
72a12e91c4
2 changed files with 83 additions and 23 deletions
|
@ -468,9 +468,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
result=result,
|
||||
user_info=user_info,
|
||||
user_email=user_email,
|
||||
user_id_models=user_id_models,
|
||||
max_internal_user_budget=max_internal_user_budget,
|
||||
internal_user_budget_duration=internal_user_budget_duration,
|
||||
user_defined_values=user_defined_values,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
@ -831,37 +828,20 @@ class SSOAuthenticationHandler:
|
|||
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
||||
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
|
||||
user_email: Optional[str],
|
||||
user_id_models: List[str],
|
||||
max_internal_user_budget: Optional[float],
|
||||
internal_user_budget_duration: Optional[str],
|
||||
user_defined_values: Optional[SSOUserDefinedValues],
|
||||
prisma_client: PrismaClient,
|
||||
):
|
||||
"""
|
||||
Connects the SSO Users to the User Table in LiteLLM DB
|
||||
|
||||
- If user on LiteLLM DB, update the user_id with the SSO user_id
|
||||
- If user on LiteLLM DB, update the user_email with the SSO user_email
|
||||
- If user not on LiteLLM DB, insert the user into LiteLLM DB
|
||||
"""
|
||||
try:
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
user_defined_values = SSOUserDefinedValues(
|
||||
models=getattr(user_info, "models", user_id_models),
|
||||
user_id=user_info.user_id or "",
|
||||
user_email=getattr(user_info, "user_email", user_email),
|
||||
user_role=getattr(user_info, "user_role", None),
|
||||
max_budget=getattr(
|
||||
user_info, "max_budget", max_internal_user_budget
|
||||
),
|
||||
budget_duration=getattr(
|
||||
user_info, "budget_duration", internal_user_budget_duration
|
||||
),
|
||||
)
|
||||
|
||||
# update id
|
||||
await prisma_client.db.litellm_usertable.update_many(
|
||||
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
|
||||
where={"user_id": user_id}, data={"user_email": user_email}
|
||||
)
|
||||
else:
|
||||
verbose_proxy_logger.info(
|
||||
|
@ -1045,7 +1025,7 @@ class MicrosoftSSOHandler:
|
|||
response = response or {}
|
||||
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
|
||||
openid_response = CustomOpenID(
|
||||
email=response.get("mail"),
|
||||
email=response.get("userPrincipalName") or response.get("mail"),
|
||||
display_name=response.get("displayName"),
|
||||
provider="microsoft",
|
||||
id=response.get("id"),
|
||||
|
|
|
@ -21,6 +21,7 @@ from litellm.proxy.management_endpoints.types import CustomOpenID
|
|||
from litellm.proxy.management_endpoints.ui_sso import (
|
||||
GoogleSSOHandler,
|
||||
MicrosoftSSOHandler,
|
||||
SSOAuthenticationHandler,
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||
MicrosoftGraphAPIUserGroupDirectoryObject,
|
||||
|
@ -29,6 +30,37 @@ from litellm.types.proxy.management_endpoints.ui_sso import (
|
|||
)
|
||||
|
||||
|
||||
def test_microsoft_sso_handler_openid_from_response_user_principal_name():
|
||||
# Arrange
|
||||
# Create a mock response similar to what Microsoft SSO would return
|
||||
mock_response = {
|
||||
"userPrincipalName": "test@example.com",
|
||||
"displayName": "Test User",
|
||||
"id": "user123",
|
||||
"givenName": "Test",
|
||||
"surname": "User",
|
||||
"some_other_field": "value",
|
||||
}
|
||||
expected_team_ids = ["team1", "team2"]
|
||||
# Act
|
||||
# Call the method being tested
|
||||
result = MicrosoftSSOHandler.openid_from_response(
|
||||
response=mock_response, team_ids=expected_team_ids
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
||||
# Check that the result is a CustomOpenID object with the expected values
|
||||
assert isinstance(result, CustomOpenID)
|
||||
assert result.email == "test@example.com"
|
||||
assert result.display_name == "Test User"
|
||||
assert result.provider == "microsoft"
|
||||
assert result.id == "user123"
|
||||
assert result.first_name == "Test"
|
||||
assert result.last_name == "User"
|
||||
assert result.team_ids == expected_team_ids
|
||||
|
||||
|
||||
def test_microsoft_sso_handler_openid_from_response():
|
||||
# Arrange
|
||||
# Create a mock response similar to what Microsoft SSO would return
|
||||
|
@ -386,6 +418,54 @@ def test_get_group_ids_from_graph_api_response():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_sso_user_existing_user():
|
||||
"""
|
||||
If a user_id is already in the LiteLLM DB and the user signed in with SSO. Ensure that the user_id is updated with the SSO user_email
|
||||
|
||||
SSO Test
|
||||
"""
|
||||
# Arrange
|
||||
mock_prisma = MagicMock()
|
||||
mock_prisma.db = MagicMock()
|
||||
mock_prisma.db.litellm_usertable = MagicMock()
|
||||
mock_prisma.db.litellm_usertable.update_many = AsyncMock()
|
||||
|
||||
# Create a mock existing user
|
||||
mock_user = MagicMock()
|
||||
mock_user.user_id = "existing_user_123"
|
||||
mock_user.user_email = "old_email@example.com"
|
||||
|
||||
# Create mock SSO response
|
||||
mock_sso_response = CustomOpenID(
|
||||
email="new_email@example.com",
|
||||
display_name="Test User",
|
||||
provider="microsoft",
|
||||
id="existing_user_123",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
team_ids=[],
|
||||
)
|
||||
|
||||
# Create mock user defined values
|
||||
mock_user_defined_values = MagicMock()
|
||||
|
||||
# Act
|
||||
result = await SSOAuthenticationHandler.upsert_sso_user(
|
||||
result=mock_sso_response,
|
||||
user_info=mock_user,
|
||||
user_email="new_email@example.com",
|
||||
user_defined_values=mock_user_defined_values,
|
||||
prisma_client=mock_prisma,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_prisma.db.litellm_usertable.update_many.assert_called_once_with(
|
||||
where={"user_id": "existing_user_123"},
|
||||
data={"user_email": "new_email@example.com"},
|
||||
)
|
||||
assert result == mock_user
|
||||
|
||||
|
||||
async def test_default_team_params():
|
||||
"""
|
||||
When litellm.default_team_params is set, it should be used to create a new team
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue