Remove user_id from url (#10192)

* fix(user_dashboard.tsx): initial commit using user id from jwt instead of url

* fix(proxy_server.py): remove user id from url

fixes security issue around sharing url's

* fix(user_dashboard.tsx): handle user id being null
This commit is contained in:
Krish Dholakia 2025-04-21 16:22:57 -07:00 committed by GitHub
parent a34778dda6
commit 89131d8ed3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 41 additions and 29 deletions

View file

@ -553,7 +553,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
algorithm="HS256",
)
if user_id is not None and isinstance(user_id, str):
litellm_dashboard_ui += "?userID=" + user_id
litellm_dashboard_ui += "?login=success"
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
return redirect_response
@ -592,9 +592,9 @@ async def insert_sso_user(
if user_defined_values.get("max_budget") is None:
user_defined_values["max_budget"] = litellm.max_internal_user_budget
if user_defined_values.get("budget_duration") is None:
user_defined_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
user_defined_values[
"budget_duration"
] = litellm.internal_user_budget_duration
if user_defined_values["user_role"] is None:
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
@ -787,9 +787,9 @@ class SSOAuthenticationHandler:
if state:
redirect_params["state"] = state
elif "okta" in generic_authorization_endpoint:
redirect_params["state"] = (
uuid.uuid4().hex
) # set state param for okta - required
redirect_params[
"state"
] = uuid.uuid4().hex # set state param for okta - required
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
raise ValueError(
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
@ -1023,7 +1023,7 @@ class MicrosoftSSOHandler:
original_msft_result = (
await microsoft_sso.verify_and_process(
request=request,
convert_response=False,
convert_response=False, # type: ignore
)
or {}
)
@ -1034,9 +1034,9 @@ class MicrosoftSSOHandler:
# if user is trying to get the raw sso response for debugging, return the raw sso response
if return_raw_sso_response:
original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = (
user_team_ids
)
original_msft_result[
MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY
] = user_team_ids
return original_msft_result or {}
result = MicrosoftSSOHandler.openid_from_response(
@ -1086,12 +1086,13 @@ class MicrosoftSSOHandler:
service_principal_group_ids: Optional[List[str]] = []
service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = []
if service_principal_id:
service_principal_group_ids, service_principal_teams = (
await MicrosoftSSOHandler.get_group_ids_from_service_principal(
service_principal_id=service_principal_id,
async_client=async_client,
access_token=access_token,
)
(
service_principal_group_ids,
service_principal_teams,
) = await MicrosoftSSOHandler.get_group_ids_from_service_principal(
service_principal_id=service_principal_id,
async_client=async_client,
access_token=access_token,
)
verbose_proxy_logger.debug(
f"Service principal group IDs: {service_principal_group_ids}"
@ -1103,9 +1104,9 @@ class MicrosoftSSOHandler:
# Fetch user membership from Microsoft Graph API
all_group_ids = []
next_link: Optional[str] = (
MicrosoftSSOHandler.graph_api_user_groups_endpoint
)
next_link: Optional[
str
] = MicrosoftSSOHandler.graph_api_user_groups_endpoint
auth_headers = {"Authorization": f"Bearer {access_token}"}
page_count = 0
@ -1304,7 +1305,7 @@ class GoogleSSOHandler:
return (
await google_sso.verify_and_process(
request=request,
convert_response=False,
convert_response=False, # type: ignore
)
or {}
)

View file

@ -6815,7 +6815,7 @@ async def login(request: Request): # noqa: PLR0915
master_key,
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id
litellm_dashboard_ui += "?login=success"
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
redirect_response.set_cookie(key="token", value=jwt_token)
return redirect_response
@ -6891,7 +6891,7 @@ async def login(request: Request): # noqa: PLR0915
master_key,
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id
litellm_dashboard_ui += "?login=success"
redirect_response = RedirectResponse(
url=litellm_dashboard_ui, status_code=303
)

View file

@ -104,7 +104,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli
# Assert the response
assert response.status_code == 303
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
assert response.headers["location"].startswith(f"/ui/?login=success")
# Verify that the user was added to the database
user = await prisma_client.db.litellm_usertable.find_first(
@ -177,7 +177,7 @@ async def test_auth_callback_new_user_with_sso_default(
# Assert the response
assert response.status_code == 303
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
assert response.headers["location"].startswith(f"/ui/?login=success")
# Verify that the user was added to the database
user = await prisma_client.db.litellm_usertable.find_first(

View file

@ -84,8 +84,7 @@ export default function Onboarding() {
formValues.password
).then((data) => {
let litellm_dashboard_ui = "/ui/";
const user_id = data.data?.user_id || data.user_id;
litellm_dashboard_ui += "?userID=" + user_id;
litellm_dashboard_ui += "?login=success";
// set cookie "token" to jwtToken
document.cookie = "token=" + jwtToken;

View file

@ -98,8 +98,8 @@ export default function CreateKeyPage() {
const searchParams = useSearchParams()!;
const [modelData, setModelData] = useState<any>({ data: [] });
const [token, setToken] = useState<string | null>(null);
const [userID, setUserID] = useState<string | null>(null);
const userID = searchParams.get("userID");
const invitation_id = searchParams.get("invitation_id");
// Get page from URL, default to 'api-keys' if not present
@ -177,6 +177,10 @@ export default function CreateKeyPage() {
if (decoded.auth_header_name) {
setGlobalLitellmHeaderName(decoded.auth_header_name);
}
if (decoded.user_id) {
setUserID(decoded.user_id);
}
}
}, [token]);

View file

@ -295,7 +295,8 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
)
}
if (userID == null || token == null) {
if (token == null) {
// user is not logged in as yet
console.log("All cookies before redirect:", document.cookie);
@ -314,6 +315,13 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
return null;
}
if (userID == null) {
return (
<h1>User ID is not set</h1>
);
}
if (userRole == null) {
setUserRole("App Owner");
}