mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
working graph api assignment
This commit is contained in:
parent
588c567d92
commit
0601ae55c4
3 changed files with 129 additions and 3 deletions
|
@ -11,7 +11,7 @@ Has all /sso/* routes
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
@ -19,6 +19,11 @@ from fastapi.responses import RedirectResponse
|
|||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
|
@ -51,6 +56,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
|||
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.proxy.management_endpoints.ui_sso import *
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi_sso.sso.base import OpenID
|
||||
|
@ -841,6 +847,10 @@ class MicrosoftSSOHandler:
|
|||
Handles Microsoft SSO callback response and returns a CustomOpenID object
|
||||
"""
|
||||
|
||||
graph_api_base_url = "https://graph.microsoft.com/v1.0"
|
||||
graph_api_user_groups_endpoint = f"{graph_api_base_url}/me/memberOf"
|
||||
MAX_GRAPH_API_PAGES = 200
|
||||
|
||||
@staticmethod
|
||||
async def get_microsoft_callback_response(
|
||||
request: Request,
|
||||
|
@ -889,15 +899,20 @@ class MicrosoftSSOHandler:
|
|||
if return_raw_sso_response:
|
||||
return original_msft_result or {}
|
||||
|
||||
user_team_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
|
||||
access_token=microsoft_sso.access_token
|
||||
)
|
||||
|
||||
result = MicrosoftSSOHandler.openid_from_response(
|
||||
response=original_msft_result,
|
||||
jwt_handler=jwt_handler,
|
||||
team_ids=user_team_ids,
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def openid_from_response(
|
||||
response: Optional[dict], jwt_handler: JWTHandler
|
||||
response: Optional[dict], jwt_handler: JWTHandler, team_ids: List[str]
|
||||
) -> CustomOpenID:
|
||||
response = response or {}
|
||||
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
|
||||
|
@ -908,11 +923,101 @@ class MicrosoftSSOHandler:
|
|||
id=response.get("id"),
|
||||
first_name=response.get("givenName"),
|
||||
last_name=response.get("surname"),
|
||||
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
|
||||
team_ids=team_ids,
|
||||
)
|
||||
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
|
||||
return openid_response
|
||||
|
||||
@staticmethod
|
||||
async def get_user_groups_from_graph_api(
|
||||
access_token: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns a list of `team_ids` the user belongs to from the Microsoft Graph API
|
||||
|
||||
Args:
|
||||
access_token (Optional[str]): Microsoft Graph API access token
|
||||
|
||||
Returns:
|
||||
List[str]: List of group IDs the user belongs to
|
||||
"""
|
||||
try:
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.SSO_HANDLER
|
||||
)
|
||||
|
||||
all_group_ids = []
|
||||
next_link = MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
||||
auth_headers = {"Authorization": f"Bearer {access_token}"}
|
||||
page_count = 0
|
||||
|
||||
while next_link and page_count < MicrosoftSSOHandler.MAX_GRAPH_API_PAGES:
|
||||
group_ids, next_link = await MicrosoftSSOHandler.fetch_and_parse_groups(
|
||||
url=next_link, headers=auth_headers, async_client=async_client
|
||||
)
|
||||
all_group_ids.extend(group_ids)
|
||||
page_count += 1
|
||||
|
||||
if next_link and page_count >= MicrosoftSSOHandler.MAX_GRAPH_API_PAGES:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included."
|
||||
)
|
||||
|
||||
return all_group_ids
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Error getting user groups from Microsoft Graph API: {e}"
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def fetch_and_parse_groups(
|
||||
url: str, headers: dict, async_client: AsyncHTTPHandler
|
||||
) -> Tuple[List[str], Optional[str]]:
|
||||
"""Helper function to fetch and parse group data from a URL"""
|
||||
response = await async_client.get(url, headers=headers)
|
||||
response_json = response.json()
|
||||
response_typed = await MicrosoftSSOHandler._cast_graph_api_response_dict(
|
||||
response=response_json
|
||||
)
|
||||
group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(
|
||||
response=response_typed
|
||||
)
|
||||
return group_ids, response_typed.get("odata_nextLink")
|
||||
|
||||
@staticmethod
|
||||
def _get_group_ids_from_graph_api_response(
|
||||
response: MicrosoftGraphAPIUserGroupResponse,
|
||||
) -> List[str]:
|
||||
group_ids = []
|
||||
for _object in response.get("value", []) or []:
|
||||
if _object.get("id") is not None:
|
||||
group_ids.append(_object.get("id"))
|
||||
return group_ids
|
||||
|
||||
@staticmethod
|
||||
async def _cast_graph_api_response_dict(
|
||||
response: dict,
|
||||
) -> MicrosoftGraphAPIUserGroupResponse:
|
||||
directory_objects: List[MicrosoftGraphAPIUserGroupDirectoryObject] = []
|
||||
for _object in response.get("value", []):
|
||||
directory_objects.append(
|
||||
MicrosoftGraphAPIUserGroupDirectoryObject(
|
||||
odata_type=_object.get("@odata.type"),
|
||||
id=_object.get("id"),
|
||||
deletedDateTime=_object.get("deletedDateTime"),
|
||||
description=_object.get("description"),
|
||||
displayName=_object.get("displayName"),
|
||||
roleTemplateId=_object.get("roleTemplateId"),
|
||||
)
|
||||
)
|
||||
return MicrosoftGraphAPIUserGroupResponse(
|
||||
odata_context=response.get("@odata.context"),
|
||||
odata_nextLink=response.get("@odata.nextLink"),
|
||||
value=directory_objects,
|
||||
)
|
||||
|
||||
|
||||
class GoogleSSOHandler:
|
||||
"""
|
||||
|
|
|
@ -19,6 +19,7 @@ class httpxSpecialProvider(str, Enum):
|
|||
SecretManager = "secret_manager"
|
||||
PassThroughEndpoint = "pass_through_endpoint"
|
||||
PromptFactory = "prompt_factory"
|
||||
SSO_HANDLER = "sso_handler"
|
||||
|
||||
|
||||
VerifyTypes = Union[str, bool, ssl.SSLContext]
|
||||
|
|
20
litellm/types/proxy/management_endpoints/ui_sso.py
Normal file
20
litellm/types/proxy/management_endpoints/ui_sso.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
from typing import List, Optional, TypedDict
|
||||
|
||||
|
||||
class MicrosoftGraphAPIUserGroupDirectoryObject(TypedDict, total=False):
|
||||
"""Model for Microsoft Graph API directory object"""
|
||||
|
||||
odata_type: Optional[str]
|
||||
id: Optional[str]
|
||||
deletedDateTime: Optional[str]
|
||||
description: Optional[str]
|
||||
displayName: Optional[str]
|
||||
roleTemplateId: Optional[str]
|
||||
|
||||
|
||||
class MicrosoftGraphAPIUserGroupResponse(TypedDict, total=False):
|
||||
"""Model for Microsoft Graph API user groups response"""
|
||||
|
||||
odata_context: Optional[str]
|
||||
odata_nextLink: Optional[str]
|
||||
value: Optional[List[MicrosoftGraphAPIUserGroupDirectoryObject]]
|
Loading…
Add table
Add a link
Reference in a new issue