fix: replace python-jose with PyJWT for JWT handling (#3756)

# What does this PR do?

This commit migrates the authentication system from python-jose to PyJWT
to eliminate the dependency on the archived rsa package. The migration
includes:

- Refactored OAuth2TokenAuthProvider to use PyJWT's PyJWKClient for
clean JWKS handling
- Removed manual JWKS fetching, caching and key extraction logic in
favor of PyJWT's built-in functionality

The new implementation is cleaner, more maintainable, and follows PyJWT
best practices while maintaining full backward compatibility.

## Test Plan

Unit tests. Auth CI.

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-10-14 09:35:48 +02:00 committed by GitHub
parent 968c364a3e
commit 1136daf310
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 93 additions and 86 deletions

View file

@ -5,13 +5,11 @@
# the root directory of this source tree.
import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from urllib.parse import parse_qs, urljoin, urlparse
import httpx
from jose import jwt
import jwt
from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import TokenValidationError
@ -98,9 +96,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
def __init__(self, config: OAuth2TokenAuthConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
self._jwks_client: jwt.PyJWKClient | None = None
async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks:
@ -109,23 +105,60 @@ class OAuth2TokenAuthProvider(AuthProvider):
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
def _get_jwks_client(self) -> jwt.PyJWKClient:
if self._jwks_client is None:
ssl_context = None
if not self.config.verify_tls:
# Disable SSL verification if verify_tls is False
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
elif self.config.tls_cafile:
# Use custom CA file if provided
ssl_context = ssl.create_default_context(
cafile=self.config.tls_cafile.as_posix(),
)
# If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)
# Prepare headers for JWKS request - this is needed for Kubernetes to authenticate
# to the JWK endpoint, we must use the token in the config to authenticate
headers = {}
if self.config.jwks and self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
self._jwks_client = jwt.PyJWKClient(
self.config.jwks.uri if self.config.jwks else None,
cache_keys=True,
max_cached_keys=10,
lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
headers=headers,
ssl_context=ssl_context,
)
return self._jwks_client
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
try:
header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
jwks_client: jwt.PyJWKClient = self._get_jwks_client()
signing_key = jwks_client.get_signing_key_from_jwt(token)
algorithm = jwt.get_unverified_header(token)["alg"]
claims = jwt.decode(
token,
key_data,
signing_key.key,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
)
# Decode and verify the JWT
claims = jwt.decode(
token,
signing_key.key,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
)
except Exception as exc:
raise ValueError("Invalid JWT token") from exc
@ -201,37 +234,6 @@ class OAuth2TokenAuthProvider(AuthProvider):
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
If the cache is expired, we refresh the JWKS from the JWKS URI.
Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
* It doesn't have user authentication flows
* It doesn't have refresh tokens
"""
async with self._jwks_lock:
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
headers = {}
if self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""