diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ba27de78be..d62dec6ee0 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -272,6 +272,7 @@ class LiteLLMRoutes(enum.Enum): "/key/health", "/team/info", "/team/list", + "/organization/info", "/organization/list", "/team/available", "/user/info", @@ -282,6 +283,11 @@ class LiteLLMRoutes(enum.Enum): "/health", "/key/list", "/user/filter/ui", + "/user/list", + "/user/available_roles", + "/guardrails/list", + "/cache/ping", + "/get/config/callbacks", ] # NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend @@ -300,6 +306,8 @@ class LiteLLMRoutes(enum.Enum): "/user/update", "/user/delete", "/user/info", + # user invitation management + "/invitation/new", # team "/team/new", "/team/update", @@ -309,6 +317,20 @@ class LiteLLMRoutes(enum.Enum): "/team/block", "/team/unblock", "/team/available", + # team member management + "/team/member_add", + "/team/member_update", + "/team/member_delete", + # organization management + "/organization/new", + "/organization/update", + "/organization/delete", + "/organization/info", + "/organization/list", + # organization member management + "/organization/member_add", + "/organization/member_update", + "/organization/member_delete", # model "/model/new", "/model/update", @@ -355,20 +377,32 @@ class LiteLLMRoutes(enum.Enum): "/sso", "/sso/get/ui_settings", "/login", + "/sso/session/validate", "/key/info", "/config", "/spend", "/model/info", + "/model/metrics", + "/model/metrics/{sub_path}", + "/model/settings", + "/get/litellm_model_cost_map", + "/model/streaming_metrics", "/v2/model/info", "/v2/key/info", "/models", "/v1/models", "/global/spend", "/global/spend/logs", + "/spend/logs/ui", + "/spend/logs/ui/{id}", "/global/spend/keys", "/global/spend/models", "/global/predict/spend/logs", "/global/activity", + "/global/activity/{sub_path}", + "/global/activity/exceptions", + "/global/activity/exceptions/{sub_path}", + "/global/all_end_users", "/health/services", ] + info_routes @@ -2459,6 +2493,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): "spend_tracking_routes", "global_spend_tracking_routes", "info_routes", + "ui_routes", ] team_id_jwt_field: Optional[str] = None team_id_upsert: bool = False diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 183b5609d0..e910e9aeed 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -204,9 +204,11 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool: """ for allowed_route in allowed_routes: - if ( - allowed_route in LiteLLMRoutes.__members__ - and user_route in LiteLLMRoutes[allowed_route].value + if allowed_route in LiteLLMRoutes.__members__ and ( + RouteChecks.check_route_access( + route=user_route, + allowed_routes=LiteLLMRoutes[allowed_route].value, + ) ): return True elif allowed_route == user_route: @@ -217,16 +219,18 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool: def allowed_routes_check( user_role: Literal[ LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, LitellmUserRoles.TEAM, LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, ], user_route: str, litellm_proxy_roles: LiteLLM_JWTAuth, + jwt_valid_token: dict, ) -> bool: """ Check if user -> not admin - allowed to access these routes """ - if user_role == LitellmUserRoles.PROXY_ADMIN: is_allowed = _allowed_routes_check( user_route=user_route, diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 29f4b31f9c..93246d1217 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -33,6 +33,7 @@ from litellm.proxy._types import ( ScopeMapping, Span, ) +from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler from litellm.proxy.utils import PrismaClient, ProxyLogging from .auth_checks import ( @@ -406,10 +407,60 @@ class JWTHandler: else: return False + def _validate_ui_token(self, token: str) -> Optional[dict]: + """ + Helper function to validate tokens generated for the LiteLLM UI. + Returns the decoded payload if it's a valid UI token, None otherwise. + """ + import jwt + + from litellm.proxy.proxy_server import master_key + + try: + # Decode without verification to check if it's a UI token + unverified_payload = jwt.decode(token, options={"verify_signature": False}) + + # Check if this looks like a UI token (has specific claims that only UI tokens would have) + if UISessionHandler.is_ui_session_token(unverified_payload): + + # This looks like a UI token, now verify it with the master key + if not master_key: + verbose_proxy_logger.debug( + "Missing LITELLM_MASTER_KEY for UI token validation" + ) + return None + + try: + payload = jwt.decode( + token, + master_key, + algorithms=["HS256"], + audience="litellm-ui", + leeway=self.leeway, + ) + verbose_proxy_logger.debug( + "Successfully validated UI token for payload: %s", + json.dumps(payload, indent=4), + ) + return payload + except jwt.InvalidTokenError as e: + verbose_proxy_logger.debug(f"Invalid UI token: {str(e)}") + raise ValueError( + f"Invalid UI token, Unable to validate token signature {str(e)}" + ) + + return None # Not a UI token + except Exception as e: + raise e + async def auth_jwt(self, token: str) -> dict: # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret # the key in different ways (e.g. HS* and RS*)." + + ui_payload = self._validate_ui_token(token) + if ui_payload: + return ui_payload algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"] audience = os.getenv("JWT_AUDIENCE") @@ -616,6 +667,7 @@ class JWTAuthManager: user_id: Optional[str], org_id: Optional[str], api_key: str, + jwt_valid_token: dict, ) -> Optional[JWTAuthBuilderResult]: """Check admin status and route access permissions""" if not jwt_handler.is_admin(scopes=scopes): @@ -625,6 +677,7 @@ class JWTAuthManager: user_role=LitellmUserRoles.PROXY_ADMIN, user_route=route, litellm_proxy_roles=jwt_handler.litellm_jwtauth, + jwt_valid_token=jwt_valid_token, ) if not is_allowed: allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes @@ -698,6 +751,7 @@ class JWTAuthManager: user_api_key_cache: DualCache, parent_otel_span: Optional[Span], proxy_logging_obj: ProxyLogging, + jwt_valid_token: dict, ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]: """Find first team with access to the requested model""" @@ -730,6 +784,7 @@ class JWTAuthManager: user_role=LitellmUserRoles.TEAM, user_route=route, litellm_proxy_roles=jwt_handler.litellm_jwtauth, + jwt_valid_token=jwt_valid_token, ) if is_allowed: return team_id, team_object @@ -920,7 +975,13 @@ class JWTAuthManager: # Check admin access admin_result = await JWTAuthManager.check_admin_access( - jwt_handler, scopes, route, user_id, org_id, api_key + jwt_handler=jwt_handler, + scopes=scopes, + route=route, + user_id=user_id, + org_id=org_id, + api_key=api_key, + jwt_valid_token=jwt_valid_token, ) if admin_result: return admin_result @@ -952,6 +1013,7 @@ class JWTAuthManager: user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, + jwt_valid_token=jwt_valid_token, ) # Get other objects diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index a18a7ab5e1..7711e1ab0f 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -1,5 +1,5 @@ import re -from typing import List, Optional +from typing import List, Optional, Set, Union from fastapi import HTTPException, Request, status @@ -225,7 +225,9 @@ class RouteChecks: return False @staticmethod - def check_route_access(route: str, allowed_routes: List[str]) -> bool: + def check_route_access( + route: str, allowed_routes: Union[List[str], Set[str]] + ) -> bool: """ Check if a route has access by checking both exact matches and patterns diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index ecefc64d67..f3659ed404 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -51,6 +51,7 @@ from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.auth.service_account_checks import service_account_checks from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler from litellm.proxy.utils import PrismaClient, ProxyLogging, _to_ns from litellm.types.services import ServiceTypes @@ -335,6 +336,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 "pass_through_endpoints", None ) passed_in_key: Optional[str] = None + cookie_token: Optional[str] = None if isinstance(api_key, str): passed_in_key = api_key api_key = _get_bearer_token(api_key=api_key) @@ -344,6 +346,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 api_key = anthropic_api_key_header elif isinstance(google_ai_studio_api_key_header, str): api_key = google_ai_studio_api_key_header + elif cookie_token := UISessionHandler._get_ui_session_token_from_cookies( + request + ): + api_key = cookie_token elif pass_through_endpoints is not None: for endpoint in pass_through_endpoints: if endpoint.get("path", "") == route: @@ -420,7 +426,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if general_settings.get("enable_oauth2_proxy_auth", False) is True: return await handle_oauth2_proxy_request(request=request) - if general_settings.get("enable_jwt_auth", False) is True: + if ( + general_settings.get("enable_jwt_auth", False) is True + or cookie_token is not None + ): from litellm.proxy.proxy_server import premium_user if premium_user is not True: diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 2e2720c104..7a1d15fe01 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -8,7 +8,7 @@ Has all /sso/* routes import asyncio import os import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, List, Optional, Union, cast from fastapi import APIRouter, Depends, HTTPException, Request, status @@ -43,6 +43,7 @@ from litellm.proxy.management_endpoints.sso_helper_utils import ( ) from litellm.proxy.management_endpoints.team_endpoints import team_member_add from litellm.proxy.management_endpoints.types import CustomOpenID +from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler from litellm.secret_managers.main import str_to_bool if TYPE_CHECKING: @@ -408,11 +409,7 @@ def get_disabled_non_admin_personal_key_creation(): @router.get("/sso/callback", tags=["experimental"], include_in_schema=False) async def auth_callback(request: Request): # noqa: PLR0915 """Verify login""" - from litellm.proxy.management_endpoints.key_management_endpoints import ( - generate_key_helper_fn, - ) from litellm.proxy.proxy_server import ( - general_settings, jwt_handler, master_key, premium_user, @@ -426,6 +423,7 @@ async def auth_callback(request: Request): # noqa: PLR0915 microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + user_role: Optional[LitellmUserRoles] = None # get url from request if master_key is None: raise ProxyException( @@ -531,16 +529,7 @@ async def auth_callback(request: Request): # noqa: PLR0915 max_internal_user_budget = litellm.max_internal_user_budget internal_user_budget_duration = litellm.internal_user_budget_duration - # User might not be already created on first generation of key # But if it is, we want their models preferences - default_ui_key_values: Dict[str, Any] = { - "duration": "24hr", - "key_max_budget": litellm.max_ui_session_budget, - "aliases": {}, - "config": {}, - "spend": 0, - "team_id": "litellm-dashboard", - } user_defined_values: Optional[SSOUserDefinedValues] = None if user_custom_sso is not None: @@ -559,7 +548,6 @@ async def auth_callback(request: Request): # noqa: PLR0915 ) _user_id_from_sso = user_id - user_role = None try: if prisma_client is not None: try: @@ -632,24 +620,14 @@ async def auth_callback(request: Request): # noqa: PLR0915 f"user_defined_values for creating ui key: {user_defined_values}" ) - default_ui_key_values.update(user_defined_values) - default_ui_key_values["request_type"] = "key" - response = await generate_key_helper_fn( - **default_ui_key_values, # type: ignore - table_name="key", - ) - - key = response["token"] # type: ignore - user_id = response["user_id"] # type: ignore - litellm_dashboard_ui = "/ui/" - user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value + user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY if ( os.getenv("PROXY_ADMIN_ID", None) is not None and os.environ["PROXY_ADMIN_ID"] == user_id ): # checks if user is admin - user_role = LitellmUserRoles.PROXY_ADMIN.value + user_role = LitellmUserRoles.PROXY_ADMIN verbose_proxy_logger.debug( f"user_role: {user_role}; ui_access_mode: {ui_access_mode}" @@ -669,30 +647,20 @@ async def auth_callback(request: Request): # noqa: PLR0915 disabled_non_admin_personal_key_creation = ( get_disabled_non_admin_personal_key_creation() ) - - import jwt - - jwt_token = jwt.encode( # type: ignore - { - "user_id": user_id, - "key": key, - "user_email": user_email, - "user_role": user_role, - "login_method": "sso", - "premium_user": premium_user, - "auth_header_name": general_settings.get( - "litellm_key_header_name", "Authorization" - ), - "disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation, - }, - master_key, - algorithm="HS256", + jwt_token = UISessionHandler.build_authenticated_ui_jwt_token( + user_id=user_defined_values.get("user_id", ""), + user_role=user_role, + user_email=user_defined_values.get("user_email", ""), + premium_user=premium_user, + disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation, + login_method="sso", ) if user_id is not None and isinstance(user_id, str): litellm_dashboard_ui += "?userID=" + user_id - redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) - redirect_response.set_cookie(key="token", value=jwt_token, secure=True) - return redirect_response + + return UISessionHandler.generate_authenticated_redirect_response( + redirect_url=litellm_dashboard_ui, jwt_token=jwt_token + ) async def insert_sso_user( @@ -778,3 +746,25 @@ async def get_ui_settings(request: Request): ), "DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries, } + + +@router.get( + "/sso/session/validate", + include_in_schema=False, + tags=["experimental"], +) +async def validate_session( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + from litellm.proxy.auth.handle_jwt import JWTHandler + from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler + + ui_session_token = UISessionHandler._get_ui_session_token_from_cookies(request) + ui_session_id = UISessionHandler._get_latest_ui_cookie_name(request.cookies) + if ui_session_token is None: + raise HTTPException(status_code=401, detail="Unauthorized") + jwt_handler = JWTHandler() + validated_jwt_token = await jwt_handler.auth_jwt(token=ui_session_token) + validated_jwt_token["session_id"] = ui_session_id + return {"valid": True, "data": validated_jwt_token} diff --git a/litellm/proxy/management_helpers/ui_session_handler.py b/litellm/proxy/management_helpers/ui_session_handler.py new file mode 100644 index 0000000000..c4def9d1f1 --- /dev/null +++ b/litellm/proxy/management_helpers/ui_session_handler.py @@ -0,0 +1,141 @@ +import time +from datetime import datetime, timedelta, timezone +from typing import Literal, Optional + +from fastapi.requests import Request +from fastapi.responses import RedirectResponse + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import LiteLLM_JWTAuth, LitellmUserRoles + + +class UISessionHandler: + + @staticmethod + def _get_latest_ui_cookie_name(cookies: dict) -> Optional[str]: + """ + Get the name of the most recent LiteLLM UI cookie + """ + # Find all LiteLLM UI cookies (format: litellm_ui_token_{timestamp}) + litellm_ui_cookies = [ + k for k in cookies.keys() if k.startswith("litellm_ui_token_") + ] + if not litellm_ui_cookies: + return None + + # Sort by timestamp (descending) to get the most recent one + try: + # Extract timestamps and sort numerically + sorted_cookies = sorted( + litellm_ui_cookies, + key=lambda x: int(x.split("_")[-1]), + reverse=True, + ) + return sorted_cookies[0] + except (ValueError, IndexError): + # Fallback to simple string sort if timestamp extraction fails + litellm_ui_cookies.sort(reverse=True) + return litellm_ui_cookies[0] + + @staticmethod + # Add this function to extract auth token from cookies + def _get_ui_session_token_from_cookies(request: Request) -> Optional[str]: + """ + Extract authentication token from cookies if present + """ + try: + cookies = request.cookies + verbose_proxy_logger.debug(f"AUTH COOKIES: {cookies}") + + cookie_name = UISessionHandler._get_latest_ui_cookie_name(cookies) + if cookie_name: + return cookies[cookie_name] + + return None + except Exception as e: + verbose_proxy_logger.error( + f"Error getting UI session token from cookies: {e}" + ) + return None + + @staticmethod + def build_authenticated_ui_jwt_token( + user_id: str, + user_role: Optional[LitellmUserRoles], + user_email: Optional[str], + premium_user: bool, + disabled_non_admin_personal_key_creation: bool, + login_method: Literal["username_password", "sso"], + ) -> str: + """ + Build a JWT token for the authenticated UI session + + This token is used to authenticate the user's session when they are redirected to the UI + """ + import jwt + + from litellm.proxy.proxy_server import general_settings, master_key + + if master_key is None: + raise ValueError("Master key is not set") + + expiration = datetime.now(timezone.utc) + timedelta(hours=24) + initial_payload = { + "user_id": user_id, + "user_email": user_email, + "user_role": user_role, # this is the path without sso - we can assume only admins will use this + "login_method": login_method, + "premium_user": premium_user, + "auth_header_name": general_settings.get( + "litellm_key_header_name", "Authorization" + ), + "iss": "litellm-proxy", # Issuer - identifies this as an internal token + "aud": "litellm-ui", # Audience - identifies this as a UI token + "exp": expiration, + "disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation, + } + + if ( + user_role == LitellmUserRoles.PROXY_ADMIN + or user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY + ): + initial_payload["scope"] = [ + LiteLLM_JWTAuth().admin_jwt_scope, + ] + + jwt_token = jwt.encode( + initial_payload, + master_key, + algorithm="HS256", + ) + return jwt_token + + @staticmethod + def is_ui_session_token(token_dict: dict) -> bool: + """ + Returns True if the token is a UI session token + """ + return ( + token_dict.get("iss") == "litellm-proxy" + and token_dict.get("aud") == "litellm-ui" + ) + + @staticmethod + def generate_authenticated_redirect_response( + redirect_url: str, jwt_token: str + ) -> RedirectResponse: + redirect_response = RedirectResponse(url=redirect_url, status_code=303) + redirect_response.set_cookie( + key=UISessionHandler._generate_token_name(), + value=jwt_token, + secure=True, + httponly=True, + samesite="strict", + ) + return redirect_response + + @staticmethod + def _generate_token_name() -> str: + current_timestamp = int(time.time()) + cookie_name = f"litellm_ui_token_{current_timestamp}" + return cookie_name diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 28aceb7519..04682a4d0b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -7375,6 +7375,8 @@ async def login(request: Request): # noqa: PLR0915 import multipart except ImportError: subprocess.run(["pip", "install", "python-multipart"]) + from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler + global master_key if master_key is None: raise ProxyException( @@ -7445,56 +7447,23 @@ async def login(request: Request): # noqa: PLR0915 user_role=user_role, ) ) - if os.getenv("DATABASE_URL") is not None: - response = await generate_key_helper_fn( - request_type="key", - **{ - "user_role": LitellmUserRoles.PROXY_ADMIN, - "duration": "24hr", - "key_max_budget": litellm.max_ui_session_budget, - "models": [], - "aliases": {}, - "config": {}, - "spend": 0, - "user_id": key_user_id, - "team_id": "litellm-dashboard", - }, # type: ignore - ) - else: - raise ProxyException( - message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.", - type=ProxyErrorTypes.auth_error, - param="DATABASE_URL", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - key = response["token"] # type: ignore litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "") if litellm_dashboard_ui.endswith("/"): litellm_dashboard_ui += "ui/" else: litellm_dashboard_ui += "/ui/" - import jwt - - jwt_token = jwt.encode( # type: ignore - { - "user_id": user_id, - "key": key, - "user_email": None, - "user_role": user_role, # this is the path without sso - we can assume only admins will use this - "login_method": "username_password", - "premium_user": premium_user, - "auth_header_name": general_settings.get( - "litellm_key_header_name", "Authorization" - ), - "disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation, - }, - master_key, - algorithm="HS256", + jwt_token = UISessionHandler.build_authenticated_ui_jwt_token( + user_id=user_id, + user_role=user_role, + user_email=None, + premium_user=premium_user, + disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation, + login_method="username_password", ) litellm_dashboard_ui += "?userID=" + user_id - redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) - redirect_response.set_cookie(key="token", value=jwt_token) - return redirect_response + return UISessionHandler.generate_authenticated_redirect_response( + redirect_url=litellm_dashboard_ui, jwt_token=jwt_token + ) elif _user_row is not None: """ When sharing invite links @@ -7513,58 +7482,23 @@ async def login(request: Request): # noqa: PLR0915 if secrets.compare_digest(password, _password) or secrets.compare_digest( hash_password, _password ): - if os.getenv("DATABASE_URL") is not None: - response = await generate_key_helper_fn( - request_type="key", - **{ # type: ignore - "user_role": user_role, - "duration": "24hr", - "key_max_budget": litellm.max_ui_session_budget, - "models": [], - "aliases": {}, - "config": {}, - "spend": 0, - "user_id": user_id, - "team_id": "litellm-dashboard", - }, - ) - else: - raise ProxyException( - message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.", - type=ProxyErrorTypes.auth_error, - param="DATABASE_URL", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - key = response["token"] # type: ignore litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "") if litellm_dashboard_ui.endswith("/"): litellm_dashboard_ui += "ui/" else: litellm_dashboard_ui += "/ui/" - import jwt - - jwt_token = jwt.encode( # type: ignore - { - "user_id": user_id, - "key": key, - "user_email": user_email, - "user_role": user_role, - "login_method": "username_password", - "premium_user": premium_user, - "auth_header_name": general_settings.get( - "litellm_key_header_name", "Authorization" - ), - "disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation, - }, - master_key, - algorithm="HS256", + jwt_token = UISessionHandler.build_authenticated_ui_jwt_token( + user_id=user_id, + user_role=user_role, + user_email=user_email, + premium_user=premium_user, + disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation, + login_method="username_password", ) litellm_dashboard_ui += "?userID=" + user_id - redirect_response = RedirectResponse( - url=litellm_dashboard_ui, status_code=303 + return UISessionHandler.generate_authenticated_redirect_response( + redirect_url=litellm_dashboard_ui, jwt_token=jwt_token ) - redirect_response.set_cookie(key="token", value=jwt_token) - return redirect_response else: raise ProxyException( message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}", @@ -7590,6 +7524,8 @@ async def onboarding(invite_link: str): - Get user from db - Pass in user_email if set """ + from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler + global prisma_client, master_key, general_settings if master_key is None: raise ProxyException( @@ -7646,51 +7582,26 @@ async def onboarding(invite_link: str): user_email = user_obj.user_email - response = await generate_key_helper_fn( - request_type="key", - **{ - "user_role": user_obj.user_role, - "duration": "24hr", - "key_max_budget": litellm.max_ui_session_budget, - "models": [], - "aliases": {}, - "config": {}, - "spend": 0, - "user_id": user_obj.user_id, - "team_id": "litellm-dashboard", - }, # type: ignore - ) - key = response["token"] # type: ignore - litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "") if litellm_dashboard_ui.endswith("/"): litellm_dashboard_ui += "ui/onboarding" else: litellm_dashboard_ui += "/ui/onboarding" - import jwt disabled_non_admin_personal_key_creation = ( get_disabled_non_admin_personal_key_creation() ) - - jwt_token = jwt.encode( # type: ignore - { - "user_id": user_obj.user_id, - "key": key, - "user_email": user_obj.user_email, - "user_role": user_obj.user_role, - "login_method": "username_password", - "premium_user": premium_user, - "auth_header_name": general_settings.get( - "litellm_key_header_name", "Authorization" - ), - "disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation, - }, - master_key, - algorithm="HS256", + jwt_token = UISessionHandler.build_authenticated_ui_jwt_token( + user_id=user_obj.user_id, + user_role=user_obj.user_role, + user_email=user_obj.user_email, + premium_user=user_obj.premium_user, + disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation, + login_method="username_password", ) litellm_dashboard_ui += "?token={}&user_email={}".format(jwt_token, user_email) + return { "login_url": litellm_dashboard_ui, "token": jwt_token, diff --git a/tests/litellm/proxy/auth/test_handle_jwt.py b/tests/litellm/proxy/auth/test_handle_jwt.py new file mode 100644 index 0000000000..36eeac0df8 --- /dev/null +++ b/tests/litellm/proxy/auth/test_handle_jwt.py @@ -0,0 +1,108 @@ +import datetime +import json +import os +import sys +from datetime import timezone + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path + +import jwt + +from litellm.proxy.auth.handle_jwt import JWTHandler +from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler + + +class TestJWTHandler: + @pytest.fixture + def jwt_handler(self): + handler = JWTHandler() + handler.leeway = 60 # Set leeway for testing + return handler + + @pytest.fixture + def setup_mocks(self, monkeypatch): + # Mock master_key + test_master_key = "test_master_key" + monkeypatch.setattr("litellm.proxy.proxy_server.master_key", test_master_key) + + # Mock UISessionHandler.is_ui_session_token to return True for our test token + def mock_is_ui_session_token(payload): + return "ui_session_id" in payload + + monkeypatch.setattr( + UISessionHandler, "is_ui_session_token", mock_is_ui_session_token + ) + + return test_master_key + + def test_validate_valid_ui_token(self, jwt_handler, setup_mocks): + # Setup + test_master_key = setup_mocks + + # Create a valid UI token + valid_payload = { + "ui_session_id": "test_session_id", + "exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600, + "iat": datetime.datetime.now(tz=timezone.utc).timestamp(), + "aud": "litellm-ui", + } + valid_token = jwt.encode(valid_payload, test_master_key, algorithm="HS256") + + # Test valid UI token + result = jwt_handler._validate_ui_token(valid_token) + assert result is not None + assert result["ui_session_id"] == "test_session_id" + + def test_validate_expired_ui_token(self, jwt_handler, setup_mocks): + # Setup + test_master_key = setup_mocks + + # Create an expired UI token + expired_payload = { + "ui_session_id": "test_session_id", + "exp": datetime.datetime.now(tz=timezone.utc).timestamp() - 3600, + "iat": datetime.datetime.now(tz=timezone.utc).timestamp() - 7200, + "aud": "litellm-ui", + } + expired_token = jwt.encode(expired_payload, test_master_key, algorithm="HS256") + + # Test expired UI token + with pytest.raises(ValueError, match="Invalid UI token"): + jwt_handler._validate_ui_token(expired_token) + + def test_validate_invalid_signature_ui_token(self, jwt_handler, setup_mocks): + # Setup + test_master_key = setup_mocks + + # Create a token with invalid signature + valid_payload = { + "ui_session_id": "test_session_id", + "exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600, + "iat": datetime.datetime.now(tz=timezone.utc).timestamp(), + "aud": "litellm-ui", + } + invalid_token = jwt.encode(valid_payload, "wrong_key", algorithm="HS256") + + # Test UI token with invalid signature + with pytest.raises(ValueError, match="Invalid UI token"): + jwt_handler._validate_ui_token(invalid_token) + + def test_validate_non_ui_token(self, jwt_handler, setup_mocks): + # Setup + test_master_key = setup_mocks + + # Create a non-UI token + non_ui_payload = { + "sub": "user123", + "exp": datetime.datetime.now(tz=timezone.utc).timestamp() + 3600, + } + non_ui_token = jwt.encode(non_ui_payload, test_master_key, algorithm="HS256") + + # Test non-UI token + result = jwt_handler._validate_ui_token(non_ui_token) + assert result is None diff --git a/tests/litellm/proxy/management_endpoints/test_ui_session_handler.py b/tests/litellm/proxy/management_endpoints/test_ui_session_handler.py new file mode 100644 index 0000000000..956579239a --- /dev/null +++ b/tests/litellm/proxy/management_endpoints/test_ui_session_handler.py @@ -0,0 +1,179 @@ +import datetime +import json +import os +import sys +from datetime import timezone + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path + +import time +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import jwt +import pytest +from fastapi.requests import Request +from fastapi.responses import RedirectResponse + +from litellm.proxy._types import LitellmUserRoles +from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler + + +class TestUISessionHandler: + + def test_get_latest_ui_cookie_name(self): + # Test with multiple cookies + cookies = { + "litellm_ui_token_1000": "value1", + "litellm_ui_token_2000": "value2", + "other_cookie": "other_value", + } + + result = UISessionHandler._get_latest_ui_cookie_name(cookies) + assert result == "litellm_ui_token_2000" + + # Test with no matching cookies + cookies = {"other_cookie": "value"} + result = UISessionHandler._get_latest_ui_cookie_name(cookies) + assert result is None + + def test_get_ui_session_token_from_cookies(self): + # Create mock request with cookies + mock_request = MagicMock() + mock_request.cookies = { + "litellm_ui_token_1000": "test_token", + "other_cookie": "other_value", + } + + result = UISessionHandler._get_ui_session_token_from_cookies(mock_request) + assert result == "test_token" + + # Test with no matching cookies + mock_request.cookies = {"other_cookie": "value"} + result = UISessionHandler._get_ui_session_token_from_cookies(mock_request) + assert result is None + + @patch("litellm.proxy.proxy_server.master_key", "test_master_key") + @patch( + "litellm.proxy.proxy_server.general_settings", + {"litellm_key_header_name": "X-API-Key"}, + ) + def test_build_authenticated_ui_jwt_token(self): + # Test token generation + token = UISessionHandler.build_authenticated_ui_jwt_token( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + user_email="test@example.com", + premium_user=True, + disabled_non_admin_personal_key_creation=False, + login_method="username_password", + ) + + # Decode and verify token + decoded = jwt.decode( + token, "test_master_key", algorithms=["HS256"], audience="litellm-ui" + ) + + assert decoded["user_id"] == "test_user" + assert decoded["user_email"] == "test@example.com" + assert decoded["user_role"] == LitellmUserRoles.PROXY_ADMIN + assert decoded["premium_user"] is True + assert decoded["login_method"] == "username_password" + assert decoded["auth_header_name"] == "X-API-Key" + assert decoded["iss"] == "litellm-proxy" + assert decoded["aud"] == "litellm-ui" + assert "exp" in decoded + assert decoded["disabled_non_admin_personal_key_creation"] is False + assert decoded["scope"] == ["litellm_proxy_admin"] + + def test_is_ui_session_token(self): + # Valid UI session token + token_dict = { + "iss": "litellm-proxy", + "aud": "litellm-ui", + "user_id": "test_user", + } + assert UISessionHandler.is_ui_session_token(token_dict) is True + + # Invalid token (wrong issuer) + token_dict = { + "iss": "other-issuer", + "aud": "litellm-ui", + } + assert UISessionHandler.is_ui_session_token(token_dict) is False + + # Invalid token (wrong audience) + token_dict = { + "iss": "litellm-proxy", + "aud": "other-audience", + } + assert UISessionHandler.is_ui_session_token(token_dict) is False + + def test_generate_authenticated_redirect_response(self): + redirect_url = "https://example.com/dashboard" + jwt_token = "test.jwt.token" + + response = UISessionHandler.generate_authenticated_redirect_response( + redirect_url=redirect_url, jwt_token=jwt_token + ) + + assert isinstance(response, RedirectResponse) + assert response.status_code == 303 + assert response.headers["location"] == redirect_url + + # Check cookie was set + cookie_header = response.headers.get("set-cookie", "") + assert "test.jwt.token" in cookie_header + assert "Secure" in cookie_header + assert "HttpOnly" in cookie_header + assert "SameSite=strict" in cookie_header + + def test_generate_token_name(self): + # Mock time.time() to return a fixed value + with patch("time.time", return_value=1234567890): + token_name = UISessionHandler._generate_token_name() + assert token_name == "litellm_ui_token_1234567890" + + def test_latest_token_is_used(self): + """Test that the most recent token is correctly identified and used""" + # Create mock request with multiple UI tokens of different timestamps + mock_request = MagicMock() + mock_request.cookies = { + "litellm_ui_token_1000": "old_token", + "litellm_ui_token_2000": "newer_token", + "litellm_ui_token_1500": "middle_token", + "other_cookie": "other_value", + } + + # Get the token from cookies + result = UISessionHandler._get_ui_session_token_from_cookies(mock_request) + + # Verify the newest token (highest timestamp) is returned + assert result == "newer_token" + + # Test with timestamps that aren't in numerical order in the dictionary + mock_request.cookies = { + "litellm_ui_token_5000": "newest_token", + "litellm_ui_token_1000": "oldest_token", + "litellm_ui_token_3000": "middle_token", + "other_cookie": "other_value", + } + + result = UISessionHandler._get_ui_session_token_from_cookies(mock_request) + assert result == "newest_token" + + # Test with non-numeric timestamp parts + mock_request.cookies = { + "litellm_ui_token_abc": "invalid_token", + "litellm_ui_token_2000": "valid_token", + "other_cookie": "other_value", + } + + # Should handle the error and still return a token (fallback to string sort) + result = UISessionHandler._get_ui_session_token_from_cookies(mock_request) + assert result is not None diff --git a/tests/test_ui_session_handler.py b/tests/test_ui_session_handler.py new file mode 100644 index 0000000000..f1e6e00908 --- /dev/null +++ b/tests/test_ui_session_handler.py @@ -0,0 +1,124 @@ +import pytest +import time +import jwt +from datetime import datetime, timezone +from fastapi.requests import Request +from fastapi.responses import RedirectResponse +from unittest.mock import MagicMock, patch + +from litellm.proxy.management_helpers.ui_session_handler import UISessionHandler +from litellm.proxy._types import LitellmUserRoles + + +class TestUISessionHandler: + + def test_get_latest_ui_cookie_name(self): + # Test with multiple cookies + cookies = { + "litellm_ui_token_1000": "value1", + "litellm_ui_token_2000": "value2", + "other_cookie": "other_value", + } + + result = UISessionHandler._get_latest_ui_cookie_name(cookies) + assert result == "litellm_ui_token_2000" + + # Test with no matching cookies + cookies = {"other_cookie": "value"} + result = UISessionHandler._get_latest_ui_cookie_name(cookies) + assert result is None + + def test_get_ui_session_token_from_cookies(self): + # Create mock request with cookies + mock_request = MagicMock() + mock_request.cookies = { + "litellm_ui_token_1000": "test_token", + "other_cookie": "other_value", + } + + result = UISessionHandler._get_ui_session_token_from_cookies(mock_request) + assert result == "test_token" + + # Test with no matching cookies + mock_request.cookies = {"other_cookie": "value"} + result = UISessionHandler._get_ui_session_token_from_cookies(mock_request) + assert result is None + + @patch("litellm.proxy.proxy_server.master_key", "test_master_key") + @patch( + "litellm.proxy.proxy_server.general_settings", + {"litellm_key_header_name": "X-API-Key"}, + ) + def test_build_authenticated_ui_jwt_token(self): + # Test token generation + token = UISessionHandler.build_authenticated_ui_jwt_token( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + user_email="test@example.com", + premium_user=True, + disabled_non_admin_personal_key_creation=False, + login_method="username_password", + ) + + # Decode and verify token + decoded = jwt.decode(token, "test_master_key", algorithms=["HS256"]) + + assert decoded["user_id"] == "test_user" + assert decoded["user_email"] == "test@example.com" + assert decoded["user_role"] == LitellmUserRoles.PROXY_ADMIN + assert decoded["premium_user"] is True + assert decoded["login_method"] == "username_password" + assert decoded["auth_header_name"] == "X-API-Key" + assert decoded["iss"] == "litellm-proxy" + assert decoded["aud"] == "litellm-ui" + assert "exp" in decoded + assert decoded["disabled_non_admin_personal_key_creation"] is False + assert decoded["scope"] == ["litellm:admin"] + + def test_is_ui_session_token(self): + # Valid UI session token + token_dict = { + "iss": "litellm-proxy", + "aud": "litellm-ui", + "user_id": "test_user", + } + assert UISessionHandler.is_ui_session_token(token_dict) is True + + # Invalid token (wrong issuer) + token_dict = { + "iss": "other-issuer", + "aud": "litellm-ui", + } + assert UISessionHandler.is_ui_session_token(token_dict) is False + + # Invalid token (wrong audience) + token_dict = { + "iss": "litellm-proxy", + "aud": "other-audience", + } + assert UISessionHandler.is_ui_session_token(token_dict) is False + + def test_generate_authenticated_redirect_response(self): + redirect_url = "https://example.com/dashboard" + jwt_token = "test.jwt.token" + + response = UISessionHandler.generate_authenticated_redirect_response( + redirect_url=redirect_url, jwt_token=jwt_token + ) + + assert isinstance(response, RedirectResponse) + assert response.status_code == 303 + assert response.headers["location"] == redirect_url + + # Check cookie was set + cookie_header = response.headers.get("set-cookie", "") + assert "test.jwt.token" in cookie_header + assert "Secure" in cookie_header + assert "HttpOnly" in cookie_header + assert "SameSite=strict" in cookie_header + + def test_generate_token_name(self): + # Mock time.time() to return a fixed value + with patch("time.time", return_value=1234567890): + token_name = UISessionHandler._generate_token_name() + assert token_name == "litellm_ui_token_1234567890" diff --git a/ui/litellm-dashboard/src/app/onboarding/page.tsx b/ui/litellm-dashboard/src/app/onboarding/page.tsx index e46e46fcb5..bc40761f89 100644 --- a/ui/litellm-dashboard/src/app/onboarding/page.tsx +++ b/ui/litellm-dashboard/src/app/onboarding/page.tsx @@ -20,12 +20,12 @@ import { } from "@/components/networking"; import { jwtDecode } from "jwt-decode"; import { Form, Button as Button2, message } from "antd"; -import { getCookie } from "@/utils/cookieUtils"; +import { getUISessionDetails, setAuthToken } from "@/utils/cookieUtils"; export default function Onboarding() { const [form] = Form.useForm(); const searchParams = useSearchParams()!; - const token = getCookie('token'); + const token = getUISessionDetails(); const inviteID = searchParams.get("invitation_id"); const [accessToken, setAccessToken] = useState(null); const [defaultUserEmail, setDefaultUserEmail] = useState(""); @@ -88,7 +88,7 @@ export default function Onboarding() { litellm_dashboard_ui += "?userID=" + user_id; // set cookie "token" to jwtToken - document.cookie = "token=" + jwtToken; + setAuthToken(jwtToken); console.log("redirecting to:", litellm_dashboard_ui); window.location.href = litellm_dashboard_ui; diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index 2612cab594..602193ac76 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -30,12 +30,7 @@ import { Organization } from "@/components/networking"; import GuardrailsPanel from "@/components/guardrails"; import { fetchUserModels } from "@/components/create_key_button"; import { fetchTeams } from "@/components/common_components/fetch_teams"; -function getCookie(name: string) { - const cookieValue = document.cookie - .split("; ") - .find((row) => row.startsWith(name + "=")); - return cookieValue ? cookieValue.split("=")[1] : null; -} +import { getUISessionDetails } from "@/utils/cookieUtils"; function formatUserRole(userRole: string) { if (!userRole) { @@ -117,63 +112,59 @@ export default function CreateKeyPage() { const [accessToken, setAccessToken] = useState(null); useEffect(() => { - const token = getCookie("token"); - setToken(token); - }, []); - - useEffect(() => { - if (!token) { - return; - } - - const decoded = jwtDecode(token) as { [key: string]: any }; - if (decoded) { - // cast decoded to dictionary - console.log("Decoded token:", decoded); - - console.log("Decoded key:", decoded.key); - // set accessToken - setAccessToken(decoded.key); - - setDisabledPersonalKeyCreation( - decoded.disabled_non_admin_personal_key_creation, - ); - - // check if userRole is defined - if (decoded.user_role) { - const formattedUserRole = formatUserRole(decoded.user_role); - console.log("Decoded user_role:", formattedUserRole); - setUserRole(formattedUserRole); - if (formattedUserRole == "Admin Viewer") { - setPage("usage"); - } - } else { - console.log("User role not defined"); - } - - if (decoded.user_email) { - setUserEmail(decoded.user_email); - } else { - console.log(`User Email is not set ${decoded}`); - } - - if (decoded.login_method) { - setShowSSOBanner( - decoded.login_method == "username_password" ? true : false, + const fetchSessionDetails = async () => { + try { + const sessionDetails = await getUISessionDetails(); + // sessionDetails is already decoded, no need for jwtDecode + console.log("Session details:", sessionDetails); + + // Set access token to the session_id + setAccessToken(sessionDetails.session_id); + + setDisabledPersonalKeyCreation( + sessionDetails.disabled_non_admin_personal_key_creation, ); - } else { - console.log(`User Email is not set ${decoded}`); + + if (sessionDetails.user_role) { + const formattedUserRole = formatUserRole(sessionDetails.user_role); + console.log("User role:", formattedUserRole); + setUserRole(formattedUserRole); + if (formattedUserRole == "Admin Viewer") { + setPage("usage"); + } + } else { + console.log("User role not defined"); + } + + if (sessionDetails.user_email) { + setUserEmail(sessionDetails.user_email); + } else { + console.log("User Email is not set"); + } + + if (sessionDetails.login_method) { + setShowSSOBanner( + sessionDetails.login_method == "username_password" ? true : false, + ); + } + + if (sessionDetails.premium_user) { + setPremiumUser(sessionDetails.premium_user); + } + + if (sessionDetails.auth_header_name) { + setGlobalLitellmHeaderName(sessionDetails.auth_header_name); + } + + // Store the full session details as token for components that need it + setToken(JSON.stringify(sessionDetails)); + } catch (error) { + console.error("Error fetching session details:", error); } - - if (decoded.premium_user) { - setPremiumUser(decoded.premium_user); - } - - if (decoded.auth_header_name) { - setGlobalLitellmHeaderName(decoded.auth_header_name); - } - } - }, [token]); + }; + + fetchSessionDetails(); + }, []); useEffect(() => { if (accessToken && userID && userRole) { diff --git a/ui/litellm-dashboard/src/components/model_dashboard.tsx b/ui/litellm-dashboard/src/components/model_dashboard.tsx index 86df8f7442..5b5db91668 100644 --- a/ui/litellm-dashboard/src/components/model_dashboard.tsx +++ b/ui/litellm-dashboard/src/components/model_dashboard.tsx @@ -462,6 +462,7 @@ const ModelDashboard: React.FC = ({ useEffect(() => { if (!accessToken || !token || !userRole || !userID) { + console.log("exiting on model_dashboard.tsx because of missing accessToken, token, userRole, or userID", accessToken, token, userRole, userID); return; } const fetchData = async () => { diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 4673e064f5..71652b1a66 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -63,6 +63,34 @@ const handleError = async (errorData: string) => { // Global variable for the header name let globalLitellmHeaderName: string = "Authorization"; +const fetchWithCredentials = async (url: string, options: RequestInit = {}) => { + const defaultOptions: RequestInit = { + credentials: 'include', + headers: { + 'Content-Type': 'application/json', + }, + }; + + // Merge the default options with the provided options + const mergedOptions = { + ...defaultOptions, + ...options, + headers: { + ...defaultOptions.headers, + ...(options.headers || {}), + }, + }; + + // Remove the Authorization header if it exists + if (mergedOptions.headers && 'Authorization' in mergedOptions.headers) { + delete mergedOptions.headers['Authorization']; + } + + const response = await fetch(url, mergedOptions); + + return response; +}; + // Function to set the global header name export function setGlobalLitellmHeaderName(headerName: string = "Authorization") { console.log(`setGlobalLitellmHeaderName: ${headerName}`); @@ -71,7 +99,7 @@ export function setGlobalLitellmHeaderName(headerName: string = "Authorization") export const getOpenAPISchema = async () => { const url = proxyBaseUrl ? `${proxyBaseUrl}/openapi.json` : `/openapi.json`; - const response = await fetch(url); + const response = await fetchWithCredentials(url); const jsonData = await response.json(); return jsonData; } @@ -81,11 +109,10 @@ export const modelCostMap = async ( ) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/get/litellm_model_cost_map` : `/get/litellm_model_cost_map`; - const response = await fetch( + const response = await fetchWithCredentials( url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, } @@ -104,10 +131,9 @@ export const modelCreateCall = async ( ) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/model/new` : `/model/new`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -150,10 +176,9 @@ export const modelSettingsCall = async (accessToken: String) => { : `/model/settings`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -181,10 +206,9 @@ export const modelDeleteCall = async ( console.log(`model_id in model delete call: ${model_id}`); try { const url = proxyBaseUrl ? `${proxyBaseUrl}/model/delete` : `/model/delete`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -222,10 +246,9 @@ export const budgetDeleteCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/budget/delete` : `/budget/delete`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -257,10 +280,9 @@ export const budgetCreateCall = async ( console.log("Form Values after check:", formValues); const url = proxyBaseUrl ? `${proxyBaseUrl}/budget/new` : `/budget/new`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -294,10 +316,9 @@ export const budgetUpdateCall = async ( console.log("Form Values after check:", formValues); const url = proxyBaseUrl ? `${proxyBaseUrl}/budget/update` : `/budget/update`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -330,10 +351,9 @@ export const invitationCreateCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/invitation/new` : `/invitation/new`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -369,10 +389,9 @@ export const invitationClaimCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/invitation/claim` : `/invitation/claim`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -407,10 +426,9 @@ export const alertingSettingsCall = async (accessToken: String) => { : `/alerting/settings`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -464,10 +482,9 @@ export const keyCreateCall = async ( console.log("Form Values after check:", formValues); const url = proxyBaseUrl ? `${proxyBaseUrl}/key/generate` : `/key/generate`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -526,10 +543,9 @@ export const userCreateCall = async ( console.log("Form Values after check:", formValues); const url = proxyBaseUrl ? `${proxyBaseUrl}/user/new` : `/user/new`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -560,10 +576,9 @@ export const keyDeleteCall = async (accessToken: String, user_key: String) => { const url = proxyBaseUrl ? `${proxyBaseUrl}/key/delete` : `/key/delete`; console.log("in keyDeleteCall:", user_key); //message.info("Making key delete request"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -593,10 +608,9 @@ export const userDeleteCall = async (accessToken: string, userIds: string[]) => const url = proxyBaseUrl ? `${proxyBaseUrl}/user/delete` : `/user/delete`; console.log("in userDeleteCall:", userIds); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -625,10 +639,9 @@ export const teamDeleteCall = async (accessToken: String, teamID: String) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/team/delete` : `/team/delete`; console.log("in teamDeleteCall:", teamID); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -685,10 +698,9 @@ export const userListCall = async ( url += `?${queryString}`; } - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -740,10 +752,9 @@ export const userInfoCall = async ( } console.log("Requesting user data from:", url); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -773,10 +784,9 @@ export const teamInfoCall = async ( url = `${url}?team_id=${teamID}`; } console.log("in teamInfoCall"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -823,10 +833,9 @@ export const teamListCall = async ( url += `?${queryString}`; } - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -857,10 +866,9 @@ export const availableTeamListCall = async ( try { let url = proxyBaseUrl ? `${proxyBaseUrl}/team/available` : `/team/available`; console.log("in availableTeamListCall"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -885,10 +893,9 @@ export const organizationListCall = async (accessToken: String) => { */ try { let url = proxyBaseUrl ? `${proxyBaseUrl}/organization/list` : `/organization/list`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -917,10 +924,9 @@ export const organizationInfoCall = async ( url = `${url}?organization_id=${organizationID}`; } console.log("in teamInfoCall"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -960,10 +966,9 @@ export const organizationCreateCall = async ( } const url = proxyBaseUrl ? `${proxyBaseUrl}/organization/new` : `/organization/new`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -996,10 +1001,9 @@ export const organizationUpdateCall = async ( console.log("Form Values in organizationUpdateCall:", formValues); // Log the form values before making the API call const url = proxyBaseUrl ? `${proxyBaseUrl}/organization/update` : `/organization/update`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "PATCH", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -1030,10 +1034,9 @@ export const organizationDeleteCall = async ( ) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/organization/delete` : `/organization/delete`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "DELETE", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -1063,10 +1066,9 @@ export const getTotalSpendCall = async (accessToken: String) => { let url = proxyBaseUrl ? `${proxyBaseUrl}/global/spend` : `/global/spend`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1096,7 +1098,7 @@ export const getOnboardingCredentials = async (inviteUUID: String) => { : `/onboarding/get_token`; url += `?invite_link=${inviteUUID}`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { "Content-Type": "application/json", @@ -1128,10 +1130,9 @@ export const claimOnboardingToken = async ( ? `${proxyBaseUrl}/onboarding/claim_token` : `/onboarding/claim_token`; try { - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -1162,10 +1163,9 @@ export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: st ? `${proxyBaseUrl}/key/${keyToRegenerate}/regenerate` : `/key/${keyToRegenerate}/regenerate`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify(formData), @@ -1201,10 +1201,9 @@ export const modelInfoCall = async ( let url = proxyBaseUrl ? `${proxyBaseUrl}/v2/model/info` : `/v2/model/info`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1249,10 +1248,9 @@ export const modelHubCall = async (accessToken: String) => { : `/model_group/info`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1280,10 +1278,9 @@ export const getAllowedIPs = async (accessToken: String) => { ? `${proxyBaseUrl}/get/allowed_ips` : `/get/allowed_ips`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1309,10 +1306,9 @@ export const addAllowedIP = async (accessToken: String, ip: String) => { ? `${proxyBaseUrl}/add/allowed_ip` : `/add/allowed_ip`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ ip: ip }), @@ -1339,10 +1335,9 @@ export const deleteAllowedIP = async (accessToken: String, ip: String) => { ? `${proxyBaseUrl}/delete/allowed_ip` : `/delete/allowed_ip`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, body: JSON.stringify({ ip: ip }), @@ -1381,10 +1376,9 @@ export const modelMetricsCall = async ( url = `${url}?_selected_model_group=${modelGroup}&startTime=${startTime}&endTime=${endTime}&api_key=${apiKey}&customer=${customer}`; } // message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1420,10 +1414,9 @@ export const streamingModelMetricsCall = async ( url = `${url}?_selected_model_group=${modelGroup}&startTime=${startTime}&endTime=${endTime}`; } // message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1465,10 +1458,9 @@ export const modelMetricsSlowResponsesCall = async ( } // message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1509,10 +1501,9 @@ export const modelExceptionsCall = async ( if (modelGroup) { url = `${url}?_selected_model_group=${modelGroup}&startTime=${startTime}&endTime=${endTime}&api_key=${apiKey}&customer=${customer}`; } - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1549,10 +1540,9 @@ export const modelAvailableCall = async ( } //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1579,10 +1569,9 @@ export const keySpendLogsCall = async (accessToken: String, token: String) => { ? `${proxyBaseUrl}/global/spend/logs` : `/global/spend/logs`; console.log("in keySpendLogsCall:", url); - const response = await fetch(`${url}?api_key=${token}`, { + const response = await fetchWithCredentials(`${url}?api_key=${token}`, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1607,10 +1596,9 @@ export const teamSpendLogsCall = async (accessToken: String) => { ? `${proxyBaseUrl}/global/spend/teams` : `/global/spend/teams`; console.log("in teamSpendLogsCall:", url); - const response = await fetch(`${url}`, { + const response = await fetchWithCredentials(`${url}`, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1650,10 +1638,9 @@ export const tagsSpendLogsCall = async ( } console.log("in tagsSpendLogsCall:", url); - const response = await fetch(`${url}`, { + const response = await fetchWithCredentials(`${url}`, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1678,10 +1665,9 @@ export const allTagNamesCall = async (accessToken: String) => { : `/global/spend/all_tag_names`; console.log("in global/spend/all_tag_names call", url); - const response = await fetch(`${url}`, { + const response = await fetchWithCredentials(`${url}`, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1706,10 +1692,9 @@ export const allEndUsersCall = async (accessToken: String) => { : `/global/all_end_users`; console.log("in global/all_end_users call", url); - const response = await fetch(`${url}`, { + const response = await fetchWithCredentials(`${url}`, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1738,10 +1723,9 @@ export const userFilterUICall = async (accessToken: String, params: URLSearchPar url += `?user_id=${params.get("user_id")}`; } - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1774,10 +1758,9 @@ export const userSpendLogsCall = async ( url = `${url}?start_date=${startTime}&end_date=${endTime}`; } //message.info("Making spend logs request"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1829,10 +1812,9 @@ export const uiSpendLogsCall = async ( url += `?${queryString}`; } - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -1860,10 +1842,9 @@ export const adminSpendLogsCall = async (accessToken: String) => { : `/global/spend/logs`; //message.info("Making spend logs request"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -1890,10 +1871,9 @@ export const adminTopKeysCall = async (accessToken: String) => { : `/global/spend/keys?limit=5`; //message.info("Making spend keys request"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -1940,15 +1920,14 @@ export const adminTopEndUsersCall = async ( // Define requestOptions with body as an optional property const requestOptions = { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: body, }; - const response = await fetch(url, requestOptions); + const response = await fetchWithCredentials(url, requestOptions); if (!response.ok) { const errorData = await response.text(); handleError(errorData); @@ -1991,7 +1970,7 @@ export const adminspendByProvider = async ( }, }; - const response = await fetch(url, requestOptions); + const response = await fetchWithCredentials(url, requestOptions); if (!response.ok) { const errorData = await response.text(); @@ -2029,7 +2008,7 @@ export const adminGlobalActivity = async ( }, }; - const response = await fetch(url, requestOptions); + const response = await fetchWithCredentials(url, requestOptions); if (!response.ok) { const errorData = await response.text(); @@ -2065,7 +2044,7 @@ export const adminGlobalCacheActivity = async ( }, }; - const response = await fetch(url, requestOptions); + const response = await fetchWithCredentials(url, requestOptions); if (!response.ok) { const errorData = await response.text(); @@ -2101,7 +2080,7 @@ export const adminGlobalActivityPerModel = async ( }, }; - const response = await fetch(url, requestOptions); + const response = await fetchWithCredentials(url, requestOptions); if (!response.ok) { const errorData = await response.text(); @@ -2142,7 +2121,7 @@ export const adminGlobalActivityExceptions = async ( }, }; - const response = await fetch(url, requestOptions); + const response = await fetchWithCredentials(url, requestOptions); if (!response.ok) { const errorData = await response.text(); @@ -2183,7 +2162,7 @@ export const adminGlobalActivityExceptionsPerDeployment = async ( }, }; - const response = await fetch(url, requestOptions); + const response = await fetchWithCredentials(url, requestOptions); if (!response.ok) { const errorData = await response.text(); @@ -2205,10 +2184,9 @@ export const adminTopModelsCall = async (accessToken: String) => { : `/global/spend/models?limit=5`; //message.info("Making top models request"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -2232,10 +2210,9 @@ export const keyInfoCall = async (accessToken: String, keys: String[]) => { try { let url = proxyBaseUrl ? `${proxyBaseUrl}/v2/key/info` : `/v2/key/info`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2297,10 +2274,9 @@ export const keyListCall = async ( url += `?${queryString}`; } - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -2325,10 +2301,9 @@ export const spendUsersCall = async (accessToken: String, userID: String) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/spend/users` : `/spend/users`; console.log("in spendUsersCall:", url); - const response = await fetch(`${url}?user_id=${userID}`, { + const response = await fetchWithCredentials(`${url}?user_id=${userID}`, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -2357,10 +2332,9 @@ export const userRequestModelCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/user/request_model` : `/user/request_model`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2392,10 +2366,9 @@ export const userGetRequesedtModelsCall = async (accessToken: String) => { ? `${proxyBaseUrl}/user/get_requests` : `/user/get_requests`; console.log("in userGetRequesedtModelsCall:", url); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -2432,10 +2405,9 @@ export const userGetAllUsersCall = async ( ? `${proxyBaseUrl}/user/get_users?role=${role}` : `/user/get_users?role=${role}`; console.log("in userGetAllUsersCall:", url); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -2461,10 +2433,9 @@ export const getPossibleUserRoles = async (accessToken: String) => { const url = proxyBaseUrl ? `${proxyBaseUrl}/user/available_roles` : `/user/available_roles`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -2499,10 +2470,9 @@ export const teamCreateCall = async ( } const url = proxyBaseUrl ? `${proxyBaseUrl}/team/new` : `/team/new`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2554,10 +2524,9 @@ export const keyUpdateCall = async ( } } const url = proxyBaseUrl ? `${proxyBaseUrl}/key/update` : `/key/update`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2589,10 +2558,9 @@ export const teamUpdateCall = async ( console.log("Form Values in teamUpateCall:", formValues); // Log the form values before making the API call const url = proxyBaseUrl ? `${proxyBaseUrl}/team/update` : `/team/update`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2624,10 +2592,9 @@ export const modelUpdateCall = async ( console.log("Form Values in modelUpateCall:", formValues); // Log the form values before making the API call const url = proxyBaseUrl ? `${proxyBaseUrl}/model/update` : `/model/update`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2668,10 +2635,9 @@ export const teamMemberAddCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/team/member_add` : `/team/member_add`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2708,10 +2674,9 @@ export const teamMemberUpdateCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/team/member_update` : `/team/member_update`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2749,10 +2714,9 @@ export const teamMemberDeleteCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/team/member_delete` : `/team/member_delete`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2790,10 +2754,9 @@ export const organizationMemberAddCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/organization/member_add` : `/organization/member_add`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2831,10 +2794,9 @@ export const organizationMemberDeleteCall = async ( ? `${proxyBaseUrl}/organization/member_delete` : `/organization/member_delete`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "DELETE", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2870,10 +2832,9 @@ export const organizationMemberUpdateCall = async ( ? `${proxyBaseUrl}/organization/member_update` : `/organization/member_update`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "PATCH", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2912,10 +2873,9 @@ export const userUpdateUserCall = async ( response_body["user_role"] = userRole; } response_body = JSON.stringify(response_body); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: response_body, @@ -2950,10 +2910,9 @@ export const PredictedSpendLogsCall = async ( //message.info("Predicting spend logs request"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -2986,10 +2945,9 @@ export const slackBudgetAlertsHealthCheck = async (accessToken: String) => { console.log("Checking Slack Budget Alerts service health"); //message.info("Sending Test Slack alert..."); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, }); @@ -3025,10 +2983,9 @@ export const serviceHealthCheck = async ( console.log("Checking Slack Budget Alerts service health"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3060,10 +3017,9 @@ export const getBudgetList = async (accessToken: String) => { let url = proxyBaseUrl ? `${proxyBaseUrl}/budget/list` : `/budget/list`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3093,10 +3049,9 @@ export const getBudgetSettings = async (accessToken: String) => { : `/budget/settings`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3131,10 +3086,9 @@ export const getCallbacksCall = async ( : `/get/config/callbacks`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3162,10 +3116,9 @@ export const getGeneralSettingsCall = async (accessToken: String) => { : `/config/list?config_type=general_settings`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3194,10 +3147,9 @@ export const getPassThroughEndpointsCall = async (accessToken: String) => { : `/config/pass_through_endpoint`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3228,10 +3180,9 @@ export const getConfigFieldSetting = async ( : `/config/field/info?field_name=${fieldName}`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3265,10 +3216,9 @@ export const updatePassThroughFieldSetting = async ( field_value: fieldValue, }; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify(formData), @@ -3302,10 +3252,9 @@ export const createPassThroughEndpoint = async ( let url = proxyBaseUrl ? `${proxyBaseUrl}/config/pass_through_endpoint` : `/config/pass_through_endpoint`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -3345,10 +3294,9 @@ export const updateConfigFieldSetting = async ( config_type: "general_settings", }; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify(formData), @@ -3385,10 +3333,9 @@ export const deleteConfigFieldSetting = async ( config_type: "general_settings", }; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify(formData), @@ -3417,10 +3364,9 @@ export const deletePassThroughEndpointsCall = async (accessToken: String, endpoi : `/config/pass_through_endpoint${endpointId}`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "DELETE", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3452,10 +3398,9 @@ export const setCallbacksCall = async ( let url = proxyBaseUrl ? `${proxyBaseUrl}/config/update` : `/config/update`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "POST", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, body: JSON.stringify({ @@ -3487,10 +3432,9 @@ export const healthCheckCall = async (accessToken: String) => { let url = proxyBaseUrl ? `${proxyBaseUrl}/health` : `/health`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3519,10 +3463,9 @@ export const cachingHealthCheckCall = async (accessToken: String) => { let url = proxyBaseUrl ? `${proxyBaseUrl}/cache/ping` : `/cache/ping`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3555,10 +3498,9 @@ export const getProxyUISettings = async ( : `/sso/get/ui_settings`; //message.info("Requesting model data"); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3584,10 +3526,9 @@ export const getGuardrailsList = async (accessToken: String) => { try { let url = proxyBaseUrl ? `${proxyBaseUrl}/guardrails/list` : `/guardrails/list`; - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3621,10 +3562,9 @@ export const uiSpendLogDetailsCall = async ( console.log("Fetching log details from:", url); - const response = await fetch(url, { + const response = await fetchWithCredentials(url, { method: "GET", - headers: { - [globalLitellmHeaderName]: `Bearer ${accessToken}`, + headers: { "Content-Type": "application/json", }, }); @@ -3644,3 +3584,33 @@ export const uiSpendLogDetailsCall = async ( } }; + +/** + * Validates the current session token with the server + * @returns The validated session data or null if validation fails + */ +export const validateSession = async () => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/sso/session/validate` : `/sso/session/validate`; + + const response = await fetchWithCredentials(url, { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.text(); + handleError(errorData); + throw new Error("Session validation failed"); + } + + const data = await response.json(); + console.log("Session validation response:", data); + return data; + } catch (error) { + console.error("Failed to validate session:", error); + throw error; + } +}; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/user_dashboard.tsx b/ui/litellm-dashboard/src/components/user_dashboard.tsx index 22b47d525f..ff0f1b60f1 100644 --- a/ui/litellm-dashboard/src/components/user_dashboard.tsx +++ b/ui/litellm-dashboard/src/components/user_dashboard.tsx @@ -21,6 +21,7 @@ import { useSearchParams, useRouter } from "next/navigation"; import { Team } from "./key_team_helpers/key_list"; import { jwtDecode } from "jwt-decode"; import { Typography } from "antd"; +import { getUISessionDetails } from "@/utils/cookieUtils"; import { clearTokenCookies } from "@/utils/cookieUtils"; const isLocal = process.env.NODE_ENV === "development"; if (isLocal != true) { @@ -45,14 +46,6 @@ export type UserInfo = { spend: number; } -function getCookie(name: string) { - console.log("COOKIES", document.cookie) - const cookieValue = document.cookie - .split('; ') - .find(row => row.startsWith(name + '=')); - return cookieValue ? cookieValue.split('=')[1] : null; -} - interface UserDashboardProps { userID: string | null; userRole: string | null; @@ -94,7 +87,7 @@ const UserDashboard: React.FC = ({ // Assuming useSearchParams() hook exists and works in your setup const searchParams = useSearchParams()!; - const token = getCookie('token'); + const token = getUISessionDetails(); const invitation_id = searchParams.get("invitation_id"); @@ -146,32 +139,37 @@ const UserDashboard: React.FC = ({ // console.log(`selectedTeam: ${Object.entries(selectedTeam)}`); // Moved useEffect inside the component and used a condition to run fetch only if the params are available useEffect(() => { - if (token) { - const decoded = jwtDecode(token) as { [key: string]: any }; - if (decoded) { - // cast decoded to dictionary - console.log("Decoded token:", decoded); - - console.log("Decoded key:", decoded.key); - // set accessToken - setAccessToken(decoded.key); - + const fetchSessionDetails = async () => { + try { + const sessionDetails = await getUISessionDetails(); + console.log("Session details:", sessionDetails); + + // Set access token to the session_id + setAccessToken(sessionDetails.session_id); + // check if userRole is defined - if (decoded.user_role) { - const formattedUserRole = formatUserRole(decoded.user_role); - console.log("Decoded user_role:", formattedUserRole); + if (sessionDetails.user_role) { + const formattedUserRole = formatUserRole(sessionDetails.user_role); + console.log("User role:", formattedUserRole); setUserRole(formattedUserRole); } else { console.log("User role not defined"); } - if (decoded.user_email) { - setUserEmail(decoded.user_email); + if (sessionDetails.user_email) { + setUserEmail(sessionDetails.user_email); } else { - console.log(`User Email is not set ${decoded}`); + console.log("User Email is not set"); } + } catch (error) { + console.error("Error fetching session details:", error); } - } + }; + + fetchSessionDetails(); + }, []); + + useEffect(() => { if (userID && accessToken && userRole && !keys && !userSpendData) { const cachedUserModels = sessionStorage.getItem("userModels" + userID); if (cachedUserModels) { @@ -246,7 +244,7 @@ const UserDashboard: React.FC = ({ fetchTeams(accessToken, userID, userRole, currentOrg, setTeams); } } - }, [userID, token, accessToken, keys, userRole]); + }, [userID, accessToken, keys, userRole]); useEffect(() => { console.log(`currentOrg: ${JSON.stringify(currentOrg)}, accessToken: ${accessToken}, userID: ${userID}, userRole: ${userRole}`) @@ -333,31 +331,35 @@ const UserDashboard: React.FC = ({
- + {accessToken && ( + <> + - + + + )}
diff --git a/ui/litellm-dashboard/src/utils/cookieUtils.ts b/ui/litellm-dashboard/src/utils/cookieUtils.ts index a09cf4e97f..793f59ef26 100644 --- a/ui/litellm-dashboard/src/utils/cookieUtils.ts +++ b/ui/litellm-dashboard/src/utils/cookieUtils.ts @@ -5,6 +5,24 @@ /** * Clears the token cookie from both root and /ui paths */ +import { validateSession } from "../components/networking" + +// Define the interface for the JWT token data +export interface JWTTokenData { + user_id: string; + user_email: string | null; + user_role: string; + login_method: string; + premium_user: boolean; + auth_header_name: string; + iss: string; + aud: string; + exp: number; + disabled_non_admin_personal_key_creation: boolean; + scopes: string[]; + session_id: string; // ui session id currently in progress +} + export function clearTokenCookies() { // Get the current domain const domain = window.location.hostname; @@ -13,32 +31,51 @@ export function clearTokenCookies() { const paths = ['/', '/ui']; const sameSiteValues = ['Lax', 'Strict', 'None']; - paths.forEach(path => { - // Basic clearing - document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`; - - // With domain - document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`; - - // Try different SameSite values - sameSiteValues.forEach(sameSite => { - const secureFlag = sameSite === 'None' ? ' Secure;' : ''; - document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`; - document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`; + // Get all cookies + const allCookies = document.cookie.split("; "); + const tokenPattern = /^token_\d+$/; + + // Find all token cookies + const tokenCookieNames = allCookies + .map(cookie => cookie.split("=")[0]) + .filter(name => name === "token" || tokenPattern.test(name)); + + // Clear each token cookie with various combinations + tokenCookieNames.forEach(cookieName => { + paths.forEach(path => { + // Basic clearing + document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`; + + // With domain + document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`; + + // Try different SameSite values + sameSiteValues.forEach(sameSite => { + const secureFlag = sameSite === 'None' ? ' Secure;' : ''; + document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`; + document.cookie = `${cookieName}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`; + }); }); }); console.log("After clearing cookies:", document.cookie); } -/** - * Gets a cookie value by name - * @param name The name of the cookie to retrieve - * @returns The cookie value or null if not found - */ -export function getCookie(name: string) { - const cookieValue = document.cookie - .split('; ') - .find(row => row.startsWith(name + '=')); - return cookieValue ? cookieValue.split('=')[1] : null; -} \ No newline at end of file +export function setAuthToken(token: string) { + // Generate a token name with current timestamp + const currentTimestamp = Math.floor(Date.now() / 1000); + const tokenName = `token_${currentTimestamp}`; + + // Set the cookie with the timestamp-based name + document.cookie = `${tokenName}=${token}; path=/; domain=${window.location.hostname};`; +} + +export async function getUISessionDetails(): Promise { + const validated_jwt_token = await validateSession(); + + if (validated_jwt_token?.data) { + return validated_jwt_token.data as JWTTokenData; + } else { + throw new Error("Invalid JWT token"); + } +}