diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 6e50bdd5a3..23054fc45b 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -553,7 +553,7 @@ async def auth_callback(request: Request): # noqa: PLR0915 algorithm="HS256", ) if user_id is not None and isinstance(user_id, str): - litellm_dashboard_ui += "?userID=" + user_id + litellm_dashboard_ui += "?login=success" redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) redirect_response.set_cookie(key="token", value=jwt_token, secure=True) return redirect_response @@ -592,9 +592,9 @@ async def insert_sso_user( if user_defined_values.get("max_budget") is None: user_defined_values["max_budget"] = litellm.max_internal_user_budget if user_defined_values.get("budget_duration") is None: - user_defined_values["budget_duration"] = ( - litellm.internal_user_budget_duration - ) + user_defined_values[ + "budget_duration" + ] = litellm.internal_user_budget_duration if user_defined_values["user_role"] is None: user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY @@ -787,9 +787,9 @@ class SSOAuthenticationHandler: if state: redirect_params["state"] = state elif "okta" in generic_authorization_endpoint: - redirect_params["state"] = ( - uuid.uuid4().hex - ) # set state param for okta - required + redirect_params[ + "state" + ] = uuid.uuid4().hex # set state param for okta - required return await generic_sso.get_login_redirect(**redirect_params) # type: ignore raise ValueError( "Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso" @@ -1023,7 +1023,7 @@ class MicrosoftSSOHandler: original_msft_result = ( await microsoft_sso.verify_and_process( request=request, - convert_response=False, + convert_response=False, # type: ignore ) or {} ) @@ -1034,9 +1034,9 @@ class MicrosoftSSOHandler: # if user is trying to get the raw sso response for debugging, return the raw sso response if return_raw_sso_response: - original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = ( - user_team_ids - ) + original_msft_result[ + MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY + ] = user_team_ids return original_msft_result or {} result = MicrosoftSSOHandler.openid_from_response( @@ -1086,12 +1086,13 @@ class MicrosoftSSOHandler: service_principal_group_ids: Optional[List[str]] = [] service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = [] if service_principal_id: - service_principal_group_ids, service_principal_teams = ( - await MicrosoftSSOHandler.get_group_ids_from_service_principal( - service_principal_id=service_principal_id, - async_client=async_client, - access_token=access_token, - ) + ( + service_principal_group_ids, + service_principal_teams, + ) = await MicrosoftSSOHandler.get_group_ids_from_service_principal( + service_principal_id=service_principal_id, + async_client=async_client, + access_token=access_token, ) verbose_proxy_logger.debug( f"Service principal group IDs: {service_principal_group_ids}" @@ -1103,9 +1104,9 @@ class MicrosoftSSOHandler: # Fetch user membership from Microsoft Graph API all_group_ids = [] - next_link: Optional[str] = ( - MicrosoftSSOHandler.graph_api_user_groups_endpoint - ) + next_link: Optional[ + str + ] = MicrosoftSSOHandler.graph_api_user_groups_endpoint auth_headers = {"Authorization": f"Bearer {access_token}"} page_count = 0 @@ -1304,7 +1305,7 @@ class GoogleSSOHandler: return ( await google_sso.verify_and_process( request=request, - convert_response=False, + convert_response=False, # type: ignore ) or {} ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 50662e69d5..4ca6b35db4 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -6815,7 +6815,7 @@ async def login(request: Request): # noqa: PLR0915 master_key, algorithm="HS256", ) - litellm_dashboard_ui += "?userID=" + user_id + litellm_dashboard_ui += "?login=success" redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) redirect_response.set_cookie(key="token", value=jwt_token) return redirect_response @@ -6891,7 +6891,7 @@ async def login(request: Request): # noqa: PLR0915 master_key, algorithm="HS256", ) - litellm_dashboard_ui += "?userID=" + user_id + litellm_dashboard_ui += "?login=success" redirect_response = RedirectResponse( url=litellm_dashboard_ui, status_code=303 ) diff --git a/tests/proxy_admin_ui_tests/test_sso_sign_in.py b/tests/proxy_admin_ui_tests/test_sso_sign_in.py index 3d5dd9ffcc..5de198b04b 100644 --- a/tests/proxy_admin_ui_tests/test_sso_sign_in.py +++ b/tests/proxy_admin_ui_tests/test_sso_sign_in.py @@ -104,7 +104,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli # Assert the response assert response.status_code == 303 - assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}") + assert response.headers["location"].startswith(f"/ui/?login=success") # Verify that the user was added to the database user = await prisma_client.db.litellm_usertable.find_first( @@ -177,7 +177,7 @@ async def test_auth_callback_new_user_with_sso_default( # Assert the response assert response.status_code == 303 - assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}") + assert response.headers["location"].startswith(f"/ui/?login=success") # Verify that the user was added to the database user = await prisma_client.db.litellm_usertable.find_first( diff --git a/ui/litellm-dashboard/src/app/onboarding/page.tsx b/ui/litellm-dashboard/src/app/onboarding/page.tsx index d65d2eb510..01b8f33084 100644 --- a/ui/litellm-dashboard/src/app/onboarding/page.tsx +++ b/ui/litellm-dashboard/src/app/onboarding/page.tsx @@ -84,8 +84,7 @@ export default function Onboarding() { formValues.password ).then((data) => { let litellm_dashboard_ui = "/ui/"; - const user_id = data.data?.user_id || data.user_id; - litellm_dashboard_ui += "?userID=" + user_id; + litellm_dashboard_ui += "?login=success"; // set cookie "token" to jwtToken document.cookie = "token=" + jwtToken; diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index ae256bf0ac..10df55cda5 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -98,8 +98,8 @@ export default function CreateKeyPage() { const searchParams = useSearchParams()!; const [modelData, setModelData] = useState({ data: [] }); const [token, setToken] = useState(null); + const [userID, setUserID] = useState(null); - const userID = searchParams.get("userID"); const invitation_id = searchParams.get("invitation_id"); // Get page from URL, default to 'api-keys' if not present @@ -177,6 +177,10 @@ export default function CreateKeyPage() { if (decoded.auth_header_name) { setGlobalLitellmHeaderName(decoded.auth_header_name); } + + if (decoded.user_id) { + setUserID(decoded.user_id); + } } }, [token]); diff --git a/ui/litellm-dashboard/src/components/user_dashboard.tsx b/ui/litellm-dashboard/src/components/user_dashboard.tsx index ab25a96e31..c0f4c96d86 100644 --- a/ui/litellm-dashboard/src/components/user_dashboard.tsx +++ b/ui/litellm-dashboard/src/components/user_dashboard.tsx @@ -295,7 +295,8 @@ const UserDashboard: React.FC = ({ ) } - if (userID == null || token == null) { + + if (token == null) { // user is not logged in as yet console.log("All cookies before redirect:", document.cookie); @@ -314,6 +315,13 @@ const UserDashboard: React.FC = ({ return null; } + if (userID == null) { + return ( +

User ID is not set

+ ); + } + + if (userRole == null) { setUserRole("App Owner"); }