mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(proxy_server.py): remove user id from url
fixes security issue around sharing url's
This commit is contained in:
parent
9ff4fa56a7
commit
707dc1d56a
4 changed files with 27 additions and 27 deletions
|
@ -553,7 +553,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
if user_id is not None and isinstance(user_id, str):
|
if user_id is not None and isinstance(user_id, str):
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?login=success"
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
||||||
return redirect_response
|
return redirect_response
|
||||||
|
@ -592,9 +592,9 @@ async def insert_sso_user(
|
||||||
if user_defined_values.get("max_budget") is None:
|
if user_defined_values.get("max_budget") is None:
|
||||||
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
||||||
if user_defined_values.get("budget_duration") is None:
|
if user_defined_values.get("budget_duration") is None:
|
||||||
user_defined_values["budget_duration"] = (
|
user_defined_values[
|
||||||
litellm.internal_user_budget_duration
|
"budget_duration"
|
||||||
)
|
] = litellm.internal_user_budget_duration
|
||||||
|
|
||||||
if user_defined_values["user_role"] is None:
|
if user_defined_values["user_role"] is None:
|
||||||
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
|
@ -787,9 +787,9 @@ class SSOAuthenticationHandler:
|
||||||
if state:
|
if state:
|
||||||
redirect_params["state"] = state
|
redirect_params["state"] = state
|
||||||
elif "okta" in generic_authorization_endpoint:
|
elif "okta" in generic_authorization_endpoint:
|
||||||
redirect_params["state"] = (
|
redirect_params[
|
||||||
uuid.uuid4().hex
|
"state"
|
||||||
) # set state param for okta - required
|
] = uuid.uuid4().hex # set state param for okta - required
|
||||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
|
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
|
||||||
|
@ -1023,7 +1023,7 @@ class MicrosoftSSOHandler:
|
||||||
original_msft_result = (
|
original_msft_result = (
|
||||||
await microsoft_sso.verify_and_process(
|
await microsoft_sso.verify_and_process(
|
||||||
request=request,
|
request=request,
|
||||||
convert_response=False,
|
convert_response=False, # type: ignore
|
||||||
)
|
)
|
||||||
or {}
|
or {}
|
||||||
)
|
)
|
||||||
|
@ -1034,9 +1034,9 @@ class MicrosoftSSOHandler:
|
||||||
|
|
||||||
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
||||||
if return_raw_sso_response:
|
if return_raw_sso_response:
|
||||||
original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = (
|
original_msft_result[
|
||||||
user_team_ids
|
MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY
|
||||||
)
|
] = user_team_ids
|
||||||
return original_msft_result or {}
|
return original_msft_result or {}
|
||||||
|
|
||||||
result = MicrosoftSSOHandler.openid_from_response(
|
result = MicrosoftSSOHandler.openid_from_response(
|
||||||
|
@ -1086,12 +1086,13 @@ class MicrosoftSSOHandler:
|
||||||
service_principal_group_ids: Optional[List[str]] = []
|
service_principal_group_ids: Optional[List[str]] = []
|
||||||
service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = []
|
service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = []
|
||||||
if service_principal_id:
|
if service_principal_id:
|
||||||
service_principal_group_ids, service_principal_teams = (
|
(
|
||||||
await MicrosoftSSOHandler.get_group_ids_from_service_principal(
|
service_principal_group_ids,
|
||||||
service_principal_id=service_principal_id,
|
service_principal_teams,
|
||||||
async_client=async_client,
|
) = await MicrosoftSSOHandler.get_group_ids_from_service_principal(
|
||||||
access_token=access_token,
|
service_principal_id=service_principal_id,
|
||||||
)
|
async_client=async_client,
|
||||||
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Service principal group IDs: {service_principal_group_ids}"
|
f"Service principal group IDs: {service_principal_group_ids}"
|
||||||
|
@ -1103,9 +1104,9 @@ class MicrosoftSSOHandler:
|
||||||
|
|
||||||
# Fetch user membership from Microsoft Graph API
|
# Fetch user membership from Microsoft Graph API
|
||||||
all_group_ids = []
|
all_group_ids = []
|
||||||
next_link: Optional[str] = (
|
next_link: Optional[
|
||||||
MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
str
|
||||||
)
|
] = MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
||||||
auth_headers = {"Authorization": f"Bearer {access_token}"}
|
auth_headers = {"Authorization": f"Bearer {access_token}"}
|
||||||
page_count = 0
|
page_count = 0
|
||||||
|
|
||||||
|
@ -1304,7 +1305,7 @@ class GoogleSSOHandler:
|
||||||
return (
|
return (
|
||||||
await google_sso.verify_and_process(
|
await google_sso.verify_and_process(
|
||||||
request=request,
|
request=request,
|
||||||
convert_response=False,
|
convert_response=False, # type: ignore
|
||||||
)
|
)
|
||||||
or {}
|
or {}
|
||||||
)
|
)
|
||||||
|
|
|
@ -6815,7 +6815,7 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
master_key,
|
master_key,
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?login=success"
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
redirect_response.set_cookie(key="token", value=jwt_token)
|
||||||
return redirect_response
|
return redirect_response
|
||||||
|
@ -6891,7 +6891,7 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
master_key,
|
master_key,
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?login=success"
|
||||||
redirect_response = RedirectResponse(
|
redirect_response = RedirectResponse(
|
||||||
url=litellm_dashboard_ui, status_code=303
|
url=litellm_dashboard_ui, status_code=303
|
||||||
)
|
)
|
||||||
|
|
|
@ -104,7 +104,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli
|
||||||
|
|
||||||
# Assert the response
|
# Assert the response
|
||||||
assert response.status_code == 303
|
assert response.status_code == 303
|
||||||
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
|
assert response.headers["location"].startswith(f"/ui/?login=success")
|
||||||
|
|
||||||
# Verify that the user was added to the database
|
# Verify that the user was added to the database
|
||||||
user = await prisma_client.db.litellm_usertable.find_first(
|
user = await prisma_client.db.litellm_usertable.find_first(
|
||||||
|
@ -177,7 +177,7 @@ async def test_auth_callback_new_user_with_sso_default(
|
||||||
|
|
||||||
# Assert the response
|
# Assert the response
|
||||||
assert response.status_code == 303
|
assert response.status_code == 303
|
||||||
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
|
assert response.headers["location"].startswith(f"/ui/?login=success")
|
||||||
|
|
||||||
# Verify that the user was added to the database
|
# Verify that the user was added to the database
|
||||||
user = await prisma_client.db.litellm_usertable.find_first(
|
user = await prisma_client.db.litellm_usertable.find_first(
|
||||||
|
|
|
@ -84,8 +84,7 @@ export default function Onboarding() {
|
||||||
formValues.password
|
formValues.password
|
||||||
).then((data) => {
|
).then((data) => {
|
||||||
let litellm_dashboard_ui = "/ui/";
|
let litellm_dashboard_ui = "/ui/";
|
||||||
const user_id = data.data?.user_id || data.user_id;
|
litellm_dashboard_ui += "?login=success";
|
||||||
litellm_dashboard_ui += "?userID=" + user_id;
|
|
||||||
|
|
||||||
// set cookie "token" to jwtToken
|
// set cookie "token" to jwtToken
|
||||||
document.cookie = "token=" + jwtToken;
|
document.cookie = "token=" + jwtToken;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue