mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
[Bug Fix MSFT SSO] Use correct field for user email when using MSFT SSO (#9886)
* fix openid_from_response * test_microsoft_sso_handler_openid_from_response_user_principal_name * test upsert_sso_user
This commit is contained in:
parent
94a553dbb2
commit
72a12e91c4
2 changed files with 83 additions and 23 deletions
|
@ -468,9 +468,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
result=result,
|
result=result,
|
||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
user_email=user_email,
|
user_email=user_email,
|
||||||
user_id_models=user_id_models,
|
|
||||||
max_internal_user_budget=max_internal_user_budget,
|
|
||||||
internal_user_budget_duration=internal_user_budget_duration,
|
|
||||||
user_defined_values=user_defined_values,
|
user_defined_values=user_defined_values,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
)
|
)
|
||||||
|
@ -831,37 +828,20 @@ class SSOAuthenticationHandler:
|
||||||
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
||||||
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
|
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
|
||||||
user_email: Optional[str],
|
user_email: Optional[str],
|
||||||
user_id_models: List[str],
|
|
||||||
max_internal_user_budget: Optional[float],
|
|
||||||
internal_user_budget_duration: Optional[str],
|
|
||||||
user_defined_values: Optional[SSOUserDefinedValues],
|
user_defined_values: Optional[SSOUserDefinedValues],
|
||||||
prisma_client: PrismaClient,
|
prisma_client: PrismaClient,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Connects the SSO Users to the User Table in LiteLLM DB
|
Connects the SSO Users to the User Table in LiteLLM DB
|
||||||
|
|
||||||
- If user on LiteLLM DB, update the user_id with the SSO user_id
|
- If user on LiteLLM DB, update the user_email with the SSO user_email
|
||||||
- If user not on LiteLLM DB, insert the user into LiteLLM DB
|
- If user not on LiteLLM DB, insert the user into LiteLLM DB
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if user_info is not None:
|
if user_info is not None:
|
||||||
user_id = user_info.user_id
|
user_id = user_info.user_id
|
||||||
user_defined_values = SSOUserDefinedValues(
|
|
||||||
models=getattr(user_info, "models", user_id_models),
|
|
||||||
user_id=user_info.user_id or "",
|
|
||||||
user_email=getattr(user_info, "user_email", user_email),
|
|
||||||
user_role=getattr(user_info, "user_role", None),
|
|
||||||
max_budget=getattr(
|
|
||||||
user_info, "max_budget", max_internal_user_budget
|
|
||||||
),
|
|
||||||
budget_duration=getattr(
|
|
||||||
user_info, "budget_duration", internal_user_budget_duration
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# update id
|
|
||||||
await prisma_client.db.litellm_usertable.update_many(
|
await prisma_client.db.litellm_usertable.update_many(
|
||||||
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
|
where={"user_id": user_id}, data={"user_email": user_email}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
|
@ -1045,7 +1025,7 @@ class MicrosoftSSOHandler:
|
||||||
response = response or {}
|
response = response or {}
|
||||||
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
|
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
|
||||||
openid_response = CustomOpenID(
|
openid_response = CustomOpenID(
|
||||||
email=response.get("mail"),
|
email=response.get("userPrincipalName") or response.get("mail"),
|
||||||
display_name=response.get("displayName"),
|
display_name=response.get("displayName"),
|
||||||
provider="microsoft",
|
provider="microsoft",
|
||||||
id=response.get("id"),
|
id=response.get("id"),
|
||||||
|
|
|
@ -21,6 +21,7 @@ from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||||
from litellm.proxy.management_endpoints.ui_sso import (
|
from litellm.proxy.management_endpoints.ui_sso import (
|
||||||
GoogleSSOHandler,
|
GoogleSSOHandler,
|
||||||
MicrosoftSSOHandler,
|
MicrosoftSSOHandler,
|
||||||
|
SSOAuthenticationHandler,
|
||||||
)
|
)
|
||||||
from litellm.types.proxy.management_endpoints.ui_sso import (
|
from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||||
MicrosoftGraphAPIUserGroupDirectoryObject,
|
MicrosoftGraphAPIUserGroupDirectoryObject,
|
||||||
|
@ -29,6 +30,37 @@ from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_microsoft_sso_handler_openid_from_response_user_principal_name():
|
||||||
|
# Arrange
|
||||||
|
# Create a mock response similar to what Microsoft SSO would return
|
||||||
|
mock_response = {
|
||||||
|
"userPrincipalName": "test@example.com",
|
||||||
|
"displayName": "Test User",
|
||||||
|
"id": "user123",
|
||||||
|
"givenName": "Test",
|
||||||
|
"surname": "User",
|
||||||
|
"some_other_field": "value",
|
||||||
|
}
|
||||||
|
expected_team_ids = ["team1", "team2"]
|
||||||
|
# Act
|
||||||
|
# Call the method being tested
|
||||||
|
result = MicrosoftSSOHandler.openid_from_response(
|
||||||
|
response=mock_response, team_ids=expected_team_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
|
||||||
|
# Check that the result is a CustomOpenID object with the expected values
|
||||||
|
assert isinstance(result, CustomOpenID)
|
||||||
|
assert result.email == "test@example.com"
|
||||||
|
assert result.display_name == "Test User"
|
||||||
|
assert result.provider == "microsoft"
|
||||||
|
assert result.id == "user123"
|
||||||
|
assert result.first_name == "Test"
|
||||||
|
assert result.last_name == "User"
|
||||||
|
assert result.team_ids == expected_team_ids
|
||||||
|
|
||||||
|
|
||||||
def test_microsoft_sso_handler_openid_from_response():
|
def test_microsoft_sso_handler_openid_from_response():
|
||||||
# Arrange
|
# Arrange
|
||||||
# Create a mock response similar to what Microsoft SSO would return
|
# Create a mock response similar to what Microsoft SSO would return
|
||||||
|
@ -386,6 +418,54 @@ def test_get_group_ids_from_graph_api_response():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_sso_user_existing_user():
|
||||||
|
"""
|
||||||
|
If a user_id is already in the LiteLLM DB and the user signed in with SSO. Ensure that the user_id is updated with the SSO user_email
|
||||||
|
|
||||||
|
SSO Test
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_prisma = MagicMock()
|
||||||
|
mock_prisma.db = MagicMock()
|
||||||
|
mock_prisma.db.litellm_usertable = MagicMock()
|
||||||
|
mock_prisma.db.litellm_usertable.update_many = AsyncMock()
|
||||||
|
|
||||||
|
# Create a mock existing user
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.user_id = "existing_user_123"
|
||||||
|
mock_user.user_email = "old_email@example.com"
|
||||||
|
|
||||||
|
# Create mock SSO response
|
||||||
|
mock_sso_response = CustomOpenID(
|
||||||
|
email="new_email@example.com",
|
||||||
|
display_name="Test User",
|
||||||
|
provider="microsoft",
|
||||||
|
id="existing_user_123",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
team_ids=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock user defined values
|
||||||
|
mock_user_defined_values = MagicMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await SSOAuthenticationHandler.upsert_sso_user(
|
||||||
|
result=mock_sso_response,
|
||||||
|
user_info=mock_user,
|
||||||
|
user_email="new_email@example.com",
|
||||||
|
user_defined_values=mock_user_defined_values,
|
||||||
|
prisma_client=mock_prisma,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_prisma.db.litellm_usertable.update_many.assert_called_once_with(
|
||||||
|
where={"user_id": "existing_user_123"},
|
||||||
|
data={"user_email": "new_email@example.com"},
|
||||||
|
)
|
||||||
|
assert result == mock_user
|
||||||
|
|
||||||
|
|
||||||
async def test_default_team_params():
|
async def test_default_team_params():
|
||||||
"""
|
"""
|
||||||
When litellm.default_team_params is set, it should be used to create a new team
|
When litellm.default_team_params is set, it should be used to create a new team
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue