fix: replace python-jose with PyJWT for JWT handling (#3756)

# What does this PR do?

This commit migrates the authentication system from python-jose to PyJWT
to eliminate the dependency on the archived rsa package. The migration
includes:

- Refactored OAuth2TokenAuthProvider to use PyJWT's PyJWKClient for
clean JWKS handling
- Removed manual JWKS fetching, caching and key extraction logic in
favor of PyJWT's built-in functionality

The new implementation is cleaner, more maintainable, and follows PyJWT
best practices while maintaining full backward compatibility.

## Test Plan

Unit tests. Auth CI.

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-10-14 09:35:48 +02:00 committed by GitHub
parent 968c364a3e
commit 1136daf310
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 93 additions and 86 deletions

View file

@ -5,13 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import ssl import ssl
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import Lock
from urllib.parse import parse_qs, urljoin, urlparse from urllib.parse import parse_qs, urljoin, urlparse
import httpx import httpx
from jose import jwt import jwt
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import TokenValidationError from llama_stack.apis.common.errors import TokenValidationError
@ -98,9 +96,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
def __init__(self, config: OAuth2TokenAuthConfig): def __init__(self, config: OAuth2TokenAuthConfig):
self.config = config self.config = config
self._jwks_at: float = 0.0 self._jwks_client: jwt.PyJWKClient | None = None
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> User: async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks: if self.config.jwks:
@ -109,23 +105,60 @@ class OAuth2TokenAuthProvider(AuthProvider):
return await self.introspect_token(token, scope) return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured") raise ValueError("One of jwks or introspection must be configured")
def _get_jwks_client(self) -> jwt.PyJWKClient:
if self._jwks_client is None:
ssl_context = None
if not self.config.verify_tls:
# Disable SSL verification if verify_tls is False
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
elif self.config.tls_cafile:
# Use custom CA file if provided
ssl_context = ssl.create_default_context(
cafile=self.config.tls_cafile.as_posix(),
)
# If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)
# Prepare headers for JWKS request - this is needed for Kubernetes to authenticate
# to the JWK endpoint, we must use the token in the config to authenticate
headers = {}
if self.config.jwks and self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
self._jwks_client = jwt.PyJWKClient(
self.config.jwks.uri if self.config.jwks else None,
cache_keys=True,
max_cached_keys=10,
lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
headers=headers,
ssl_context=ssl_context,
)
return self._jwks_client
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token.""" """Validate a token using the JWT token."""
await self._refresh_jwks()
try: try:
header = jwt.get_unverified_header(token) jwks_client: jwt.PyJWKClient = self._get_jwks_client()
kid = header["kid"] signing_key = jwks_client.get_signing_key_from_jwt(token)
if kid not in self._jwks: algorithm = jwt.get_unverified_header(token)["alg"]
raise ValueError(f"Unknown key ID: {kid}")
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
claims = jwt.decode( claims = jwt.decode(
token, token,
key_data, signing_key.key,
algorithms=[algorithm], algorithms=[algorithm],
audience=self.config.audience, audience=self.config.audience,
issuer=self.config.issuer, issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
)
# Decode and verify the JWT
claims = jwt.decode(
token,
signing_key.key,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
) )
except Exception as exc: except Exception as exc:
raise ValueError("Invalid JWT token") from exc raise ValueError("Invalid JWT token") from exc
@ -201,37 +234,6 @@ class OAuth2TokenAuthProvider(AuthProvider):
else: else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header" return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
If the cache is expired, we refresh the JWKS from the JWKS URI.
Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
* It doesn't have user authentication flows
* It doesn't have refresh tokens
"""
async with self._jwks_lock:
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
headers = {}
if self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
class CustomAuthProvider(AuthProvider): class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint.""" """Custom authentication provider that uses an external endpoint."""

View file

@ -34,7 +34,7 @@ dependencies = [
"openai>=1.107", # for expires_after support "openai>=1.107", # for expires_after support
"prompt-toolkit", "prompt-toolkit",
"python-dotenv", "python-dotenv",
"python-jose[cryptography]", "pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
"pydantic>=2.11.9", "pydantic>=2.11.9",
"rich", "rich",
"starlette", "starlette",

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
from unittest.mock import AsyncMock, patch import json
from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs):
@pytest.fixture @pytest.fixture
def jwt_token_valid(): def jwt_token_valid():
from jose import jwt import jwt
return jwt.encode( return jwt.encode(
{ {
@ -389,8 +390,30 @@ def jwt_token_valid():
) )
@patch("httpx.AsyncClient.get", new=mock_jwks_response) @pytest.fixture
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid): def mock_jwks_urlopen():
"""Mock urllib.request.urlopen for PyJWKClient JWKS requests."""
with patch("urllib.request.urlopen") as mock_urlopen:
# Mock the JWKS response for PyJWKClient
mock_response = Mock()
mock_response.read.return_value = json.dumps(
{
"keys": [
{
"kid": "1234567890",
"kty": "oct",
"alg": "HS256",
"use": "sig",
"k": base64.b64encode(b"foobarbaz").decode(),
}
]
}
).encode()
mock_urlopen.return_value.__enter__.return_value = mock_response
yield mock_urlopen
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"} assert response.json() == {"message": "Authentication successful"}
@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
assert response.status_code == 401 assert response.status_code == 401
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen):
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"} assert response.json() == {"message": "Authentication successful"}

49
uv.lock generated
View file

@ -874,18 +874,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" }, { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" },
] ]
[[package]]
name = "ecdsa"
version = "0.19.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "six" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
]
[[package]] [[package]]
name = "eval-type-backport" name = "eval-type-backport"
version = "0.2.2" version = "0.2.2"
@ -1787,8 +1775,8 @@ dependencies = [
{ name = "pillow" }, { name = "pillow" },
{ name = "prompt-toolkit" }, { name = "prompt-toolkit" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "pyjwt", extra = ["crypto"] },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" }, { name = "python-multipart" },
{ name = "rich" }, { name = "rich" },
{ name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlalchemy", extra = ["asyncio"] },
@ -1910,8 +1898,8 @@ requires-dist = [
{ name = "pillow" }, { name = "pillow" },
{ name = "prompt-toolkit" }, { name = "prompt-toolkit" },
{ name = "pydantic", specifier = ">=2.11.9" }, { name = "pydantic", specifier = ">=2.11.9" },
{ name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.0" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "python-jose", extras = ["cryptography"] },
{ name = "python-multipart", specifier = ">=0.0.20" }, { name = "python-multipart", specifier = ">=0.0.20" },
{ name = "rich" }, { name = "rich" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
@ -3558,6 +3546,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
] ]
[[package]]
name = "pyjwt"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
]
[package.optional-dependencies]
crypto = [
{ name = "cryptography" },
]
[[package]] [[package]]
name = "pymilvus" name = "pymilvus"
version = "2.6.1" version = "2.6.1"
@ -3747,25 +3749,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" },
] ]
[[package]]
name = "python-jose"
version = "3.5.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ecdsa" },
{ name = "pyasn1" },
{ name = "rsa" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c6/77/3a1c9039db7124eb039772b935f2244fbb73fc8ee65b9acf2375da1c07bf/python_jose-3.5.0.tar.gz", hash = "sha256:fb4eaa44dbeb1c26dcc69e4bd7ec54a1cb8dd64d3b4d81ef08d90ff453f2b01b", size = 92726, upload-time = "2025-05-28T17:31:54.288Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d9/c3/0bd11992072e6a1c513b16500a5d07f91a24017c5909b02c72c62d7ad024/python_jose-3.5.0-py2.py3-none-any.whl", hash = "sha256:abd1202f23d34dfad2c3d28cb8617b90acf34132c7afd60abd0b0b7d3cb55771", size = 34624, upload-time = "2025-05-28T17:31:52.802Z" },
]
[package.optional-dependencies]
cryptography = [
{ name = "cryptography" },
]
[[package]] [[package]]
name = "python-multipart" name = "python-multipart"
version = "0.0.20" version = "0.0.20"