#### What this tests ####
#    Unit tests for JWT-Auth

import sys, os, asyncio, time, random
import traceback
from dotenv import load_dotenv

load_dotenv()

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
from litellm.proxy._types import LiteLLMProxyRoles
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.caching import DualCache
from datetime import datetime, timedelta, timezone

# Sample RSA public key in JWK form, shared by the tests below.
public_key = {
    "kty": "RSA",
    "e": "AQAB",
    "n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
    "alg": "RS256",
}


def _token_expiry(minutes: int = 10) -> int:
    """Return a Unix timestamp ``minutes`` from now, for use as a JWT ``exp`` claim.

    An aware UTC datetime is used on purpose: calling ``.timestamp()`` on the
    naive result of ``datetime.utcnow()`` interprets it as local time, which
    shifts the expiry by the host's UTC offset and can yield an already-expired
    token on machines that are not running in UTC.
    """
    return int((datetime.now(timezone.utc) + timedelta(minutes=minutes)).timestamp())


def _sign_rs256_token(private_key_pem: bytes, scope: str):
    """Build and sign an RS256 JWT for subject ``user123`` with the given scope.

    Args:
        private_key_pem: PEM-encoded RSA private key bytes used for signing.
        scope: Value placed in the token's ``scope`` claim.

    Returns:
        Whatever ``jwt.encode`` returns for the payload (the encoded token).
    """
    import jwt

    payload = {
        "sub": "user123",
        "exp": _token_expiry(),  # set the token to expire in 10 minutes
        "scope": scope,
    }
    # jwt.encode is handed the key as text, so convert the PEM bytes first.
    return jwt.encode(payload, private_key_pem.decode("utf-8"), algorithm="RS256")


def test_load_config_with_custom_role_names():
    """LiteLLMProxyRoles picks up a custom ``proxy_admin`` role name from config."""
    config = {
        "general_settings": {
            "litellm_proxy_roles": {"proxy_admin": "litellm-proxy-admin"}
        }
    }

    role_settings = config.get("general_settings", {}).get("litellm_proxy_roles", {})
    proxy_roles = LiteLLMProxyRoles(**role_settings)

    print(f"proxy_roles: {proxy_roles}")
    assert proxy_roles.proxy_admin == "litellm-proxy-admin"


@pytest.mark.asyncio
async def test_token_single_public_key():
    """With one key cached, ``get_public_key(kid=None)`` returns that key as a dict."""
    expected_modulus = public_key["n"]

    jwt_handler = JWTHandler()

    backend_keys = {
        "keys": [
            {
                "kty": "RSA",
                "use": "sig",
                "e": "AQAB",
                "n": expected_modulus,
                "alg": "RS256",
            }
        ]
    }

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"])

    jwt_handler.user_api_key_cache = cache

    fetched_key = await jwt_handler.get_public_key(kid=None)

    assert fetched_key is not None
    assert isinstance(fetched_key, dict)
    assert fetched_key["n"] == expected_modulus


@pytest.mark.asyncio
async def test_valid_invalid_token():
    """
    Tests
    - valid token
    - invalid token
    """
    import jwt, json
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.hazmat.backends import default_backend

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Get public key in PEM format
    public_key_pem = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    public_key_obj = serialization.load_pem_public_key(
        public_key_pem, backend=default_backend()
    )

    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))

    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache

    # VALID TOKEN
    ## GENERATE A TOKEN
    token = _sign_rs256_token(private_key, scope="litellm-proxy-admin")

    ## VERIFY IT WORKS
    response = await jwt_handler.auth_jwt(token=token)

    assert response is not None
    assert isinstance(response, dict)

    print(f"response: {response}")

    # INVALID TOKEN
    ## GENERATE A TOKEN
    token = _sign_rs256_token(private_key, scope="litellm-NO-SCOPE")

    ## VERIFY IT WORKS
    # The only requirement here is that auth_jwt does not raise for a correctly
    # signed token whose scope is not "litellm-proxy-admin".
    # NOTE(review): presumably scope enforcement happens outside auth_jwt - confirm.
    try:
        response = await jwt_handler.auth_jwt(token=token)
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")