From 1580f993fe26fb74b95249d05f5bd482ff6f1900 Mon Sep 17 00:00:00 2001 From: Sara Ghaemi Date: Tue, 7 May 2024 11:22:17 -0400 Subject: [PATCH 1/4] Updated JWT handler to support PEM public key --- litellm/proxy/auth/handle_jwt.py | 34 ++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 1324c2c59..c12f48cc1 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -15,6 +15,9 @@ from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable from litellm.proxy.utils import PrismaClient from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from typing import Optional +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization class JWTHandler: @@ -142,8 +145,8 @@ class JWTHandler: public_key = keys[0] elif len(keys) > 1: for key in keys: - if kid is not None and key["kid"] == kid: - public_key = key + if kid is not None and key == kid: + public_key = keys[key] if public_key is None: raise Exception( @@ -191,6 +194,33 @@ class JWTHandler: raise Exception("Token Expired") except Exception as e: raise Exception(f"Validation fails: {str(e)}") + elif public_key is not None and isinstance(public_key, str): + audience = os.getenv("JWT_AUDIENCE") + if audience is None: + raise Exception("Missing JWT Audience from environment.") + try: + cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend()) + + # Extract public key + key = cert.public_key().public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo + ) + + # decode the token using the public key + payload = jwt.decode( + token, + key, + algorithms=["RS256"], + audience=audience, + ) + return payload + + except jwt.ExpiredSignatureError: + # the token is expired, do something to refresh it + raise Exception("Token Expired") + except Exception as e: + raise Exception(f"Validation fails: {str(e)}") raise Exception("Invalid JWT Submitted") From 66b2b5fab9b5b87af1d27ceaca76e402064280c1 Mon Sep 17 00:00:00 2001 From: Sara Ghaemi Date: Tue, 7 May 2024 11:37:04 -0400 Subject: [PATCH 2/4] made audience optional and updated docs --- docs/my-website/docs/proxy/token_auth.md | 1 + litellm/proxy/auth/handle_jwt.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index 81475951f..e4772d70a 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback. ### Step 1. Setup Proxy - `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`. +- `JWT_AUDIENCE`: This is the audience used for decoding the JWT. If not set, the decode step will not verify the audience. ```bash export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks" diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index c12f48cc1..606ff6828 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -156,6 +156,11 @@ class JWTHandler: return public_key async def auth_jwt(self, token: str) -> dict: + audience = os.getenv("JWT_AUDIENCE") + decode_options = None + if audience is None: + decode_options = {"verify_aud": False} + from jwt.algorithms import RSAAlgorithm header = jwt.get_unverified_header(token) @@ -185,7 +190,8 @@ class JWTHandler: token, public_key_rsa, # type: ignore algorithms=["RS256"], - options={"verify_aud": False}, + options=decode_options, + audience=audience, ) return payload @@ -195,9 +201,6 @@ class JWTHandler: except Exception as e: raise Exception(f"Validation fails: {str(e)}") elif public_key is not None and isinstance(public_key, str): - audience = os.getenv("JWT_AUDIENCE") - if audience is None: - raise Exception("Missing JWT Audience from environment.") try: cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend()) @@ -213,6 +216,7 @@ class JWTHandler: key, algorithms=["RS256"], audience=audience, + options=decode_options ) return payload From 7017899d372a2b1e67ec3bb1ef0af28bf0a81fd6 Mon Sep 17 00:00:00 2001 From: Sara Ghaemi Date: Tue, 7 May 2024 12:10:47 -0400 Subject: [PATCH 3/4] updated tests to also check for audience if found --- litellm/tests/test_jwt.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 407814e84..3dacebe27 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -24,7 +24,6 @@ public_key = { "alg": "RS256", } - def test_load_config_with_custom_role_names(): config = { "general_settings": { @@ -136,6 +135,8 @@ async def test_valid_invalid_token(): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-proxy-admin", } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -163,6 +164,8 @@ async def test_valid_invalid_token(): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-NO-SCOPE", } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -266,6 +269,8 @@ async def test_team_token_output(prisma_client): "scope": "litellm_team", "client_id": team_id, } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -280,6 +285,8 @@ async def test_team_token_output(prisma_client): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") @@ -421,6 +428,8 @@ async def test_user_token_output(prisma_client): "scope": "litellm_team", "client_id": team_id, } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -435,6 +444,8 @@ async def test_user_token_output(prisma_client): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") From 86e0dd68c3c6f74f7955ba09dfde99691c3daab6 Mon Sep 17 00:00:00 2001 From: Sara Ghaemi Date: Tue, 7 May 2024 13:28:57 -0400 Subject: [PATCH 4/4] updated tests --- litellm/tests/test_jwt.py | 41 +++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 3dacebe27..b3af9913f 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -77,9 +77,9 @@ async def test_token_single_public_key(): == "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ" ) - +@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_valid_invalid_token(): +async def test_valid_invalid_token(audience): """ Tests - valid token @@ -90,6 +90,10 @@ async def test_valid_invalid_token(): from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.backends import default_backend + os.environ.pop('JWT_AUDIENCE', None) + if audience: + os.environ["JWT_AUDIENCE"] = audience + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -134,9 +138,8 @@ async def test_valid_invalid_token(): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-proxy-admin", + "aud": audience } - if os.getenv("JWT_AUDIENCE"): - payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -163,9 +166,8 @@ async def test_valid_invalid_token(): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-NO-SCOPE", + "aud": audience } - if os.getenv("JWT_AUDIENCE"): - payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -181,7 +183,6 @@ async def test_valid_invalid_token(): except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") - @pytest.fixture def prisma_client(): import litellm @@ -204,8 +205,9 @@ def prisma_client(): return prisma_client +@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_team_token_output(prisma_client): +async def test_team_token_output(prisma_client, audience): import jwt, json from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -220,6 +222,10 @@ async def test_team_token_output(prisma_client): setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() + os.environ.pop('JWT_AUDIENCE', None) + if audience: + os.environ["JWT_AUDIENCE"] = audience + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -268,9 +274,8 @@ async def test_team_token_output(prisma_client): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_team", "client_id": team_id, + "aud": audience } - if os.getenv("JWT_AUDIENCE"): - payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -284,9 +289,8 @@ async def test_team_token_output(prisma_client): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", + "aud": audience } - if os.getenv("JWT_AUDIENCE"): - payload["aud"] = os.getenv("JWT_AUDIENCE") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") @@ -354,8 +358,9 @@ async def test_team_token_output(prisma_client): assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] +@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_user_token_output(prisma_client): +async def test_user_token_output(prisma_client, audience): """ - If user required, check if it exists - fail initial request (when user doesn't exist) @@ -376,6 +381,10 @@ async def test_user_token_output(prisma_client): setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() + os.environ.pop('JWT_AUDIENCE', None) + if audience: + os.environ["JWT_AUDIENCE"] = audience + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -427,9 +436,8 @@ async def test_user_token_output(prisma_client): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_team", "client_id": team_id, + "aud": audience } - if os.getenv("JWT_AUDIENCE"): - payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -443,9 +451,8 @@ async def test_user_token_output(prisma_client): "sub": user_id, "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", + "aud": audience } - if os.getenv("JWT_AUDIENCE"): - payload["aud"] = os.getenv("JWT_AUDIENCE") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")