mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Remove user_id from url (#10192)
* fix(user_dashboard.tsx): initial commit using user id from jwt instead of url * fix(proxy_server.py): remove user id from url fixes security issue around sharing url's * fix(user_dashboard.tsx): handle user id being null
This commit is contained in:
parent
ca3649e6fb
commit
58cb6be9e7
6 changed files with 41 additions and 29 deletions
|
@ -553,7 +553,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
if user_id is not None and isinstance(user_id, str):
|
if user_id is not None and isinstance(user_id, str):
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?login=success"
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
||||||
return redirect_response
|
return redirect_response
|
||||||
|
@ -592,9 +592,9 @@ async def insert_sso_user(
|
||||||
if user_defined_values.get("max_budget") is None:
|
if user_defined_values.get("max_budget") is None:
|
||||||
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
||||||
if user_defined_values.get("budget_duration") is None:
|
if user_defined_values.get("budget_duration") is None:
|
||||||
user_defined_values["budget_duration"] = (
|
user_defined_values[
|
||||||
litellm.internal_user_budget_duration
|
"budget_duration"
|
||||||
)
|
] = litellm.internal_user_budget_duration
|
||||||
|
|
||||||
if user_defined_values["user_role"] is None:
|
if user_defined_values["user_role"] is None:
|
||||||
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
|
@ -787,9 +787,9 @@ class SSOAuthenticationHandler:
|
||||||
if state:
|
if state:
|
||||||
redirect_params["state"] = state
|
redirect_params["state"] = state
|
||||||
elif "okta" in generic_authorization_endpoint:
|
elif "okta" in generic_authorization_endpoint:
|
||||||
redirect_params["state"] = (
|
redirect_params[
|
||||||
uuid.uuid4().hex
|
"state"
|
||||||
) # set state param for okta - required
|
] = uuid.uuid4().hex # set state param for okta - required
|
||||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
|
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
|
||||||
|
@ -1023,7 +1023,7 @@ class MicrosoftSSOHandler:
|
||||||
original_msft_result = (
|
original_msft_result = (
|
||||||
await microsoft_sso.verify_and_process(
|
await microsoft_sso.verify_and_process(
|
||||||
request=request,
|
request=request,
|
||||||
convert_response=False,
|
convert_response=False, # type: ignore
|
||||||
)
|
)
|
||||||
or {}
|
or {}
|
||||||
)
|
)
|
||||||
|
@ -1034,9 +1034,9 @@ class MicrosoftSSOHandler:
|
||||||
|
|
||||||
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
||||||
if return_raw_sso_response:
|
if return_raw_sso_response:
|
||||||
original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = (
|
original_msft_result[
|
||||||
user_team_ids
|
MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY
|
||||||
)
|
] = user_team_ids
|
||||||
return original_msft_result or {}
|
return original_msft_result or {}
|
||||||
|
|
||||||
result = MicrosoftSSOHandler.openid_from_response(
|
result = MicrosoftSSOHandler.openid_from_response(
|
||||||
|
@ -1086,12 +1086,13 @@ class MicrosoftSSOHandler:
|
||||||
service_principal_group_ids: Optional[List[str]] = []
|
service_principal_group_ids: Optional[List[str]] = []
|
||||||
service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = []
|
service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = []
|
||||||
if service_principal_id:
|
if service_principal_id:
|
||||||
service_principal_group_ids, service_principal_teams = (
|
(
|
||||||
await MicrosoftSSOHandler.get_group_ids_from_service_principal(
|
service_principal_group_ids,
|
||||||
service_principal_id=service_principal_id,
|
service_principal_teams,
|
||||||
async_client=async_client,
|
) = await MicrosoftSSOHandler.get_group_ids_from_service_principal(
|
||||||
access_token=access_token,
|
service_principal_id=service_principal_id,
|
||||||
)
|
async_client=async_client,
|
||||||
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Service principal group IDs: {service_principal_group_ids}"
|
f"Service principal group IDs: {service_principal_group_ids}"
|
||||||
|
@ -1103,9 +1104,9 @@ class MicrosoftSSOHandler:
|
||||||
|
|
||||||
# Fetch user membership from Microsoft Graph API
|
# Fetch user membership from Microsoft Graph API
|
||||||
all_group_ids = []
|
all_group_ids = []
|
||||||
next_link: Optional[str] = (
|
next_link: Optional[
|
||||||
MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
str
|
||||||
)
|
] = MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
||||||
auth_headers = {"Authorization": f"Bearer {access_token}"}
|
auth_headers = {"Authorization": f"Bearer {access_token}"}
|
||||||
page_count = 0
|
page_count = 0
|
||||||
|
|
||||||
|
@ -1304,7 +1305,7 @@ class GoogleSSOHandler:
|
||||||
return (
|
return (
|
||||||
await google_sso.verify_and_process(
|
await google_sso.verify_and_process(
|
||||||
request=request,
|
request=request,
|
||||||
convert_response=False,
|
convert_response=False, # type: ignore
|
||||||
)
|
)
|
||||||
or {}
|
or {}
|
||||||
)
|
)
|
||||||
|
|
|
@ -6815,7 +6815,7 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
master_key,
|
master_key,
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?login=success"
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
redirect_response.set_cookie(key="token", value=jwt_token)
|
||||||
return redirect_response
|
return redirect_response
|
||||||
|
@ -6891,7 +6891,7 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
master_key,
|
master_key,
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?login=success"
|
||||||
redirect_response = RedirectResponse(
|
redirect_response = RedirectResponse(
|
||||||
url=litellm_dashboard_ui, status_code=303
|
url=litellm_dashboard_ui, status_code=303
|
||||||
)
|
)
|
||||||
|
|
|
@ -104,7 +104,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli
|
||||||
|
|
||||||
# Assert the response
|
# Assert the response
|
||||||
assert response.status_code == 303
|
assert response.status_code == 303
|
||||||
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
|
assert response.headers["location"].startswith(f"/ui/?login=success")
|
||||||
|
|
||||||
# Verify that the user was added to the database
|
# Verify that the user was added to the database
|
||||||
user = await prisma_client.db.litellm_usertable.find_first(
|
user = await prisma_client.db.litellm_usertable.find_first(
|
||||||
|
@ -177,7 +177,7 @@ async def test_auth_callback_new_user_with_sso_default(
|
||||||
|
|
||||||
# Assert the response
|
# Assert the response
|
||||||
assert response.status_code == 303
|
assert response.status_code == 303
|
||||||
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
|
assert response.headers["location"].startswith(f"/ui/?login=success")
|
||||||
|
|
||||||
# Verify that the user was added to the database
|
# Verify that the user was added to the database
|
||||||
user = await prisma_client.db.litellm_usertable.find_first(
|
user = await prisma_client.db.litellm_usertable.find_first(
|
||||||
|
|
|
@ -84,8 +84,7 @@ export default function Onboarding() {
|
||||||
formValues.password
|
formValues.password
|
||||||
).then((data) => {
|
).then((data) => {
|
||||||
let litellm_dashboard_ui = "/ui/";
|
let litellm_dashboard_ui = "/ui/";
|
||||||
const user_id = data.data?.user_id || data.user_id;
|
litellm_dashboard_ui += "?login=success";
|
||||||
litellm_dashboard_ui += "?userID=" + user_id;
|
|
||||||
|
|
||||||
// set cookie "token" to jwtToken
|
// set cookie "token" to jwtToken
|
||||||
document.cookie = "token=" + jwtToken;
|
document.cookie = "token=" + jwtToken;
|
||||||
|
|
|
@ -98,8 +98,8 @@ export default function CreateKeyPage() {
|
||||||
const searchParams = useSearchParams()!;
|
const searchParams = useSearchParams()!;
|
||||||
const [modelData, setModelData] = useState<any>({ data: [] });
|
const [modelData, setModelData] = useState<any>({ data: [] });
|
||||||
const [token, setToken] = useState<string | null>(null);
|
const [token, setToken] = useState<string | null>(null);
|
||||||
|
const [userID, setUserID] = useState<string | null>(null);
|
||||||
|
|
||||||
const userID = searchParams.get("userID");
|
|
||||||
const invitation_id = searchParams.get("invitation_id");
|
const invitation_id = searchParams.get("invitation_id");
|
||||||
|
|
||||||
// Get page from URL, default to 'api-keys' if not present
|
// Get page from URL, default to 'api-keys' if not present
|
||||||
|
@ -177,6 +177,10 @@ export default function CreateKeyPage() {
|
||||||
if (decoded.auth_header_name) {
|
if (decoded.auth_header_name) {
|
||||||
setGlobalLitellmHeaderName(decoded.auth_header_name);
|
setGlobalLitellmHeaderName(decoded.auth_header_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (decoded.user_id) {
|
||||||
|
setUserID(decoded.user_id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, [token]);
|
}, [token]);
|
||||||
|
|
||||||
|
|
|
@ -295,7 +295,8 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (userID == null || token == null) {
|
|
||||||
|
if (token == null) {
|
||||||
// user is not logged in as yet
|
// user is not logged in as yet
|
||||||
console.log("All cookies before redirect:", document.cookie);
|
console.log("All cookies before redirect:", document.cookie);
|
||||||
|
|
||||||
|
@ -314,6 +315,13 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (userID == null) {
|
||||||
|
return (
|
||||||
|
<h1>User ID is not set</h1>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if (userRole == null) {
|
if (userRole == null) {
|
||||||
setUserRole("App Owner");
|
setUserRole("App Owner");
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue