Merge pull request #3500 from ghaemisr/main

Added support for JWT auth with PEM cert public keys
This commit is contained in:
Krish Dholakia 2024-05-07 11:07:30 -07:00 committed by GitHub
commit 93e5fb49d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 62 additions and 9 deletions

View file

@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback.
### Step 1. Setup Proxy
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
- `JWT_AUDIENCE`: This is the audience used for decoding the JWT. If not set, the decode step will not verify the audience.
```bash
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"

View file

@ -15,6 +15,9 @@ from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
from litellm.proxy.utils import PrismaClient
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from typing import Optional
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
class JWTHandler:
@ -142,8 +145,8 @@ class JWTHandler:
public_key = keys[0]
elif len(keys) > 1:
for key in keys:
if kid is not None and key["kid"] == kid:
public_key = key
if kid is not None and key == kid:
public_key = keys[key]
if public_key is None:
raise Exception(
@ -153,6 +156,11 @@ class JWTHandler:
return public_key
async def auth_jwt(self, token: str) -> dict:
audience = os.getenv("JWT_AUDIENCE")
decode_options = None
if audience is None:
decode_options = {"verify_aud": False}
from jwt.algorithms import RSAAlgorithm
header = jwt.get_unverified_header(token)
@ -182,7 +190,33 @@ class JWTHandler:
token,
public_key_rsa, # type: ignore
algorithms=["RS256"],
options={"verify_aud": False},
options=decode_options,
audience=audience,
)
return payload
except jwt.ExpiredSignatureError:
# the token is expired, do something to refresh it
raise Exception("Token Expired")
except Exception as e:
raise Exception(f"Validation fails: {str(e)}")
elif public_key is not None and isinstance(public_key, str):
try:
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
# Extract public key
key = cert.public_key().public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo
)
# decode the token using the public key
payload = jwt.decode(
token,
key,
algorithms=["RS256"],
audience=audience,
options=decode_options
)
return payload

View file

@ -24,7 +24,6 @@ public_key = {
"alg": "RS256",
}
def test_load_config_with_custom_role_names():
config = {
"general_settings": {
@ -78,9 +77,9 @@ async def test_token_single_public_key():
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
)
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_valid_invalid_token():
async def test_valid_invalid_token(audience):
"""
Tests
- valid token
@ -91,6 +90,10 @@ async def test_valid_invalid_token():
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
os.environ.pop('JWT_AUDIENCE', None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
@ -135,6 +138,7 @@ async def test_valid_invalid_token():
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-proxy-admin",
"aud": audience
}
# Generate the JWT token
@ -162,6 +166,7 @@ async def test_valid_invalid_token():
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-NO-SCOPE",
"aud": audience
}
# Generate the JWT token
@ -178,7 +183,6 @@ async def test_valid_invalid_token():
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.fixture
def prisma_client():
import litellm
@ -201,8 +205,9 @@ def prisma_client():
return prisma_client
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client):
async def test_team_token_output(prisma_client, audience):
import jwt, json
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
@ -217,6 +222,10 @@ async def test_team_token_output(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
os.environ.pop('JWT_AUDIENCE', None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
@ -265,6 +274,7 @@ async def test_team_token_output(prisma_client):
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": audience
}
# Generate the JWT token
@ -279,6 +289,7 @@ async def test_team_token_output(prisma_client):
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin",
"aud": audience
}
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
@ -347,8 +358,9 @@ async def test_team_token_output(prisma_client):
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_user_token_output(prisma_client):
async def test_user_token_output(prisma_client, audience):
"""
- If user required, check if it exists
- fail initial request (when user doesn't exist)
@ -369,6 +381,10 @@ async def test_user_token_output(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
os.environ.pop('JWT_AUDIENCE', None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
@ -420,6 +436,7 @@ async def test_user_token_output(prisma_client):
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": audience
}
# Generate the JWT token
@ -434,6 +451,7 @@ async def test_user_token_output(prisma_client):
"sub": user_id,
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin",
"aud": audience
}
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")