Litellm dev 01 08 2025 p1 (#7640)

* feat(ui_sso.py): support reading team ids from sso token

* feat(ui_sso.py): working upsert sso user teams membership in litellm - if team exists

Adds user to relevant teams, if user is part of teams and team exists on litellm

* fix(ui_sso.py): safely handle add team member task

* build(ui/): support setting team id when creating team on UI

* build(ui/): teams.tsx

allow setting team id on ui

* build(circle_ci/requirements.txt): add fastapi-sso to ci/cd testing

* fix: fix linting errors
This commit is contained in:
Krish Dholakia 2025-01-08 22:08:20 -08:00 committed by GitHub
parent 6d8cfeaf14
commit b77832a793
10 changed files with 269 additions and 120 deletions

View file

@ -8,7 +8,7 @@ Has all /sso/* routes
import asyncio
import os
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
@ -18,13 +18,17 @@ from litellm._logging import verbose_proxy_logger
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
from litellm.proxy._types import (
LitellmUserRoles,
Member,
NewUserRequest,
NewUserResponse,
ProxyErrorTypes,
ProxyException,
SSOUserDefinedValues,
TeamMemberAddRequest,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import _has_user_setup_sso
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.admin_ui_utils import (
admin_ui_disabled,
@ -36,6 +40,8 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
check_is_admin_only_access,
has_admin_ui_access,
)
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.secret_managers.main import str_to_bool
if TYPE_CHECKING:
@ -221,6 +227,170 @@ async def google_login(request: Request): # noqa: PLR0915
return HTMLResponse(content=html_form, status_code=200)
def generic_response_convertor(response, jwt_handler: JWTHandler):
generic_user_id_attribute_name = os.getenv(
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
)
generic_user_display_name_attribute_name = os.getenv(
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
)
generic_user_email_attribute_name = os.getenv(
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_first_name_attribute_name = os.getenv(
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
)
generic_user_last_name_attribute_name = os.getenv(
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
)
generic_provider_attribute_name = os.getenv(
"GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}"
)
return CustomOpenID(
id=response.get(generic_user_id_attribute_name),
display_name=response.get(generic_user_display_name_attribute_name),
email=response.get(generic_user_email_attribute_name),
first_name=response.get(generic_user_first_name_attribute_name),
last_name=response.get(generic_user_last_name_attribute_name),
provider=response.get(generic_provider_attribute_name),
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
)
async def get_generic_sso_response(
request: Request,
jwt_handler: JWTHandler,
generic_client_id: str,
redirect_url: str,
) -> Optional[OpenID]:
# make generic sso provider
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
generic_include_client_id = (
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
def response_convertor(response, client):
return generic_response_convertor(
response=response,
jwt_handler=jwt_handler,
)
SSOProvider = create_provider(
name="oidc",
discovery_document=discovery,
response_convertor=response_convertor,
)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
scope=generic_scope,
)
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
result = await generic_sso.verify_and_process(
request, params={"include_client_id": generic_include_client_id}
)
verbose_proxy_logger.debug("generic result: %s", result)
return result
async def create_team_member_add_task(team_id, user_info):
"""Create a task for adding a member to a team."""
try:
member = Member(user_id=user_info.user_id, role="user")
team_member_add_request = TeamMemberAddRequest(
member=member,
team_id=team_id,
)
return await team_member_add(
data=team_member_add_request,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
http_request=Request(scope={"type": "http", "path": "/sso/callback"}),
)
except Exception as e:
verbose_proxy_logger.debug(
f"[Non-Blocking] Error trying to add sso user to db: {e}"
)
async def add_missing_team_member(user_info: NewUserResponse, sso_teams: List[str]):
"""
- Get missing teams (diff b/w user_info.team_ids and sso_teams)
- Add missing user to missing teams
"""
if user_info.teams is None:
return
missing_teams = set(sso_teams) - set(user_info.teams)
missing_teams_list = list(missing_teams)
tasks = []
tasks = [
create_team_member_add_task(team_id, user_info)
for team_id in missing_teams_list
]
try:
await asyncio.gather(*tasks)
except Exception as e:
verbose_proxy_logger.debug(
f"[Non-Blocking] Error trying to add sso user to db: {e}"
)
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
async def auth_callback(request: Request): # noqa: PLR0915
"""Verify login"""
@ -229,6 +399,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
)
from litellm.proxy.proxy_server import (
general_settings,
jwt_handler,
master_key,
premium_user,
prisma_client,
@ -299,116 +470,12 @@ async def auth_callback(request: Request): # noqa: PLR0915
)
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
result = await get_generic_sso_response(
request=request,
jwt_handler=jwt_handler,
generic_client_id=generic_client_id,
redirect_url=redirect_url,
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
generic_include_client_id = (
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
generic_user_id_attribute_name = os.getenv(
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
)
generic_user_display_name_attribute_name = os.getenv(
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
)
generic_user_email_attribute_name = os.getenv(
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
generic_user_first_name_attribute_name = os.getenv(
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
)
generic_user_last_name_attribute_name = os.getenv(
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
)
generic_provider_attribute_name = os.getenv(
"GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
def response_convertor(response, client):
return OpenID(
id=response.get(generic_user_id_attribute_name),
display_name=response.get(generic_user_display_name_attribute_name),
email=response.get(generic_user_email_attribute_name),
first_name=response.get(generic_user_first_name_attribute_name),
last_name=response.get(generic_user_last_name_attribute_name),
provider=response.get(generic_provider_attribute_name),
)
SSOProvider = create_provider(
name="oidc",
discovery_document=discovery,
response_convertor=response_convertor,
)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
scope=generic_scope,
)
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
result = await generic_sso.verify_and_process(
request, params={"include_client_id": generic_include_client_id}
)
verbose_proxy_logger.debug("generic result: %s", result)
# User is Authe'd in - generate key for the UI to access Proxy
user_email: Optional[str] = getattr(result, "email", None)
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
@ -428,6 +495,9 @@ async def auth_callback(request: Request): # noqa: PLR0915
# generic client id
if generic_client_id is not None and result is not None:
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
user_id = getattr(result, "id", None)
user_email = getattr(result, "email", None)
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
@ -508,12 +578,21 @@ async def auth_callback(request: Request): # noqa: PLR0915
)
else:
# user not in DB, insert User into LiteLLM DB
user_role = await insert_sso_user(
user_info = await insert_sso_user(
result_openid=result,
user_defined_values=user_defined_values,
)
except Exception:
pass
user_role = (
user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
)
sso_teams = getattr(result, "team_ids", [])
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
except Exception as e:
verbose_proxy_logger.debug(
f"[Non-Blocking] Error trying to add sso user to db: {e}"
)
if user_defined_values is None:
raise Exception(
@ -588,13 +667,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
async def insert_sso_user(
result_openid: Optional[OpenID],
user_defined_values: Optional[SSOUserDefinedValues] = None,
) -> str:
) -> NewUserResponse:
"""
Helper function to create a New User in LiteLLM DB after a successful SSO login
Args:
result_openid (OpenID): User information in OpenID format if the login was successful.
user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read
Returns:
Tuple[str, str]: User ID and User Role
"""
verbose_proxy_logger.debug(
f"Inserting SSO user into DB. User values: {user_defined_values}"
@ -629,9 +711,9 @@ async def insert_sso_user(
if result_openid:
new_user_request.metadata = {"auth_provider": result_openid.provider}
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
response = await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
return response
@router.get(