diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 31bea6504f..6d00b2774f 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -23,3 +23,5 @@ general_settings: alerting: ["slack"] alerting_args: report_check_interval: 10 + enable_jwt_auth: True + diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index fce6d4254b..d06469b71f 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -123,18 +123,8 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool: """ for allowed_route in allowed_routes: if ( - allowed_route == LiteLLMRoutes.openai_routes.name - and user_route in LiteLLMRoutes.openai_routes.value - ): - return True - elif ( - allowed_route == LiteLLMRoutes.info_routes.name - and user_route in LiteLLMRoutes.info_routes.value - ): - return True - elif ( - allowed_route == LiteLLMRoutes.management_routes.name - and user_route in LiteLLMRoutes.management_routes.value + allowed_route in LiteLLMRoutes.__members__ + and user_route in LiteLLMRoutes[allowed_route].value ): return True elif allowed_route == user_route: @@ -152,17 +142,11 @@ def allowed_routes_check( """ if user_role == "proxy_admin": - if litellm_proxy_roles.admin_allowed_routes is None: - is_allowed = _allowed_routes_check( - user_route=user_route, allowed_routes=["management_routes"] - ) - return is_allowed - elif litellm_proxy_roles.admin_allowed_routes is not None: - is_allowed = _allowed_routes_check( - user_route=user_route, - allowed_routes=litellm_proxy_roles.admin_allowed_routes, - ) - return is_allowed + is_allowed = _allowed_routes_check( + user_route=user_route, + allowed_routes=litellm_proxy_roles.admin_allowed_routes, + ) + return is_allowed elif user_role == "team": if litellm_proxy_roles.team_allowed_routes is None: diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_allowed_routes_admin(prisma_client, audience):
    """
    Check that a JWT carrying the proxy-admin scope passes `user_api_key_auth`
    for every concrete route behind `admin_allowed_routes`.

    - generate an RSA key pair and seed the handler's cache with the public JWK
    - sign an admin-scoped token with the private key
    - expand each `admin_allowed_routes` entry that names a `LiteLLMRoutes`
      member into the concrete routes it stands for
    - call `user_api_key_auth` once per concrete route; any exception raised
      by the auth call fails the test

    Parameters:
    - prisma_client: fixture, installed on the proxy server module and connected
    - audience: value for the `JWT_AUDIENCE` env var and the token's `aud`
      claim; `None` leaves the env var unset
    """
    import json
    import time

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import user_api_key_auth
    from litellm.proxy._types import UserAPIKeyAuth
    import litellm

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # start from a clean slate so a value left by another test cannot leak in
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    # Private key as a PEM string - used below to sign the token
    private_key_str = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    # Public key as a JWK (JSON Web Key) dict
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))
    assert isinstance(public_jwk, dict)

    # set cache - the handler below is pointed at this cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")

    ## GENERATE A VALID ADMIN TOKEN
    # Use epoch seconds directly: `datetime.utcnow().timestamp()` treats the
    # naive UTC value as local time, which yields an already-expired token on
    # machines whose timezone is ahead of UTC.
    expiration_time = int(time.time()) + 10 * 60

    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_proxy_admin",
        "aud": audience,
    }
    admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")

    # verify token - raises if the handler cannot validate it
    await jwt_handler.auth_jwt(token=admin_token)

    ## RUN IT THROUGH USER API KEY AUTH
    bearer_token = "Bearer " + admin_token

    # expand route-group names into the concrete routes they cover
    pseudo_routes = jwt_handler.litellm_jwtauth.admin_allowed_routes
    actual_routes = []
    for pseudo_route in pseudo_routes:
        if pseudo_route in LiteLLMRoutes.__members__:
            actual_routes.extend(LiteLLMRoutes[pseudo_route].value)

    # without this the loop below could run zero times and the test would
    # pass without checking anything
    assert actual_routes, f"no concrete admin routes found in {pseudo_routes}"

    # enable jwt auth + install the handler once; neither changes per route
    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {
            "enable_jwt_auth": True,
        },
    )
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    for route in actual_routes:
        request = Request(scope={"type": "http"})
        request._url = URL(url=route)

        # an exception here means the admin token was rejected for `route`
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        assert isinstance(
            result, UserAPIKeyAuth
        ), f"unexpected auth result for route={route}"