mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
* refactor SSO handler * render sso JWT on ui * docs debug sso * fix sso login flow use await * fix ui sso debug JWT * test ui sso * remove redis vl * fix redisvl==0.5.1 * fix ml dtypes * fix redisvl * fix redis vl * fix debug_sso_callback * fix linting error * fix redis semantic caching dep * working graph api assignment * test msft sso handler openid * testing for msft group assignment * fix debug graph api sso flow * fix linting errors * add_user_to_teams_from_sso_response * fix linting error
381 lines
11 KiB
Python
381 lines
11 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
from typing import Optional, cast
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import Request
|
|
from fastapi.testclient import TestClient
|
|
|
|
# Prepend the directory three levels above the current working directory to
# sys.path (os.path.abspath resolves "../../../" against the cwd); presumably
# so the `litellm` imports below resolve against the local checkout.
sys.path.insert(0, os.path.abspath("../../../"))
|
|
|
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
|
from litellm.proxy.management_endpoints.types import CustomOpenID
|
|
from litellm.proxy.management_endpoints.ui_sso import (
|
|
GoogleSSOHandler,
|
|
MicrosoftSSOHandler,
|
|
)
|
|
from litellm.types.proxy.management_endpoints.ui_sso import (
|
|
MicrosoftGraphAPIUserGroupDirectoryObject,
|
|
MicrosoftGraphAPIUserGroupResponse,
|
|
)
|
|
|
|
|
|
def test_microsoft_sso_handler_openid_from_response():
    """Check that openid_from_response turns a Microsoft profile dict into a CustomOpenID."""
    # Arrange: a payload shaped like a Microsoft SSO profile; it also carries
    # one key ("some_other_field") that none of the assertions below inspect.
    team_ids = ["team1", "team2"]
    sso_payload = dict(
        mail="test@example.com",
        displayName="Test User",
        id="user123",
        givenName="Test",
        surname="User",
        some_other_field="value",
    )

    # Act
    openid = MicrosoftSSOHandler.openid_from_response(
        response=sso_payload, team_ids=team_ids
    )

    # Assert: every mapped attribute carries the value from the payload
    assert isinstance(openid, CustomOpenID)
    expected_attributes = {
        "email": "test@example.com",
        "display_name": "Test User",
        "provider": "microsoft",
        "id": "user123",
        "first_name": "Test",
        "last_name": "User",
        "team_ids": team_ids,
    }
    for attribute, expected in expected_attributes.items():
        assert getattr(openid, attribute) == expected, attribute
|
|
|
|
|
|
def test_microsoft_sso_handler_with_empty_response():
    """Check that a None response still produces a CustomOpenID with unset profile fields."""
    # Act: no SSO payload at all, and no teams
    openid = MicrosoftSSOHandler.openid_from_response(response=None, team_ids=[])

    # Assert: provider and team_ids are populated, every profile field is None
    assert isinstance(openid, CustomOpenID)
    assert openid.provider == "microsoft"
    assert openid.team_ids == []
    for attribute in ("email", "display_name", "id", "first_name", "last_name"):
        assert getattr(openid, attribute) is None, attribute
|
|
|
|
|
|
def test_get_microsoft_callback_response():
    """Check that the Microsoft callback converts the SSO payload into a CustomOpenID.

    ``MicrosoftSSO.verify_and_process`` is replaced with a coroutine function
    instead of a mock returning a pre-built ``asyncio.Future``. A Future
    constructed outside a running loop binds to the thread's implicit current
    event loop; that is deprecated and raises ``RuntimeError`` when no current
    loop exists (for example after an earlier ``asyncio.run()`` cleared it).
    The coroutine stub is created and awaited inside the loop that
    ``asyncio.run()`` owns, so the test no longer depends on that state.
    """
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = {
        "mail": "microsoft_user@example.com",
        "displayName": "Microsoft User",
        "id": "msft123",
        "givenName": "Microsoft",
        "surname": "User",
    }

    async def fake_verify_and_process(*args, **kwargs):
        # Accepts any call shape (including the bound instance) and resolves
        # to the canned payload.
        return mock_response

    with patch.dict(
        os.environ,
        {"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"},
    ):
        with patch(
            "fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
            new=fake_verify_and_process,
        ):
            # Act
            result = asyncio.run(
                MicrosoftSSOHandler.get_microsoft_callback_response(
                    request=mock_request,
                    microsoft_client_id="mock_client_id",
                    redirect_url="http://mock_redirect_url",
                )
            )

    # Assert
    assert isinstance(result, CustomOpenID)
    assert result.email == "microsoft_user@example.com"
    assert result.display_name == "Microsoft User"
    assert result.provider == "microsoft"
    assert result.id == "msft123"
    assert result.first_name == "Microsoft"
    assert result.last_name == "User"
|
|
|
|
|
|
def test_get_microsoft_callback_response_raw_sso_response():
    """Check that ``return_raw_sso_response=True`` hands back the raw SSO dict.

    ``MicrosoftSSO.verify_and_process`` is replaced with a coroutine function
    instead of a mock returning a pre-built ``asyncio.Future``. A Future
    constructed outside a running loop binds to the thread's implicit current
    event loop; that is deprecated and raises ``RuntimeError`` when no current
    loop exists (for example after an earlier ``asyncio.run()`` cleared it).
    """
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = {
        "mail": "microsoft_user@example.com",
        "displayName": "Microsoft User",
        "id": "msft123",
        "givenName": "Microsoft",
        "surname": "User",
    }

    async def fake_verify_and_process(*args, **kwargs):
        # Accepts any call shape (including the bound instance) and resolves
        # to the canned payload.
        return mock_response

    with patch.dict(
        os.environ,
        {"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"},
    ):
        with patch(
            "fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
            new=fake_verify_and_process,
        ):
            # Act
            result = asyncio.run(
                MicrosoftSSOHandler.get_microsoft_callback_response(
                    request=mock_request,
                    microsoft_client_id="mock_client_id",
                    redirect_url="http://mock_redirect_url",
                    return_raw_sso_response=True,
                )
            )

    # Assert
    print("result from verify_and_process", result)
    assert isinstance(result, dict)
    assert result["mail"] == "microsoft_user@example.com"
    assert result["displayName"] == "Microsoft User"
    assert result["id"] == "msft123"
    assert result["givenName"] == "Microsoft"
    assert result["surname"] == "User"
|
|
|
|
|
|
def test_get_google_callback_response():
    """Check that the Google callback returns the SSO payload as a dict.

    ``GoogleSSO.verify_and_process`` is replaced with a coroutine function
    instead of a mock returning a pre-built ``asyncio.Future``. A Future
    constructed outside a running loop binds to the thread's implicit current
    event loop; that is deprecated and raises ``RuntimeError`` when no current
    loop exists (for example after an earlier ``asyncio.run()`` cleared it).
    """
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = {
        "email": "google_user@example.com",
        "name": "Google User",
        "sub": "google123",
        "given_name": "Google",
        "family_name": "User",
    }

    async def fake_verify_and_process(*args, **kwargs):
        # Accepts any call shape (including the bound instance) and resolves
        # to the canned payload.
        return mock_response

    with patch.dict(os.environ, {"GOOGLE_CLIENT_SECRET": "mock_secret"}):
        with patch(
            "fastapi_sso.sso.google.GoogleSSO.verify_and_process",
            new=fake_verify_and_process,
        ):
            # Act
            result = asyncio.run(
                GoogleSSOHandler.get_google_callback_response(
                    request=mock_request,
                    google_client_id="mock_client_id",
                    redirect_url="http://mock_redirect_url",
                )
            )

    # Assert
    assert isinstance(result, dict)
    assert result.get("email") == "google_user@example.com"
    assert result.get("name") == "Google User"
    assert result.get("sub") == "google123"
    assert result.get("given_name") == "Google"
    assert result.get("family_name") == "User"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_groups_from_graph_api():
    """Check that a single-page Graph API reply yields both group ids."""
    # Arrange: one page holding two groups and no "@odata.nextLink"
    graph_payload = {
        "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group1",
                "displayName": "Group 1",
            },
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group2",
                "displayName": "Group 2",
            },
        ],
    }

    async def fake_get(*args, **kwargs):
        # Stand-in for the HTTP client's async get(); .json() returns the payload.
        reply = MagicMock()
        reply.json.return_value = graph_payload
        return reply

    with patch(
        "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
    ) as client_factory:
        http_client = MagicMock()
        http_client.get = fake_get
        client_factory.return_value = http_client

        # Act
        group_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(group_ids, list)
    assert len(group_ids) == 2
    for expected_id in ("group1", "group2"):
        assert expected_id in group_ids
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_groups_pagination():
    """Check that groups are collected across two Graph API pages."""
    # Arrange: page one advertises a next link, page two does not
    page_one = {
        "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        "@odata.nextLink": "https://graph.microsoft.com/v1.0/me/memberOf?$skiptoken=page2",
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group1",
                "displayName": "Group 1",
            },
        ],
    }
    page_two = {
        "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group2",
                "displayName": "Group 2",
            },
        ],
    }
    pages = [page_one, page_two]
    pages_served = 0

    async def fake_get(*args, **kwargs):
        # Serve the pages in order, counting how many requests were made.
        nonlocal pages_served
        reply = MagicMock()
        reply.json.return_value = pages[pages_served]
        pages_served += 1
        return reply

    with patch(
        "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
    ) as client_factory:
        http_client = MagicMock()
        http_client.get = fake_get
        client_factory.return_value = http_client

        # Act
        group_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(group_ids, list)
    assert len(group_ids) == 2
    for expected_id in ("group1", "group2"):
        assert expected_id in group_ids
    assert pages_served == 2  # Verify both pages were fetched
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_groups_empty_response():
    """Check that a Graph API reply with no groups produces an empty list."""
    # Arrange: a well-formed reply whose "value" array is empty
    graph_payload = {
        "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        "value": [],
    }

    async def fake_get(*args, **kwargs):
        # Stand-in for the HTTP client's async get(); .json() returns the payload.
        reply = MagicMock()
        reply.json.return_value = graph_payload
        return reply

    with patch(
        "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
    ) as client_factory:
        http_client = MagicMock()
        http_client.get = fake_get
        client_factory.return_value = http_client

        # Act
        group_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(group_ids, list)
    assert len(group_ids) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_groups_error_handling():
    """Check that a failing HTTP request results in an empty list, not an exception."""

    # Arrange: every get() call on the patched client raises
    async def failing_get(*args, **kwargs):
        raise Exception("API Error")

    with patch(
        "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
    ) as client_factory:
        http_client = MagicMock()
        http_client.get = failing_get
        client_factory.return_value = http_client

        # Act
        group_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(group_ids, list)
    assert len(group_ids) == 0
|
|
|
|
|
|
def test_get_group_ids_from_graph_api_response():
    """Check that group ids are extracted and an entry whose id is None is left out."""

    def make_group(group_id, display_name):
        # Build one directory object; only id and displayName vary per entry.
        return MicrosoftGraphAPIUserGroupDirectoryObject(
            odata_type="#microsoft.graph.group",
            id=group_id,
            displayName=display_name,
            description=None,
            deletedDateTime=None,
            roleTemplateId=None,
        )

    # Arrange: two valid groups plus one with no id
    graph_response = MicrosoftGraphAPIUserGroupResponse(
        odata_context="https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        odata_nextLink=None,
        value=[
            make_group("group1", "Group 1"),
            make_group("group2", "Group 2"),
            make_group(None, "Invalid Group"),  # Test handling of None id
        ],
    )

    # Act
    group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(
        graph_response
    )

    # Assert
    assert isinstance(group_ids, list)
    assert len(group_ids) == 2
    for expected_id in ("group1", "group2"):
        assert expected_id in group_ids
|