mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
Litellm dev 01 08 2025 p1 (#7640)
* feat(ui_sso.py): support reading team ids from sso token * feat(ui_sso.py): working upsert sso user teams membership in litellm - if team exists Adds user to relevant teams, if user is part of teams and team exists on litellm * fix(ui_sso.py): safely handle add team member task * build(ui/): support setting team id when creating team on UI * build(ui/): teams.tsx allow setting team id on ui * build(circle_ci/requirements.txt): add fastapi-sso to ci/cd testing * fix: fix linting errors
This commit is contained in:
parent
6d8cfeaf14
commit
b77832a793
10 changed files with 269 additions and 120 deletions
|
@ -8,7 +8,7 @@ Has all /sso/* routes
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
@ -18,13 +18,17 @@ from litellm._logging import verbose_proxy_logger
|
|||
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
|
||||
from litellm.proxy._types import (
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
NewUserRequest,
|
||||
NewUserResponse,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
SSOUserDefinedValues,
|
||||
TeamMemberAddRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_utils import _has_user_setup_sso
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.admin_ui_utils import (
|
||||
admin_ui_disabled,
|
||||
|
@ -36,6 +40,8 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
|||
check_is_admin_only_access,
|
||||
has_admin_ui_access,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -221,6 +227,170 @@ async def google_login(request: Request): # noqa: PLR0915
|
|||
return HTMLResponse(content=html_form, status_code=200)
|
||||
|
||||
|
||||
def generic_response_convertor(response, jwt_handler: JWTHandler):
|
||||
generic_user_id_attribute_name = os.getenv(
|
||||
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
||||
)
|
||||
generic_user_display_name_attribute_name = os.getenv(
|
||||
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
|
||||
)
|
||||
generic_user_email_attribute_name = os.getenv(
|
||||
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
||||
)
|
||||
|
||||
generic_user_first_name_attribute_name = os.getenv(
|
||||
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
|
||||
)
|
||||
generic_user_last_name_attribute_name = os.getenv(
|
||||
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
|
||||
)
|
||||
|
||||
generic_provider_attribute_name = os.getenv(
|
||||
"GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}"
|
||||
)
|
||||
|
||||
return CustomOpenID(
|
||||
id=response.get(generic_user_id_attribute_name),
|
||||
display_name=response.get(generic_user_display_name_attribute_name),
|
||||
email=response.get(generic_user_email_attribute_name),
|
||||
first_name=response.get(generic_user_first_name_attribute_name),
|
||||
last_name=response.get(generic_user_last_name_attribute_name),
|
||||
provider=response.get(generic_provider_attribute_name),
|
||||
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
|
||||
)
|
||||
|
||||
|
||||
async def get_generic_sso_response(
|
||||
request: Request,
|
||||
jwt_handler: JWTHandler,
|
||||
generic_client_id: str,
|
||||
redirect_url: str,
|
||||
) -> Optional[OpenID]:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
generic_include_client_id = (
|
||||
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
||||
)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_authorization_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_token_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_TOKEN_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_userinfo_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_USERINFO_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
|
||||
def response_convertor(response, client):
|
||||
return generic_response_convertor(
|
||||
response=response,
|
||||
jwt_handler=jwt_handler,
|
||||
)
|
||||
|
||||
SSOProvider = create_provider(
|
||||
name="oidc",
|
||||
discovery_document=discovery,
|
||||
response_convertor=response_convertor,
|
||||
)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
scope=generic_scope,
|
||||
)
|
||||
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
|
||||
result = await generic_sso.verify_and_process(
|
||||
request, params={"include_client_id": generic_include_client_id}
|
||||
)
|
||||
verbose_proxy_logger.debug("generic result: %s", result)
|
||||
return result
|
||||
|
||||
|
||||
async def create_team_member_add_task(team_id, user_info):
|
||||
"""Create a task for adding a member to a team."""
|
||||
try:
|
||||
member = Member(user_id=user_info.user_id, role="user")
|
||||
team_member_add_request = TeamMemberAddRequest(
|
||||
member=member,
|
||||
team_id=team_id,
|
||||
)
|
||||
return await team_member_add(
|
||||
data=team_member_add_request,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
http_request=Request(scope={"type": "http", "path": "/sso/callback"}),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"[Non-Blocking] Error trying to add sso user to db: {e}"
|
||||
)
|
||||
|
||||
|
||||
async def add_missing_team_member(user_info: NewUserResponse, sso_teams: List[str]):
|
||||
"""
|
||||
- Get missing teams (diff b/w user_info.team_ids and sso_teams)
|
||||
- Add missing user to missing teams
|
||||
"""
|
||||
if user_info.teams is None:
|
||||
return
|
||||
missing_teams = set(sso_teams) - set(user_info.teams)
|
||||
missing_teams_list = list(missing_teams)
|
||||
tasks = []
|
||||
tasks = [
|
||||
create_team_member_add_task(team_id, user_info)
|
||||
for team_id in missing_teams_list
|
||||
]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"[Non-Blocking] Error trying to add sso user to db: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
||||
async def auth_callback(request: Request): # noqa: PLR0915
|
||||
"""Verify login"""
|
||||
|
@ -229,6 +399,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
jwt_handler,
|
||||
master_key,
|
||||
premium_user,
|
||||
prisma_client,
|
||||
|
@ -299,116 +470,12 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
)
|
||||
result = await microsoft_sso.verify_and_process(request)
|
||||
elif generic_client_id is not None:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
generic_authorization_endpoint = os.getenv(
|
||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||
result = await get_generic_sso_response(
|
||||
request=request,
|
||||
jwt_handler=jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
generic_include_client_id = (
|
||||
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
||||
)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_authorization_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_token_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_TOKEN_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_userinfo_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_USERINFO_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
|
||||
generic_user_id_attribute_name = os.getenv(
|
||||
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
||||
)
|
||||
generic_user_display_name_attribute_name = os.getenv(
|
||||
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
|
||||
)
|
||||
generic_user_email_attribute_name = os.getenv(
|
||||
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
||||
)
|
||||
generic_user_role_attribute_name = os.getenv(
|
||||
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
||||
)
|
||||
generic_user_first_name_attribute_name = os.getenv(
|
||||
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
|
||||
)
|
||||
generic_user_last_name_attribute_name = os.getenv(
|
||||
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
|
||||
)
|
||||
|
||||
generic_provider_attribute_name = os.getenv(
|
||||
"GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
||||
)
|
||||
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
|
||||
def response_convertor(response, client):
|
||||
return OpenID(
|
||||
id=response.get(generic_user_id_attribute_name),
|
||||
display_name=response.get(generic_user_display_name_attribute_name),
|
||||
email=response.get(generic_user_email_attribute_name),
|
||||
first_name=response.get(generic_user_first_name_attribute_name),
|
||||
last_name=response.get(generic_user_last_name_attribute_name),
|
||||
provider=response.get(generic_provider_attribute_name),
|
||||
)
|
||||
|
||||
SSOProvider = create_provider(
|
||||
name="oidc",
|
||||
discovery_document=discovery,
|
||||
response_convertor=response_convertor,
|
||||
)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
scope=generic_scope,
|
||||
)
|
||||
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
|
||||
result = await generic_sso.verify_and_process(
|
||||
request, params={"include_client_id": generic_include_client_id}
|
||||
)
|
||||
verbose_proxy_logger.debug("generic result: %s", result)
|
||||
|
||||
# User is Authe'd in - generate key for the UI to access Proxy
|
||||
user_email: Optional[str] = getattr(result, "email", None)
|
||||
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
||||
|
@ -428,6 +495,9 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
|
||||
# generic client id
|
||||
if generic_client_id is not None and result is not None:
|
||||
generic_user_role_attribute_name = os.getenv(
|
||||
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
||||
)
|
||||
user_id = getattr(result, "id", None)
|
||||
user_email = getattr(result, "email", None)
|
||||
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
|
||||
|
@ -508,12 +578,21 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
)
|
||||
else:
|
||||
# user not in DB, insert User into LiteLLM DB
|
||||
user_role = await insert_sso_user(
|
||||
user_info = await insert_sso_user(
|
||||
result_openid=result,
|
||||
user_defined_values=user_defined_values,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
user_role = (
|
||||
user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
)
|
||||
sso_teams = getattr(result, "team_ids", [])
|
||||
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"[Non-Blocking] Error trying to add sso user to db: {e}"
|
||||
)
|
||||
|
||||
if user_defined_values is None:
|
||||
raise Exception(
|
||||
|
@ -588,13 +667,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
async def insert_sso_user(
|
||||
result_openid: Optional[OpenID],
|
||||
user_defined_values: Optional[SSOUserDefinedValues] = None,
|
||||
) -> str:
|
||||
) -> NewUserResponse:
|
||||
"""
|
||||
Helper function to create a New User in LiteLLM DB after a successful SSO login
|
||||
|
||||
Args:
|
||||
result_openid (OpenID): User information in OpenID format if the login was successful.
|
||||
user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: User ID and User Role
|
||||
"""
|
||||
verbose_proxy_logger.debug(
|
||||
f"Inserting SSO user into DB. User values: {user_defined_values}"
|
||||
|
@ -629,9 +711,9 @@ async def insert_sso_user(
|
|||
if result_openid:
|
||||
new_user_request.metadata = {"auth_provider": result_openid.provider}
|
||||
|
||||
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
|
||||
response = await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
|
||||
|
||||
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
return response
|
||||
|
||||
|
||||
@router.get(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue