diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index 81475951f..e4772d70a 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback. ### Step 1. Setup Proxy - `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`. +- `JWT_AUDIENCE`: This is the audience used for decoding the JWT. If not set, the decode step will not verify the audience. ```bash export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks" diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 1324c2c59..606ff6828 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -15,6 +15,9 @@ from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable from litellm.proxy.utils import PrismaClient from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from typing import Optional +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization class JWTHandler: @@ -142,8 +145,8 @@ class JWTHandler: public_key = keys[0] elif len(keys) > 1: for key in keys: - if kid is not None and key["kid"] == kid: - public_key = key + if kid is not None and key == kid: + public_key = keys[key] if public_key is None: raise Exception( @@ -153,6 +156,11 @@ class JWTHandler: return public_key async def auth_jwt(self, token: str) -> dict: + audience = os.getenv("JWT_AUDIENCE") + decode_options = None + if audience is None: + decode_options = {"verify_aud": False} + from jwt.algorithms import RSAAlgorithm header = jwt.get_unverified_header(token) @@ -182,7 +190,33 @@ class JWTHandler: token, public_key_rsa, # type: ignore algorithms=["RS256"], - options={"verify_aud": False}, + options=decode_options, + audience=audience, + ) + return payload + + except jwt.ExpiredSignatureError: + # the token is expired, do something to refresh it + raise Exception("Token Expired") + except Exception as e: + raise Exception(f"Validation fails: {str(e)}") + elif public_key is not None and isinstance(public_key, str): + try: + cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend()) + + # Extract public key + key = cert.public_key().public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo + ) + + # decode the token using the public key + payload = jwt.decode( + token, + key, + algorithms=["RS256"], + audience=audience, + options=decode_options ) return payload diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 407814e84..b3af9913f 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -24,7 +24,6 @@ public_key = { "alg": "RS256", } - def test_load_config_with_custom_role_names(): config = { "general_settings": { @@ -78,9 +77,9 @@ async def test_token_single_public_key(): == "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ" ) - +@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_valid_invalid_token(): +async def test_valid_invalid_token(audience): """ Tests - valid token @@ -91,6 +90,10 @@ async def test_valid_invalid_token(): from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.backends import default_backend + os.environ.pop('JWT_AUDIENCE', None) + if audience: + os.environ["JWT_AUDIENCE"] = audience + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -135,6 +138,7 @@ async def test_valid_invalid_token(): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-proxy-admin", + "aud": audience } # Generate the JWT token @@ -162,6 +166,7 @@ async def test_valid_invalid_token(): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-NO-SCOPE", + "aud": audience } # Generate the JWT token @@ -178,7 +183,6 @@ async def test_valid_invalid_token(): except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") - @pytest.fixture def prisma_client(): import litellm @@ -201,8 +205,9 @@ def prisma_client(): return prisma_client +@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_team_token_output(prisma_client): +async def test_team_token_output(prisma_client, audience): import jwt, json from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -217,6 +222,10 @@ async def test_team_token_output(prisma_client): setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() + os.environ.pop('JWT_AUDIENCE', None) + if audience: + os.environ["JWT_AUDIENCE"] = audience + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -265,6 +274,7 @@ async def test_team_token_output(prisma_client): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_team", "client_id": team_id, + "aud": audience } # Generate the JWT token @@ -279,6 +289,7 @@ async def test_team_token_output(prisma_client): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", + "aud": audience } admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") @@ -347,8 +358,9 @@ async def test_team_token_output(prisma_client): assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] +@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_user_token_output(prisma_client): +async def test_user_token_output(prisma_client, audience): """ - If user required, check if it exists - fail initial request (when user doesn't exist) @@ -369,6 +381,10 @@ async def test_user_token_output(prisma_client): setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() + os.environ.pop('JWT_AUDIENCE', None) + if audience: + os.environ["JWT_AUDIENCE"] = audience + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -420,6 +436,7 @@ async def test_user_token_output(prisma_client): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_team", "client_id": team_id, + "aud": audience } # Generate the JWT token @@ -434,6 +451,7 @@ async def test_user_token_output(prisma_client): "sub": user_id, "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", + "aud": audience } admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")