mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Revert "(UI) - Security Improvement, move to JWT Auth for Admin UI Sessions (#8995)"
This reverts commit 01a44a4e47
.
This commit is contained in:
parent
207f41cbea
commit
8d6815ce98
17 changed files with 539 additions and 1105 deletions
|
@ -272,7 +272,6 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/key/health",
|
||||
"/team/info",
|
||||
"/team/list",
|
||||
"/organization/info",
|
||||
"/organization/list",
|
||||
"/team/available",
|
||||
"/user/info",
|
||||
|
@ -283,11 +282,6 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/health",
|
||||
"/key/list",
|
||||
"/user/filter/ui",
|
||||
"/user/list",
|
||||
"/user/available_roles",
|
||||
"/guardrails/list",
|
||||
"/cache/ping",
|
||||
"/get/config/callbacks",
|
||||
]
|
||||
|
||||
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
|
||||
|
@ -306,8 +300,6 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/user/update",
|
||||
"/user/delete",
|
||||
"/user/info",
|
||||
# user invitation management
|
||||
"/invitation/new",
|
||||
# team
|
||||
"/team/new",
|
||||
"/team/update",
|
||||
|
@ -317,20 +309,6 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/team/block",
|
||||
"/team/unblock",
|
||||
"/team/available",
|
||||
# team member management
|
||||
"/team/member_add",
|
||||
"/team/member_update",
|
||||
"/team/member_delete",
|
||||
# organization management
|
||||
"/organization/new",
|
||||
"/organization/update",
|
||||
"/organization/delete",
|
||||
"/organization/info",
|
||||
"/organization/list",
|
||||
# organization member management
|
||||
"/organization/member_add",
|
||||
"/organization/member_update",
|
||||
"/organization/member_delete",
|
||||
# model
|
||||
"/model/new",
|
||||
"/model/update",
|
||||
|
@ -377,32 +355,20 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/sso",
|
||||
"/sso/get/ui_settings",
|
||||
"/login",
|
||||
"/sso/session/validate",
|
||||
"/key/info",
|
||||
"/config",
|
||||
"/spend",
|
||||
"/model/info",
|
||||
"/model/metrics",
|
||||
"/model/metrics/{sub_path}",
|
||||
"/model/settings",
|
||||
"/get/litellm_model_cost_map",
|
||||
"/model/streaming_metrics",
|
||||
"/v2/model/info",
|
||||
"/v2/key/info",
|
||||
"/models",
|
||||
"/v1/models",
|
||||
"/global/spend",
|
||||
"/global/spend/logs",
|
||||
"/spend/logs/ui",
|
||||
"/spend/logs/ui/{id}",
|
||||
"/global/spend/keys",
|
||||
"/global/spend/models",
|
||||
"/global/predict/spend/logs",
|
||||
"/global/activity",
|
||||
"/global/activity/{sub_path}",
|
||||
"/global/activity/exceptions",
|
||||
"/global/activity/exceptions/{sub_path}",
|
||||
"/global/all_end_users",
|
||||
"/health/services",
|
||||
] + info_routes
|
||||
|
||||
|
@ -2493,7 +2459,6 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||
"spend_tracking_routes",
|
||||
"global_spend_tracking_routes",
|
||||
"info_routes",
|
||||
"ui_routes",
|
||||
]
|
||||
team_id_jwt_field: Optional[str] = None
|
||||
team_id_upsert: bool = False
|
||||
|
|
|
@ -204,11 +204,9 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
|||
"""
|
||||
|
||||
for allowed_route in allowed_routes:
|
||||
if allowed_route in LiteLLMRoutes.__members__ and (
|
||||
RouteChecks.check_route_access(
|
||||
route=user_route,
|
||||
allowed_routes=LiteLLMRoutes[allowed_route].value,
|
||||
)
|
||||
if (
|
||||
allowed_route in LiteLLMRoutes.__members__
|
||||
and user_route in LiteLLMRoutes[allowed_route].value
|
||||
):
|
||||
return True
|
||||
elif allowed_route == user_route:
|
||||
|
@ -219,18 +217,16 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
|||
def allowed_routes_check(
|
||||
user_role: Literal[
|
||||
LitellmUserRoles.PROXY_ADMIN,
|
||||
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
|
||||
LitellmUserRoles.TEAM,
|
||||
LitellmUserRoles.INTERNAL_USER,
|
||||
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
||||
],
|
||||
user_route: str,
|
||||
litellm_proxy_roles: LiteLLM_JWTAuth,
|
||||
jwt_valid_token: dict,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user -> not admin - allowed to access these routes
|
||||
"""
|
||||
|
||||
if user_role == LitellmUserRoles.PROXY_ADMIN:
|
||||
is_allowed = _allowed_routes_check(
|
||||
user_route=user_route,
|
||||
|
|
|
@ -33,7 +33,6 @@ from litellm.proxy._types import (
|
|||
ScopeMapping,
|
||||
Span,
|
||||
)
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
||||
from .auth_checks import (
|
||||
|
@ -407,60 +406,10 @@ class JWTHandler:
|
|||
else:
|
||||
return False
|
||||
|
||||
def _validate_ui_token(self, token: str) -> Optional[dict]:
|
||||
"""
|
||||
Helper function to validate tokens generated for the LiteLLM UI.
|
||||
Returns the decoded payload if it's a valid UI token, None otherwise.
|
||||
"""
|
||||
import jwt
|
||||
|
||||
from litellm.proxy.proxy_server import master_key
|
||||
|
||||
try:
|
||||
# Decode without verification to check if it's a UI token
|
||||
unverified_payload = jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
# Check if this looks like a UI token (has specific claims that only UI tokens would have)
|
||||
if UISessionHandler.is_ui_session_token(unverified_payload):
|
||||
|
||||
# This looks like a UI token, now verify it with the master key
|
||||
if not master_key:
|
||||
verbose_proxy_logger.debug(
|
||||
"Missing LITELLM_MASTER_KEY for UI token validation"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
master_key,
|
||||
algorithms=["HS256"],
|
||||
audience="litellm-ui",
|
||||
leeway=self.leeway,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"Successfully validated UI token for payload: %s",
|
||||
json.dumps(payload, indent=4),
|
||||
)
|
||||
return payload
|
||||
except jwt.InvalidTokenError as e:
|
||||
verbose_proxy_logger.debug(f"Invalid UI token: {str(e)}")
|
||||
raise ValueError(
|
||||
f"Invalid UI token, Unable to validate token signature {str(e)}"
|
||||
)
|
||||
|
||||
return None # Not a UI token
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def auth_jwt(self, token: str) -> dict:
|
||||
# Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
|
||||
# "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
|
||||
# the key in different ways (e.g. HS* and RS*)."
|
||||
|
||||
ui_payload = self._validate_ui_token(token)
|
||||
if ui_payload:
|
||||
return ui_payload
|
||||
algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]
|
||||
|
||||
audience = os.getenv("JWT_AUDIENCE")
|
||||
|
@ -667,7 +616,6 @@ class JWTAuthManager:
|
|||
user_id: Optional[str],
|
||||
org_id: Optional[str],
|
||||
api_key: str,
|
||||
jwt_valid_token: dict,
|
||||
) -> Optional[JWTAuthBuilderResult]:
|
||||
"""Check admin status and route access permissions"""
|
||||
if not jwt_handler.is_admin(scopes=scopes):
|
||||
|
@ -677,7 +625,6 @@ class JWTAuthManager:
|
|||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
jwt_valid_token=jwt_valid_token,
|
||||
)
|
||||
if not is_allowed:
|
||||
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||
|
@ -751,7 +698,6 @@ class JWTAuthManager:
|
|||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
jwt_valid_token: dict,
|
||||
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
|
||||
"""Find first team with access to the requested model"""
|
||||
|
||||
|
@ -784,7 +730,6 @@ class JWTAuthManager:
|
|||
user_role=LitellmUserRoles.TEAM,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
jwt_valid_token=jwt_valid_token,
|
||||
)
|
||||
if is_allowed:
|
||||
return team_id, team_object
|
||||
|
@ -975,13 +920,7 @@ class JWTAuthManager:
|
|||
|
||||
# Check admin access
|
||||
admin_result = await JWTAuthManager.check_admin_access(
|
||||
jwt_handler=jwt_handler,
|
||||
scopes=scopes,
|
||||
route=route,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
api_key=api_key,
|
||||
jwt_valid_token=jwt_valid_token,
|
||||
jwt_handler, scopes, route, user_id, org_id, api_key
|
||||
)
|
||||
if admin_result:
|
||||
return admin_result
|
||||
|
@ -1013,7 +952,6 @@ class JWTAuthManager:
|
|||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
jwt_valid_token=jwt_valid_token,
|
||||
)
|
||||
|
||||
# Get other objects
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import re
|
||||
from typing import List, Optional, Set, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
|
@ -225,9 +225,7 @@ class RouteChecks:
|
|||
return False
|
||||
|
||||
@staticmethod
|
||||
def check_route_access(
|
||||
route: str, allowed_routes: Union[List[str], Set[str]]
|
||||
) -> bool:
|
||||
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
|
||||
"""
|
||||
Check if a route has access by checking both exact matches and patterns
|
||||
|
||||
|
|
|
@ -51,7 +51,6 @@ from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
|||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.auth.service_account_checks import service_account_checks
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, _to_ns
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
@ -336,7 +335,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
"pass_through_endpoints", None
|
||||
)
|
||||
passed_in_key: Optional[str] = None
|
||||
cookie_token: Optional[str] = None
|
||||
if isinstance(api_key, str):
|
||||
passed_in_key = api_key
|
||||
api_key = _get_bearer_token(api_key=api_key)
|
||||
|
@ -346,10 +344,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
api_key = anthropic_api_key_header
|
||||
elif isinstance(google_ai_studio_api_key_header, str):
|
||||
api_key = google_ai_studio_api_key_header
|
||||
elif cookie_token := UISessionHandler._get_ui_session_token_from_cookies(
|
||||
request
|
||||
):
|
||||
api_key = cookie_token
|
||||
elif pass_through_endpoints is not None:
|
||||
for endpoint in pass_through_endpoints:
|
||||
if endpoint.get("path", "") == route:
|
||||
|
@ -426,10 +420,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
if general_settings.get("enable_oauth2_proxy_auth", False) is True:
|
||||
return await handle_oauth2_proxy_request(request=request)
|
||||
|
||||
if (
|
||||
general_settings.get("enable_jwt_auth", False) is True
|
||||
or cookie_token is not None
|
||||
):
|
||||
if general_settings.get("enable_jwt_auth", False) is True:
|
||||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
if premium_user is not True:
|
||||
|
|
|
@ -8,7 +8,7 @@ Has all /sso/* routes
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
|
@ -43,7 +43,6 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
|||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -409,7 +408,11 @@ def get_disabled_non_admin_personal_key_creation():
|
|||
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
||||
async def auth_callback(request: Request): # noqa: PLR0915
|
||||
"""Verify login"""
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
jwt_handler,
|
||||
master_key,
|
||||
premium_user,
|
||||
|
@ -423,7 +426,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
user_role: Optional[LitellmUserRoles] = None
|
||||
# get url from request
|
||||
if master_key is None:
|
||||
raise ProxyException(
|
||||
|
@ -529,7 +531,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
max_internal_user_budget = litellm.max_internal_user_budget
|
||||
internal_user_budget_duration = litellm.internal_user_budget_duration
|
||||
|
||||
# User might not be already created on first generation of key
|
||||
# But if it is, we want their models preferences
|
||||
default_ui_key_values: Dict[str, Any] = {
|
||||
"duration": "24hr",
|
||||
"key_max_budget": litellm.max_ui_session_budget,
|
||||
"aliases": {},
|
||||
"config": {},
|
||||
"spend": 0,
|
||||
"team_id": "litellm-dashboard",
|
||||
}
|
||||
user_defined_values: Optional[SSOUserDefinedValues] = None
|
||||
|
||||
if user_custom_sso is not None:
|
||||
|
@ -548,6 +559,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
)
|
||||
|
||||
_user_id_from_sso = user_id
|
||||
user_role = None
|
||||
try:
|
||||
if prisma_client is not None:
|
||||
try:
|
||||
|
@ -620,14 +632,24 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
f"user_defined_values for creating ui key: {user_defined_values}"
|
||||
)
|
||||
|
||||
default_ui_key_values.update(user_defined_values)
|
||||
default_ui_key_values["request_type"] = "key"
|
||||
response = await generate_key_helper_fn(
|
||||
**default_ui_key_values, # type: ignore
|
||||
table_name="key",
|
||||
)
|
||||
|
||||
key = response["token"] # type: ignore
|
||||
user_id = response["user_id"] # type: ignore
|
||||
|
||||
litellm_dashboard_ui = "/ui/"
|
||||
user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
|
||||
if (
|
||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
||||
and os.environ["PROXY_ADMIN_ID"] == user_id
|
||||
):
|
||||
# checks if user is admin
|
||||
user_role = LitellmUserRoles.PROXY_ADMIN
|
||||
user_role = LitellmUserRoles.PROXY_ADMIN.value
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
|
||||
|
@ -647,20 +669,30 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
disabled_non_admin_personal_key_creation = (
|
||||
get_disabled_non_admin_personal_key_creation()
|
||||
)
|
||||
jwt_token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||
user_id=user_defined_values.get("user_id", ""),
|
||||
user_role=user_role,
|
||||
user_email=user_defined_values.get("user_email", ""),
|
||||
premium_user=premium_user,
|
||||
disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
|
||||
login_method="sso",
|
||||
|
||||
import jwt
|
||||
|
||||
jwt_token = jwt.encode( # type: ignore
|
||||
{
|
||||
"user_id": user_id,
|
||||
"key": key,
|
||||
"user_email": user_email,
|
||||
"user_role": user_role,
|
||||
"login_method": "sso",
|
||||
"premium_user": premium_user,
|
||||
"auth_header_name": general_settings.get(
|
||||
"litellm_key_header_name", "Authorization"
|
||||
),
|
||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
||||
},
|
||||
master_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
if user_id is not None and isinstance(user_id, str):
|
||||
litellm_dashboard_ui += "?userID=" + user_id
|
||||
|
||||
return UISessionHandler.generate_authenticated_redirect_response(
|
||||
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||
)
|
||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
||||
return redirect_response
|
||||
|
||||
|
||||
async def insert_sso_user(
|
||||
|
@ -746,25 +778,3 @@ async def get_ui_settings(request: Request):
|
|||
),
|
||||
"DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sso/session/validate",
|
||||
include_in_schema=False,
|
||||
tags=["experimental"],
|
||||
)
|
||||
async def validate_session(
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
|
||||
ui_session_token = UISessionHandler._get_ui_session_token_from_cookies(request)
|
||||
ui_session_id = UISessionHandler._get_latest_ui_cookie_name(request.cookies)
|
||||
if ui_session_token is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
jwt_handler = JWTHandler()
|
||||
validated_jwt_token = await jwt_handler.auth_jwt(token=ui_session_token)
|
||||
validated_jwt_token["session_id"] = ui_session_id
|
||||
return {"valid": True, "data": validated_jwt_token}
|
||||
|
|
|
@ -1,141 +0,0 @@
|
|||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth, LitellmUserRoles
|
||||
|
||||
|
||||
class UISessionHandler:
|
||||
|
||||
@staticmethod
|
||||
def _get_latest_ui_cookie_name(cookies: dict) -> Optional[str]:
|
||||
"""
|
||||
Get the name of the most recent LiteLLM UI cookie
|
||||
"""
|
||||
# Find all LiteLLM UI cookies (format: litellm_ui_token_{timestamp})
|
||||
litellm_ui_cookies = [
|
||||
k for k in cookies.keys() if k.startswith("litellm_ui_token_")
|
||||
]
|
||||
if not litellm_ui_cookies:
|
||||
return None
|
||||
|
||||
# Sort by timestamp (descending) to get the most recent one
|
||||
try:
|
||||
# Extract timestamps and sort numerically
|
||||
sorted_cookies = sorted(
|
||||
litellm_ui_cookies,
|
||||
key=lambda x: int(x.split("_")[-1]),
|
||||
reverse=True,
|
||||
)
|
||||
return sorted_cookies[0]
|
||||
except (ValueError, IndexError):
|
||||
# Fallback to simple string sort if timestamp extraction fails
|
||||
litellm_ui_cookies.sort(reverse=True)
|
||||
return litellm_ui_cookies[0]
|
||||
|
||||
@staticmethod
|
||||
# Add this function to extract auth token from cookies
|
||||
def _get_ui_session_token_from_cookies(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract authentication token from cookies if present
|
||||
"""
|
||||
try:
|
||||
cookies = request.cookies
|
||||
verbose_proxy_logger.debug(f"AUTH COOKIES: {cookies}")
|
||||
|
||||
cookie_name = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||
if cookie_name:
|
||||
return cookies[cookie_name]
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Error getting UI session token from cookies: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def build_authenticated_ui_jwt_token(
|
||||
user_id: str,
|
||||
user_role: Optional[LitellmUserRoles],
|
||||
user_email: Optional[str],
|
||||
premium_user: bool,
|
||||
disabled_non_admin_personal_key_creation: bool,
|
||||
login_method: Literal["username_password", "sso"],
|
||||
) -> str:
|
||||
"""
|
||||
Build a JWT token for the authenticated UI session
|
||||
|
||||
This token is used to authenticate the user's session when they are redirected to the UI
|
||||
"""
|
||||
import jwt
|
||||
|
||||
from litellm.proxy.proxy_server import general_settings, master_key
|
||||
|
||||
if master_key is None:
|
||||
raise ValueError("Master key is not set")
|
||||
|
||||
expiration = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
initial_payload = {
|
||||
"user_id": user_id,
|
||||
"user_email": user_email,
|
||||
"user_role": user_role, # this is the path without sso - we can assume only admins will use this
|
||||
"login_method": login_method,
|
||||
"premium_user": premium_user,
|
||||
"auth_header_name": general_settings.get(
|
||||
"litellm_key_header_name", "Authorization"
|
||||
),
|
||||
"iss": "litellm-proxy", # Issuer - identifies this as an internal token
|
||||
"aud": "litellm-ui", # Audience - identifies this as a UI token
|
||||
"exp": expiration,
|
||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
||||
}
|
||||
|
||||
if (
|
||||
user_role == LitellmUserRoles.PROXY_ADMIN
|
||||
or user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
|
||||
):
|
||||
initial_payload["scope"] = [
|
||||
LiteLLM_JWTAuth().admin_jwt_scope,
|
||||
]
|
||||
|
||||
jwt_token = jwt.encode(
|
||||
initial_payload,
|
||||
master_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
return jwt_token
|
||||
|
||||
@staticmethod
|
||||
def is_ui_session_token(token_dict: dict) -> bool:
|
||||
"""
|
||||
Returns True if the token is a UI session token
|
||||
"""
|
||||
return (
|
||||
token_dict.get("iss") == "litellm-proxy"
|
||||
and token_dict.get("aud") == "litellm-ui"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_authenticated_redirect_response(
|
||||
redirect_url: str, jwt_token: str
|
||||
) -> RedirectResponse:
|
||||
redirect_response = RedirectResponse(url=redirect_url, status_code=303)
|
||||
redirect_response.set_cookie(
|
||||
key=UISessionHandler._generate_token_name(),
|
||||
value=jwt_token,
|
||||
secure=True,
|
||||
httponly=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return redirect_response
|
||||
|
||||
@staticmethod
|
||||
def _generate_token_name() -> str:
|
||||
current_timestamp = int(time.time())
|
||||
cookie_name = f"litellm_ui_token_{current_timestamp}"
|
||||
return cookie_name
|
|
@ -7377,8 +7377,6 @@ async def login(request: Request): # noqa: PLR0915
|
|||
import multipart
|
||||
except ImportError:
|
||||
subprocess.run(["pip", "install", "python-multipart"])
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
|
||||
global master_key
|
||||
if master_key is None:
|
||||
raise ProxyException(
|
||||
|
@ -7449,23 +7447,56 @@ async def login(request: Request): # noqa: PLR0915
|
|||
user_role=user_role,
|
||||
)
|
||||
)
|
||||
if os.getenv("DATABASE_URL") is not None:
|
||||
response = await generate_key_helper_fn(
|
||||
request_type="key",
|
||||
**{
|
||||
"user_role": LitellmUserRoles.PROXY_ADMIN,
|
||||
"duration": "24hr",
|
||||
"key_max_budget": litellm.max_ui_session_budget,
|
||||
"models": [],
|
||||
"aliases": {},
|
||||
"config": {},
|
||||
"spend": 0,
|
||||
"user_id": key_user_id,
|
||||
"team_id": "litellm-dashboard",
|
||||
}, # type: ignore
|
||||
)
|
||||
else:
|
||||
raise ProxyException(
|
||||
message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="DATABASE_URL",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
key = response["token"] # type: ignore
|
||||
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
||||
if litellm_dashboard_ui.endswith("/"):
|
||||
litellm_dashboard_ui += "ui/"
|
||||
else:
|
||||
litellm_dashboard_ui += "/ui/"
|
||||
jwt_token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||
user_id=user_id,
|
||||
user_role=user_role,
|
||||
user_email=None,
|
||||
premium_user=premium_user,
|
||||
disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
|
||||
login_method="username_password",
|
||||
import jwt
|
||||
|
||||
jwt_token = jwt.encode( # type: ignore
|
||||
{
|
||||
"user_id": user_id,
|
||||
"key": key,
|
||||
"user_email": None,
|
||||
"user_role": user_role, # this is the path without sso - we can assume only admins will use this
|
||||
"login_method": "username_password",
|
||||
"premium_user": premium_user,
|
||||
"auth_header_name": general_settings.get(
|
||||
"litellm_key_header_name", "Authorization"
|
||||
),
|
||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
||||
},
|
||||
master_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
litellm_dashboard_ui += "?userID=" + user_id
|
||||
return UISessionHandler.generate_authenticated_redirect_response(
|
||||
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||
)
|
||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
||||
return redirect_response
|
||||
elif _user_row is not None:
|
||||
"""
|
||||
When sharing invite links
|
||||
|
@ -7484,23 +7515,58 @@ async def login(request: Request): # noqa: PLR0915
|
|||
if secrets.compare_digest(password, _password) or secrets.compare_digest(
|
||||
hash_password, _password
|
||||
):
|
||||
if os.getenv("DATABASE_URL") is not None:
|
||||
response = await generate_key_helper_fn(
|
||||
request_type="key",
|
||||
**{ # type: ignore
|
||||
"user_role": user_role,
|
||||
"duration": "24hr",
|
||||
"key_max_budget": litellm.max_ui_session_budget,
|
||||
"models": [],
|
||||
"aliases": {},
|
||||
"config": {},
|
||||
"spend": 0,
|
||||
"user_id": user_id,
|
||||
"team_id": "litellm-dashboard",
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise ProxyException(
|
||||
message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="DATABASE_URL",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
key = response["token"] # type: ignore
|
||||
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
||||
if litellm_dashboard_ui.endswith("/"):
|
||||
litellm_dashboard_ui += "ui/"
|
||||
else:
|
||||
litellm_dashboard_ui += "/ui/"
|
||||
jwt_token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||
user_id=user_id,
|
||||
user_role=user_role,
|
||||
user_email=user_email,
|
||||
premium_user=premium_user,
|
||||
disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
|
||||
login_method="username_password",
|
||||
import jwt
|
||||
|
||||
jwt_token = jwt.encode( # type: ignore
|
||||
{
|
||||
"user_id": user_id,
|
||||
"key": key,
|
||||
"user_email": user_email,
|
||||
"user_role": user_role,
|
||||
"login_method": "username_password",
|
||||
"premium_user": premium_user,
|
||||
"auth_header_name": general_settings.get(
|
||||
"litellm_key_header_name", "Authorization"
|
||||
),
|
||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
||||
},
|
||||
master_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
litellm_dashboard_ui += "?userID=" + user_id
|
||||
return UISessionHandler.generate_authenticated_redirect_response(
|
||||
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||
redirect_response = RedirectResponse(
|
||||
url=litellm_dashboard_ui, status_code=303
|
||||
)
|
||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
||||
return redirect_response
|
||||
else:
|
||||
raise ProxyException(
|
||||
message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
|
||||
|
@ -7526,8 +7592,6 @@ async def onboarding(invite_link: str):
|
|||
- Get user from db
|
||||
- Pass in user_email if set
|
||||
"""
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
|
||||
global prisma_client, master_key, general_settings
|
||||
if master_key is None:
|
||||
raise ProxyException(
|
||||
|
@ -7584,26 +7648,51 @@ async def onboarding(invite_link: str):
|
|||
|
||||
user_email = user_obj.user_email
|
||||
|
||||
response = await generate_key_helper_fn(
|
||||
request_type="key",
|
||||
**{
|
||||
"user_role": user_obj.user_role,
|
||||
"duration": "24hr",
|
||||
"key_max_budget": litellm.max_ui_session_budget,
|
||||
"models": [],
|
||||
"aliases": {},
|
||||
"config": {},
|
||||
"spend": 0,
|
||||
"user_id": user_obj.user_id,
|
||||
"team_id": "litellm-dashboard",
|
||||
}, # type: ignore
|
||||
)
|
||||
key = response["token"] # type: ignore
|
||||
|
||||
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
||||
if litellm_dashboard_ui.endswith("/"):
|
||||
litellm_dashboard_ui += "ui/onboarding"
|
||||
else:
|
||||
litellm_dashboard_ui += "/ui/onboarding"
|
||||
import jwt
|
||||
|
||||
disabled_non_admin_personal_key_creation = (
|
||||
get_disabled_non_admin_personal_key_creation()
|
||||
)
|
||||
jwt_token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||
user_id=user_obj.user_id,
|
||||
user_role=user_obj.user_role,
|
||||
user_email=user_obj.user_email,
|
||||
premium_user=user_obj.premium_user,
|
||||
disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
|
||||
login_method="username_password",
|
||||
|
||||
jwt_token = jwt.encode( # type: ignore
|
||||
{
|
||||
"user_id": user_obj.user_id,
|
||||
"key": key,
|
||||
"user_email": user_obj.user_email,
|
||||
"user_role": user_obj.user_role,
|
||||
"login_method": "username_password",
|
||||
"premium_user": premium_user,
|
||||
"auth_header_name": general_settings.get(
|
||||
"litellm_key_header_name", "Authorization"
|
||||
),
|
||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
||||
},
|
||||
master_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
|
||||
litellm_dashboard_ui += "?token={}&user_email={}".format(jwt_token, user_email)
|
||||
|
||||
return {
|
||||
"login_url": litellm_dashboard_ui,
|
||||
"token": jwt_token,
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import jwt
|
||||
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
|
||||
|
||||
class TestJWTHandler:
|
||||
@pytest.fixture
|
||||
def jwt_handler(self):
|
||||
handler = JWTHandler()
|
||||
handler.leeway = 60 # Set leeway for testing
|
||||
return handler
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mocks(self, monkeypatch):
|
||||
# Mock master_key
|
||||
test_master_key = "test_master_key"
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.master_key", test_master_key)
|
||||
|
||||
# Mock UISessionHandler.is_ui_session_token to return True for our test token
|
||||
def mock_is_ui_session_token(payload):
|
||||
return "ui_session_id" in payload
|
||||
|
||||
monkeypatch.setattr(
|
||||
UISessionHandler, "is_ui_session_token", mock_is_ui_session_token
|
||||
)
|
||||
|
||||
return test_master_key
|
||||
|
||||
def test_validate_valid_ui_token(self, jwt_handler, setup_mocks):
|
||||
# Setup
|
||||
test_master_key = setup_mocks
|
||||
|
||||
# Create a valid UI token
|
||||
valid_payload = {
|
||||
"ui_session_id": "test_session_id",
|
||||
"exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600,
|
||||
"iat": datetime.datetime.now(tz=timezone.utc).timestamp(),
|
||||
"aud": "litellm-ui",
|
||||
}
|
||||
valid_token = jwt.encode(valid_payload, test_master_key, algorithm="HS256")
|
||||
|
||||
# Test valid UI token
|
||||
result = jwt_handler._validate_ui_token(valid_token)
|
||||
assert result is not None
|
||||
assert result["ui_session_id"] == "test_session_id"
|
||||
|
||||
def test_validate_expired_ui_token(self, jwt_handler, setup_mocks):
|
||||
# Setup
|
||||
test_master_key = setup_mocks
|
||||
|
||||
# Create an expired UI token
|
||||
expired_payload = {
|
||||
"ui_session_id": "test_session_id",
|
||||
"exp": datetime.datetime.now(tz=timezone.utc).timestamp() - 3600,
|
||||
"iat": datetime.datetime.now(tz=timezone.utc).timestamp() - 7200,
|
||||
"aud": "litellm-ui",
|
||||
}
|
||||
expired_token = jwt.encode(expired_payload, test_master_key, algorithm="HS256")
|
||||
|
||||
# Test expired UI token
|
||||
with pytest.raises(ValueError, match="Invalid UI token"):
|
||||
jwt_handler._validate_ui_token(expired_token)
|
||||
|
||||
def test_validate_invalid_signature_ui_token(self, jwt_handler, setup_mocks):
|
||||
# Setup
|
||||
test_master_key = setup_mocks
|
||||
|
||||
# Create a token with invalid signature
|
||||
valid_payload = {
|
||||
"ui_session_id": "test_session_id",
|
||||
"exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600,
|
||||
"iat": datetime.datetime.now(tz=timezone.utc).timestamp(),
|
||||
"aud": "litellm-ui",
|
||||
}
|
||||
invalid_token = jwt.encode(valid_payload, "wrong_key", algorithm="HS256")
|
||||
|
||||
# Test UI token with invalid signature
|
||||
with pytest.raises(ValueError, match="Invalid UI token"):
|
||||
jwt_handler._validate_ui_token(invalid_token)
|
||||
|
||||
def test_validate_non_ui_token(self, jwt_handler, setup_mocks):
|
||||
# Setup
|
||||
test_master_key = setup_mocks
|
||||
|
||||
# Create a non-UI token
|
||||
non_ui_payload = {
|
||||
"sub": "user123",
|
||||
"exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600,
|
||||
}
|
||||
non_ui_token = jwt.encode(non_ui_payload, test_master_key, algorithm="HS256")
|
||||
|
||||
# Test non-UI token
|
||||
result = jwt_handler._validate_ui_token(non_ui_token)
|
||||
assert result is None
|
|
@ -1,179 +0,0 @@
|
|||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
|
||||
|
||||
class TestUISessionHandler:
|
||||
|
||||
def test_get_latest_ui_cookie_name(self):
|
||||
# Test with multiple cookies
|
||||
cookies = {
|
||||
"litellm_ui_token_1000": "value1",
|
||||
"litellm_ui_token_2000": "value2",
|
||||
"other_cookie": "other_value",
|
||||
}
|
||||
|
||||
result = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||
assert result == "litellm_ui_token_2000"
|
||||
|
||||
# Test with no matching cookies
|
||||
cookies = {"other_cookie": "value"}
|
||||
result = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||
assert result is None
|
||||
|
||||
def test_get_ui_session_token_from_cookies(self):
|
||||
# Create mock request with cookies
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {
|
||||
"litellm_ui_token_1000": "test_token",
|
||||
"other_cookie": "other_value",
|
||||
}
|
||||
|
||||
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||
assert result == "test_token"
|
||||
|
||||
# Test with no matching cookies
|
||||
mock_request.cookies = {"other_cookie": "value"}
|
||||
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||
assert result is None
|
||||
|
||||
@patch("litellm.proxy.proxy_server.master_key", "test_master_key")
|
||||
@patch(
|
||||
"litellm.proxy.proxy_server.general_settings",
|
||||
{"litellm_key_header_name": "X-API-Key"},
|
||||
)
|
||||
def test_build_authenticated_ui_jwt_token(self):
|
||||
# Test token generation
|
||||
token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||
user_id="test_user",
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_email="test@example.com",
|
||||
premium_user=True,
|
||||
disabled_non_admin_personal_key_creation=False,
|
||||
login_method="username_password",
|
||||
)
|
||||
|
||||
# Decode and verify token
|
||||
decoded = jwt.decode(
|
||||
token, "test_master_key", algorithms=["HS256"], audience="litellm-ui"
|
||||
)
|
||||
|
||||
assert decoded["user_id"] == "test_user"
|
||||
assert decoded["user_email"] == "test@example.com"
|
||||
assert decoded["user_role"] == LitellmUserRoles.PROXY_ADMIN
|
||||
assert decoded["premium_user"] is True
|
||||
assert decoded["login_method"] == "username_password"
|
||||
assert decoded["auth_header_name"] == "X-API-Key"
|
||||
assert decoded["iss"] == "litellm-proxy"
|
||||
assert decoded["aud"] == "litellm-ui"
|
||||
assert "exp" in decoded
|
||||
assert decoded["disabled_non_admin_personal_key_creation"] is False
|
||||
assert decoded["scope"] == ["litellm_proxy_admin"]
|
||||
|
||||
def test_is_ui_session_token(self):
|
||||
# Valid UI session token
|
||||
token_dict = {
|
||||
"iss": "litellm-proxy",
|
||||
"aud": "litellm-ui",
|
||||
"user_id": "test_user",
|
||||
}
|
||||
assert UISessionHandler.is_ui_session_token(token_dict) is True
|
||||
|
||||
# Invalid token (wrong issuer)
|
||||
token_dict = {
|
||||
"iss": "other-issuer",
|
||||
"aud": "litellm-ui",
|
||||
}
|
||||
assert UISessionHandler.is_ui_session_token(token_dict) is False
|
||||
|
||||
# Invalid token (wrong audience)
|
||||
token_dict = {
|
||||
"iss": "litellm-proxy",
|
||||
"aud": "other-audience",
|
||||
}
|
||||
assert UISessionHandler.is_ui_session_token(token_dict) is False
|
||||
|
||||
def test_generate_authenticated_redirect_response(self):
|
||||
redirect_url = "https://example.com/dashboard"
|
||||
jwt_token = "test.jwt.token"
|
||||
|
||||
response = UISessionHandler.generate_authenticated_redirect_response(
|
||||
redirect_url=redirect_url, jwt_token=jwt_token
|
||||
)
|
||||
|
||||
assert isinstance(response, RedirectResponse)
|
||||
assert response.status_code == 303
|
||||
assert response.headers["location"] == redirect_url
|
||||
|
||||
# Check cookie was set
|
||||
cookie_header = response.headers.get("set-cookie", "")
|
||||
assert "test.jwt.token" in cookie_header
|
||||
assert "Secure" in cookie_header
|
||||
assert "HttpOnly" in cookie_header
|
||||
assert "SameSite=strict" in cookie_header
|
||||
|
||||
def test_generate_token_name(self):
|
||||
# Mock time.time() to return a fixed value
|
||||
with patch("time.time", return_value=1234567890):
|
||||
token_name = UISessionHandler._generate_token_name()
|
||||
assert token_name == "litellm_ui_token_1234567890"
|
||||
|
||||
def test_latest_token_is_used(self):
|
||||
"""Test that the most recent token is correctly identified and used"""
|
||||
# Create mock request with multiple UI tokens of different timestamps
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {
|
||||
"litellm_ui_token_1000": "old_token",
|
||||
"litellm_ui_token_2000": "newer_token",
|
||||
"litellm_ui_token_1500": "middle_token",
|
||||
"other_cookie": "other_value",
|
||||
}
|
||||
|
||||
# Get the token from cookies
|
||||
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||
|
||||
# Verify the newest token (highest timestamp) is returned
|
||||
assert result == "newer_token"
|
||||
|
||||
# Test with timestamps that aren't in numerical order in the dictionary
|
||||
mock_request.cookies = {
|
||||
"litellm_ui_token_5000": "newest_token",
|
||||
"litellm_ui_token_1000": "oldest_token",
|
||||
"litellm_ui_token_3000": "middle_token",
|
||||
"other_cookie": "other_value",
|
||||
}
|
||||
|
||||
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||
assert result == "newest_token"
|
||||
|
||||
# Test with non-numeric timestamp parts
|
||||
mock_request.cookies = {
|
||||
"litellm_ui_token_abc": "invalid_token",
|
||||
"litellm_ui_token_2000": "valid_token",
|
||||
"other_cookie": "other_value",
|
||||
}
|
||||
|
||||
# Should handle the error and still return a token (fallback to string sort)
|
||||
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||
assert result is not None
|
|
@ -1,124 +0,0 @@
|
|||
import pytest
|
||||
import time
|
||||
import jwt
|
||||
from datetime import datetime, timezone
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
|
||||
class TestUISessionHandler:
|
||||
|
||||
def test_get_latest_ui_cookie_name(self):
|
||||
# Test with multiple cookies
|
||||
cookies = {
|
||||
"litellm_ui_token_1000": "value1",
|
||||
"litellm_ui_token_2000": "value2",
|
||||
"other_cookie": "other_value",
|
||||
}
|
||||
|
||||
result = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||
assert result == "litellm_ui_token_2000"
|
||||
|
||||
# Test with no matching cookies
|
||||
cookies = {"other_cookie": "value"}
|
||||
result = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||
assert result is None
|
||||
|
||||
def test_get_ui_session_token_from_cookies(self):
|
||||
# Create mock request with cookies
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {
|
||||
"litellm_ui_token_1000": "test_token",
|
||||
"other_cookie": "other_value",
|
||||
}
|
||||
|
||||
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||
assert result == "test_token"
|
||||
|
||||
# Test with no matching cookies
|
||||
mock_request.cookies = {"other_cookie": "value"}
|
||||
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||
assert result is None
|
||||
|
||||
@patch("litellm.proxy.proxy_server.master_key", "test_master_key")
|
||||
@patch(
|
||||
"litellm.proxy.proxy_server.general_settings",
|
||||
{"litellm_key_header_name": "X-API-Key"},
|
||||
)
|
||||
def test_build_authenticated_ui_jwt_token(self):
|
||||
# Test token generation
|
||||
token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||
user_id="test_user",
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_email="test@example.com",
|
||||
premium_user=True,
|
||||
disabled_non_admin_personal_key_creation=False,
|
||||
login_method="username_password",
|
||||
)
|
||||
|
||||
# Decode and verify token
|
||||
decoded = jwt.decode(token, "test_master_key", algorithms=["HS256"])
|
||||
|
||||
assert decoded["user_id"] == "test_user"
|
||||
assert decoded["user_email"] == "test@example.com"
|
||||
assert decoded["user_role"] == LitellmUserRoles.PROXY_ADMIN
|
||||
assert decoded["premium_user"] is True
|
||||
assert decoded["login_method"] == "username_password"
|
||||
assert decoded["auth_header_name"] == "X-API-Key"
|
||||
assert decoded["iss"] == "litellm-proxy"
|
||||
assert decoded["aud"] == "litellm-ui"
|
||||
assert "exp" in decoded
|
||||
assert decoded["disabled_non_admin_personal_key_creation"] is False
|
||||
assert decoded["scope"] == ["litellm:admin"]
|
||||
|
||||
def test_is_ui_session_token(self):
|
||||
# Valid UI session token
|
||||
token_dict = {
|
||||
"iss": "litellm-proxy",
|
||||
"aud": "litellm-ui",
|
||||
"user_id": "test_user",
|
||||
}
|
||||
assert UISessionHandler.is_ui_session_token(token_dict) is True
|
||||
|
||||
# Invalid token (wrong issuer)
|
||||
token_dict = {
|
||||
"iss": "other-issuer",
|
||||
"aud": "litellm-ui",
|
||||
}
|
||||
assert UISessionHandler.is_ui_session_token(token_dict) is False
|
||||
|
||||
# Invalid token (wrong audience)
|
||||
token_dict = {
|
||||
"iss": "litellm-proxy",
|
||||
"aud": "other-audience",
|
||||
}
|
||||
assert UISessionHandler.is_ui_session_token(token_dict) is False
|
||||
|
||||
def test_generate_authenticated_redirect_response(self):
|
||||
redirect_url = "https://example.com/dashboard"
|
||||
jwt_token = "test.jwt.token"
|
||||
|
||||
response = UISessionHandler.generate_authenticated_redirect_response(
|
||||
redirect_url=redirect_url, jwt_token=jwt_token
|
||||
)
|
||||
|
||||
assert isinstance(response, RedirectResponse)
|
||||
assert response.status_code == 303
|
||||
assert response.headers["location"] == redirect_url
|
||||
|
||||
# Check cookie was set
|
||||
cookie_header = response.headers.get("set-cookie", "")
|
||||
assert "test.jwt.token" in cookie_header
|
||||
assert "Secure" in cookie_header
|
||||
assert "HttpOnly" in cookie_header
|
||||
assert "SameSite=strict" in cookie_header
|
||||
|
||||
def test_generate_token_name(self):
|
||||
# Mock time.time() to return a fixed value
|
||||
with patch("time.time", return_value=1234567890):
|
||||
token_name = UISessionHandler._generate_token_name()
|
||||
assert token_name == "litellm_ui_token_1234567890"
|
|
@ -20,12 +20,12 @@ import {
|
|||
} from "@/components/networking";
|
||||
import { jwtDecode } from "jwt-decode";
|
||||
import { Form, Button as Button2, message } from "antd";
|
||||
import { getUISessionDetails, setAuthToken } from "@/utils/cookieUtils";
|
||||
import { getCookie } from "@/utils/cookieUtils";
|
||||
|
||||
export default function Onboarding() {
|
||||
const [form] = Form.useForm();
|
||||
const searchParams = useSearchParams()!;
|
||||
const token = getUISessionDetails();
|
||||
const token = getCookie('token');
|
||||
const inviteID = searchParams.get("invitation_id");
|
||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||
const [defaultUserEmail, setDefaultUserEmail] = useState<string>("");
|
||||
|
@ -88,7 +88,7 @@ export default function Onboarding() {
|
|||
litellm_dashboard_ui += "?userID=" + user_id;
|
||||
|
||||
// set cookie "token" to jwtToken
|
||||
setAuthToken(jwtToken);
|
||||
document.cookie = "token=" + jwtToken;
|
||||
console.log("redirecting to:", litellm_dashboard_ui);
|
||||
|
||||
window.location.href = litellm_dashboard_ui;
|
||||
|
|
|
@ -30,7 +30,12 @@ import { Organization } from "@/components/networking";
|
|||
import GuardrailsPanel from "@/components/guardrails";
|
||||
import { fetchUserModels } from "@/components/create_key_button";
|
||||
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
||||
import { getUISessionDetails } from "@/utils/cookieUtils";
|
||||
function getCookie(name: string) {
|
||||
const cookieValue = document.cookie
|
||||
.split("; ")
|
||||
.find((row) => row.startsWith(name + "="));
|
||||
return cookieValue ? cookieValue.split("=")[1] : null;
|
||||
}
|
||||
|
||||
function formatUserRole(userRole: string) {
|
||||
if (!userRole) {
|
||||
|
@ -112,59 +117,63 @@ export default function CreateKeyPage() {
|
|||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchSessionDetails = async () => {
|
||||
try {
|
||||
const sessionDetails = await getUISessionDetails();
|
||||
// sessionDetails is already decoded, no need for jwtDecode
|
||||
console.log("Session details:", sessionDetails);
|
||||
|
||||
// Set access token to the session_id
|
||||
setAccessToken(sessionDetails.session_id);
|
||||
|
||||
setDisabledPersonalKeyCreation(
|
||||
sessionDetails.disabled_non_admin_personal_key_creation,
|
||||
);
|
||||
|
||||
if (sessionDetails.user_role) {
|
||||
const formattedUserRole = formatUserRole(sessionDetails.user_role);
|
||||
console.log("User role:", formattedUserRole);
|
||||
setUserRole(formattedUserRole);
|
||||
if (formattedUserRole == "Admin Viewer") {
|
||||
setPage("usage");
|
||||
}
|
||||
} else {
|
||||
console.log("User role not defined");
|
||||
}
|
||||
|
||||
if (sessionDetails.user_email) {
|
||||
setUserEmail(sessionDetails.user_email);
|
||||
} else {
|
||||
console.log("User Email is not set");
|
||||
}
|
||||
|
||||
if (sessionDetails.login_method) {
|
||||
setShowSSOBanner(
|
||||
sessionDetails.login_method == "username_password" ? true : false,
|
||||
);
|
||||
}
|
||||
|
||||
if (sessionDetails.premium_user) {
|
||||
setPremiumUser(sessionDetails.premium_user);
|
||||
}
|
||||
|
||||
if (sessionDetails.auth_header_name) {
|
||||
setGlobalLitellmHeaderName(sessionDetails.auth_header_name);
|
||||
}
|
||||
|
||||
// Store the full session details as token for components that need it
|
||||
setToken(JSON.stringify(sessionDetails));
|
||||
} catch (error) {
|
||||
console.error("Error fetching session details:", error);
|
||||
}
|
||||
};
|
||||
|
||||
fetchSessionDetails();
|
||||
const token = getCookie("token");
|
||||
setToken(token);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!token) {
|
||||
return;
|
||||
}
|
||||
|
||||
const decoded = jwtDecode(token) as { [key: string]: any };
|
||||
if (decoded) {
|
||||
// cast decoded to dictionary
|
||||
console.log("Decoded token:", decoded);
|
||||
|
||||
console.log("Decoded key:", decoded.key);
|
||||
// set accessToken
|
||||
setAccessToken(decoded.key);
|
||||
|
||||
setDisabledPersonalKeyCreation(
|
||||
decoded.disabled_non_admin_personal_key_creation,
|
||||
);
|
||||
|
||||
// check if userRole is defined
|
||||
if (decoded.user_role) {
|
||||
const formattedUserRole = formatUserRole(decoded.user_role);
|
||||
console.log("Decoded user_role:", formattedUserRole);
|
||||
setUserRole(formattedUserRole);
|
||||
if (formattedUserRole == "Admin Viewer") {
|
||||
setPage("usage");
|
||||
}
|
||||
} else {
|
||||
console.log("User role not defined");
|
||||
}
|
||||
|
||||
if (decoded.user_email) {
|
||||
setUserEmail(decoded.user_email);
|
||||
} else {
|
||||
console.log(`User Email is not set ${decoded}`);
|
||||
}
|
||||
|
||||
if (decoded.login_method) {
|
||||
setShowSSOBanner(
|
||||
decoded.login_method == "username_password" ? true : false,
|
||||
);
|
||||
} else {
|
||||
console.log(`User Email is not set ${decoded}`);
|
||||
}
|
||||
|
||||
if (decoded.premium_user) {
|
||||
setPremiumUser(decoded.premium_user);
|
||||
}
|
||||
|
||||
if (decoded.auth_header_name) {
|
||||
setGlobalLitellmHeaderName(decoded.auth_header_name);
|
||||
}
|
||||
}
|
||||
}, [token]);
|
||||
|
||||
useEffect(() => {
|
||||
if (accessToken && userID && userRole) {
|
||||
|
|
|
@ -462,7 +462,6 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
|
||||
useEffect(() => {
|
||||
if (!accessToken || !token || !userRole || !userID) {
|
||||
console.log("exiting on model_dashboard.tsx because of missing accessToken, token, userRole, or userID", accessToken, token, userRole, userID);
|
||||
return;
|
||||
}
|
||||
const fetchData = async () => {
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -21,7 +21,6 @@ import { useSearchParams, useRouter } from "next/navigation";
|
|||
import { Team } from "./key_team_helpers/key_list";
|
||||
import { jwtDecode } from "jwt-decode";
|
||||
import { Typography } from "antd";
|
||||
import { getUISessionDetails } from "@/utils/cookieUtils";
|
||||
import { clearTokenCookies } from "@/utils/cookieUtils";
|
||||
const isLocal = process.env.NODE_ENV === "development";
|
||||
if (isLocal != true) {
|
||||
|
@ -46,6 +45,14 @@ export type UserInfo = {
|
|||
spend: number;
|
||||
}
|
||||
|
||||
function getCookie(name: string) {
|
||||
console.log("COOKIES", document.cookie)
|
||||
const cookieValue = document.cookie
|
||||
.split('; ')
|
||||
.find(row => row.startsWith(name + '='));
|
||||
return cookieValue ? cookieValue.split('=')[1] : null;
|
||||
}
|
||||
|
||||
interface UserDashboardProps {
|
||||
userID: string | null;
|
||||
userRole: string | null;
|
||||
|
@ -87,7 +94,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
// Assuming useSearchParams() hook exists and works in your setup
|
||||
const searchParams = useSearchParams()!;
|
||||
|
||||
const token = getUISessionDetails();
|
||||
const token = getCookie('token');
|
||||
|
||||
const invitation_id = searchParams.get("invitation_id");
|
||||
|
||||
|
@ -139,37 +146,32 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
// console.log(`selectedTeam: ${Object.entries(selectedTeam)}`);
|
||||
// Moved useEffect inside the component and used a condition to run fetch only if the params are available
|
||||
useEffect(() => {
|
||||
const fetchSessionDetails = async () => {
|
||||
try {
|
||||
const sessionDetails = await getUISessionDetails();
|
||||
console.log("Session details:", sessionDetails);
|
||||
|
||||
// Set access token to the session_id
|
||||
setAccessToken(sessionDetails.session_id);
|
||||
|
||||
if (token) {
|
||||
const decoded = jwtDecode(token) as { [key: string]: any };
|
||||
if (decoded) {
|
||||
// cast decoded to dictionary
|
||||
console.log("Decoded token:", decoded);
|
||||
|
||||
console.log("Decoded key:", decoded.key);
|
||||
// set accessToken
|
||||
setAccessToken(decoded.key);
|
||||
|
||||
// check if userRole is defined
|
||||
if (sessionDetails.user_role) {
|
||||
const formattedUserRole = formatUserRole(sessionDetails.user_role);
|
||||
console.log("User role:", formattedUserRole);
|
||||
if (decoded.user_role) {
|
||||
const formattedUserRole = formatUserRole(decoded.user_role);
|
||||
console.log("Decoded user_role:", formattedUserRole);
|
||||
setUserRole(formattedUserRole);
|
||||
} else {
|
||||
console.log("User role not defined");
|
||||
}
|
||||
|
||||
if (sessionDetails.user_email) {
|
||||
setUserEmail(sessionDetails.user_email);
|
||||
if (decoded.user_email) {
|
||||
setUserEmail(decoded.user_email);
|
||||
} else {
|
||||
console.log("User Email is not set");
|
||||
console.log(`User Email is not set ${decoded}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching session details:", error);
|
||||
}
|
||||
};
|
||||
|
||||
fetchSessionDetails();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
}
|
||||
if (userID && accessToken && userRole && !keys && !userSpendData) {
|
||||
const cachedUserModels = sessionStorage.getItem("userModels" + userID);
|
||||
if (cachedUserModels) {
|
||||
|
@ -244,7 +246,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
fetchTeams(accessToken, userID, userRole, currentOrg, setTeams);
|
||||
}
|
||||
}
|
||||
}, [userID, accessToken, keys, userRole]);
|
||||
}, [userID, token, accessToken, keys, userRole]);
|
||||
|
||||
useEffect(() => {
|
||||
console.log(`currentOrg: ${JSON.stringify(currentOrg)}, accessToken: ${accessToken}, userID: ${userID}, userRole: ${userRole}`)
|
||||
|
@ -331,35 +333,31 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
<div className="w-full mx-4 h-[75vh]">
|
||||
<Grid numItems={1} className="gap-2 p-8 w-full mt-2">
|
||||
<Col numColSpan={1} className="flex flex-col gap-2">
|
||||
{accessToken && (
|
||||
<>
|
||||
<CreateKey
|
||||
key={selectedTeam ? selectedTeam.team_id : null}
|
||||
userID={userID}
|
||||
team={selectedTeam as Team | null}
|
||||
teams={teams as Team[]}
|
||||
userRole={userRole}
|
||||
accessToken={accessToken}
|
||||
data={keys}
|
||||
setData={setKeys}
|
||||
/>
|
||||
<CreateKey
|
||||
key={selectedTeam ? selectedTeam.team_id : null}
|
||||
userID={userID}
|
||||
team={selectedTeam as Team | null}
|
||||
teams={teams as Team[]}
|
||||
userRole={userRole}
|
||||
accessToken={accessToken}
|
||||
data={keys}
|
||||
setData={setKeys}
|
||||
/>
|
||||
|
||||
<ViewKeyTable
|
||||
userID={userID}
|
||||
userRole={userRole}
|
||||
accessToken={accessToken}
|
||||
selectedTeam={selectedTeam ? selectedTeam : null}
|
||||
setSelectedTeam={setSelectedTeam}
|
||||
data={keys}
|
||||
setData={setKeys}
|
||||
premiumUser={premiumUser}
|
||||
teams={teams}
|
||||
currentOrg={currentOrg}
|
||||
setCurrentOrg={setCurrentOrg}
|
||||
organizations={organizations}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<ViewKeyTable
|
||||
userID={userID}
|
||||
userRole={userRole}
|
||||
accessToken={accessToken}
|
||||
selectedTeam={selectedTeam ? selectedTeam : null}
|
||||
setSelectedTeam={setSelectedTeam}
|
||||
data={keys}
|
||||
setData={setKeys}
|
||||
premiumUser={premiumUser}
|
||||
teams={teams}
|
||||
currentOrg={currentOrg}
|
||||
setCurrentOrg={setCurrentOrg}
|
||||
organizations={organizations}
|
||||
/>
|
||||
</Col>
|
||||
</Grid>
|
||||
</div>
|
||||
|
|
|
@ -5,24 +5,6 @@
|
|||
/**
|
||||
* Clears the token cookie from both root and /ui paths
|
||||
*/
|
||||
import { validateSession } from "../components/networking"
|
||||
|
||||
// Define the interface for the JWT token data
|
||||
export interface JWTTokenData {
|
||||
user_id: string;
|
||||
user_email: string | null;
|
||||
user_role: string;
|
||||
login_method: string;
|
||||
premium_user: boolean;
|
||||
auth_header_name: string;
|
||||
iss: string;
|
||||
aud: string;
|
||||
exp: number;
|
||||
disabled_non_admin_personal_key_creation: boolean;
|
||||
scopes: string[];
|
||||
session_id: string; // ui session id currently in progress
|
||||
}
|
||||
|
||||
export function clearTokenCookies() {
|
||||
// Get the current domain
|
||||
const domain = window.location.hostname;
|
||||
|
@ -31,51 +13,32 @@ export function clearTokenCookies() {
|
|||
const paths = ['/', '/ui'];
|
||||
const sameSiteValues = ['Lax', 'Strict', 'None'];
|
||||
|
||||
// Get all cookies
|
||||
const allCookies = document.cookie.split("; ");
|
||||
const tokenPattern = /^token_\d+$/;
|
||||
|
||||
// Find all token cookies
|
||||
const tokenCookieNames = allCookies
|
||||
.map(cookie => cookie.split("=")[0])
|
||||
.filter(name => name === "token" || tokenPattern.test(name));
|
||||
|
||||
// Clear each token cookie with various combinations
|
||||
tokenCookieNames.forEach(cookieName => {
|
||||
paths.forEach(path => {
|
||||
// Basic clearing
|
||||
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
|
||||
|
||||
// With domain
|
||||
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
|
||||
|
||||
// Try different SameSite values
|
||||
sameSiteValues.forEach(sameSite => {
|
||||
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
|
||||
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
|
||||
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
|
||||
});
|
||||
paths.forEach(path => {
|
||||
// Basic clearing
|
||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
|
||||
|
||||
// With domain
|
||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
|
||||
|
||||
// Try different SameSite values
|
||||
sameSiteValues.forEach(sameSite => {
|
||||
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
|
||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
|
||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
|
||||
});
|
||||
});
|
||||
|
||||
console.log("After clearing cookies:", document.cookie);
|
||||
}
|
||||
|
||||
export function setAuthToken(token: string) {
|
||||
// Generate a token name with current timestamp
|
||||
const currentTimestamp = Math.floor(Date.now() / 1000);
|
||||
const tokenName = `token_${currentTimestamp}`;
|
||||
|
||||
// Set the cookie with the timestamp-based name
|
||||
document.cookie = `${tokenName}=${token}; path=/; domain=${window.location.hostname};`;
|
||||
}
|
||||
|
||||
export async function getUISessionDetails(): Promise<JWTTokenData> {
|
||||
const validated_jwt_token = await validateSession();
|
||||
|
||||
if (validated_jwt_token?.data) {
|
||||
return validated_jwt_token.data as JWTTokenData;
|
||||
} else {
|
||||
throw new Error("Invalid JWT token");
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Gets a cookie value by name
|
||||
* @param name The name of the cookie to retrieve
|
||||
* @returns The cookie value or null if not found
|
||||
*/
|
||||
export function getCookie(name: string) {
|
||||
const cookieValue = document.cookie
|
||||
.split('; ')
|
||||
.find(row => row.startsWith(name + '='));
|
||||
return cookieValue ? cookieValue.split('=')[1] : null;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue