mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
[Feat] - SSO - Use MSFT Graph API to assign users to teams (#9865)
* refactor SSO handler * render sso JWT on ui * docs debug sso * fix sso login flow use await * fix ui sso debug JWT * test ui sso * remove redis vl * fix redisvl==0.5.1 * fix ml dtypes * fix redisvl * fix redis vl * fix debug_sso_callback * fix linting error * fix redis semantic caching dep * working graph api assignment * test msft sso handler openid * testing for msft group assignment * fix debug graph api sso flow * fix linting errors * add_user_to_teams_from_sso_response * fix linting error
This commit is contained in:
parent
a1433da4a7
commit
4c1bb74c3d
4 changed files with 368 additions and 34 deletions
|
@ -11,7 +11,7 @@ Has all /sso/* routes
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
@ -19,6 +19,11 @@ from fastapi.responses import RedirectResponse
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
|
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
get_async_httpx_client,
|
||||||
|
httpxSpecialProvider,
|
||||||
|
)
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
|
@ -51,6 +56,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
||||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
|
from litellm.types.proxy.management_endpoints.ui_sso import *
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi_sso.sso.base import OpenID
|
from fastapi_sso.sso.base import OpenID
|
||||||
|
@ -357,7 +363,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
request=request,
|
request=request,
|
||||||
microsoft_client_id=microsoft_client_id,
|
microsoft_client_id=microsoft_client_id,
|
||||||
redirect_url=redirect_url,
|
redirect_url=redirect_url,
|
||||||
jwt_handler=jwt_handler,
|
|
||||||
)
|
)
|
||||||
elif generic_client_id is not None:
|
elif generic_client_id is not None:
|
||||||
result = await get_generic_sso_response(
|
result = await get_generic_sso_response(
|
||||||
|
@ -490,8 +495,10 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
user_role = (
|
user_role = (
|
||||||
user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
)
|
)
|
||||||
sso_teams = getattr(result, "team_ids", [])
|
await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
|
||||||
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
result=result,
|
||||||
|
user_info=user_info,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -835,18 +842,42 @@ class SSOAuthenticationHandler:
|
||||||
redirect_url += "/" + sso_callback_route
|
redirect_url += "/" + sso_callback_route
|
||||||
return redirect_url
|
return redirect_url
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def add_user_to_teams_from_sso_response(
|
||||||
|
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
||||||
|
user_info: Union[NewUserResponse, LiteLLM_UserTable],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adds the user as a team member to the teams specified in the SSO responses `team_ids` field
|
||||||
|
|
||||||
|
|
||||||
|
The `team_ids` field is populated by litellm after processing the SSO response
|
||||||
|
"""
|
||||||
|
sso_teams = getattr(result, "team_ids", [])
|
||||||
|
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftSSOHandler:
|
class MicrosoftSSOHandler:
|
||||||
"""
|
"""
|
||||||
Handles Microsoft SSO callback response and returns a CustomOpenID object
|
Handles Microsoft SSO callback response and returns a CustomOpenID object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
graph_api_base_url = "https://graph.microsoft.com/v1.0"
|
||||||
|
graph_api_user_groups_endpoint = f"{graph_api_base_url}/me/memberOf"
|
||||||
|
|
||||||
|
"""
|
||||||
|
Constants
|
||||||
|
"""
|
||||||
|
MAX_GRAPH_API_PAGES = 200
|
||||||
|
|
||||||
|
# used for debugging to show the user groups litellm found from Graph API
|
||||||
|
GRAPH_API_RESPONSE_KEY = "graph_api_user_groups"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_microsoft_callback_response(
|
async def get_microsoft_callback_response(
|
||||||
request: Request,
|
request: Request,
|
||||||
microsoft_client_id: str,
|
microsoft_client_id: str,
|
||||||
redirect_url: str,
|
redirect_url: str,
|
||||||
jwt_handler: JWTHandler,
|
|
||||||
return_raw_sso_response: bool = False,
|
return_raw_sso_response: bool = False,
|
||||||
) -> Union[CustomOpenID, OpenID, dict]:
|
) -> Union[CustomOpenID, OpenID, dict]:
|
||||||
"""
|
"""
|
||||||
|
@ -880,24 +911,34 @@ class MicrosoftSSOHandler:
|
||||||
redirect_uri=redirect_url,
|
redirect_uri=redirect_url,
|
||||||
allow_insecure_http=True,
|
allow_insecure_http=True,
|
||||||
)
|
)
|
||||||
original_msft_result = await microsoft_sso.verify_and_process(
|
original_msft_result = (
|
||||||
request=request,
|
await microsoft_sso.verify_and_process(
|
||||||
convert_response=False,
|
request=request,
|
||||||
|
convert_response=False,
|
||||||
|
)
|
||||||
|
or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
user_team_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
|
||||||
|
access_token=microsoft_sso.access_token
|
||||||
)
|
)
|
||||||
|
|
||||||
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
||||||
if return_raw_sso_response:
|
if return_raw_sso_response:
|
||||||
|
original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = (
|
||||||
|
user_team_ids
|
||||||
|
)
|
||||||
return original_msft_result or {}
|
return original_msft_result or {}
|
||||||
|
|
||||||
result = MicrosoftSSOHandler.openid_from_response(
|
result = MicrosoftSSOHandler.openid_from_response(
|
||||||
response=original_msft_result,
|
response=original_msft_result,
|
||||||
jwt_handler=jwt_handler,
|
team_ids=user_team_ids,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def openid_from_response(
|
def openid_from_response(
|
||||||
response: Optional[dict], jwt_handler: JWTHandler
|
response: Optional[dict], team_ids: List[str]
|
||||||
) -> CustomOpenID:
|
) -> CustomOpenID:
|
||||||
response = response or {}
|
response = response or {}
|
||||||
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
|
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
|
||||||
|
@ -908,11 +949,110 @@ class MicrosoftSSOHandler:
|
||||||
id=response.get("id"),
|
id=response.get("id"),
|
||||||
first_name=response.get("givenName"),
|
first_name=response.get("givenName"),
|
||||||
last_name=response.get("surname"),
|
last_name=response.get("surname"),
|
||||||
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
|
team_ids=team_ids,
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
|
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
|
||||||
return openid_response
|
return openid_response
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_groups_from_graph_api(
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Returns a list of `team_ids` the user belongs to from the Microsoft Graph API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token (Optional[str]): Microsoft Graph API access token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: List of group IDs the user belongs to
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async_client = get_async_httpx_client(
|
||||||
|
llm_provider=httpxSpecialProvider.SSO_HANDLER
|
||||||
|
)
|
||||||
|
|
||||||
|
all_group_ids = []
|
||||||
|
next_link: Optional[str] = (
|
||||||
|
MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
||||||
|
)
|
||||||
|
auth_headers = {"Authorization": f"Bearer {access_token}"}
|
||||||
|
page_count = 0
|
||||||
|
|
||||||
|
while (
|
||||||
|
next_link is not None
|
||||||
|
and page_count < MicrosoftSSOHandler.MAX_GRAPH_API_PAGES
|
||||||
|
):
|
||||||
|
group_ids, next_link = await MicrosoftSSOHandler.fetch_and_parse_groups(
|
||||||
|
url=next_link, headers=auth_headers, async_client=async_client
|
||||||
|
)
|
||||||
|
all_group_ids.extend(group_ids)
|
||||||
|
page_count += 1
|
||||||
|
|
||||||
|
if (
|
||||||
|
next_link is not None
|
||||||
|
and page_count >= MicrosoftSSOHandler.MAX_GRAPH_API_PAGES
|
||||||
|
):
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included."
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_group_ids
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.error(
|
||||||
|
f"Error getting user groups from Microsoft Graph API: {e}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def fetch_and_parse_groups(
|
||||||
|
url: str, headers: dict, async_client: AsyncHTTPHandler
|
||||||
|
) -> Tuple[List[str], Optional[str]]:
|
||||||
|
"""Helper function to fetch and parse group data from a URL"""
|
||||||
|
response = await async_client.get(url, headers=headers)
|
||||||
|
response_json = response.json()
|
||||||
|
response_typed = await MicrosoftSSOHandler._cast_graph_api_response_dict(
|
||||||
|
response=response_json
|
||||||
|
)
|
||||||
|
group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(
|
||||||
|
response=response_typed
|
||||||
|
)
|
||||||
|
return group_ids, response_typed.get("odata_nextLink")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_group_ids_from_graph_api_response(
|
||||||
|
response: MicrosoftGraphAPIUserGroupResponse,
|
||||||
|
) -> List[str]:
|
||||||
|
group_ids = []
|
||||||
|
for _object in response.get("value", []) or []:
|
||||||
|
_group_id = _object.get("id")
|
||||||
|
if _group_id is not None:
|
||||||
|
group_ids.append(_group_id)
|
||||||
|
return group_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _cast_graph_api_response_dict(
|
||||||
|
response: dict,
|
||||||
|
) -> MicrosoftGraphAPIUserGroupResponse:
|
||||||
|
directory_objects: List[MicrosoftGraphAPIUserGroupDirectoryObject] = []
|
||||||
|
for _object in response.get("value", []):
|
||||||
|
directory_objects.append(
|
||||||
|
MicrosoftGraphAPIUserGroupDirectoryObject(
|
||||||
|
odata_type=_object.get("@odata.type"),
|
||||||
|
id=_object.get("id"),
|
||||||
|
deletedDateTime=_object.get("deletedDateTime"),
|
||||||
|
description=_object.get("description"),
|
||||||
|
displayName=_object.get("displayName"),
|
||||||
|
roleTemplateId=_object.get("roleTemplateId"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return MicrosoftGraphAPIUserGroupResponse(
|
||||||
|
odata_context=response.get("@odata.context"),
|
||||||
|
odata_nextLink=response.get("@odata.nextLink"),
|
||||||
|
value=directory_objects,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GoogleSSOHandler:
|
class GoogleSSOHandler:
|
||||||
"""
|
"""
|
||||||
|
@ -1046,9 +1186,9 @@ async def debug_sso_callback(request: Request):
|
||||||
request=request,
|
request=request,
|
||||||
microsoft_client_id=microsoft_client_id,
|
microsoft_client_id=microsoft_client_id,
|
||||||
redirect_url=redirect_url,
|
redirect_url=redirect_url,
|
||||||
jwt_handler=jwt_handler,
|
|
||||||
return_raw_sso_response=True,
|
return_raw_sso_response=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif generic_client_id is not None:
|
elif generic_client_id is not None:
|
||||||
result = await get_generic_sso_response(
|
result = await get_generic_sso_response(
|
||||||
request=request,
|
request=request,
|
||||||
|
|
|
@ -19,6 +19,7 @@ class httpxSpecialProvider(str, Enum):
|
||||||
SecretManager = "secret_manager"
|
SecretManager = "secret_manager"
|
||||||
PassThroughEndpoint = "pass_through_endpoint"
|
PassThroughEndpoint = "pass_through_endpoint"
|
||||||
PromptFactory = "prompt_factory"
|
PromptFactory = "prompt_factory"
|
||||||
|
SSO_HANDLER = "sso_handler"
|
||||||
|
|
||||||
|
|
||||||
VerifyTypes = Union[str, bool, ssl.SSLContext]
|
VerifyTypes = Union[str, bool, ssl.SSLContext]
|
||||||
|
|
20
litellm/types/proxy/management_endpoints/ui_sso.py
Normal file
20
litellm/types/proxy/management_endpoints/ui_sso.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
from typing import List, Optional, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphAPIUserGroupDirectoryObject(TypedDict, total=False):
|
||||||
|
"""Model for Microsoft Graph API directory object"""
|
||||||
|
|
||||||
|
odata_type: Optional[str]
|
||||||
|
id: Optional[str]
|
||||||
|
deletedDateTime: Optional[str]
|
||||||
|
description: Optional[str]
|
||||||
|
displayName: Optional[str]
|
||||||
|
roleTemplateId: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphAPIUserGroupResponse(TypedDict, total=False):
|
||||||
|
"""Model for Microsoft Graph API user groups response"""
|
||||||
|
|
||||||
|
odata_context: Optional[str]
|
||||||
|
odata_nextLink: Optional[str]
|
||||||
|
value: Optional[List[MicrosoftGraphAPIUserGroupDirectoryObject]]
|
|
@ -19,6 +19,10 @@ from litellm.proxy.management_endpoints.ui_sso import (
|
||||||
GoogleSSOHandler,
|
GoogleSSOHandler,
|
||||||
MicrosoftSSOHandler,
|
MicrosoftSSOHandler,
|
||||||
)
|
)
|
||||||
|
from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||||
|
MicrosoftGraphAPIUserGroupDirectoryObject,
|
||||||
|
MicrosoftGraphAPIUserGroupResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_microsoft_sso_handler_openid_from_response():
|
def test_microsoft_sso_handler_openid_from_response():
|
||||||
|
@ -32,23 +36,14 @@ def test_microsoft_sso_handler_openid_from_response():
|
||||||
"surname": "User",
|
"surname": "User",
|
||||||
"some_other_field": "value",
|
"some_other_field": "value",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create a mock JWTHandler that returns predetermined team IDs
|
|
||||||
mock_jwt_handler = MagicMock(spec=JWTHandler)
|
|
||||||
expected_team_ids = ["team1", "team2"]
|
expected_team_ids = ["team1", "team2"]
|
||||||
mock_jwt_handler.get_team_ids_from_jwt.return_value = expected_team_ids
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# Call the method being tested
|
# Call the method being tested
|
||||||
result = MicrosoftSSOHandler.openid_from_response(
|
result = MicrosoftSSOHandler.openid_from_response(
|
||||||
response=mock_response, jwt_handler=mock_jwt_handler
|
response=mock_response, team_ids=expected_team_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Verify the JWT handler was called with the correct parameters
|
|
||||||
mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with(
|
|
||||||
cast(dict, mock_response)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that the result is a CustomOpenID object with the expected values
|
# Check that the result is a CustomOpenID object with the expected values
|
||||||
assert isinstance(result, CustomOpenID)
|
assert isinstance(result, CustomOpenID)
|
||||||
|
@ -64,13 +59,9 @@ def test_microsoft_sso_handler_openid_from_response():
|
||||||
def test_microsoft_sso_handler_with_empty_response():
|
def test_microsoft_sso_handler_with_empty_response():
|
||||||
# Arrange
|
# Arrange
|
||||||
# Test with None response
|
# Test with None response
|
||||||
mock_jwt_handler = MagicMock(spec=JWTHandler)
|
|
||||||
mock_jwt_handler.get_team_ids_from_jwt.return_value = []
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = MicrosoftSSOHandler.openid_from_response(
|
result = MicrosoftSSOHandler.openid_from_response(response=None, team_ids=[])
|
||||||
response=None, jwt_handler=mock_jwt_handler
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert isinstance(result, CustomOpenID)
|
assert isinstance(result, CustomOpenID)
|
||||||
|
@ -82,14 +73,10 @@ def test_microsoft_sso_handler_with_empty_response():
|
||||||
assert result.last_name is None
|
assert result.last_name is None
|
||||||
assert result.team_ids == []
|
assert result.team_ids == []
|
||||||
|
|
||||||
# Make sure the JWT handler was called with an empty dict
|
|
||||||
mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with({})
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_microsoft_callback_response():
|
def test_get_microsoft_callback_response():
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_request = MagicMock(spec=Request)
|
mock_request = MagicMock(spec=Request)
|
||||||
mock_jwt_handler = MagicMock(spec=JWTHandler)
|
|
||||||
mock_response = {
|
mock_response = {
|
||||||
"mail": "microsoft_user@example.com",
|
"mail": "microsoft_user@example.com",
|
||||||
"displayName": "Microsoft User",
|
"displayName": "Microsoft User",
|
||||||
|
@ -115,7 +102,6 @@ def test_get_microsoft_callback_response():
|
||||||
request=mock_request,
|
request=mock_request,
|
||||||
microsoft_client_id="mock_client_id",
|
microsoft_client_id="mock_client_id",
|
||||||
redirect_url="http://mock_redirect_url",
|
redirect_url="http://mock_redirect_url",
|
||||||
jwt_handler=mock_jwt_handler,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -132,7 +118,6 @@ def test_get_microsoft_callback_response():
|
||||||
def test_get_microsoft_callback_response_raw_sso_response():
|
def test_get_microsoft_callback_response_raw_sso_response():
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_request = MagicMock(spec=Request)
|
mock_request = MagicMock(spec=Request)
|
||||||
mock_jwt_handler = MagicMock(spec=JWTHandler)
|
|
||||||
mock_response = {
|
mock_response = {
|
||||||
"mail": "microsoft_user@example.com",
|
"mail": "microsoft_user@example.com",
|
||||||
"displayName": "Microsoft User",
|
"displayName": "Microsoft User",
|
||||||
|
@ -157,7 +142,6 @@ def test_get_microsoft_callback_response_raw_sso_response():
|
||||||
request=mock_request,
|
request=mock_request,
|
||||||
microsoft_client_id="mock_client_id",
|
microsoft_client_id="mock_client_id",
|
||||||
redirect_url="http://mock_redirect_url",
|
redirect_url="http://mock_redirect_url",
|
||||||
jwt_handler=mock_jwt_handler,
|
|
||||||
return_raw_sso_response=True,
|
return_raw_sso_response=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -206,3 +190,192 @@ def test_get_google_callback_response():
|
||||||
assert result.get("sub") == "google123"
|
assert result.get("sub") == "google123"
|
||||||
assert result.get("given_name") == "Google"
|
assert result.get("given_name") == "Google"
|
||||||
assert result.get("family_name") == "User"
|
assert result.get("family_name") == "User"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_groups_from_graph_api():
|
||||||
|
# Arrange
|
||||||
|
mock_response = {
|
||||||
|
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"@odata.type": "#microsoft.graph.group",
|
||||||
|
"id": "group1",
|
||||||
|
"displayName": "Group 1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"@odata.type": "#microsoft.graph.group",
|
||||||
|
"id": "group2",
|
||||||
|
"displayName": "Group 2",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def mock_get(*args, **kwargs):
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.json.return_value = mock_response
|
||||||
|
return mock
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
|
||||||
|
) as mock_client:
|
||||||
|
mock_client.return_value = MagicMock()
|
||||||
|
mock_client.return_value.get = mock_get
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
|
||||||
|
access_token="mock_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "group1" in result
|
||||||
|
assert "group2" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_groups_pagination():
|
||||||
|
# Arrange
|
||||||
|
first_response = {
|
||||||
|
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
|
||||||
|
"@odata.nextLink": "https://graph.microsoft.com/v1.0/me/memberOf?$skiptoken=page2",
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"@odata.type": "#microsoft.graph.group",
|
||||||
|
"id": "group1",
|
||||||
|
"displayName": "Group 1",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
second_response = {
|
||||||
|
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"@odata.type": "#microsoft.graph.group",
|
||||||
|
"id": "group2",
|
||||||
|
"displayName": "Group 2",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
responses = [first_response, second_response]
|
||||||
|
current_response = {"index": 0}
|
||||||
|
|
||||||
|
async def mock_get(*args, **kwargs):
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.json.return_value = responses[current_response["index"]]
|
||||||
|
current_response["index"] += 1
|
||||||
|
return mock
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
|
||||||
|
) as mock_client:
|
||||||
|
mock_client.return_value = MagicMock()
|
||||||
|
mock_client.return_value.get = mock_get
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
|
||||||
|
access_token="mock_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "group1" in result
|
||||||
|
assert "group2" in result
|
||||||
|
assert current_response["index"] == 2 # Verify both pages were fetched
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_groups_empty_response():
|
||||||
|
# Arrange
|
||||||
|
mock_response = {
|
||||||
|
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
|
||||||
|
"value": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def mock_get(*args, **kwargs):
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.json.return_value = mock_response
|
||||||
|
return mock
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
|
||||||
|
) as mock_client:
|
||||||
|
mock_client.return_value = MagicMock()
|
||||||
|
mock_client.return_value.get = mock_get
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
|
||||||
|
access_token="mock_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_groups_error_handling():
|
||||||
|
# Arrange
|
||||||
|
async def mock_get(*args, **kwargs):
|
||||||
|
raise Exception("API Error")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
|
||||||
|
) as mock_client:
|
||||||
|
mock_client.return_value = MagicMock()
|
||||||
|
mock_client.return_value.get = mock_get
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
|
||||||
|
access_token="mock_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_group_ids_from_graph_api_response():
|
||||||
|
# Arrange
|
||||||
|
mock_response = MicrosoftGraphAPIUserGroupResponse(
|
||||||
|
odata_context="https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
|
||||||
|
odata_nextLink=None,
|
||||||
|
value=[
|
||||||
|
MicrosoftGraphAPIUserGroupDirectoryObject(
|
||||||
|
odata_type="#microsoft.graph.group",
|
||||||
|
id="group1",
|
||||||
|
displayName="Group 1",
|
||||||
|
description=None,
|
||||||
|
deletedDateTime=None,
|
||||||
|
roleTemplateId=None,
|
||||||
|
),
|
||||||
|
MicrosoftGraphAPIUserGroupDirectoryObject(
|
||||||
|
odata_type="#microsoft.graph.group",
|
||||||
|
id="group2",
|
||||||
|
displayName="Group 2",
|
||||||
|
description=None,
|
||||||
|
deletedDateTime=None,
|
||||||
|
roleTemplateId=None,
|
||||||
|
),
|
||||||
|
MicrosoftGraphAPIUserGroupDirectoryObject(
|
||||||
|
odata_type="#microsoft.graph.group",
|
||||||
|
id=None, # Test handling of None id
|
||||||
|
displayName="Invalid Group",
|
||||||
|
description=None,
|
||||||
|
deletedDateTime=None,
|
||||||
|
roleTemplateId=None,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(mock_response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "group1" in result
|
||||||
|
assert "group2" in result
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue