This commit is contained in:
Ishaan Jaff 2025-04-21 18:39:44 +05:30 committed by GitHub
commit cbe0b58263
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 117 additions and 46 deletions

View file

@ -10,6 +10,7 @@ Has all /sso/* routes
import asyncio
import os
import time
import uuid
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
@ -57,6 +58,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
)
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
from litellm.proxy.utils import PrismaClient
from litellm.secret_managers.main import str_to_bool
from litellm.types.proxy.management_endpoints.ui_sso import *
@ -554,9 +556,10 @@ async def auth_callback(request: Request): # noqa: PLR0915
)
if user_id is not None and isinstance(user_id, str):
litellm_dashboard_ui += "?userID=" + user_id
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
return redirect_response
return UISessionHandler.generate_authenticated_redirect_response(
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
)
async def insert_sso_user(

View file

@ -0,0 +1,24 @@
import time
from fastapi.responses import RedirectResponse
class UISessionHandler:
@staticmethod
def generate_authenticated_redirect_response(
redirect_url: str, jwt_token: str
) -> RedirectResponse:
redirect_response = RedirectResponse(url=redirect_url, status_code=303)
redirect_response.set_cookie(
key=UISessionHandler._generate_token_name(),
value=jwt_token,
secure=True,
samesite="strict",
)
return redirect_response
@staticmethod
def _generate_token_name() -> str:
current_timestamp = int(time.time())
cookie_name = f"token_{current_timestamp}"
return cookie_name

View file

@ -6699,6 +6699,8 @@ async def login(request: Request): # noqa: PLR0915
import multipart
except ImportError:
subprocess.run(["pip", "install", "python-multipart"])
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
global master_key
if master_key is None:
raise ProxyException(
@ -6816,9 +6818,9 @@ async def login(request: Request): # noqa: PLR0915
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
redirect_response.set_cookie(key="token", value=jwt_token)
return redirect_response
return UISessionHandler.generate_authenticated_redirect_response(
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
)
elif _user_row is not None:
"""
When sharing invite links
@ -6892,11 +6894,9 @@ async def login(request: Request): # noqa: PLR0915
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id
redirect_response = RedirectResponse(
url=litellm_dashboard_ui, status_code=303
return UISessionHandler.generate_authenticated_redirect_response(
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
)
redirect_response.set_cookie(key="token", value=jwt_token)
return redirect_response
else:
raise ProxyException(
message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",

View file

@ -20,12 +20,12 @@ import {
} from "@/components/networking";
import { jwtDecode } from "jwt-decode";
import { Form, Button as Button2, message } from "antd";
import { getCookie } from "@/utils/cookieUtils";
import { getAuthToken, setAuthToken } from "@/utils/cookieUtils";
export default function Onboarding() {
const [form] = Form.useForm();
const searchParams = useSearchParams()!;
const token = getCookie('token');
const token = getAuthToken();
const inviteID = searchParams.get("invitation_id");
const [accessToken, setAccessToken] = useState<string | null>(null);
const [defaultUserEmail, setDefaultUserEmail] = useState<string>("");
@ -88,7 +88,7 @@ export default function Onboarding() {
litellm_dashboard_ui += "?userID=" + user_id;
// set cookie "token" to jwtToken
document.cookie = "token=" + jwtToken;
setAuthToken(jwtToken);
console.log("redirecting to:", litellm_dashboard_ui);
window.location.href = litellm_dashboard_ui;

View file

@ -32,6 +32,7 @@ import GuardrailsPanel from "@/components/guardrails";
import TransformRequestPanel from "@/components/transform_request";
import { fetchUserModels } from "@/components/create_key_button";
import { fetchTeams } from "@/components/common_components/fetch_teams";
import { getAuthToken } from "@/utils/cookieUtils";
import MCPToolsViewer from "@/components/mcp_tools";
import TagManagement from "@/components/tag_management";
@ -122,7 +123,7 @@ export default function CreateKeyPage() {
const [accessToken, setAccessToken] = useState<string | null>(null);
useEffect(() => {
const token = getCookie("token");
const token = getAuthToken();
setToken(token);
}, []);

View file

@ -21,6 +21,7 @@ import { useSearchParams, useRouter } from "next/navigation";
import { Team } from "./key_team_helpers/key_list";
import { jwtDecode } from "jwt-decode";
import { Typography } from "antd";
import { getAuthToken } from "@/utils/cookieUtils";
import { clearTokenCookies } from "@/utils/cookieUtils";
const isLocal = process.env.NODE_ENV === "development";
if (isLocal != true) {
@ -45,14 +46,6 @@ export type UserInfo = {
spend: number;
}
function getCookie(name: string) {
console.log("COOKIES", document.cookie)
const cookieValue = document.cookie
.split('; ')
.find(row => row.startsWith(name + '='));
return cookieValue ? cookieValue.split('=')[1] : null;
}
interface UserDashboardProps {
userID: string | null;
userRole: string | null;
@ -94,7 +87,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
// Assuming useSearchParams() hook exists and works in your setup
const searchParams = useSearchParams()!;
const token = getCookie('token');
const token = getAuthToken();
const invitation_id = searchParams.get("invitation_id");

View file

@ -13,32 +13,82 @@ export function clearTokenCookies() {
const paths = ['/', '/ui'];
const sameSiteValues = ['Lax', 'Strict', 'None'];
paths.forEach(path => {
// Basic clearing
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
// With domain
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
// Try different SameSite values
sameSiteValues.forEach(sameSite => {
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
// Get all cookies
const allCookies = document.cookie.split("; ");
const tokenPattern = /^token_\d+$/;
// Find all token cookies
const tokenCookieNames = allCookies
.map(cookie => cookie.split("=")[0])
.filter(name => name === "token" || tokenPattern.test(name));
// Clear each token cookie with various combinations
tokenCookieNames.forEach(cookieName => {
paths.forEach(path => {
// Basic clearing
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
// With domain
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
// Try different SameSite values
sameSiteValues.forEach(sameSite => {
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
});
});
});
console.log("After clearing cookies:", document.cookie);
}
/**
* Gets a cookie value by name
* @param name The name of the cookie to retrieve
* @returns The cookie value or null if not found
*/
export function getCookie(name: string) {
const cookieValue = document.cookie
.split('; ')
.find(row => row.startsWith(name + '='));
return cookieValue ? cookieValue.split('=')[1] : null;
}
export function setAuthToken(token: string) {
// Generate a token name with current timestamp
const currentTimestamp = Math.floor(Date.now() / 1000);
const tokenName = `token_${currentTimestamp}`;
// Set the cookie with the timestamp-based name
document.cookie = `${tokenName}=${token}; path=/; domain=${window.location.hostname};`;
}
export function getAuthToken() {
// Check if we're in a browser environment
if (typeof window === 'undefined' || typeof document === 'undefined') {
return null;
}
const tokenPattern = /^token_(\d+)$/;
const allCookies = document.cookie.split("; ");
const tokenCookies = allCookies
.map(cookie => {
const parts = cookie.split("=");
const name = parts[0];
// Explicitly skip cookies named just "token"
if (name === "token") {
return null;
}
// Only match cookies with the token_{timestamp} format
const match = name.match(tokenPattern);
if (match) {
return {
name,
timestamp: parseInt(match[1], 10),
value: parts.slice(1).join("=")
};
}
return null;
})
.filter((cookie): cookie is { name: string; timestamp: number; value: string } => cookie !== null);
if (tokenCookies.length > 0) {
// Sort by timestamp (newest first)
tokenCookies.sort((a, b) => b.timestamp - a.timestamp);
return tokenCookies[0].value;
}
return null;
}