forked from phoenix/litellm-mirror
Merge pull request #3500 from ghaemisr/main
Added support for JWT auth with PEM cert public keys
This commit is contained in:
commit
93e5fb49d3
3 changed files with 62 additions and 9 deletions
|
@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback.
|
|||
### Step 1. Setup Proxy
|
||||
|
||||
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
||||
- `JWT_AUDIENCE`: The expected audience (the `aud` claim) that the JWT is checked against when it is decoded. If not set, the decode step does not verify the audience.
|
||||
|
||||
```bash
|
||||
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"
|
||||
|
|
|
@ -15,6 +15,9 @@ from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
|
|||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||
from typing import Optional
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
|
||||
class JWTHandler:
|
||||
|
@ -142,8 +145,8 @@ class JWTHandler:
|
|||
public_key = keys[0]
|
||||
elif len(keys) > 1:
|
||||
for key in keys:
|
||||
if kid is not None and key["kid"] == kid:
|
||||
public_key = key
|
||||
if kid is not None and key == kid:
|
||||
public_key = keys[key]
|
||||
|
||||
if public_key is None:
|
||||
raise Exception(
|
||||
|
@ -153,6 +156,11 @@ class JWTHandler:
|
|||
return public_key
|
||||
|
||||
async def auth_jwt(self, token: str) -> dict:
|
||||
audience = os.getenv("JWT_AUDIENCE")
|
||||
decode_options = None
|
||||
if audience is None:
|
||||
decode_options = {"verify_aud": False}
|
||||
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
@ -182,7 +190,33 @@ class JWTHandler:
|
|||
token,
|
||||
public_key_rsa, # type: ignore
|
||||
algorithms=["RS256"],
|
||||
options={"verify_aud": False},
|
||||
options=decode_options,
|
||||
audience=audience,
|
||||
)
|
||||
return payload
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
# the token is expired, do something to refresh it
|
||||
raise Exception("Token Expired")
|
||||
except Exception as e:
|
||||
raise Exception(f"Validation fails: {str(e)}")
|
||||
elif public_key is not None and isinstance(public_key, str):
|
||||
try:
|
||||
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
|
||||
|
||||
# Extract public key
|
||||
key = cert.public_key().public_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
|
||||
# decode the token using the public key
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
key,
|
||||
algorithms=["RS256"],
|
||||
audience=audience,
|
||||
options=decode_options
|
||||
)
|
||||
return payload
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@ public_key = {
|
|||
"alg": "RS256",
|
||||
}
|
||||
|
||||
|
||||
def test_load_config_with_custom_role_names():
|
||||
config = {
|
||||
"general_settings": {
|
||||
|
@ -78,9 +77,9 @@ async def test_token_single_public_key():
|
|||
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_invalid_token():
|
||||
async def test_valid_invalid_token(audience):
|
||||
"""
|
||||
Tests
|
||||
- valid token
|
||||
|
@ -91,6 +90,10 @@ async def test_valid_invalid_token():
|
|||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
# Generate a private / public key pair using RSA algorithm
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
|
@ -135,6 +138,7 @@ async def test_valid_invalid_token():
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm-proxy-admin",
|
||||
"aud": audience
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -162,6 +166,7 @@ async def test_valid_invalid_token():
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm-NO-SCOPE",
|
||||
"aud": audience
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -178,7 +183,6 @@ async def test_valid_invalid_token():
|
|||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prisma_client():
|
||||
import litellm
|
||||
|
@ -201,8 +205,9 @@ def prisma_client():
|
|||
return prisma_client
|
||||
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_token_output(prisma_client):
|
||||
async def test_team_token_output(prisma_client, audience):
|
||||
import jwt, json
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
@ -217,6 +222,10 @@ async def test_team_token_output(prisma_client):
|
|||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
# Generate a private / public key pair using RSA algorithm
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
|
@ -265,6 +274,7 @@ async def test_team_token_output(prisma_client):
|
|||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_team",
|
||||
"client_id": team_id,
|
||||
"aud": audience
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -279,6 +289,7 @@ async def test_team_token_output(prisma_client):
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_proxy_admin",
|
||||
"aud": audience
|
||||
}
|
||||
|
||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||
|
@ -347,8 +358,9 @@ async def test_team_token_output(prisma_client):
|
|||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_token_output(prisma_client):
|
||||
async def test_user_token_output(prisma_client, audience):
|
||||
"""
|
||||
- If user required, check if it exists
|
||||
- fail initial request (when user doesn't exist)
|
||||
|
@ -369,6 +381,10 @@ async def test_user_token_output(prisma_client):
|
|||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
# Generate a private / public key pair using RSA algorithm
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
|
@ -420,6 +436,7 @@ async def test_user_token_output(prisma_client):
|
|||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_team",
|
||||
"client_id": team_id,
|
||||
"aud": audience
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -434,6 +451,7 @@ async def test_user_token_output(prisma_client):
|
|||
"sub": user_id,
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_proxy_admin",
|
||||
"aud": audience
|
||||
}
|
||||
|
||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue