mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
[SSO] Connect LiteLLM to Azure Entra ID Enterprise Application (#9872)
* refactor SSO handler * render sso JWT on ui * docs debug sso * fix sso login flow use await * fix ui sso debug JWT * test ui sso * remove redis vl * fix redisvl==0.5.1 * fix ml dtypes * fix redisvl * fix redis vl * fix debug_sso_callback * fix linting error * fix redis semantic caching dep * working graph api assignment * test msft sso handler openid * testing for msft group assignment * fix debug graph api sso flow * fix linting errors * add_user_to_teams_from_sso_response * ui sso fix team assignments * linting fix _get_group_ids_from_graph_api_response * add MicrosoftServicePrincipalTeam * create_litellm_teams_from_service_principal_team_ids * create_litellm_teams_from_service_principal_team_ids * docs MICROSOFT_SERVICE_PRINCIPAL_ID * fix linting errors
This commit is contained in:
parent
50fc49e4bb
commit
1d9ec118dd
3 changed files with 219 additions and 34 deletions
|
@ -449,6 +449,7 @@ router_settings:
|
||||||
| MICROSOFT_CLIENT_ID | Client ID for Microsoft services
|
| MICROSOFT_CLIENT_ID | Client ID for Microsoft services
|
||||||
| MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
|
| MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
|
||||||
| MICROSOFT_TENANT | Tenant ID for Microsoft Azure
|
| MICROSOFT_TENANT | Tenant ID for Microsoft Azure
|
||||||
|
| MICROSOFT_SERVICE_PRINCIPAL_ID | Service Principal ID for Microsoft Enterprise Application. (This is an advanced feature if you want litellm to auto-assign members to Litellm Teams based on their Microsoft Entra ID Groups)
|
||||||
| NO_DOCS | Flag to disable documentation generation
|
| NO_DOCS | Flag to disable documentation generation
|
||||||
| NO_PROXY | List of addresses to bypass proxy
|
| NO_PROXY | List of addresses to bypass proxy
|
||||||
| OAUTH_TOKEN_INFO_ENDPOINT | Endpoint for OAuth token info retrieval
|
| OAUTH_TOKEN_INFO_ENDPOINT | Endpoint for OAuth token info retrieval
|
||||||
|
|
|
@ -28,6 +28,7 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
Member,
|
Member,
|
||||||
|
NewTeamRequest,
|
||||||
NewUserRequest,
|
NewUserRequest,
|
||||||
NewUserResponse,
|
NewUserResponse,
|
||||||
ProxyErrorTypes,
|
ProxyErrorTypes,
|
||||||
|
@ -53,8 +54,9 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||||
check_is_admin_only_access,
|
check_is_admin_only_access,
|
||||||
has_admin_ui_access,
|
has_admin_ui_access,
|
||||||
)
|
)
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
|
||||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
from litellm.types.proxy.management_endpoints.ui_sso import *
|
from litellm.types.proxy.management_endpoints.ui_sso import *
|
||||||
|
|
||||||
|
@ -461,40 +463,22 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}"
|
f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_info is not None:
|
# Upsert SSO User to LiteLLM DB
|
||||||
user_id = user_info.user_id
|
user_info = await SSOAuthenticationHandler.upsert_sso_user(
|
||||||
user_defined_values = SSOUserDefinedValues(
|
result=result,
|
||||||
models=getattr(user_info, "models", user_id_models),
|
user_info=user_info,
|
||||||
user_id=user_info.user_id,
|
user_email=user_email,
|
||||||
user_email=getattr(user_info, "user_email", user_email),
|
user_id_models=user_id_models,
|
||||||
user_role=getattr(user_info, "user_role", None),
|
max_internal_user_budget=max_internal_user_budget,
|
||||||
max_budget=getattr(
|
internal_user_budget_duration=internal_user_budget_duration,
|
||||||
user_info, "max_budget", max_internal_user_budget
|
user_defined_values=user_defined_values,
|
||||||
),
|
prisma_client=prisma_client,
|
||||||
budget_duration=getattr(
|
)
|
||||||
user_info, "budget_duration", internal_user_budget_duration
|
if user_info and user_info.user_role is not None:
|
||||||
),
|
user_role = user_info.user_role
|
||||||
)
|
|
||||||
|
|
||||||
user_role = getattr(user_info, "user_role", None)
|
|
||||||
|
|
||||||
# update id
|
|
||||||
await prisma_client.db.litellm_usertable.update_many(
|
|
||||||
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
verbose_proxy_logger.info(
|
user_role = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
"user not in DB, inserting user into LiteLLM DB"
|
|
||||||
)
|
|
||||||
# user not in DB, insert User into LiteLLM DB
|
|
||||||
user_info = await insert_sso_user(
|
|
||||||
result_openid=result,
|
|
||||||
user_defined_values=user_defined_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
user_role = (
|
|
||||||
user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
|
||||||
)
|
|
||||||
await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
|
await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
|
||||||
result=result,
|
result=result,
|
||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
|
@ -842,10 +826,61 @@ class SSOAuthenticationHandler:
|
||||||
redirect_url += "/" + sso_callback_route
|
redirect_url += "/" + sso_callback_route
|
||||||
return redirect_url
|
return redirect_url
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def upsert_sso_user(
|
||||||
|
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
||||||
|
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
|
||||||
|
user_email: Optional[str],
|
||||||
|
user_id_models: List[str],
|
||||||
|
max_internal_user_budget: Optional[float],
|
||||||
|
internal_user_budget_duration: Optional[str],
|
||||||
|
user_defined_values: Optional[SSOUserDefinedValues],
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Connects the SSO Users to the User Table in LiteLLM DB
|
||||||
|
|
||||||
|
- If user on LiteLLM DB, update the user_id with the SSO user_id
|
||||||
|
- If user not on LiteLLM DB, insert the user into LiteLLM DB
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if user_info is not None:
|
||||||
|
user_id = user_info.user_id
|
||||||
|
user_defined_values = SSOUserDefinedValues(
|
||||||
|
models=getattr(user_info, "models", user_id_models),
|
||||||
|
user_id=user_info.user_id or "",
|
||||||
|
user_email=getattr(user_info, "user_email", user_email),
|
||||||
|
user_role=getattr(user_info, "user_role", None),
|
||||||
|
max_budget=getattr(
|
||||||
|
user_info, "max_budget", max_internal_user_budget
|
||||||
|
),
|
||||||
|
budget_duration=getattr(
|
||||||
|
user_info, "budget_duration", internal_user_budget_duration
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# update id
|
||||||
|
await prisma_client.db.litellm_usertable.update_many(
|
||||||
|
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
"user not in DB, inserting user into LiteLLM DB"
|
||||||
|
)
|
||||||
|
# user not in DB, insert User into LiteLLM DB
|
||||||
|
user_info = await insert_sso_user(
|
||||||
|
result_openid=result,
|
||||||
|
user_defined_values=user_defined_values,
|
||||||
|
)
|
||||||
|
return user_info
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.error(f"Error upserting SSO user into LiteLLM DB: {e}")
|
||||||
|
return user_info
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def add_user_to_teams_from_sso_response(
|
async def add_user_to_teams_from_sso_response(
|
||||||
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
||||||
user_info: Union[NewUserResponse, LiteLLM_UserTable],
|
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Adds the user as a team member to the teams specified in the SSO responses `team_ids` field
|
Adds the user as a team member to the teams specified in the SSO responses `team_ids` field
|
||||||
|
@ -853,6 +888,11 @@ class SSOAuthenticationHandler:
|
||||||
|
|
||||||
The `team_ids` field is populated by litellm after processing the SSO response
|
The `team_ids` field is populated by litellm after processing the SSO response
|
||||||
"""
|
"""
|
||||||
|
if user_info is None:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"User not found in LiteLLM DB, skipping team member addition"
|
||||||
|
)
|
||||||
|
return
|
||||||
sso_teams = getattr(result, "team_ids", [])
|
sso_teams = getattr(result, "team_ids", [])
|
||||||
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
||||||
|
|
||||||
|
@ -972,6 +1012,27 @@ class MicrosoftSSOHandler:
|
||||||
llm_provider=httpxSpecialProvider.SSO_HANDLER
|
llm_provider=httpxSpecialProvider.SSO_HANDLER
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Handle MSFT Enterprise Application Groups
|
||||||
|
service_principal_id = os.getenv("MICROSOFT_SERVICE_PRINCIPAL_ID", None)
|
||||||
|
service_principal_group_ids: Optional[List[str]] = []
|
||||||
|
service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = []
|
||||||
|
if service_principal_id:
|
||||||
|
service_principal_group_ids, service_principal_teams = (
|
||||||
|
await MicrosoftSSOHandler.get_group_ids_from_service_principal(
|
||||||
|
service_principal_id=service_principal_id,
|
||||||
|
async_client=async_client,
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Service principal group IDs: {service_principal_group_ids}"
|
||||||
|
)
|
||||||
|
if len(service_principal_group_ids) > 0:
|
||||||
|
await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
|
||||||
|
service_principal_teams=service_principal_teams,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch user membership from Microsoft Graph API
|
||||||
all_group_ids = []
|
all_group_ids = []
|
||||||
next_link: Optional[str] = (
|
next_link: Optional[str] = (
|
||||||
MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
||||||
|
@ -997,6 +1058,14 @@ class MicrosoftSSOHandler:
|
||||||
f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included."
|
f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If service_principal_group_ids is not empty, only return group_ids that are in both all_group_ids and service_principal_group_ids
|
||||||
|
if service_principal_group_ids and len(service_principal_group_ids) > 0:
|
||||||
|
all_group_ids = [
|
||||||
|
group_id
|
||||||
|
for group_id in all_group_ids
|
||||||
|
if group_id in service_principal_group_ids
|
||||||
|
]
|
||||||
|
|
||||||
return all_group_ids
|
return all_group_ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1053,6 +1122,114 @@ class MicrosoftSSOHandler:
|
||||||
value=directory_objects,
|
value=directory_objects,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_group_ids_from_service_principal(
|
||||||
|
service_principal_id: str,
|
||||||
|
async_client: AsyncHTTPHandler,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
) -> Tuple[List[str], List[MicrosoftServicePrincipalTeam]]:
|
||||||
|
"""
|
||||||
|
Gets the groups belonging to the Service Principal Application
|
||||||
|
|
||||||
|
Service Principal Id is an `Enterprise Application` in Azure AD
|
||||||
|
|
||||||
|
Users use Enterprise Applications to manage Groups and Users on Microsoft Entra ID
|
||||||
|
"""
|
||||||
|
base_url = "https://graph.microsoft.com/v1.0"
|
||||||
|
# Endpoint to get app role assignments for the given service principal
|
||||||
|
endpoint = f"/servicePrincipals/{service_principal_id}/appRoleAssignedTo"
|
||||||
|
url = base_url + endpoint
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await async_client.get(url, headers=headers)
|
||||||
|
response_json = response.json()
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Response from service principal app role assigned to: {response_json}"
|
||||||
|
)
|
||||||
|
group_ids: List[str] = []
|
||||||
|
service_principal_teams: List[MicrosoftServicePrincipalTeam] = []
|
||||||
|
|
||||||
|
for _object in response_json.get("value", []):
|
||||||
|
if _object.get("principalType") == "Group":
|
||||||
|
# Append the group ID to the list
|
||||||
|
group_ids.append(_object.get("principalId"))
|
||||||
|
# Append the service principal team to the list
|
||||||
|
service_principal_teams.append(
|
||||||
|
MicrosoftServicePrincipalTeam(
|
||||||
|
principalDisplayName=_object.get("principalDisplayName"),
|
||||||
|
principalId=_object.get("principalId"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return group_ids, service_principal_teams
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_litellm_teams_from_service_principal_team_ids(
|
||||||
|
service_principal_teams: List[MicrosoftServicePrincipalTeam],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates Litellm Teams from the Service Principal Group IDs
|
||||||
|
|
||||||
|
When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="Prisma client not found. Set it in the proxy_server.py file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="prisma_client",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
|
||||||
|
)
|
||||||
|
for service_principal_team in service_principal_teams:
|
||||||
|
litellm_team_id: Optional[str] = service_principal_team.get("principalId")
|
||||||
|
litellm_team_name: Optional[str] = service_principal_team.get(
|
||||||
|
"principalDisplayName"
|
||||||
|
)
|
||||||
|
if not litellm_team_id:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Skipping team creation for {litellm_team_name} because it has no principalId"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
team_obj = await prisma_client.db.litellm_teamtable.find_first(
|
||||||
|
where={"team_id": litellm_team_id}
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(f"Team object: {team_obj}")
|
||||||
|
|
||||||
|
# only create a new team if it doesn't exist
|
||||||
|
if team_obj:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Team already exists: {litellm_team_id} - {litellm_team_name}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
await new_team(
|
||||||
|
data=NewTeamRequest(
|
||||||
|
team_id=litellm_team_id,
|
||||||
|
team_alias=litellm_team_name,
|
||||||
|
),
|
||||||
|
# params used for Audit Logging
|
||||||
|
http_request=Request(scope={"type": "http", "method": "POST"}),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
token="",
|
||||||
|
key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
|
||||||
|
|
||||||
|
|
||||||
class GoogleSSOHandler:
|
class GoogleSSOHandler:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -18,3 +18,10 @@ class MicrosoftGraphAPIUserGroupResponse(TypedDict, total=False):
|
||||||
odata_context: Optional[str]
|
odata_context: Optional[str]
|
||||||
odata_nextLink: Optional[str]
|
odata_nextLink: Optional[str]
|
||||||
value: Optional[List[MicrosoftGraphAPIUserGroupDirectoryObject]]
|
value: Optional[List[MicrosoftGraphAPIUserGroupDirectoryObject]]
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftServicePrincipalTeam(TypedDict, total=False):
|
||||||
|
"""Model for Microsoft Service Principal Team"""
|
||||||
|
|
||||||
|
principalDisplayName: Optional[str]
|
||||||
|
principalId: Optional[str]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue