diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 455bdda938..863349a8fe 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -449,6 +449,7 @@ router_settings: | MICROSOFT_CLIENT_ID | Client ID for Microsoft services | MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services | MICROSOFT_TENANT | Tenant ID for Microsoft Azure +| MICROSOFT_SERVICE_PRINCIPAL_ID | Service Principal ID for Microsoft Enterprise Application. (This is an advanced feature if you want litellm to auto-assign members to Litellm Teams based on their Microsoft Entra ID Groups) | NO_DOCS | Flag to disable documentation generation | NO_PROXY | List of addresses to bypass proxy | OAUTH_TOKEN_INFO_ENDPOINT | Endpoint for OAuth token info retrieval diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 2fe86d4e6c..0cd3600220 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -28,6 +28,7 @@ from litellm.proxy._types import ( LiteLLM_UserTable, LitellmUserRoles, Member, + NewTeamRequest, NewUserRequest, NewUserResponse, ProxyErrorTypes, @@ -53,8 +54,9 @@ from litellm.proxy.management_endpoints.sso_helper_utils import ( check_is_admin_only_access, has_admin_ui_access, ) -from litellm.proxy.management_endpoints.team_endpoints import team_member_add +from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add from litellm.proxy.management_endpoints.types import CustomOpenID +from litellm.proxy.utils import PrismaClient from litellm.secret_managers.main import str_to_bool from litellm.types.proxy.management_endpoints.ui_sso import * @@ -461,40 +463,22 @@ async def auth_callback(request: Request): # noqa: PLR0915 f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}" ) - if user_info is not None: - user_id = user_info.user_id - user_defined_values = SSOUserDefinedValues( - models=getattr(user_info, "models", user_id_models), - user_id=user_info.user_id, - user_email=getattr(user_info, "user_email", user_email), - user_role=getattr(user_info, "user_role", None), - max_budget=getattr( - user_info, "max_budget", max_internal_user_budget - ), - budget_duration=getattr( - user_info, "budget_duration", internal_user_budget_duration - ), - ) - - user_role = getattr(user_info, "user_role", None) - - # update id - await prisma_client.db.litellm_usertable.update_many( - where={"user_email": user_email}, data={"user_id": user_id} # type: ignore - ) + # Upsert SSO User to LiteLLM DB + user_info = await SSOAuthenticationHandler.upsert_sso_user( + result=result, + user_info=user_info, + user_email=user_email, + user_id_models=user_id_models, + max_internal_user_budget=max_internal_user_budget, + internal_user_budget_duration=internal_user_budget_duration, + user_defined_values=user_defined_values, + prisma_client=prisma_client, + ) + if user_info and user_info.user_role is not None: + user_role = user_info.user_role else: - verbose_proxy_logger.info( - "user not in DB, inserting user into LiteLLM DB" - ) - # user not in DB, insert User into LiteLLM DB - user_info = await insert_sso_user( - result_openid=result, - user_defined_values=user_defined_values, - ) + user_role = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY - user_role = ( - user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY - ) await SSOAuthenticationHandler.add_user_to_teams_from_sso_response( result=result, user_info=user_info, @@ -842,10 +826,61 @@ class SSOAuthenticationHandler: redirect_url += "/" + sso_callback_route return redirect_url + @staticmethod + async def upsert_sso_user( + result: Optional[Union[CustomOpenID, OpenID, dict]], + user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]], + user_email: Optional[str], + user_id_models: List[str], + max_internal_user_budget: Optional[float], + internal_user_budget_duration: Optional[str], + user_defined_values: Optional[SSOUserDefinedValues], + prisma_client: PrismaClient, + ): + """ + Connects the SSO Users to the User Table in LiteLLM DB + + - If user on LiteLLM DB, update the user_id with the SSO user_id + - If user not on LiteLLM DB, insert the user into LiteLLM DB + """ + try: + if user_info is not None: + user_id = user_info.user_id + user_defined_values = SSOUserDefinedValues( + models=getattr(user_info, "models", user_id_models), + user_id=user_info.user_id or "", + user_email=getattr(user_info, "user_email", user_email), + user_role=getattr(user_info, "user_role", None), + max_budget=getattr( + user_info, "max_budget", max_internal_user_budget + ), + budget_duration=getattr( + user_info, "budget_duration", internal_user_budget_duration + ), + ) + + # update id + await prisma_client.db.litellm_usertable.update_many( + where={"user_email": user_email}, data={"user_id": user_id} # type: ignore + ) + else: + verbose_proxy_logger.info( + "user not in DB, inserting user into LiteLLM DB" + ) + # user not in DB, insert User into LiteLLM DB + user_info = await insert_sso_user( + result_openid=result, + user_defined_values=user_defined_values, + ) + return user_info + except Exception as e: + verbose_proxy_logger.error(f"Error upserting SSO user into LiteLLM DB: {e}") + return user_info + @staticmethod async def add_user_to_teams_from_sso_response( result: Optional[Union[CustomOpenID, OpenID, dict]], - user_info: Union[NewUserResponse, LiteLLM_UserTable], + user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]], ): """ Adds the user as a team member to the teams specified in the SSO responses `team_ids` field @@ -853,6 +888,11 @@ class SSOAuthenticationHandler: The `team_ids` field is populated by litellm after processing the SSO response """ + if user_info is None: + verbose_proxy_logger.debug( + "User not found in LiteLLM DB, skipping team member addition" + ) + return sso_teams = getattr(result, "team_ids", []) await add_missing_team_member(user_info=user_info, sso_teams=sso_teams) @@ -972,6 +1012,27 @@ class MicrosoftSSOHandler: llm_provider=httpxSpecialProvider.SSO_HANDLER ) + # Handle MSFT Enterprise Application Groups + service_principal_id = os.getenv("MICROSOFT_SERVICE_PRINCIPAL_ID", None) + service_principal_group_ids: Optional[List[str]] = [] + service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = [] + if service_principal_id: + service_principal_group_ids, service_principal_teams = ( + await MicrosoftSSOHandler.get_group_ids_from_service_principal( + service_principal_id=service_principal_id, + async_client=async_client, + access_token=access_token, + ) + ) + verbose_proxy_logger.debug( + f"Service principal group IDs: {service_principal_group_ids}" + ) + if len(service_principal_group_ids) > 0: + await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids( + service_principal_teams=service_principal_teams, + ) + + # Fetch user membership from Microsoft Graph API all_group_ids = [] next_link: Optional[str] = ( MicrosoftSSOHandler.graph_api_user_groups_endpoint @@ -997,6 +1058,14 @@ class MicrosoftSSOHandler: f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included." ) + # If service_principal_group_ids is not empty, only return group_ids that are in both all_group_ids and service_principal_group_ids + if service_principal_group_ids and len(service_principal_group_ids) > 0: + all_group_ids = [ + group_id + for group_id in all_group_ids + if group_id in service_principal_group_ids + ] + return all_group_ids except Exception as e: @@ -1053,6 +1122,114 @@ class MicrosoftSSOHandler: value=directory_objects, ) + @staticmethod + async def get_group_ids_from_service_principal( + service_principal_id: str, + async_client: AsyncHTTPHandler, + access_token: Optional[str] = None, + ) -> Tuple[List[str], List[MicrosoftServicePrincipalTeam]]: + """ + Gets the groups belonging to the Service Principal Application + + Service Principal Id is an `Enterprise Application` in Azure AD + + Users use Enterprise Applications to manage Groups and Users on Microsoft Entra ID + """ + base_url = "https://graph.microsoft.com/v1.0" + # Endpoint to get app role assignments for the given service principal + endpoint = f"/servicePrincipals/{service_principal_id}/appRoleAssignedTo" + url = base_url + endpoint + + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + response = await async_client.get(url, headers=headers) + response_json = response.json() + verbose_proxy_logger.debug( + f"Response from service principal app role assigned to: {response_json}" + ) + group_ids: List[str] = [] + service_principal_teams: List[MicrosoftServicePrincipalTeam] = [] + + for _object in response_json.get("value", []): + if _object.get("principalType") == "Group": + # Append the group ID to the list + group_ids.append(_object.get("principalId")) + # Append the service principal team to the list + service_principal_teams.append( + MicrosoftServicePrincipalTeam( + principalDisplayName=_object.get("principalDisplayName"), + principalId=_object.get("principalId"), + ) + ) + + return group_ids, service_principal_teams + + @staticmethod + async def create_litellm_teams_from_service_principal_team_ids( + service_principal_teams: List[MicrosoftServicePrincipalTeam], + ): + """ + Creates Litellm Teams from the Service Principal Group IDs + + When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise ProxyException( + message="Prisma client not found. Set it in the proxy_server.py file", + type=ProxyErrorTypes.auth_error, + param="prisma_client", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + verbose_proxy_logger.debug( + f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}" + ) + for service_principal_team in service_principal_teams: + litellm_team_id: Optional[str] = service_principal_team.get("principalId") + litellm_team_name: Optional[str] = service_principal_team.get( + "principalDisplayName" + ) + if not litellm_team_id: + verbose_proxy_logger.debug( + f"Skipping team creation for {litellm_team_name} because it has no principalId" + ) + continue + + try: + verbose_proxy_logger.debug( + f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}" + ) + + team_obj = await prisma_client.db.litellm_teamtable.find_first( + where={"team_id": litellm_team_id} + ) + verbose_proxy_logger.debug(f"Team object: {team_obj}") + + # only create a new team if it doesn't exist + if team_obj: + verbose_proxy_logger.debug( + f"Team already exists: {litellm_team_id} - {litellm_team_name}" + ) + continue + await new_team( + data=NewTeamRequest( + team_id=litellm_team_id, + team_alias=litellm_team_name, + ), + # params used for Audit Logging + http_request=Request(scope={"type": "http", "method": "POST"}), + user_api_key_dict=UserAPIKeyAuth( + token="", + key_alias=f"litellm.{MicrosoftSSOHandler.__name__}", + ), + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}") + class GoogleSSOHandler: """ diff --git a/litellm/types/proxy/management_endpoints/ui_sso.py b/litellm/types/proxy/management_endpoints/ui_sso.py index a706577f3d..ca17c47006 100644 --- a/litellm/types/proxy/management_endpoints/ui_sso.py +++ b/litellm/types/proxy/management_endpoints/ui_sso.py @@ -18,3 +18,10 @@ class MicrosoftGraphAPIUserGroupResponse(TypedDict, total=False): odata_context: Optional[str] odata_nextLink: Optional[str] value: Optional[List[MicrosoftGraphAPIUserGroupDirectoryObject]] + + +class MicrosoftServicePrincipalTeam(TypedDict, total=False): + """Model for Microsoft Service Principal Team""" + + principalDisplayName: Optional[str] + principalId: Optional[str]