forked from phoenix/litellm-mirror
Merge pull request #3500 from ghaemisr/main
Added support for JWT auth with PEM cert public keys
This commit is contained in:
commit
93e5fb49d3
3 changed files with 62 additions and 9 deletions
|
@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback.
|
||||||
### Step 1. Setup Proxy
|
### Step 1. Setup Proxy
|
||||||
|
|
||||||
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
||||||
|
- `JWT_AUDIENCE`: The expected audience (`aud` claim) that the JWT is validated against when it is decoded. If this variable is not set, audience verification is skipped.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"
|
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"
|
||||||
|
|
|
@ -15,6 +15,9 @@ from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from cryptography import x509
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
|
||||||
|
|
||||||
class JWTHandler:
|
class JWTHandler:
|
||||||
|
@ -142,8 +145,8 @@ class JWTHandler:
|
||||||
public_key = keys[0]
|
public_key = keys[0]
|
||||||
elif len(keys) > 1:
|
elif len(keys) > 1:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if kid is not None and key["kid"] == kid:
|
if kid is not None and key == kid:
|
||||||
public_key = key
|
public_key = keys[key]
|
||||||
|
|
||||||
if public_key is None:
|
if public_key is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -153,6 +156,11 @@ class JWTHandler:
|
||||||
return public_key
|
return public_key
|
||||||
|
|
||||||
async def auth_jwt(self, token: str) -> dict:
|
async def auth_jwt(self, token: str) -> dict:
|
||||||
|
audience = os.getenv("JWT_AUDIENCE")
|
||||||
|
decode_options = None
|
||||||
|
if audience is None:
|
||||||
|
decode_options = {"verify_aud": False}
|
||||||
|
|
||||||
from jwt.algorithms import RSAAlgorithm
|
from jwt.algorithms import RSAAlgorithm
|
||||||
|
|
||||||
header = jwt.get_unverified_header(token)
|
header = jwt.get_unverified_header(token)
|
||||||
|
@ -182,7 +190,33 @@ class JWTHandler:
|
||||||
token,
|
token,
|
||||||
public_key_rsa, # type: ignore
|
public_key_rsa, # type: ignore
|
||||||
algorithms=["RS256"],
|
algorithms=["RS256"],
|
||||||
options={"verify_aud": False},
|
options=decode_options,
|
||||||
|
audience=audience,
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
# the token is expired, do something to refresh it
|
||||||
|
raise Exception("Token Expired")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Validation fails: {str(e)}")
|
||||||
|
elif public_key is not None and isinstance(public_key, str):
|
||||||
|
try:
|
||||||
|
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
|
||||||
|
|
||||||
|
# Extract public key
|
||||||
|
key = cert.public_key().public_bytes(
|
||||||
|
serialization.Encoding.PEM,
|
||||||
|
serialization.PublicFormat.SubjectPublicKeyInfo
|
||||||
|
)
|
||||||
|
|
||||||
|
# decode the token using the public key
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
key,
|
||||||
|
algorithms=["RS256"],
|
||||||
|
audience=audience,
|
||||||
|
options=decode_options
|
||||||
)
|
)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,6 @@ public_key = {
|
||||||
"alg": "RS256",
|
"alg": "RS256",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_load_config_with_custom_role_names():
|
def test_load_config_with_custom_role_names():
|
||||||
config = {
|
config = {
|
||||||
"general_settings": {
|
"general_settings": {
|
||||||
|
@ -78,9 +77,9 @@ async def test_token_single_public_key():
|
||||||
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
|
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_invalid_token():
|
async def test_valid_invalid_token(audience):
|
||||||
"""
|
"""
|
||||||
Tests
|
Tests
|
||||||
- valid token
|
- valid token
|
||||||
|
@ -91,6 +90,10 @@ async def test_valid_invalid_token():
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
os.environ.pop('JWT_AUDIENCE', None)
|
||||||
|
if audience:
|
||||||
|
os.environ["JWT_AUDIENCE"] = audience
|
||||||
|
|
||||||
# Generate a private / public key pair using RSA algorithm
|
# Generate a private / public key pair using RSA algorithm
|
||||||
key = rsa.generate_private_key(
|
key = rsa.generate_private_key(
|
||||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
|
@ -135,6 +138,7 @@ async def test_valid_invalid_token():
|
||||||
"sub": "user123",
|
"sub": "user123",
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm-proxy-admin",
|
"scope": "litellm-proxy-admin",
|
||||||
|
"aud": audience
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate the JWT token
|
# Generate the JWT token
|
||||||
|
@ -162,6 +166,7 @@ async def test_valid_invalid_token():
|
||||||
"sub": "user123",
|
"sub": "user123",
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm-NO-SCOPE",
|
"scope": "litellm-NO-SCOPE",
|
||||||
|
"aud": audience
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate the JWT token
|
# Generate the JWT token
|
||||||
|
@ -178,7 +183,6 @@ async def test_valid_invalid_token():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def prisma_client():
|
def prisma_client():
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -201,8 +205,9 @@ def prisma_client():
|
||||||
return prisma_client
|
return prisma_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_team_token_output(prisma_client):
|
async def test_team_token_output(prisma_client, audience):
|
||||||
import jwt, json
|
import jwt, json
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
@ -217,6 +222,10 @@ async def test_team_token_output(prisma_client):
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
os.environ.pop('JWT_AUDIENCE', None)
|
||||||
|
if audience:
|
||||||
|
os.environ["JWT_AUDIENCE"] = audience
|
||||||
|
|
||||||
# Generate a private / public key pair using RSA algorithm
|
# Generate a private / public key pair using RSA algorithm
|
||||||
key = rsa.generate_private_key(
|
key = rsa.generate_private_key(
|
||||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
|
@ -265,6 +274,7 @@ async def test_team_token_output(prisma_client):
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm_team",
|
"scope": "litellm_team",
|
||||||
"client_id": team_id,
|
"client_id": team_id,
|
||||||
|
"aud": audience
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate the JWT token
|
# Generate the JWT token
|
||||||
|
@ -279,6 +289,7 @@ async def test_team_token_output(prisma_client):
|
||||||
"sub": "user123",
|
"sub": "user123",
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm_proxy_admin",
|
"scope": "litellm_proxy_admin",
|
||||||
|
"aud": audience
|
||||||
}
|
}
|
||||||
|
|
||||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
@ -347,8 +358,9 @@ async def test_team_token_output(prisma_client):
|
||||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_token_output(prisma_client):
|
async def test_user_token_output(prisma_client, audience):
|
||||||
"""
|
"""
|
||||||
- If user required, check if it exists
|
- If user required, check if it exists
|
||||||
- fail initial request (when user doesn't exist)
|
- fail initial request (when user doesn't exist)
|
||||||
|
@ -369,6 +381,10 @@ async def test_user_token_output(prisma_client):
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
os.environ.pop('JWT_AUDIENCE', None)
|
||||||
|
if audience:
|
||||||
|
os.environ["JWT_AUDIENCE"] = audience
|
||||||
|
|
||||||
# Generate a private / public key pair using RSA algorithm
|
# Generate a private / public key pair using RSA algorithm
|
||||||
key = rsa.generate_private_key(
|
key = rsa.generate_private_key(
|
||||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
|
@ -420,6 +436,7 @@ async def test_user_token_output(prisma_client):
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm_team",
|
"scope": "litellm_team",
|
||||||
"client_id": team_id,
|
"client_id": team_id,
|
||||||
|
"aud": audience
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate the JWT token
|
# Generate the JWT token
|
||||||
|
@ -434,6 +451,7 @@ async def test_user_token_output(prisma_client):
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm_proxy_admin",
|
"scope": "litellm_proxy_admin",
|
||||||
|
"aud": audience
|
||||||
}
|
}
|
||||||
|
|
||||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue