mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge 2bc560d3f9
into 0b63c7a2eb
This commit is contained in:
commit
cbe0b58263
7 changed files with 117 additions and 46 deletions
|
@ -10,6 +10,7 @@ Has all /sso/* routes
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
@ -57,6 +58,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
|
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
|
||||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
from litellm.types.proxy.management_endpoints.ui_sso import *
|
from litellm.types.proxy.management_endpoints.ui_sso import *
|
||||||
|
@ -554,9 +556,10 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
)
|
)
|
||||||
if user_id is not None and isinstance(user_id, str):
|
if user_id is not None and isinstance(user_id, str):
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?userID=" + user_id
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
return UISessionHandler.generate_authenticated_redirect_response(
|
||||||
return redirect_response
|
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def insert_sso_user(
|
async def insert_sso_user(
|
||||||
|
|
24
litellm/proxy/management_helpers/ui_session_handler.py
Normal file
24
litellm/proxy/management_helpers/ui_session_handler.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
|
||||||
|
class UISessionHandler:
|
||||||
|
@staticmethod
|
||||||
|
def generate_authenticated_redirect_response(
|
||||||
|
redirect_url: str, jwt_token: str
|
||||||
|
) -> RedirectResponse:
|
||||||
|
redirect_response = RedirectResponse(url=redirect_url, status_code=303)
|
||||||
|
redirect_response.set_cookie(
|
||||||
|
key=UISessionHandler._generate_token_name(),
|
||||||
|
value=jwt_token,
|
||||||
|
secure=True,
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
return redirect_response
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_token_name() -> str:
|
||||||
|
current_timestamp = int(time.time())
|
||||||
|
cookie_name = f"token_{current_timestamp}"
|
||||||
|
return cookie_name
|
|
@ -6699,6 +6699,8 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
import multipart
|
import multipart
|
||||||
except ImportError:
|
except ImportError:
|
||||||
subprocess.run(["pip", "install", "python-multipart"])
|
subprocess.run(["pip", "install", "python-multipart"])
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
|
|
||||||
global master_key
|
global master_key
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
|
@ -6816,9 +6818,9 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?userID=" + user_id
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
return UISessionHandler.generate_authenticated_redirect_response(
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||||
return redirect_response
|
)
|
||||||
elif _user_row is not None:
|
elif _user_row is not None:
|
||||||
"""
|
"""
|
||||||
When sharing invite links
|
When sharing invite links
|
||||||
|
@ -6892,11 +6894,9 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?userID=" + user_id
|
||||||
redirect_response = RedirectResponse(
|
return UISessionHandler.generate_authenticated_redirect_response(
|
||||||
url=litellm_dashboard_ui, status_code=303
|
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||||
)
|
)
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
|
||||||
return redirect_response
|
|
||||||
else:
|
else:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
|
message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
|
||||||
|
|
|
@ -20,12 +20,12 @@ import {
|
||||||
} from "@/components/networking";
|
} from "@/components/networking";
|
||||||
import { jwtDecode } from "jwt-decode";
|
import { jwtDecode } from "jwt-decode";
|
||||||
import { Form, Button as Button2, message } from "antd";
|
import { Form, Button as Button2, message } from "antd";
|
||||||
import { getCookie } from "@/utils/cookieUtils";
|
import { getAuthToken, setAuthToken } from "@/utils/cookieUtils";
|
||||||
|
|
||||||
export default function Onboarding() {
|
export default function Onboarding() {
|
||||||
const [form] = Form.useForm();
|
const [form] = Form.useForm();
|
||||||
const searchParams = useSearchParams()!;
|
const searchParams = useSearchParams()!;
|
||||||
const token = getCookie('token');
|
const token = getAuthToken();
|
||||||
const inviteID = searchParams.get("invitation_id");
|
const inviteID = searchParams.get("invitation_id");
|
||||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||||
const [defaultUserEmail, setDefaultUserEmail] = useState<string>("");
|
const [defaultUserEmail, setDefaultUserEmail] = useState<string>("");
|
||||||
|
@ -88,7 +88,7 @@ export default function Onboarding() {
|
||||||
litellm_dashboard_ui += "?userID=" + user_id;
|
litellm_dashboard_ui += "?userID=" + user_id;
|
||||||
|
|
||||||
// set cookie "token" to jwtToken
|
// set cookie "token" to jwtToken
|
||||||
document.cookie = "token=" + jwtToken;
|
setAuthToken(jwtToken);
|
||||||
console.log("redirecting to:", litellm_dashboard_ui);
|
console.log("redirecting to:", litellm_dashboard_ui);
|
||||||
|
|
||||||
window.location.href = litellm_dashboard_ui;
|
window.location.href = litellm_dashboard_ui;
|
||||||
|
|
|
@ -32,6 +32,7 @@ import GuardrailsPanel from "@/components/guardrails";
|
||||||
import TransformRequestPanel from "@/components/transform_request";
|
import TransformRequestPanel from "@/components/transform_request";
|
||||||
import { fetchUserModels } from "@/components/create_key_button";
|
import { fetchUserModels } from "@/components/create_key_button";
|
||||||
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
||||||
|
import { getAuthToken } from "@/utils/cookieUtils";
|
||||||
import MCPToolsViewer from "@/components/mcp_tools";
|
import MCPToolsViewer from "@/components/mcp_tools";
|
||||||
import TagManagement from "@/components/tag_management";
|
import TagManagement from "@/components/tag_management";
|
||||||
|
|
||||||
|
@ -122,7 +123,7 @@ export default function CreateKeyPage() {
|
||||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const token = getCookie("token");
|
const token = getAuthToken();
|
||||||
setToken(token);
|
setToken(token);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ import { useSearchParams, useRouter } from "next/navigation";
|
||||||
import { Team } from "./key_team_helpers/key_list";
|
import { Team } from "./key_team_helpers/key_list";
|
||||||
import { jwtDecode } from "jwt-decode";
|
import { jwtDecode } from "jwt-decode";
|
||||||
import { Typography } from "antd";
|
import { Typography } from "antd";
|
||||||
|
import { getAuthToken } from "@/utils/cookieUtils";
|
||||||
import { clearTokenCookies } from "@/utils/cookieUtils";
|
import { clearTokenCookies } from "@/utils/cookieUtils";
|
||||||
const isLocal = process.env.NODE_ENV === "development";
|
const isLocal = process.env.NODE_ENV === "development";
|
||||||
if (isLocal != true) {
|
if (isLocal != true) {
|
||||||
|
@ -45,14 +46,6 @@ export type UserInfo = {
|
||||||
spend: number;
|
spend: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
function getCookie(name: string) {
|
|
||||||
console.log("COOKIES", document.cookie)
|
|
||||||
const cookieValue = document.cookie
|
|
||||||
.split('; ')
|
|
||||||
.find(row => row.startsWith(name + '='));
|
|
||||||
return cookieValue ? cookieValue.split('=')[1] : null;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface UserDashboardProps {
|
interface UserDashboardProps {
|
||||||
userID: string | null;
|
userID: string | null;
|
||||||
userRole: string | null;
|
userRole: string | null;
|
||||||
|
@ -94,7 +87,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
// Assuming useSearchParams() hook exists and works in your setup
|
// Assuming useSearchParams() hook exists and works in your setup
|
||||||
const searchParams = useSearchParams()!;
|
const searchParams = useSearchParams()!;
|
||||||
|
|
||||||
const token = getCookie('token');
|
const token = getAuthToken();
|
||||||
|
|
||||||
const invitation_id = searchParams.get("invitation_id");
|
const invitation_id = searchParams.get("invitation_id");
|
||||||
|
|
||||||
|
|
|
@ -13,32 +13,82 @@ export function clearTokenCookies() {
|
||||||
const paths = ['/', '/ui'];
|
const paths = ['/', '/ui'];
|
||||||
const sameSiteValues = ['Lax', 'Strict', 'None'];
|
const sameSiteValues = ['Lax', 'Strict', 'None'];
|
||||||
|
|
||||||
|
// Get all cookies
|
||||||
|
const allCookies = document.cookie.split("; ");
|
||||||
|
const tokenPattern = /^token_\d+$/;
|
||||||
|
|
||||||
|
// Find all token cookies
|
||||||
|
const tokenCookieNames = allCookies
|
||||||
|
.map(cookie => cookie.split("=")[0])
|
||||||
|
.filter(name => name === "token" || tokenPattern.test(name));
|
||||||
|
|
||||||
|
// Clear each token cookie with various combinations
|
||||||
|
tokenCookieNames.forEach(cookieName => {
|
||||||
paths.forEach(path => {
|
paths.forEach(path => {
|
||||||
// Basic clearing
|
// Basic clearing
|
||||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
|
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
|
||||||
|
|
||||||
// With domain
|
// With domain
|
||||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
|
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
|
||||||
|
|
||||||
// Try different SameSite values
|
// Try different SameSite values
|
||||||
sameSiteValues.forEach(sameSite => {
|
sameSiteValues.forEach(sameSite => {
|
||||||
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
|
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
|
||||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
|
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
|
||||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
|
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log("After clearing cookies:", document.cookie);
|
console.log("After clearing cookies:", document.cookie);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
export function setAuthToken(token: string) {
|
||||||
* Gets a cookie value by name
|
// Generate a token name with current timestamp
|
||||||
* @param name The name of the cookie to retrieve
|
const currentTimestamp = Math.floor(Date.now() / 1000);
|
||||||
* @returns The cookie value or null if not found
|
const tokenName = `token_${currentTimestamp}`;
|
||||||
*/
|
|
||||||
export function getCookie(name: string) {
|
// Set the cookie with the timestamp-based name
|
||||||
const cookieValue = document.cookie
|
document.cookie = `${tokenName}=${token}; path=/; domain=${window.location.hostname};`;
|
||||||
.split('; ')
|
}
|
||||||
.find(row => row.startsWith(name + '='));
|
|
||||||
return cookieValue ? cookieValue.split('=')[1] : null;
|
export function getAuthToken() {
|
||||||
|
// Check if we're in a browser environment
|
||||||
|
if (typeof window === 'undefined' || typeof document === 'undefined') {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const tokenPattern = /^token_(\d+)$/;
|
||||||
|
const allCookies = document.cookie.split("; ");
|
||||||
|
|
||||||
|
const tokenCookies = allCookies
|
||||||
|
.map(cookie => {
|
||||||
|
const parts = cookie.split("=");
|
||||||
|
const name = parts[0];
|
||||||
|
|
||||||
|
// Explicitly skip cookies named just "token"
|
||||||
|
if (name === "token") {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only match cookies with the token_{timestamp} format
|
||||||
|
const match = name.match(tokenPattern);
|
||||||
|
if (match) {
|
||||||
|
return {
|
||||||
|
name,
|
||||||
|
timestamp: parseInt(match[1], 10),
|
||||||
|
value: parts.slice(1).join("=")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
})
|
||||||
|
.filter((cookie): cookie is { name: string; timestamp: number; value: string } => cookie !== null);
|
||||||
|
|
||||||
|
if (tokenCookies.length > 0) {
|
||||||
|
// Sort by timestamp (newest first)
|
||||||
|
tokenCookies.sort((a, b) => b.timestamp - a.timestamp);
|
||||||
|
return tokenCookies[0].value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
}
|
}
|
Loading…
Add table
Add a link
Reference in a new issue