diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 268208bae..21300bc54 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -147,9 +147,7 @@ class JWTHandler: scopes = [] return scopes - async def auth_jwt(self, token: str) -> dict: - from jwt.algorithms import RSAAlgorithm - + async def get_public_key(self, kid: Optional[str]) -> dict: keys_url = os.getenv("JWT_PUBLIC_KEY_URL") if keys_url is None: @@ -169,19 +167,33 @@ class JWTHandler: else: keys = cached_keys + public_key: Optional[dict] = None + + if len(keys) == 1: + public_key = keys[0] + elif len(keys) > 1: + for key in keys: + if kid is not None and key["kid"] == kid: + public_key = key + + if public_key is None: + raise Exception( + f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}" + ) + + return public_key + + async def auth_jwt(self, token: str) -> dict: + from jwt.algorithms import RSAAlgorithm + header = jwt.get_unverified_header(token) verbose_proxy_logger.debug(f"header: {header}") kid = header.get("kid", None) - public_key = None - if len(keys) == 1: - public_key = public_key - elif len(keys) > 1: - for key in keys: - if key["kid"] == kid: - public_key = key + public_key = await self.get_public_key(kid=kid) + if public_key is not None and isinstance(public_key, dict): jwk = {} if "kty" in public_key: @@ -193,13 +205,13 @@ class JWTHandler: if "e" in public_key: jwk["e"] = public_key["e"] - public_key = RSAAlgorithm.from_jwk(json.dumps(jwk)) + public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk)) try: # decode the token using the public key payload = jwt.decode( token, - public_key, # type: ignore + public_key_rsa, # type: ignore algorithms=["RS256"], options={"verify_aud": False}, ) diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index a2c9e4e4a..57c7e5c62 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -13,6 +13,16 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest from litellm.proxy._types import LiteLLMProxyRoles +from litellm.proxy.auth.handle_jwt import JWTHandler +from litellm.caching import DualCache +from datetime import datetime, timedelta + +public_key = { + "kty": "RSA", + "e": "AQAB", + "n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ", + "alg": "RS256", +} def test_load_config_with_custom_role_names(): @@ -32,3 +42,138 @@ def test_load_config_with_custom_role_names(): # test_load_config_with_custom_role_names() + + +@pytest.mark.asyncio +async def test_token_single_public_key(): + import jwt + + jwt_handler = JWTHandler() + + backend_keys = { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "e": "AQAB", + "n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ", + "alg": "RS256", + } + ] + } + + # set cache + cache = DualCache() + + await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"]) + + jwt_handler.user_api_key_cache = cache + + public_key = await jwt_handler.get_public_key(kid=None) + + assert public_key is not None + assert isinstance(public_key, dict) + assert ( + public_key["n"] + == "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ" + ) + + +@pytest.mark.asyncio +async def test_valid_invalid_token(): + """ + Tests + - valid token + - invalid token + """ + import jwt, json + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.backends import default_backend + + # Generate a private / public key pair using RSA algorithm + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + # Get private key in PEM format + private_key = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Get public key in PEM format + public_key = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + public_key_obj = serialization.load_pem_public_key( + public_key, backend=default_backend() + ) + + # Convert RSA public key object to JWK (JSON Web Key) + public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj)) + + assert isinstance(public_jwk, dict) + + # set cache + cache = DualCache() + + await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + + jwt_handler = JWTHandler() + + jwt_handler.user_api_key_cache = cache + + # VALID TOKEN + ## GENERATE A TOKEN + # Assuming the current time is in UTC + expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp()) + + payload = { + "sub": "user123", + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm-proxy-admin", + } + + # Generate the JWT token + # But before, you should convert bytes to string + private_key_str = private_key.decode("utf-8") + token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## VERIFY IT WORKS + + # verify token + + response = await jwt_handler.auth_jwt(token=token) + + assert response is not None + assert isinstance(response, dict) + + print(f"response: {response}") + + # INVALID TOKEN + ## GENERATE A TOKEN + # Assuming the current time is in UTC + expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp()) + + payload = { + "sub": "user123", + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm-NO-SCOPE", + } + + # Generate the JWT token + # But before, you should convert bytes to string + private_key_str = private_key.decode("utf-8") + token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## VERIFY IT WORKS + + # verify token + + try: + response = await jwt_handler.auth_jwt(token=token) + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}")