mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
(UI) - Security Improvement, move to JWT Auth for Admin UI Sessions (#8995)
* (UI) - Improvements to session handling logic (#8970)
* add cookieUtils
* use utils for clearing cookies
* on logout use clearTokenCookies
* ui use correct clearTokenCookies
* navbar show userEmail on UserID page
* add timestamp on token cookie
* update generate_authenticated_redirect_response
* use common getAuthToken
* fix clearTokenCookies
* fixes for get auth token
* fix invitation link sign in logic
* Revert "fix invitation link sign in logic"
This reverts commit 30e5308cb3
.
* fix getAuthToken
* update setAuthToken
* fix ui session handling
* fix ui session handler
* bug fix stop generating LiteLLM Virtual keys for access
* working JWT insert into cookies
* use central place to build UI JWT token
* add _validate_ui_token
* fix ui session handler
* fix fetchWithCredentials
* check allowed routes for ui session tokens
* expose validate_session endpoint
* validate session endpoint
* call sso/session/validate
* getUISessionDetails
* ui move to getUISessionDetails
* /sso/session/validate
* fix cookie utils
* use getUISessionDetails
* use ui_session_id
* "/spend/logs/ui" in spend_tracking_routes
* working sign in JWT flow for proxy admin
* allow proxy admin to access ui routes
* use check_route_access
* update types
* update login method
* fixes to ui session handler
* working flow for admin and internal users
* fixes for invite links
* use JWTs for SSO sign in
* fix /invitation/new flow
* fix code quality checks
* fix _get_ui_session_token_from_cookies
* /organization/list
* ui sso sign in
* TestUISessionHandler
* TestUISessionHandler
This commit is contained in:
parent
42931638df
commit
01a44a4e47
17 changed files with 1104 additions and 538 deletions
|
@ -272,6 +272,7 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
"/key/health",
|
"/key/health",
|
||||||
"/team/info",
|
"/team/info",
|
||||||
"/team/list",
|
"/team/list",
|
||||||
|
"/organization/info",
|
||||||
"/organization/list",
|
"/organization/list",
|
||||||
"/team/available",
|
"/team/available",
|
||||||
"/user/info",
|
"/user/info",
|
||||||
|
@ -282,6 +283,11 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
"/health",
|
"/health",
|
||||||
"/key/list",
|
"/key/list",
|
||||||
"/user/filter/ui",
|
"/user/filter/ui",
|
||||||
|
"/user/list",
|
||||||
|
"/user/available_roles",
|
||||||
|
"/guardrails/list",
|
||||||
|
"/cache/ping",
|
||||||
|
"/get/config/callbacks",
|
||||||
]
|
]
|
||||||
|
|
||||||
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
|
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
|
||||||
|
@ -300,6 +306,8 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
"/user/update",
|
"/user/update",
|
||||||
"/user/delete",
|
"/user/delete",
|
||||||
"/user/info",
|
"/user/info",
|
||||||
|
# user invitation management
|
||||||
|
"/invitation/new",
|
||||||
# team
|
# team
|
||||||
"/team/new",
|
"/team/new",
|
||||||
"/team/update",
|
"/team/update",
|
||||||
|
@ -309,6 +317,20 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
"/team/block",
|
"/team/block",
|
||||||
"/team/unblock",
|
"/team/unblock",
|
||||||
"/team/available",
|
"/team/available",
|
||||||
|
# team member management
|
||||||
|
"/team/member_add",
|
||||||
|
"/team/member_update",
|
||||||
|
"/team/member_delete",
|
||||||
|
# organization management
|
||||||
|
"/organization/new",
|
||||||
|
"/organization/update",
|
||||||
|
"/organization/delete",
|
||||||
|
"/organization/info",
|
||||||
|
"/organization/list",
|
||||||
|
# organization member management
|
||||||
|
"/organization/member_add",
|
||||||
|
"/organization/member_update",
|
||||||
|
"/organization/member_delete",
|
||||||
# model
|
# model
|
||||||
"/model/new",
|
"/model/new",
|
||||||
"/model/update",
|
"/model/update",
|
||||||
|
@ -355,20 +377,32 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
"/sso",
|
"/sso",
|
||||||
"/sso/get/ui_settings",
|
"/sso/get/ui_settings",
|
||||||
"/login",
|
"/login",
|
||||||
|
"/sso/session/validate",
|
||||||
"/key/info",
|
"/key/info",
|
||||||
"/config",
|
"/config",
|
||||||
"/spend",
|
"/spend",
|
||||||
"/model/info",
|
"/model/info",
|
||||||
|
"/model/metrics",
|
||||||
|
"/model/metrics/{sub_path}",
|
||||||
|
"/model/settings",
|
||||||
|
"/get/litellm_model_cost_map",
|
||||||
|
"/model/streaming_metrics",
|
||||||
"/v2/model/info",
|
"/v2/model/info",
|
||||||
"/v2/key/info",
|
"/v2/key/info",
|
||||||
"/models",
|
"/models",
|
||||||
"/v1/models",
|
"/v1/models",
|
||||||
"/global/spend",
|
"/global/spend",
|
||||||
"/global/spend/logs",
|
"/global/spend/logs",
|
||||||
|
"/spend/logs/ui",
|
||||||
|
"/spend/logs/ui/{id}",
|
||||||
"/global/spend/keys",
|
"/global/spend/keys",
|
||||||
"/global/spend/models",
|
"/global/spend/models",
|
||||||
"/global/predict/spend/logs",
|
"/global/predict/spend/logs",
|
||||||
"/global/activity",
|
"/global/activity",
|
||||||
|
"/global/activity/{sub_path}",
|
||||||
|
"/global/activity/exceptions",
|
||||||
|
"/global/activity/exceptions/{sub_path}",
|
||||||
|
"/global/all_end_users",
|
||||||
"/health/services",
|
"/health/services",
|
||||||
] + info_routes
|
] + info_routes
|
||||||
|
|
||||||
|
@ -2459,6 +2493,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
"spend_tracking_routes",
|
"spend_tracking_routes",
|
||||||
"global_spend_tracking_routes",
|
"global_spend_tracking_routes",
|
||||||
"info_routes",
|
"info_routes",
|
||||||
|
"ui_routes",
|
||||||
]
|
]
|
||||||
team_id_jwt_field: Optional[str] = None
|
team_id_jwt_field: Optional[str] = None
|
||||||
team_id_upsert: bool = False
|
team_id_upsert: bool = False
|
||||||
|
|
|
@ -204,9 +204,11 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for allowed_route in allowed_routes:
|
for allowed_route in allowed_routes:
|
||||||
if (
|
if allowed_route in LiteLLMRoutes.__members__ and (
|
||||||
allowed_route in LiteLLMRoutes.__members__
|
RouteChecks.check_route_access(
|
||||||
and user_route in LiteLLMRoutes[allowed_route].value
|
route=user_route,
|
||||||
|
allowed_routes=LiteLLMRoutes[allowed_route].value,
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
elif allowed_route == user_route:
|
elif allowed_route == user_route:
|
||||||
|
@ -217,16 +219,18 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||||
def allowed_routes_check(
|
def allowed_routes_check(
|
||||||
user_role: Literal[
|
user_role: Literal[
|
||||||
LitellmUserRoles.PROXY_ADMIN,
|
LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
|
||||||
LitellmUserRoles.TEAM,
|
LitellmUserRoles.TEAM,
|
||||||
LitellmUserRoles.INTERNAL_USER,
|
LitellmUserRoles.INTERNAL_USER,
|
||||||
|
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
||||||
],
|
],
|
||||||
user_route: str,
|
user_route: str,
|
||||||
litellm_proxy_roles: LiteLLM_JWTAuth,
|
litellm_proxy_roles: LiteLLM_JWTAuth,
|
||||||
|
jwt_valid_token: dict,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if user -> not admin - allowed to access these routes
|
Check if user -> not admin - allowed to access these routes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if user_role == LitellmUserRoles.PROXY_ADMIN:
|
if user_role == LitellmUserRoles.PROXY_ADMIN:
|
||||||
is_allowed = _allowed_routes_check(
|
is_allowed = _allowed_routes_check(
|
||||||
user_route=user_route,
|
user_route=user_route,
|
||||||
|
|
|
@ -33,6 +33,7 @@ from litellm.proxy._types import (
|
||||||
ScopeMapping,
|
ScopeMapping,
|
||||||
Span,
|
Span,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
|
|
||||||
from .auth_checks import (
|
from .auth_checks import (
|
||||||
|
@ -406,10 +407,60 @@ class JWTHandler:
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _validate_ui_token(self, token: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Helper function to validate tokens generated for the LiteLLM UI.
|
||||||
|
Returns the decoded payload if it's a valid UI token, None otherwise.
|
||||||
|
"""
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import master_key
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Decode without verification to check if it's a UI token
|
||||||
|
unverified_payload = jwt.decode(token, options={"verify_signature": False})
|
||||||
|
|
||||||
|
# Check if this looks like a UI token (has specific claims that only UI tokens would have)
|
||||||
|
if UISessionHandler.is_ui_session_token(unverified_payload):
|
||||||
|
|
||||||
|
# This looks like a UI token, now verify it with the master key
|
||||||
|
if not master_key:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Missing LITELLM_MASTER_KEY for UI token validation"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
master_key,
|
||||||
|
algorithms=["HS256"],
|
||||||
|
audience="litellm-ui",
|
||||||
|
leeway=self.leeway,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Successfully validated UI token for payload: %s",
|
||||||
|
json.dumps(payload, indent=4),
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
except jwt.InvalidTokenError as e:
|
||||||
|
verbose_proxy_logger.debug(f"Invalid UI token: {str(e)}")
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid UI token, Unable to validate token signature {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return None # Not a UI token
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
async def auth_jwt(self, token: str) -> dict:
|
async def auth_jwt(self, token: str) -> dict:
|
||||||
# Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
|
# Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
|
||||||
# "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
|
# "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
|
||||||
# the key in different ways (e.g. HS* and RS*)."
|
# the key in different ways (e.g. HS* and RS*)."
|
||||||
|
|
||||||
|
ui_payload = self._validate_ui_token(token)
|
||||||
|
if ui_payload:
|
||||||
|
return ui_payload
|
||||||
algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]
|
algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]
|
||||||
|
|
||||||
audience = os.getenv("JWT_AUDIENCE")
|
audience = os.getenv("JWT_AUDIENCE")
|
||||||
|
@ -616,6 +667,7 @@ class JWTAuthManager:
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
org_id: Optional[str],
|
org_id: Optional[str],
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
jwt_valid_token: dict,
|
||||||
) -> Optional[JWTAuthBuilderResult]:
|
) -> Optional[JWTAuthBuilderResult]:
|
||||||
"""Check admin status and route access permissions"""
|
"""Check admin status and route access permissions"""
|
||||||
if not jwt_handler.is_admin(scopes=scopes):
|
if not jwt_handler.is_admin(scopes=scopes):
|
||||||
|
@ -625,6 +677,7 @@ class JWTAuthManager:
|
||||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
user_route=route,
|
user_route=route,
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
|
jwt_valid_token=jwt_valid_token,
|
||||||
)
|
)
|
||||||
if not is_allowed:
|
if not is_allowed:
|
||||||
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||||
|
@ -698,6 +751,7 @@ class JWTAuthManager:
|
||||||
user_api_key_cache: DualCache,
|
user_api_key_cache: DualCache,
|
||||||
parent_otel_span: Optional[Span],
|
parent_otel_span: Optional[Span],
|
||||||
proxy_logging_obj: ProxyLogging,
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
jwt_valid_token: dict,
|
||||||
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
|
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
|
||||||
"""Find first team with access to the requested model"""
|
"""Find first team with access to the requested model"""
|
||||||
|
|
||||||
|
@ -730,6 +784,7 @@ class JWTAuthManager:
|
||||||
user_role=LitellmUserRoles.TEAM,
|
user_role=LitellmUserRoles.TEAM,
|
||||||
user_route=route,
|
user_route=route,
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
|
jwt_valid_token=jwt_valid_token,
|
||||||
)
|
)
|
||||||
if is_allowed:
|
if is_allowed:
|
||||||
return team_id, team_object
|
return team_id, team_object
|
||||||
|
@ -920,7 +975,13 @@ class JWTAuthManager:
|
||||||
|
|
||||||
# Check admin access
|
# Check admin access
|
||||||
admin_result = await JWTAuthManager.check_admin_access(
|
admin_result = await JWTAuthManager.check_admin_access(
|
||||||
jwt_handler, scopes, route, user_id, org_id, api_key
|
jwt_handler=jwt_handler,
|
||||||
|
scopes=scopes,
|
||||||
|
route=route,
|
||||||
|
user_id=user_id,
|
||||||
|
org_id=org_id,
|
||||||
|
api_key=api_key,
|
||||||
|
jwt_valid_token=jwt_valid_token,
|
||||||
)
|
)
|
||||||
if admin_result:
|
if admin_result:
|
||||||
return admin_result
|
return admin_result
|
||||||
|
@ -952,6 +1013,7 @@ class JWTAuthManager:
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
jwt_valid_token=jwt_valid_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get other objects
|
# Get other objects
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Set, Union
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, status
|
from fastapi import HTTPException, Request, status
|
||||||
|
|
||||||
|
@ -225,7 +225,9 @@ class RouteChecks:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
|
def check_route_access(
|
||||||
|
route: str, allowed_routes: Union[List[str], Set[str]]
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a route has access by checking both exact matches and patterns
|
Check if a route has access by checking both exact matches and patterns
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,7 @@ from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
||||||
from litellm.proxy.auth.route_checks import RouteChecks
|
from litellm.proxy.auth.route_checks import RouteChecks
|
||||||
from litellm.proxy.auth.service_account_checks import service_account_checks
|
from litellm.proxy.auth.service_account_checks import service_account_checks
|
||||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, _to_ns
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, _to_ns
|
||||||
from litellm.types.services import ServiceTypes
|
from litellm.types.services import ServiceTypes
|
||||||
|
|
||||||
|
@ -335,6 +336,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
"pass_through_endpoints", None
|
"pass_through_endpoints", None
|
||||||
)
|
)
|
||||||
passed_in_key: Optional[str] = None
|
passed_in_key: Optional[str] = None
|
||||||
|
cookie_token: Optional[str] = None
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
passed_in_key = api_key
|
passed_in_key = api_key
|
||||||
api_key = _get_bearer_token(api_key=api_key)
|
api_key = _get_bearer_token(api_key=api_key)
|
||||||
|
@ -344,6 +346,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
api_key = anthropic_api_key_header
|
api_key = anthropic_api_key_header
|
||||||
elif isinstance(google_ai_studio_api_key_header, str):
|
elif isinstance(google_ai_studio_api_key_header, str):
|
||||||
api_key = google_ai_studio_api_key_header
|
api_key = google_ai_studio_api_key_header
|
||||||
|
elif cookie_token := UISessionHandler._get_ui_session_token_from_cookies(
|
||||||
|
request
|
||||||
|
):
|
||||||
|
api_key = cookie_token
|
||||||
elif pass_through_endpoints is not None:
|
elif pass_through_endpoints is not None:
|
||||||
for endpoint in pass_through_endpoints:
|
for endpoint in pass_through_endpoints:
|
||||||
if endpoint.get("path", "") == route:
|
if endpoint.get("path", "") == route:
|
||||||
|
@ -420,7 +426,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
if general_settings.get("enable_oauth2_proxy_auth", False) is True:
|
if general_settings.get("enable_oauth2_proxy_auth", False) is True:
|
||||||
return await handle_oauth2_proxy_request(request=request)
|
return await handle_oauth2_proxy_request(request=request)
|
||||||
|
|
||||||
if general_settings.get("enable_jwt_auth", False) is True:
|
if (
|
||||||
|
general_settings.get("enable_jwt_auth", False) is True
|
||||||
|
or cookie_token is not None
|
||||||
|
):
|
||||||
from litellm.proxy.proxy_server import premium_user
|
from litellm.proxy.proxy_server import premium_user
|
||||||
|
|
||||||
if premium_user is not True:
|
if premium_user is not True:
|
||||||
|
|
|
@ -8,7 +8,7 @@ Has all /sso/* routes
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
from typing import TYPE_CHECKING, List, Optional, Union, cast
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
|
||||||
|
@ -43,6 +43,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
||||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -408,11 +409,7 @@ def get_disabled_non_admin_personal_key_creation():
|
||||||
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
||||||
async def auth_callback(request: Request): # noqa: PLR0915
|
async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
"""Verify login"""
|
"""Verify login"""
|
||||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
|
||||||
generate_key_helper_fn,
|
|
||||||
)
|
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
general_settings,
|
|
||||||
jwt_handler,
|
jwt_handler,
|
||||||
master_key,
|
master_key,
|
||||||
premium_user,
|
premium_user,
|
||||||
|
@ -426,6 +423,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||||
|
user_role: Optional[LitellmUserRoles] = None
|
||||||
# get url from request
|
# get url from request
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
|
@ -531,16 +529,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
max_internal_user_budget = litellm.max_internal_user_budget
|
max_internal_user_budget = litellm.max_internal_user_budget
|
||||||
internal_user_budget_duration = litellm.internal_user_budget_duration
|
internal_user_budget_duration = litellm.internal_user_budget_duration
|
||||||
|
|
||||||
# User might not be already created on first generation of key
|
|
||||||
# But if it is, we want their models preferences
|
# But if it is, we want their models preferences
|
||||||
default_ui_key_values: Dict[str, Any] = {
|
|
||||||
"duration": "24hr",
|
|
||||||
"key_max_budget": litellm.max_ui_session_budget,
|
|
||||||
"aliases": {},
|
|
||||||
"config": {},
|
|
||||||
"spend": 0,
|
|
||||||
"team_id": "litellm-dashboard",
|
|
||||||
}
|
|
||||||
user_defined_values: Optional[SSOUserDefinedValues] = None
|
user_defined_values: Optional[SSOUserDefinedValues] = None
|
||||||
|
|
||||||
if user_custom_sso is not None:
|
if user_custom_sso is not None:
|
||||||
|
@ -559,7 +548,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
)
|
)
|
||||||
|
|
||||||
_user_id_from_sso = user_id
|
_user_id_from_sso = user_id
|
||||||
user_role = None
|
|
||||||
try:
|
try:
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
try:
|
try:
|
||||||
|
@ -632,24 +620,14 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
f"user_defined_values for creating ui key: {user_defined_values}"
|
f"user_defined_values for creating ui key: {user_defined_values}"
|
||||||
)
|
)
|
||||||
|
|
||||||
default_ui_key_values.update(user_defined_values)
|
|
||||||
default_ui_key_values["request_type"] = "key"
|
|
||||||
response = await generate_key_helper_fn(
|
|
||||||
**default_ui_key_values, # type: ignore
|
|
||||||
table_name="key",
|
|
||||||
)
|
|
||||||
|
|
||||||
key = response["token"] # type: ignore
|
|
||||||
user_id = response["user_id"] # type: ignore
|
|
||||||
|
|
||||||
litellm_dashboard_ui = "/ui/"
|
litellm_dashboard_ui = "/ui/"
|
||||||
user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
|
user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
if (
|
if (
|
||||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
os.getenv("PROXY_ADMIN_ID", None) is not None
|
||||||
and os.environ["PROXY_ADMIN_ID"] == user_id
|
and os.environ["PROXY_ADMIN_ID"] == user_id
|
||||||
):
|
):
|
||||||
# checks if user is admin
|
# checks if user is admin
|
||||||
user_role = LitellmUserRoles.PROXY_ADMIN.value
|
user_role = LitellmUserRoles.PROXY_ADMIN
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
|
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
|
||||||
|
@ -669,30 +647,20 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
disabled_non_admin_personal_key_creation = (
|
disabled_non_admin_personal_key_creation = (
|
||||||
get_disabled_non_admin_personal_key_creation()
|
get_disabled_non_admin_personal_key_creation()
|
||||||
)
|
)
|
||||||
|
jwt_token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||||
import jwt
|
user_id=user_defined_values.get("user_id", ""),
|
||||||
|
user_role=user_role,
|
||||||
jwt_token = jwt.encode( # type: ignore
|
user_email=user_defined_values.get("user_email", ""),
|
||||||
{
|
premium_user=premium_user,
|
||||||
"user_id": user_id,
|
disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
|
||||||
"key": key,
|
login_method="sso",
|
||||||
"user_email": user_email,
|
|
||||||
"user_role": user_role,
|
|
||||||
"login_method": "sso",
|
|
||||||
"premium_user": premium_user,
|
|
||||||
"auth_header_name": general_settings.get(
|
|
||||||
"litellm_key_header_name", "Authorization"
|
|
||||||
),
|
|
||||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
|
||||||
},
|
|
||||||
master_key,
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
)
|
||||||
if user_id is not None and isinstance(user_id, str):
|
if user_id is not None and isinstance(user_id, str):
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?userID=" + user_id
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
return UISessionHandler.generate_authenticated_redirect_response(
|
||||||
return redirect_response
|
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def insert_sso_user(
|
async def insert_sso_user(
|
||||||
|
@ -778,3 +746,25 @@ async def get_ui_settings(request: Request):
|
||||||
),
|
),
|
||||||
"DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries,
|
"DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/sso/session/validate",
|
||||||
|
include_in_schema=False,
|
||||||
|
tags=["experimental"],
|
||||||
|
)
|
||||||
|
async def validate_session(
|
||||||
|
request: Request,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
|
|
||||||
|
ui_session_token = UISessionHandler._get_ui_session_token_from_cookies(request)
|
||||||
|
ui_session_id = UISessionHandler._get_latest_ui_cookie_name(request.cookies)
|
||||||
|
if ui_session_token is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
|
validated_jwt_token = await jwt_handler.auth_jwt(token=ui_session_token)
|
||||||
|
validated_jwt_token["session_id"] = ui_session_id
|
||||||
|
return {"valid": True, "data": validated_jwt_token}
|
||||||
|
|
141
litellm/proxy/management_helpers/ui_session_handler.py
Normal file
141
litellm/proxy/management_helpers/ui_session_handler.py
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from fastapi.requests import Request
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import LiteLLM_JWTAuth, LitellmUserRoles
|
||||||
|
|
||||||
|
|
||||||
|
class UISessionHandler:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_latest_ui_cookie_name(cookies: dict) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the name of the most recent LiteLLM UI cookie
|
||||||
|
"""
|
||||||
|
# Find all LiteLLM UI cookies (format: litellm_ui_token_{timestamp})
|
||||||
|
litellm_ui_cookies = [
|
||||||
|
k for k in cookies.keys() if k.startswith("litellm_ui_token_")
|
||||||
|
]
|
||||||
|
if not litellm_ui_cookies:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Sort by timestamp (descending) to get the most recent one
|
||||||
|
try:
|
||||||
|
# Extract timestamps and sort numerically
|
||||||
|
sorted_cookies = sorted(
|
||||||
|
litellm_ui_cookies,
|
||||||
|
key=lambda x: int(x.split("_")[-1]),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
return sorted_cookies[0]
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# Fallback to simple string sort if timestamp extraction fails
|
||||||
|
litellm_ui_cookies.sort(reverse=True)
|
||||||
|
return litellm_ui_cookies[0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# Add this function to extract auth token from cookies
|
||||||
|
def _get_ui_session_token_from_cookies(request: Request) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Extract authentication token from cookies if present
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cookies = request.cookies
|
||||||
|
verbose_proxy_logger.debug(f"AUTH COOKIES: {cookies}")
|
||||||
|
|
||||||
|
cookie_name = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||||
|
if cookie_name:
|
||||||
|
return cookies[cookie_name]
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.error(
|
||||||
|
f"Error getting UI session token from cookies: {e}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_authenticated_ui_jwt_token(
|
||||||
|
user_id: str,
|
||||||
|
user_role: Optional[LitellmUserRoles],
|
||||||
|
user_email: Optional[str],
|
||||||
|
premium_user: bool,
|
||||||
|
disabled_non_admin_personal_key_creation: bool,
|
||||||
|
login_method: Literal["username_password", "sso"],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a JWT token for the authenticated UI session
|
||||||
|
|
||||||
|
This token is used to authenticate the user's session when they are redirected to the UI
|
||||||
|
"""
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import general_settings, master_key
|
||||||
|
|
||||||
|
if master_key is None:
|
||||||
|
raise ValueError("Master key is not set")
|
||||||
|
|
||||||
|
expiration = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||||
|
initial_payload = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_email": user_email,
|
||||||
|
"user_role": user_role, # this is the path without sso - we can assume only admins will use this
|
||||||
|
"login_method": login_method,
|
||||||
|
"premium_user": premium_user,
|
||||||
|
"auth_header_name": general_settings.get(
|
||||||
|
"litellm_key_header_name", "Authorization"
|
||||||
|
),
|
||||||
|
"iss": "litellm-proxy", # Issuer - identifies this as an internal token
|
||||||
|
"aud": "litellm-ui", # Audience - identifies this as a UI token
|
||||||
|
"exp": expiration,
|
||||||
|
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_role == LitellmUserRoles.PROXY_ADMIN
|
||||||
|
or user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
|
||||||
|
):
|
||||||
|
initial_payload["scope"] = [
|
||||||
|
LiteLLM_JWTAuth().admin_jwt_scope,
|
||||||
|
]
|
||||||
|
|
||||||
|
jwt_token = jwt.encode(
|
||||||
|
initial_payload,
|
||||||
|
master_key,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
return jwt_token
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_ui_session_token(token_dict: dict) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the token is a UI session token
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
token_dict.get("iss") == "litellm-proxy"
|
||||||
|
and token_dict.get("aud") == "litellm-ui"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_authenticated_redirect_response(
|
||||||
|
redirect_url: str, jwt_token: str
|
||||||
|
) -> RedirectResponse:
|
||||||
|
redirect_response = RedirectResponse(url=redirect_url, status_code=303)
|
||||||
|
redirect_response.set_cookie(
|
||||||
|
key=UISessionHandler._generate_token_name(),
|
||||||
|
value=jwt_token,
|
||||||
|
secure=True,
|
||||||
|
httponly=True,
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
return redirect_response
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_token_name() -> str:
|
||||||
|
current_timestamp = int(time.time())
|
||||||
|
cookie_name = f"litellm_ui_token_{current_timestamp}"
|
||||||
|
return cookie_name
|
|
@ -7375,6 +7375,8 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
import multipart
|
import multipart
|
||||||
except ImportError:
|
except ImportError:
|
||||||
subprocess.run(["pip", "install", "python-multipart"])
|
subprocess.run(["pip", "install", "python-multipart"])
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
|
|
||||||
global master_key
|
global master_key
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
|
@ -7445,56 +7447,23 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
user_role=user_role,
|
user_role=user_role,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if os.getenv("DATABASE_URL") is not None:
|
|
||||||
response = await generate_key_helper_fn(
|
|
||||||
request_type="key",
|
|
||||||
**{
|
|
||||||
"user_role": LitellmUserRoles.PROXY_ADMIN,
|
|
||||||
"duration": "24hr",
|
|
||||||
"key_max_budget": litellm.max_ui_session_budget,
|
|
||||||
"models": [],
|
|
||||||
"aliases": {},
|
|
||||||
"config": {},
|
|
||||||
"spend": 0,
|
|
||||||
"user_id": key_user_id,
|
|
||||||
"team_id": "litellm-dashboard",
|
|
||||||
}, # type: ignore
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ProxyException(
|
|
||||||
message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="DATABASE_URL",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
key = response["token"] # type: ignore
|
|
||||||
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
||||||
if litellm_dashboard_ui.endswith("/"):
|
if litellm_dashboard_ui.endswith("/"):
|
||||||
litellm_dashboard_ui += "ui/"
|
litellm_dashboard_ui += "ui/"
|
||||||
else:
|
else:
|
||||||
litellm_dashboard_ui += "/ui/"
|
litellm_dashboard_ui += "/ui/"
|
||||||
import jwt
|
jwt_token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||||
|
user_id=user_id,
|
||||||
jwt_token = jwt.encode( # type: ignore
|
user_role=user_role,
|
||||||
{
|
user_email=None,
|
||||||
"user_id": user_id,
|
premium_user=premium_user,
|
||||||
"key": key,
|
disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
|
||||||
"user_email": None,
|
login_method="username_password",
|
||||||
"user_role": user_role, # this is the path without sso - we can assume only admins will use this
|
|
||||||
"login_method": "username_password",
|
|
||||||
"premium_user": premium_user,
|
|
||||||
"auth_header_name": general_settings.get(
|
|
||||||
"litellm_key_header_name", "Authorization"
|
|
||||||
),
|
|
||||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
|
||||||
},
|
|
||||||
master_key,
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?userID=" + user_id
|
||||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
return UISessionHandler.generate_authenticated_redirect_response(
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||||
return redirect_response
|
)
|
||||||
elif _user_row is not None:
|
elif _user_row is not None:
|
||||||
"""
|
"""
|
||||||
When sharing invite links
|
When sharing invite links
|
||||||
|
@ -7513,58 +7482,23 @@ async def login(request: Request): # noqa: PLR0915
|
||||||
if secrets.compare_digest(password, _password) or secrets.compare_digest(
|
if secrets.compare_digest(password, _password) or secrets.compare_digest(
|
||||||
hash_password, _password
|
hash_password, _password
|
||||||
):
|
):
|
||||||
if os.getenv("DATABASE_URL") is not None:
|
|
||||||
response = await generate_key_helper_fn(
|
|
||||||
request_type="key",
|
|
||||||
**{ # type: ignore
|
|
||||||
"user_role": user_role,
|
|
||||||
"duration": "24hr",
|
|
||||||
"key_max_budget": litellm.max_ui_session_budget,
|
|
||||||
"models": [],
|
|
||||||
"aliases": {},
|
|
||||||
"config": {},
|
|
||||||
"spend": 0,
|
|
||||||
"user_id": user_id,
|
|
||||||
"team_id": "litellm-dashboard",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ProxyException(
|
|
||||||
message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="DATABASE_URL",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
key = response["token"] # type: ignore
|
|
||||||
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
||||||
if litellm_dashboard_ui.endswith("/"):
|
if litellm_dashboard_ui.endswith("/"):
|
||||||
litellm_dashboard_ui += "ui/"
|
litellm_dashboard_ui += "ui/"
|
||||||
else:
|
else:
|
||||||
litellm_dashboard_ui += "/ui/"
|
litellm_dashboard_ui += "/ui/"
|
||||||
import jwt
|
jwt_token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||||
|
user_id=user_id,
|
||||||
jwt_token = jwt.encode( # type: ignore
|
user_role=user_role,
|
||||||
{
|
user_email=user_email,
|
||||||
"user_id": user_id,
|
premium_user=premium_user,
|
||||||
"key": key,
|
disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
|
||||||
"user_email": user_email,
|
login_method="username_password",
|
||||||
"user_role": user_role,
|
|
||||||
"login_method": "username_password",
|
|
||||||
"premium_user": premium_user,
|
|
||||||
"auth_header_name": general_settings.get(
|
|
||||||
"litellm_key_header_name", "Authorization"
|
|
||||||
),
|
|
||||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
|
||||||
},
|
|
||||||
master_key,
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
)
|
||||||
litellm_dashboard_ui += "?userID=" + user_id
|
litellm_dashboard_ui += "?userID=" + user_id
|
||||||
redirect_response = RedirectResponse(
|
return UISessionHandler.generate_authenticated_redirect_response(
|
||||||
url=litellm_dashboard_ui, status_code=303
|
redirect_url=litellm_dashboard_ui, jwt_token=jwt_token
|
||||||
)
|
)
|
||||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
|
||||||
return redirect_response
|
|
||||||
else:
|
else:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
|
message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
|
||||||
|
@ -7590,6 +7524,8 @@ async def onboarding(invite_link: str):
|
||||||
- Get user from db
|
- Get user from db
|
||||||
- Pass in user_email if set
|
- Pass in user_email if set
|
||||||
"""
|
"""
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
|
|
||||||
global prisma_client, master_key, general_settings
|
global prisma_client, master_key, general_settings
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
|
@ -7646,51 +7582,26 @@ async def onboarding(invite_link: str):
|
||||||
|
|
||||||
user_email = user_obj.user_email
|
user_email = user_obj.user_email
|
||||||
|
|
||||||
response = await generate_key_helper_fn(
|
|
||||||
request_type="key",
|
|
||||||
**{
|
|
||||||
"user_role": user_obj.user_role,
|
|
||||||
"duration": "24hr",
|
|
||||||
"key_max_budget": litellm.max_ui_session_budget,
|
|
||||||
"models": [],
|
|
||||||
"aliases": {},
|
|
||||||
"config": {},
|
|
||||||
"spend": 0,
|
|
||||||
"user_id": user_obj.user_id,
|
|
||||||
"team_id": "litellm-dashboard",
|
|
||||||
}, # type: ignore
|
|
||||||
)
|
|
||||||
key = response["token"] # type: ignore
|
|
||||||
|
|
||||||
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
||||||
if litellm_dashboard_ui.endswith("/"):
|
if litellm_dashboard_ui.endswith("/"):
|
||||||
litellm_dashboard_ui += "ui/onboarding"
|
litellm_dashboard_ui += "ui/onboarding"
|
||||||
else:
|
else:
|
||||||
litellm_dashboard_ui += "/ui/onboarding"
|
litellm_dashboard_ui += "/ui/onboarding"
|
||||||
import jwt
|
|
||||||
|
|
||||||
disabled_non_admin_personal_key_creation = (
|
disabled_non_admin_personal_key_creation = (
|
||||||
get_disabled_non_admin_personal_key_creation()
|
get_disabled_non_admin_personal_key_creation()
|
||||||
)
|
)
|
||||||
|
jwt_token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||||
jwt_token = jwt.encode( # type: ignore
|
user_id=user_obj.user_id,
|
||||||
{
|
user_role=user_obj.user_role,
|
||||||
"user_id": user_obj.user_id,
|
user_email=user_obj.user_email,
|
||||||
"key": key,
|
premium_user=user_obj.premium_user,
|
||||||
"user_email": user_obj.user_email,
|
disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
|
||||||
"user_role": user_obj.user_role,
|
login_method="username_password",
|
||||||
"login_method": "username_password",
|
|
||||||
"premium_user": premium_user,
|
|
||||||
"auth_header_name": general_settings.get(
|
|
||||||
"litellm_key_header_name", "Authorization"
|
|
||||||
),
|
|
||||||
"disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
|
|
||||||
},
|
|
||||||
master_key,
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
litellm_dashboard_ui += "?token={}&user_email={}".format(jwt_token, user_email)
|
litellm_dashboard_ui += "?token={}&user_email={}".format(jwt_token, user_email)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"login_url": litellm_dashboard_ui,
|
"login_url": litellm_dashboard_ui,
|
||||||
"token": jwt_token,
|
"token": jwt_token,
|
||||||
|
|
108
tests/litellm/proxy/auth/test_handle_jwt.py
Normal file
108
tests/litellm/proxy/auth/test_handle_jwt.py
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
|
|
||||||
|
|
||||||
|
class TestJWTHandler:
|
||||||
|
@pytest.fixture
|
||||||
|
def jwt_handler(self):
|
||||||
|
handler = JWTHandler()
|
||||||
|
handler.leeway = 60 # Set leeway for testing
|
||||||
|
return handler
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def setup_mocks(self, monkeypatch):
|
||||||
|
# Mock master_key
|
||||||
|
test_master_key = "test_master_key"
|
||||||
|
monkeypatch.setattr("litellm.proxy.proxy_server.master_key", test_master_key)
|
||||||
|
|
||||||
|
# Mock UISessionHandler.is_ui_session_token to return True for our test token
|
||||||
|
def mock_is_ui_session_token(payload):
|
||||||
|
return "ui_session_id" in payload
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
UISessionHandler, "is_ui_session_token", mock_is_ui_session_token
|
||||||
|
)
|
||||||
|
|
||||||
|
return test_master_key
|
||||||
|
|
||||||
|
def test_validate_valid_ui_token(self, jwt_handler, setup_mocks):
|
||||||
|
# Setup
|
||||||
|
test_master_key = setup_mocks
|
||||||
|
|
||||||
|
# Create a valid UI token
|
||||||
|
valid_payload = {
|
||||||
|
"ui_session_id": "test_session_id",
|
||||||
|
"exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600,
|
||||||
|
"iat": datetime.datetime.now(tz=timezone.utc).timestamp(),
|
||||||
|
"aud": "litellm-ui",
|
||||||
|
}
|
||||||
|
valid_token = jwt.encode(valid_payload, test_master_key, algorithm="HS256")
|
||||||
|
|
||||||
|
# Test valid UI token
|
||||||
|
result = jwt_handler._validate_ui_token(valid_token)
|
||||||
|
assert result is not None
|
||||||
|
assert result["ui_session_id"] == "test_session_id"
|
||||||
|
|
||||||
|
def test_validate_expired_ui_token(self, jwt_handler, setup_mocks):
|
||||||
|
# Setup
|
||||||
|
test_master_key = setup_mocks
|
||||||
|
|
||||||
|
# Create an expired UI token
|
||||||
|
expired_payload = {
|
||||||
|
"ui_session_id": "test_session_id",
|
||||||
|
"exp": datetime.datetime.now(tz=timezone.utc).timestamp() - 3600,
|
||||||
|
"iat": datetime.datetime.now(tz=timezone.utc).timestamp() - 7200,
|
||||||
|
"aud": "litellm-ui",
|
||||||
|
}
|
||||||
|
expired_token = jwt.encode(expired_payload, test_master_key, algorithm="HS256")
|
||||||
|
|
||||||
|
# Test expired UI token
|
||||||
|
with pytest.raises(ValueError, match="Invalid UI token"):
|
||||||
|
jwt_handler._validate_ui_token(expired_token)
|
||||||
|
|
||||||
|
def test_validate_invalid_signature_ui_token(self, jwt_handler, setup_mocks):
|
||||||
|
# Setup
|
||||||
|
test_master_key = setup_mocks
|
||||||
|
|
||||||
|
# Create a token with invalid signature
|
||||||
|
valid_payload = {
|
||||||
|
"ui_session_id": "test_session_id",
|
||||||
|
"exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600,
|
||||||
|
"iat": datetime.datetime.now(tz=timezone.utc).timestamp(),
|
||||||
|
"aud": "litellm-ui",
|
||||||
|
}
|
||||||
|
invalid_token = jwt.encode(valid_payload, "wrong_key", algorithm="HS256")
|
||||||
|
|
||||||
|
# Test UI token with invalid signature
|
||||||
|
with pytest.raises(ValueError, match="Invalid UI token"):
|
||||||
|
jwt_handler._validate_ui_token(invalid_token)
|
||||||
|
|
||||||
|
def test_validate_non_ui_token(self, jwt_handler, setup_mocks):
|
||||||
|
# Setup
|
||||||
|
test_master_key = setup_mocks
|
||||||
|
|
||||||
|
# Create a non-UI token
|
||||||
|
non_ui_payload = {
|
||||||
|
"sub": "user123",
|
||||||
|
"exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600,
|
||||||
|
}
|
||||||
|
non_ui_token = jwt.encode(non_ui_payload, test_master_key, algorithm="HS256")
|
||||||
|
|
||||||
|
# Test non-UI token
|
||||||
|
result = jwt_handler._validate_ui_token(non_ui_token)
|
||||||
|
assert result is None
|
|
@ -0,0 +1,179 @@
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
import pytest
|
||||||
|
from fastapi.requests import Request
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
from litellm.proxy._types import LitellmUserRoles
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
|
|
||||||
|
|
||||||
|
class TestUISessionHandler:
|
||||||
|
|
||||||
|
def test_get_latest_ui_cookie_name(self):
|
||||||
|
# Test with multiple cookies
|
||||||
|
cookies = {
|
||||||
|
"litellm_ui_token_1000": "value1",
|
||||||
|
"litellm_ui_token_2000": "value2",
|
||||||
|
"other_cookie": "other_value",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||||
|
assert result == "litellm_ui_token_2000"
|
||||||
|
|
||||||
|
# Test with no matching cookies
|
||||||
|
cookies = {"other_cookie": "value"}
|
||||||
|
result = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_ui_session_token_from_cookies(self):
|
||||||
|
# Create mock request with cookies
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.cookies = {
|
||||||
|
"litellm_ui_token_1000": "test_token",
|
||||||
|
"other_cookie": "other_value",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||||
|
assert result == "test_token"
|
||||||
|
|
||||||
|
# Test with no matching cookies
|
||||||
|
mock_request.cookies = {"other_cookie": "value"}
|
||||||
|
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@patch("litellm.proxy.proxy_server.master_key", "test_master_key")
|
||||||
|
@patch(
|
||||||
|
"litellm.proxy.proxy_server.general_settings",
|
||||||
|
{"litellm_key_header_name": "X-API-Key"},
|
||||||
|
)
|
||||||
|
def test_build_authenticated_ui_jwt_token(self):
|
||||||
|
# Test token generation
|
||||||
|
token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||||
|
user_id="test_user",
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
user_email="test@example.com",
|
||||||
|
premium_user=True,
|
||||||
|
disabled_non_admin_personal_key_creation=False,
|
||||||
|
login_method="username_password",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode and verify token
|
||||||
|
decoded = jwt.decode(
|
||||||
|
token, "test_master_key", algorithms=["HS256"], audience="litellm-ui"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert decoded["user_id"] == "test_user"
|
||||||
|
assert decoded["user_email"] == "test@example.com"
|
||||||
|
assert decoded["user_role"] == LitellmUserRoles.PROXY_ADMIN
|
||||||
|
assert decoded["premium_user"] is True
|
||||||
|
assert decoded["login_method"] == "username_password"
|
||||||
|
assert decoded["auth_header_name"] == "X-API-Key"
|
||||||
|
assert decoded["iss"] == "litellm-proxy"
|
||||||
|
assert decoded["aud"] == "litellm-ui"
|
||||||
|
assert "exp" in decoded
|
||||||
|
assert decoded["disabled_non_admin_personal_key_creation"] is False
|
||||||
|
assert decoded["scope"] == ["litellm_proxy_admin"]
|
||||||
|
|
||||||
|
def test_is_ui_session_token(self):
|
||||||
|
# Valid UI session token
|
||||||
|
token_dict = {
|
||||||
|
"iss": "litellm-proxy",
|
||||||
|
"aud": "litellm-ui",
|
||||||
|
"user_id": "test_user",
|
||||||
|
}
|
||||||
|
assert UISessionHandler.is_ui_session_token(token_dict) is True
|
||||||
|
|
||||||
|
# Invalid token (wrong issuer)
|
||||||
|
token_dict = {
|
||||||
|
"iss": "other-issuer",
|
||||||
|
"aud": "litellm-ui",
|
||||||
|
}
|
||||||
|
assert UISessionHandler.is_ui_session_token(token_dict) is False
|
||||||
|
|
||||||
|
# Invalid token (wrong audience)
|
||||||
|
token_dict = {
|
||||||
|
"iss": "litellm-proxy",
|
||||||
|
"aud": "other-audience",
|
||||||
|
}
|
||||||
|
assert UISessionHandler.is_ui_session_token(token_dict) is False
|
||||||
|
|
||||||
|
def test_generate_authenticated_redirect_response(self):
|
||||||
|
redirect_url = "https://example.com/dashboard"
|
||||||
|
jwt_token = "test.jwt.token"
|
||||||
|
|
||||||
|
response = UISessionHandler.generate_authenticated_redirect_response(
|
||||||
|
redirect_url=redirect_url, jwt_token=jwt_token
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, RedirectResponse)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"] == redirect_url
|
||||||
|
|
||||||
|
# Check cookie was set
|
||||||
|
cookie_header = response.headers.get("set-cookie", "")
|
||||||
|
assert "test.jwt.token" in cookie_header
|
||||||
|
assert "Secure" in cookie_header
|
||||||
|
assert "HttpOnly" in cookie_header
|
||||||
|
assert "SameSite=strict" in cookie_header
|
||||||
|
|
||||||
|
def test_generate_token_name(self):
|
||||||
|
# Mock time.time() to return a fixed value
|
||||||
|
with patch("time.time", return_value=1234567890):
|
||||||
|
token_name = UISessionHandler._generate_token_name()
|
||||||
|
assert token_name == "litellm_ui_token_1234567890"
|
||||||
|
|
||||||
|
def test_latest_token_is_used(self):
|
||||||
|
"""Test that the most recent token is correctly identified and used"""
|
||||||
|
# Create mock request with multiple UI tokens of different timestamps
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.cookies = {
|
||||||
|
"litellm_ui_token_1000": "old_token",
|
||||||
|
"litellm_ui_token_2000": "newer_token",
|
||||||
|
"litellm_ui_token_1500": "middle_token",
|
||||||
|
"other_cookie": "other_value",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get the token from cookies
|
||||||
|
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||||
|
|
||||||
|
# Verify the newest token (highest timestamp) is returned
|
||||||
|
assert result == "newer_token"
|
||||||
|
|
||||||
|
# Test with timestamps that aren't in numerical order in the dictionary
|
||||||
|
mock_request.cookies = {
|
||||||
|
"litellm_ui_token_5000": "newest_token",
|
||||||
|
"litellm_ui_token_1000": "oldest_token",
|
||||||
|
"litellm_ui_token_3000": "middle_token",
|
||||||
|
"other_cookie": "other_value",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||||
|
assert result == "newest_token"
|
||||||
|
|
||||||
|
# Test with non-numeric timestamp parts
|
||||||
|
mock_request.cookies = {
|
||||||
|
"litellm_ui_token_abc": "invalid_token",
|
||||||
|
"litellm_ui_token_2000": "valid_token",
|
||||||
|
"other_cookie": "other_value",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should handle the error and still return a token (fallback to string sort)
|
||||||
|
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||||
|
assert result is not None
|
124
tests/test_ui_session_handler.py
Normal file
124
tests/test_ui_session_handler.py
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
import jwt
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from fastapi.requests import Request
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler
|
||||||
|
from litellm.proxy._types import LitellmUserRoles
|
||||||
|
|
||||||
|
|
||||||
|
class TestUISessionHandler:
|
||||||
|
|
||||||
|
def test_get_latest_ui_cookie_name(self):
|
||||||
|
# Test with multiple cookies
|
||||||
|
cookies = {
|
||||||
|
"litellm_ui_token_1000": "value1",
|
||||||
|
"litellm_ui_token_2000": "value2",
|
||||||
|
"other_cookie": "other_value",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||||
|
assert result == "litellm_ui_token_2000"
|
||||||
|
|
||||||
|
# Test with no matching cookies
|
||||||
|
cookies = {"other_cookie": "value"}
|
||||||
|
result = UISessionHandler._get_latest_ui_cookie_name(cookies)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_ui_session_token_from_cookies(self):
|
||||||
|
# Create mock request with cookies
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.cookies = {
|
||||||
|
"litellm_ui_token_1000": "test_token",
|
||||||
|
"other_cookie": "other_value",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||||
|
assert result == "test_token"
|
||||||
|
|
||||||
|
# Test with no matching cookies
|
||||||
|
mock_request.cookies = {"other_cookie": "value"}
|
||||||
|
result = UISessionHandler._get_ui_session_token_from_cookies(mock_request)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@patch("litellm.proxy.proxy_server.master_key", "test_master_key")
|
||||||
|
@patch(
|
||||||
|
"litellm.proxy.proxy_server.general_settings",
|
||||||
|
{"litellm_key_header_name": "X-API-Key"},
|
||||||
|
)
|
||||||
|
def test_build_authenticated_ui_jwt_token(self):
|
||||||
|
# Test token generation
|
||||||
|
token = UISessionHandler.build_authenticated_ui_jwt_token(
|
||||||
|
user_id="test_user",
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
user_email="test@example.com",
|
||||||
|
premium_user=True,
|
||||||
|
disabled_non_admin_personal_key_creation=False,
|
||||||
|
login_method="username_password",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode and verify token
|
||||||
|
decoded = jwt.decode(token, "test_master_key", algorithms=["HS256"])
|
||||||
|
|
||||||
|
assert decoded["user_id"] == "test_user"
|
||||||
|
assert decoded["user_email"] == "test@example.com"
|
||||||
|
assert decoded["user_role"] == LitellmUserRoles.PROXY_ADMIN
|
||||||
|
assert decoded["premium_user"] is True
|
||||||
|
assert decoded["login_method"] == "username_password"
|
||||||
|
assert decoded["auth_header_name"] == "X-API-Key"
|
||||||
|
assert decoded["iss"] == "litellm-proxy"
|
||||||
|
assert decoded["aud"] == "litellm-ui"
|
||||||
|
assert "exp" in decoded
|
||||||
|
assert decoded["disabled_non_admin_personal_key_creation"] is False
|
||||||
|
assert decoded["scope"] == ["litellm:admin"]
|
||||||
|
|
||||||
|
def test_is_ui_session_token(self):
|
||||||
|
# Valid UI session token
|
||||||
|
token_dict = {
|
||||||
|
"iss": "litellm-proxy",
|
||||||
|
"aud": "litellm-ui",
|
||||||
|
"user_id": "test_user",
|
||||||
|
}
|
||||||
|
assert UISessionHandler.is_ui_session_token(token_dict) is True
|
||||||
|
|
||||||
|
# Invalid token (wrong issuer)
|
||||||
|
token_dict = {
|
||||||
|
"iss": "other-issuer",
|
||||||
|
"aud": "litellm-ui",
|
||||||
|
}
|
||||||
|
assert UISessionHandler.is_ui_session_token(token_dict) is False
|
||||||
|
|
||||||
|
# Invalid token (wrong audience)
|
||||||
|
token_dict = {
|
||||||
|
"iss": "litellm-proxy",
|
||||||
|
"aud": "other-audience",
|
||||||
|
}
|
||||||
|
assert UISessionHandler.is_ui_session_token(token_dict) is False
|
||||||
|
|
||||||
|
def test_generate_authenticated_redirect_response(self):
|
||||||
|
redirect_url = "https://example.com/dashboard"
|
||||||
|
jwt_token = "test.jwt.token"
|
||||||
|
|
||||||
|
response = UISessionHandler.generate_authenticated_redirect_response(
|
||||||
|
redirect_url=redirect_url, jwt_token=jwt_token
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, RedirectResponse)
|
||||||
|
assert response.status_code == 303
|
||||||
|
assert response.headers["location"] == redirect_url
|
||||||
|
|
||||||
|
# Check cookie was set
|
||||||
|
cookie_header = response.headers.get("set-cookie", "")
|
||||||
|
assert "test.jwt.token" in cookie_header
|
||||||
|
assert "Secure" in cookie_header
|
||||||
|
assert "HttpOnly" in cookie_header
|
||||||
|
assert "SameSite=strict" in cookie_header
|
||||||
|
|
||||||
|
def test_generate_token_name(self):
|
||||||
|
# Mock time.time() to return a fixed value
|
||||||
|
with patch("time.time", return_value=1234567890):
|
||||||
|
token_name = UISessionHandler._generate_token_name()
|
||||||
|
assert token_name == "litellm_ui_token_1234567890"
|
|
@ -20,12 +20,12 @@ import {
|
||||||
} from "@/components/networking";
|
} from "@/components/networking";
|
||||||
import { jwtDecode } from "jwt-decode";
|
import { jwtDecode } from "jwt-decode";
|
||||||
import { Form, Button as Button2, message } from "antd";
|
import { Form, Button as Button2, message } from "antd";
|
||||||
import { getCookie } from "@/utils/cookieUtils";
|
import { getUISessionDetails, setAuthToken } from "@/utils/cookieUtils";
|
||||||
|
|
||||||
export default function Onboarding() {
|
export default function Onboarding() {
|
||||||
const [form] = Form.useForm();
|
const [form] = Form.useForm();
|
||||||
const searchParams = useSearchParams()!;
|
const searchParams = useSearchParams()!;
|
||||||
const token = getCookie('token');
|
const token = getUISessionDetails();
|
||||||
const inviteID = searchParams.get("invitation_id");
|
const inviteID = searchParams.get("invitation_id");
|
||||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||||
const [defaultUserEmail, setDefaultUserEmail] = useState<string>("");
|
const [defaultUserEmail, setDefaultUserEmail] = useState<string>("");
|
||||||
|
@ -88,7 +88,7 @@ export default function Onboarding() {
|
||||||
litellm_dashboard_ui += "?userID=" + user_id;
|
litellm_dashboard_ui += "?userID=" + user_id;
|
||||||
|
|
||||||
// set cookie "token" to jwtToken
|
// set cookie "token" to jwtToken
|
||||||
document.cookie = "token=" + jwtToken;
|
setAuthToken(jwtToken);
|
||||||
console.log("redirecting to:", litellm_dashboard_ui);
|
console.log("redirecting to:", litellm_dashboard_ui);
|
||||||
|
|
||||||
window.location.href = litellm_dashboard_ui;
|
window.location.href = litellm_dashboard_ui;
|
||||||
|
|
|
@ -30,12 +30,7 @@ import { Organization } from "@/components/networking";
|
||||||
import GuardrailsPanel from "@/components/guardrails";
|
import GuardrailsPanel from "@/components/guardrails";
|
||||||
import { fetchUserModels } from "@/components/create_key_button";
|
import { fetchUserModels } from "@/components/create_key_button";
|
||||||
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
||||||
function getCookie(name: string) {
|
import { getUISessionDetails } from "@/utils/cookieUtils";
|
||||||
const cookieValue = document.cookie
|
|
||||||
.split("; ")
|
|
||||||
.find((row) => row.startsWith(name + "="));
|
|
||||||
return cookieValue ? cookieValue.split("=")[1] : null;
|
|
||||||
}
|
|
||||||
|
|
||||||
function formatUserRole(userRole: string) {
|
function formatUserRole(userRole: string) {
|
||||||
if (!userRole) {
|
if (!userRole) {
|
||||||
|
@ -117,32 +112,22 @@ export default function CreateKeyPage() {
|
||||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const token = getCookie("token");
|
const fetchSessionDetails = async () => {
|
||||||
setToken(token);
|
try {
|
||||||
}, []);
|
const sessionDetails = await getUISessionDetails();
|
||||||
|
// sessionDetails is already decoded, no need for jwtDecode
|
||||||
|
console.log("Session details:", sessionDetails);
|
||||||
|
|
||||||
useEffect(() => {
|
// Set access token to the session_id
|
||||||
if (!token) {
|
setAccessToken(sessionDetails.session_id);
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const decoded = jwtDecode(token) as { [key: string]: any };
|
|
||||||
if (decoded) {
|
|
||||||
// cast decoded to dictionary
|
|
||||||
console.log("Decoded token:", decoded);
|
|
||||||
|
|
||||||
console.log("Decoded key:", decoded.key);
|
|
||||||
// set accessToken
|
|
||||||
setAccessToken(decoded.key);
|
|
||||||
|
|
||||||
setDisabledPersonalKeyCreation(
|
setDisabledPersonalKeyCreation(
|
||||||
decoded.disabled_non_admin_personal_key_creation,
|
sessionDetails.disabled_non_admin_personal_key_creation,
|
||||||
);
|
);
|
||||||
|
|
||||||
// check if userRole is defined
|
if (sessionDetails.user_role) {
|
||||||
if (decoded.user_role) {
|
const formattedUserRole = formatUserRole(sessionDetails.user_role);
|
||||||
const formattedUserRole = formatUserRole(decoded.user_role);
|
console.log("User role:", formattedUserRole);
|
||||||
console.log("Decoded user_role:", formattedUserRole);
|
|
||||||
setUserRole(formattedUserRole);
|
setUserRole(formattedUserRole);
|
||||||
if (formattedUserRole == "Admin Viewer") {
|
if (formattedUserRole == "Admin Viewer") {
|
||||||
setPage("usage");
|
setPage("usage");
|
||||||
|
@ -151,29 +136,35 @@ export default function CreateKeyPage() {
|
||||||
console.log("User role not defined");
|
console.log("User role not defined");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (decoded.user_email) {
|
if (sessionDetails.user_email) {
|
||||||
setUserEmail(decoded.user_email);
|
setUserEmail(sessionDetails.user_email);
|
||||||
} else {
|
} else {
|
||||||
console.log(`User Email is not set ${decoded}`);
|
console.log("User Email is not set");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (decoded.login_method) {
|
if (sessionDetails.login_method) {
|
||||||
setShowSSOBanner(
|
setShowSSOBanner(
|
||||||
decoded.login_method == "username_password" ? true : false,
|
sessionDetails.login_method == "username_password" ? true : false,
|
||||||
);
|
);
|
||||||
} else {
|
|
||||||
console.log(`User Email is not set ${decoded}`);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (decoded.premium_user) {
|
if (sessionDetails.premium_user) {
|
||||||
setPremiumUser(decoded.premium_user);
|
setPremiumUser(sessionDetails.premium_user);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (decoded.auth_header_name) {
|
if (sessionDetails.auth_header_name) {
|
||||||
setGlobalLitellmHeaderName(decoded.auth_header_name);
|
setGlobalLitellmHeaderName(sessionDetails.auth_header_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store the full session details as token for components that need it
|
||||||
|
setToken(JSON.stringify(sessionDetails));
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error fetching session details:", error);
|
||||||
}
|
}
|
||||||
}, [token]);
|
};
|
||||||
|
|
||||||
|
fetchSessionDetails();
|
||||||
|
}, []);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (accessToken && userID && userRole) {
|
if (accessToken && userID && userRole) {
|
||||||
|
|
|
@ -462,6 +462,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!accessToken || !token || !userRole || !userID) {
|
if (!accessToken || !token || !userRole || !userID) {
|
||||||
|
console.log("exiting on model_dashboard.tsx because of missing accessToken, token, userRole, or userID", accessToken, token, userRole, userID);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const fetchData = async () => {
|
const fetchData = async () => {
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -21,6 +21,7 @@ import { useSearchParams, useRouter } from "next/navigation";
|
||||||
import { Team } from "./key_team_helpers/key_list";
|
import { Team } from "./key_team_helpers/key_list";
|
||||||
import { jwtDecode } from "jwt-decode";
|
import { jwtDecode } from "jwt-decode";
|
||||||
import { Typography } from "antd";
|
import { Typography } from "antd";
|
||||||
|
import { getUISessionDetails } from "@/utils/cookieUtils";
|
||||||
import { clearTokenCookies } from "@/utils/cookieUtils";
|
import { clearTokenCookies } from "@/utils/cookieUtils";
|
||||||
const isLocal = process.env.NODE_ENV === "development";
|
const isLocal = process.env.NODE_ENV === "development";
|
||||||
if (isLocal != true) {
|
if (isLocal != true) {
|
||||||
|
@ -45,14 +46,6 @@ export type UserInfo = {
|
||||||
spend: number;
|
spend: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
function getCookie(name: string) {
|
|
||||||
console.log("COOKIES", document.cookie)
|
|
||||||
const cookieValue = document.cookie
|
|
||||||
.split('; ')
|
|
||||||
.find(row => row.startsWith(name + '='));
|
|
||||||
return cookieValue ? cookieValue.split('=')[1] : null;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface UserDashboardProps {
|
interface UserDashboardProps {
|
||||||
userID: string | null;
|
userID: string | null;
|
||||||
userRole: string | null;
|
userRole: string | null;
|
||||||
|
@ -94,7 +87,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
// Assuming useSearchParams() hook exists and works in your setup
|
// Assuming useSearchParams() hook exists and works in your setup
|
||||||
const searchParams = useSearchParams()!;
|
const searchParams = useSearchParams()!;
|
||||||
|
|
||||||
const token = getCookie('token');
|
const token = getUISessionDetails();
|
||||||
|
|
||||||
const invitation_id = searchParams.get("invitation_id");
|
const invitation_id = searchParams.get("invitation_id");
|
||||||
|
|
||||||
|
@ -146,32 +139,37 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
// console.log(`selectedTeam: ${Object.entries(selectedTeam)}`);
|
// console.log(`selectedTeam: ${Object.entries(selectedTeam)}`);
|
||||||
// Moved useEffect inside the component and used a condition to run fetch only if the params are available
|
// Moved useEffect inside the component and used a condition to run fetch only if the params are available
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (token) {
|
const fetchSessionDetails = async () => {
|
||||||
const decoded = jwtDecode(token) as { [key: string]: any };
|
try {
|
||||||
if (decoded) {
|
const sessionDetails = await getUISessionDetails();
|
||||||
// cast decoded to dictionary
|
console.log("Session details:", sessionDetails);
|
||||||
console.log("Decoded token:", decoded);
|
|
||||||
|
|
||||||
console.log("Decoded key:", decoded.key);
|
// Set access token to the session_id
|
||||||
// set accessToken
|
setAccessToken(sessionDetails.session_id);
|
||||||
setAccessToken(decoded.key);
|
|
||||||
|
|
||||||
// check if userRole is defined
|
// check if userRole is defined
|
||||||
if (decoded.user_role) {
|
if (sessionDetails.user_role) {
|
||||||
const formattedUserRole = formatUserRole(decoded.user_role);
|
const formattedUserRole = formatUserRole(sessionDetails.user_role);
|
||||||
console.log("Decoded user_role:", formattedUserRole);
|
console.log("User role:", formattedUserRole);
|
||||||
setUserRole(formattedUserRole);
|
setUserRole(formattedUserRole);
|
||||||
} else {
|
} else {
|
||||||
console.log("User role not defined");
|
console.log("User role not defined");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (decoded.user_email) {
|
if (sessionDetails.user_email) {
|
||||||
setUserEmail(decoded.user_email);
|
setUserEmail(sessionDetails.user_email);
|
||||||
} else {
|
} else {
|
||||||
console.log(`User Email is not set ${decoded}`);
|
console.log("User Email is not set");
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error fetching session details:", error);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fetchSessionDetails();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
if (userID && accessToken && userRole && !keys && !userSpendData) {
|
if (userID && accessToken && userRole && !keys && !userSpendData) {
|
||||||
const cachedUserModels = sessionStorage.getItem("userModels" + userID);
|
const cachedUserModels = sessionStorage.getItem("userModels" + userID);
|
||||||
if (cachedUserModels) {
|
if (cachedUserModels) {
|
||||||
|
@ -246,7 +244,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
fetchTeams(accessToken, userID, userRole, currentOrg, setTeams);
|
fetchTeams(accessToken, userID, userRole, currentOrg, setTeams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, [userID, token, accessToken, keys, userRole]);
|
}, [userID, accessToken, keys, userRole]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
console.log(`currentOrg: ${JSON.stringify(currentOrg)}, accessToken: ${accessToken}, userID: ${userID}, userRole: ${userRole}`)
|
console.log(`currentOrg: ${JSON.stringify(currentOrg)}, accessToken: ${accessToken}, userID: ${userID}, userRole: ${userRole}`)
|
||||||
|
@ -333,6 +331,8 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
<div className="w-full mx-4 h-[75vh]">
|
<div className="w-full mx-4 h-[75vh]">
|
||||||
<Grid numItems={1} className="gap-2 p-8 w-full mt-2">
|
<Grid numItems={1} className="gap-2 p-8 w-full mt-2">
|
||||||
<Col numColSpan={1} className="flex flex-col gap-2">
|
<Col numColSpan={1} className="flex flex-col gap-2">
|
||||||
|
{accessToken && (
|
||||||
|
<>
|
||||||
<CreateKey
|
<CreateKey
|
||||||
key={selectedTeam ? selectedTeam.team_id : null}
|
key={selectedTeam ? selectedTeam.team_id : null}
|
||||||
userID={userID}
|
userID={userID}
|
||||||
|
@ -358,6 +358,8 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
setCurrentOrg={setCurrentOrg}
|
setCurrentOrg={setCurrentOrg}
|
||||||
organizations={organizations}
|
organizations={organizations}
|
||||||
/>
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Col>
|
</Col>
|
||||||
</Grid>
|
</Grid>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -5,6 +5,24 @@
|
||||||
/**
|
/**
|
||||||
* Clears the token cookie from both root and /ui paths
|
* Clears the token cookie from both root and /ui paths
|
||||||
*/
|
*/
|
||||||
|
import { validateSession } from "../components/networking"
|
||||||
|
|
||||||
|
// Define the interface for the JWT token data
|
||||||
|
export interface JWTTokenData {
|
||||||
|
user_id: string;
|
||||||
|
user_email: string | null;
|
||||||
|
user_role: string;
|
||||||
|
login_method: string;
|
||||||
|
premium_user: boolean;
|
||||||
|
auth_header_name: string;
|
||||||
|
iss: string;
|
||||||
|
aud: string;
|
||||||
|
exp: number;
|
||||||
|
disabled_non_admin_personal_key_creation: boolean;
|
||||||
|
scopes: string[];
|
||||||
|
session_id: string; // ui session id currently in progress
|
||||||
|
}
|
||||||
|
|
||||||
export function clearTokenCookies() {
|
export function clearTokenCookies() {
|
||||||
// Get the current domain
|
// Get the current domain
|
||||||
const domain = window.location.hostname;
|
const domain = window.location.hostname;
|
||||||
|
@ -13,32 +31,51 @@ export function clearTokenCookies() {
|
||||||
const paths = ['/', '/ui'];
|
const paths = ['/', '/ui'];
|
||||||
const sameSiteValues = ['Lax', 'Strict', 'None'];
|
const sameSiteValues = ['Lax', 'Strict', 'None'];
|
||||||
|
|
||||||
|
// Get all cookies
|
||||||
|
const allCookies = document.cookie.split("; ");
|
||||||
|
const tokenPattern = /^token_\d+$/;
|
||||||
|
|
||||||
|
// Find all token cookies
|
||||||
|
const tokenCookieNames = allCookies
|
||||||
|
.map(cookie => cookie.split("=")[0])
|
||||||
|
.filter(name => name === "token" || tokenPattern.test(name));
|
||||||
|
|
||||||
|
// Clear each token cookie with various combinations
|
||||||
|
tokenCookieNames.forEach(cookieName => {
|
||||||
paths.forEach(path => {
|
paths.forEach(path => {
|
||||||
// Basic clearing
|
// Basic clearing
|
||||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
|
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
|
||||||
|
|
||||||
// With domain
|
// With domain
|
||||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
|
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
|
||||||
|
|
||||||
// Try different SameSite values
|
// Try different SameSite values
|
||||||
sameSiteValues.forEach(sameSite => {
|
sameSiteValues.forEach(sameSite => {
|
||||||
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
|
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
|
||||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
|
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
|
||||||
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
|
document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log("After clearing cookies:", document.cookie);
|
console.log("After clearing cookies:", document.cookie);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
export function setAuthToken(token: string) {
|
||||||
* Gets a cookie value by name
|
// Generate a token name with current timestamp
|
||||||
* @param name The name of the cookie to retrieve
|
const currentTimestamp = Math.floor(Date.now() / 1000);
|
||||||
* @returns The cookie value or null if not found
|
const tokenName = `token_${currentTimestamp}`;
|
||||||
*/
|
|
||||||
export function getCookie(name: string) {
|
// Set the cookie with the timestamp-based name
|
||||||
const cookieValue = document.cookie
|
document.cookie = `${tokenName}=${token}; path=/; domain=${window.location.hostname};`;
|
||||||
.split('; ')
|
}
|
||||||
.find(row => row.startsWith(name + '='));
|
|
||||||
return cookieValue ? cookieValue.split('=')[1] : null;
|
export async function getUISessionDetails(): Promise<JWTTokenData> {
|
||||||
|
const validated_jwt_token = await validateSession();
|
||||||
|
|
||||||
|
if (validated_jwt_token?.data) {
|
||||||
|
return validated_jwt_token.data as JWTTokenData;
|
||||||
|
} else {
|
||||||
|
throw new Error("Invalid JWT token");
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Add table
Add a link
Reference in a new issue