diff --git a/.circleci/requirements.txt b/.circleci/requirements.txt index 578bfa5729..3f281a05cc 100644 --- a/.circleci/requirements.txt +++ b/.circleci/requirements.txt @@ -9,3 +9,4 @@ anthropic orjson==3.9.15 pydantic==2.7.1 google-cloud-aiplatform==1.43.0 +fastapi-sso==0.10.0 diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 670f250677..ae52950128 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -24,4 +24,4 @@ model_list: custom_tokenizer: identifier: deepseek-ai/DeepSeek-V3-Base revision: main - auth_token: os.environ/HUGGINGFACE_API_KEY + auth_token: os.environ/HUGGINGFACE_API_KEY \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 94e2b70d92..92522616fd 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -420,6 +420,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): "info_routes", ] team_id_jwt_field: Optional[str] = None + team_ids_jwt_field: Optional[str] = None + upsert_sso_user_to_team: bool = False team_allowed_routes: List[ Literal["openai_routes", "info_routes", "management_routes"] ] = ["openai_routes", "info_routes"] diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index bfcde8fb1e..3d2e952e1d 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -8,7 +8,7 @@ JWT token must have 'litellm_proxy_admin' in scope. import json import os -from typing import Optional, cast +from typing import List, Optional, cast from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -59,6 +59,11 @@ class JWTHandler: return True return False + def get_team_ids_from_jwt(self, token: dict) -> List[str]: + if self.litellm_jwtauth.team_ids_jwt_field is not None: + return token[self.litellm_jwtauth.team_ids_jwt_field] + return [] + def get_end_user_id( self, token: dict, default_value: Optional[str] ) -> Optional[str]: diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 668c2db770..cab4850f2a 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -457,7 +457,7 @@ async def update_team( if existing_team_row is None: raise HTTPException( - status_code=404, + status_code=400, detail={"error": f"Team not found, passed team_id={data.team_id}"}, ) diff --git a/litellm/proxy/management_endpoints/types.py b/litellm/proxy/management_endpoints/types.py new file mode 100644 index 0000000000..0e811669d1 --- /dev/null +++ b/litellm/proxy/management_endpoints/types.py @@ -0,0 +1,13 @@ +""" +Types for the management endpoints + +Might include fastapi/proxy requirements.txt related imports +""" + +from typing import List + +from fastapi_sso.sso.base import OpenID + + +class CustomOpenID(OpenID): + team_ids: List[str] diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 29c22faa33..9d8c4dcf8b 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -8,7 +8,7 @@ Has all /sso/* routes import asyncio import os import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse @@ -18,13 +18,17 @@ from litellm._logging import verbose_proxy_logger from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY from litellm.proxy._types import ( LitellmUserRoles, + Member, NewUserRequest, + NewUserResponse, ProxyErrorTypes, ProxyException, SSOUserDefinedValues, + TeamMemberAddRequest, UserAPIKeyAuth, ) from litellm.proxy.auth.auth_utils import _has_user_setup_sso +from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.admin_ui_utils import ( admin_ui_disabled, @@ -36,6 +40,8 @@ from litellm.proxy.management_endpoints.sso_helper_utils import ( check_is_admin_only_access, has_admin_ui_access, ) +from litellm.proxy.management_endpoints.team_endpoints import team_member_add +from litellm.proxy.management_endpoints.types import CustomOpenID from litellm.secret_managers.main import str_to_bool if TYPE_CHECKING: @@ -221,6 +227,170 @@ async def google_login(request: Request): # noqa: PLR0915 return HTMLResponse(content=html_form, status_code=200) +def generic_response_convertor(response, jwt_handler: JWTHandler): + generic_user_id_attribute_name = os.getenv( + "GENERIC_USER_ID_ATTRIBUTE", "preferred_username" + ) + generic_user_display_name_attribute_name = os.getenv( + "GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub" + ) + generic_user_email_attribute_name = os.getenv( + "GENERIC_USER_EMAIL_ATTRIBUTE", "email" + ) + + generic_user_first_name_attribute_name = os.getenv( + "GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name" + ) + generic_user_last_name_attribute_name = os.getenv( + "GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name" + ) + + generic_provider_attribute_name = os.getenv( + "GENERIC_USER_PROVIDER_ATTRIBUTE", "provider" + ) + + verbose_proxy_logger.debug( + f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}" + ) + + return CustomOpenID( + id=response.get(generic_user_id_attribute_name), + display_name=response.get(generic_user_display_name_attribute_name), + email=response.get(generic_user_email_attribute_name), + first_name=response.get(generic_user_first_name_attribute_name), + last_name=response.get(generic_user_last_name_attribute_name), + provider=response.get(generic_provider_attribute_name), + team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)), + ) + + +async def get_generic_sso_response( + request: Request, + jwt_handler: JWTHandler, + generic_client_id: str, + redirect_url: str, +) -> Optional[OpenID]: + # make generic sso provider + from fastapi_sso.sso.base import DiscoveryDocument + from fastapi_sso.sso.generic import create_provider + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") + generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + generic_include_client_id = ( + os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true" + ) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type=ProxyErrorTypes.auth_error, + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + + def response_convertor(response, client): + return generic_response_convertor( + response=response, + jwt_handler=jwt_handler, + ) + + SSOProvider = create_provider( + name="oidc", + discovery_document=discovery, + response_convertor=response_convertor, + ) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + scope=generic_scope, + ) + verbose_proxy_logger.debug("calling generic_sso.verify_and_process") + result = await generic_sso.verify_and_process( + request, params={"include_client_id": generic_include_client_id} + ) + verbose_proxy_logger.debug("generic result: %s", result) + return result + + +async def create_team_member_add_task(team_id, user_info): + """Create a task for adding a member to a team.""" + try: + member = Member(user_id=user_info.user_id, role="user") + team_member_add_request = TeamMemberAddRequest( + member=member, + team_id=team_id, + ) + return await team_member_add( + data=team_member_add_request, + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), + http_request=Request(scope={"type": "http", "path": "/sso/callback"}), + ) + except Exception as e: + verbose_proxy_logger.debug( + f"[Non-Blocking] Error trying to add sso user to db: {e}" + ) + + +async def add_missing_team_member(user_info: NewUserResponse, sso_teams: List[str]): + """ + - Get missing teams (diff b/w user_info.team_ids and sso_teams) + - Add missing user to missing teams + """ + if user_info.teams is None: + return + missing_teams = set(sso_teams) - set(user_info.teams) + missing_teams_list = list(missing_teams) + tasks = [] + tasks = [ + create_team_member_add_task(team_id, user_info) + for team_id in missing_teams_list + ] + + try: + await asyncio.gather(*tasks) + except Exception as e: + verbose_proxy_logger.debug( + f"[Non-Blocking] Error trying to add sso user to db: {e}" + ) + + @router.get("/sso/callback", tags=["experimental"], include_in_schema=False) async def auth_callback(request: Request): # noqa: PLR0915 """Verify login""" @@ -229,6 +399,7 @@ async def auth_callback(request: Request): # noqa: PLR0915 ) from litellm.proxy.proxy_server import ( general_settings, + jwt_handler, master_key, premium_user, prisma_client, @@ -299,116 +470,12 @@ async def auth_callback(request: Request): # noqa: PLR0915 ) result = await microsoft_sso.verify_and_process(request) elif generic_client_id is not None: - # make generic sso provider - from fastapi_sso.sso.base import DiscoveryDocument, OpenID - from fastapi_sso.sso.generic import create_provider - - generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) - generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") - generic_authorization_endpoint = os.getenv( - "GENERIC_AUTHORIZATION_ENDPOINT", None + result = await get_generic_sso_response( + request=request, + jwt_handler=jwt_handler, + generic_client_id=generic_client_id, + redirect_url=redirect_url, ) - generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) - generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) - generic_include_client_id = ( - os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true" - ) - if generic_client_secret is None: - raise ProxyException( - message="GENERIC_CLIENT_SECRET not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_authorization_endpoint is None: - raise ProxyException( - message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_AUTHORIZATION_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_token_endpoint is None: - raise ProxyException( - message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_TOKEN_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_userinfo_endpoint is None: - raise ProxyException( - message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", - type=ProxyErrorTypes.auth_error, - param="GENERIC_USERINFO_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - verbose_proxy_logger.debug( - f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" - ) - verbose_proxy_logger.debug( - f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" - ) - - generic_user_id_attribute_name = os.getenv( - "GENERIC_USER_ID_ATTRIBUTE", "preferred_username" - ) - generic_user_display_name_attribute_name = os.getenv( - "GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub" - ) - generic_user_email_attribute_name = os.getenv( - "GENERIC_USER_EMAIL_ATTRIBUTE", "email" - ) - generic_user_role_attribute_name = os.getenv( - "GENERIC_USER_ROLE_ATTRIBUTE", "role" - ) - generic_user_first_name_attribute_name = os.getenv( - "GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name" - ) - generic_user_last_name_attribute_name = os.getenv( - "GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name" - ) - - generic_provider_attribute_name = os.getenv( - "GENERIC_USER_PROVIDER_ATTRIBUTE", "provider" - ) - - verbose_proxy_logger.debug( - f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}" - ) - - discovery = DiscoveryDocument( - authorization_endpoint=generic_authorization_endpoint, - token_endpoint=generic_token_endpoint, - userinfo_endpoint=generic_userinfo_endpoint, - ) - - def response_convertor(response, client): - return OpenID( - id=response.get(generic_user_id_attribute_name), - display_name=response.get(generic_user_display_name_attribute_name), - email=response.get(generic_user_email_attribute_name), - first_name=response.get(generic_user_first_name_attribute_name), - last_name=response.get(generic_user_last_name_attribute_name), - provider=response.get(generic_provider_attribute_name), - ) - - SSOProvider = create_provider( - name="oidc", - discovery_document=discovery, - response_convertor=response_convertor, - ) - generic_sso = SSOProvider( - client_id=generic_client_id, - client_secret=generic_client_secret, - redirect_uri=redirect_url, - allow_insecure_http=True, - scope=generic_scope, - ) - verbose_proxy_logger.debug("calling generic_sso.verify_and_process") - result = await generic_sso.verify_and_process( - request, params={"include_client_id": generic_include_client_id} - ) - verbose_proxy_logger.debug("generic result: %s", result) - # User is Authe'd in - generate key for the UI to access Proxy user_email: Optional[str] = getattr(result, "email", None) user_id: Optional[str] = getattr(result, "id", None) if result is not None else None @@ -428,6 +495,9 @@ async def auth_callback(request: Request): # noqa: PLR0915 # generic client id if generic_client_id is not None and result is not None: + generic_user_role_attribute_name = os.getenv( + "GENERIC_USER_ROLE_ATTRIBUTE", "role" + ) user_id = getattr(result, "id", None) user_email = getattr(result, "email", None) user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore @@ -508,12 +578,21 @@ async def auth_callback(request: Request): # noqa: PLR0915 ) else: # user not in DB, insert User into LiteLLM DB - user_role = await insert_sso_user( + user_info = await insert_sso_user( result_openid=result, user_defined_values=user_defined_values, ) - except Exception: - pass + + user_role = ( + user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + ) + sso_teams = getattr(result, "team_ids", []) + await add_missing_team_member(user_info=user_info, sso_teams=sso_teams) + + except Exception as e: + verbose_proxy_logger.debug( + f"[Non-Blocking] Error trying to add sso user to db: {e}" + ) if user_defined_values is None: raise Exception( @@ -588,13 +667,16 @@ async def auth_callback(request: Request): # noqa: PLR0915 async def insert_sso_user( result_openid: Optional[OpenID], user_defined_values: Optional[SSOUserDefinedValues] = None, -) -> str: +) -> NewUserResponse: """ Helper function to create a New User in LiteLLM DB after a successful SSO login Args: result_openid (OpenID): User information in OpenID format if the login was successful. user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read + + Returns: + Tuple[str, str]: User ID and User Role """ verbose_proxy_logger.debug( f"Inserting SSO user into DB. User values: {user_defined_values}" @@ -629,9 +711,9 @@ async def insert_sso_user( if result_openid: new_user_request.metadata = {"auth_provider": result_openid.provider} - await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth()) + response = await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth()) - return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + return response @router.get( diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index bed171df70..dd018f674f 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1327,3 +1327,35 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data): # Verify the result structure assert isinstance(result, UserInfoResponse) assert len(result.keys) == 2 + + +def test_custom_openid_response(): + from litellm.proxy.management_endpoints.ui_sso import generic_response_convertor + from litellm.proxy.management_endpoints.ui_sso import JWTHandler + from litellm.proxy._types import LiteLLM_JWTAuth + from litellm.caching import DualCache + + jwt_handler = JWTHandler() + jwt_handler.update_environment( + prisma_client={}, + user_api_key_cache=DualCache(), + litellm_jwtauth=LiteLLM_JWTAuth( + team_ids_jwt_field="department", + ), + ) + response = { + "sub": "3f196e06-7484-451e-be5a-ea6c6bb86c5b", + "email_verified": True, + "name": "Krish Dholakia", + "preferred_username": "krrishd", + "given_name": "Krish", + "department": ["/test-group"], + "family_name": "Dholakia", + "email": "krrishdholakia@gmail.com", + } + + resp = generic_response_convertor( + response=response, + jwt_handler=jwt_handler, + ) + assert resp.team_ids == ["/test-group"] diff --git a/ui/litellm-dashboard/src/components/leftnav.tsx b/ui/litellm-dashboard/src/components/leftnav.tsx index 42ee4b9683..27b5edfe1e 100644 --- a/ui/litellm-dashboard/src/components/leftnav.tsx +++ b/ui/litellm-dashboard/src/components/leftnav.tsx @@ -33,7 +33,7 @@ const menuItems: MenuItem[] = [ { key: "2", page: "models", label: "Models", roles: all_admin_roles }, { key: "4", page: "usage", label: "Usage"}, // all roles { key: "6", page: "teams", label: "Teams" }, - { key: "17", page: "organizations", label: "Organizations" }, + { key: "17", page: "organizations", label: "Organizations", roles: all_admin_roles }, { key: "5", page: "users", label: "Internal Users", roles: all_admin_roles }, { key: "8", page: "settings", label: "Logging & Alerts", roles: all_admin_roles }, { key: "9", page: "caching", label: "Caching", roles: all_admin_roles }, diff --git a/ui/litellm-dashboard/src/components/teams.tsx b/ui/litellm-dashboard/src/components/teams.tsx index 36a343637e..fc43ececc3 100644 --- a/ui/litellm-dashboard/src/components/teams.tsx +++ b/ui/litellm-dashboard/src/components/teams.tsx @@ -20,6 +20,7 @@ import { Tooltip } from "antd"; import { Select, SelectItem } from "@tremor/react"; + import { Table, TableBody, @@ -69,6 +70,7 @@ import { teamListCall } from "./networking"; + const Team: React.FC = ({ teams, searchParams, @@ -365,6 +367,7 @@ const Team: React.FC = ({ const handleCreate = async (formValues: Record) => { try { + console.log(`formValues: ${JSON.stringify(formValues)}`); if (accessToken != null) { const newTeamAlias = formValues?.team_alias; const existingTeamAliases = teams?.map((t) => t.team_alias) ?? []; @@ -746,6 +749,17 @@ const Team: React.FC = ({ Additional Settings + + { + e.target.value = e.target.value.trim(); + }} + /> +