fix: replace python-jose with PyJWT for JWT handling

This commit migrates the authentication system from python-jose to PyJWT
to eliminate the dependency on the archived rsa package. The migration
includes:

- Refactored OAuth2TokenAuthProvider to use PyJWT's PyJWKClient for
  clean JWKS handling
- Removed manual JWKS fetching, caching and key extraction logic in
  favor of PyJWT's built-in functionality

The new implementation is cleaner, more maintainable, and follows PyJWT
best practices while maintaining full backward compatibility.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-10-09 17:53:42 +02:00
parent ecc8a554d2
commit 7fa2bae3e2
No known key found for this signature in database
4 changed files with 60 additions and 86 deletions

View file

@ -5,13 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import ssl import ssl
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import Lock
from urllib.parse import parse_qs, urljoin, urlparse from urllib.parse import parse_qs, urljoin, urlparse
import httpx import httpx
from jose import jwt import jwt
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import TokenValidationError from llama_stack.apis.common.errors import TokenValidationError
@ -98,9 +96,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
def __init__(self, config: OAuth2TokenAuthConfig): def __init__(self, config: OAuth2TokenAuthConfig):
self.config = config self.config = config
self._jwks_at: float = 0.0 self._jwks_client: jwt.PyJWKClient | None = None
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> User: async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks: if self.config.jwks:
@ -111,21 +107,30 @@ class OAuth2TokenAuthProvider(AuthProvider):
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token.""" """Validate a token using the JWT token."""
await self._refresh_jwks() if self.config.jwks is None:
raise ValueError("JWKS is not configured")
try: try:
header = jwt.get_unverified_header(token) # Initialize PyJWKClient if not already done
kid = header["kid"] if self._jwks_client is None:
if kid not in self._jwks: self._jwks_client = jwt.PyJWKClient(
raise ValueError(f"Unknown key ID: {kid}") self.config.jwks.uri,
key_data = self._jwks[kid] cache_keys=True,
algorithm = header.get("alg", "RS256") max_cached_keys=10,
lifespan=3600, # 1 hour cache
)
# Get the signing key from the JWT token
signing_key = self._jwks_client.get_signing_key_from_jwt(token)
# Decode and verify the JWT
claims = jwt.decode( claims = jwt.decode(
token, token,
key_data, signing_key.key,
algorithms=[algorithm], algorithms=["RS256", "HS256", "ES256"], # Common algorithms
audience=self.config.audience, audience=self.config.audience,
issuer=self.config.issuer, issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
) )
except Exception as exc: except Exception as exc:
raise ValueError("Invalid JWT token") from exc raise ValueError("Invalid JWT token") from exc
@ -201,37 +206,6 @@ class OAuth2TokenAuthProvider(AuthProvider):
else: else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header" return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
If the cache is expired, we refresh the JWKS from the JWKS URI.
Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
* It doesn't have user authentication flows
* It doesn't have refresh tokens
"""
async with self._jwks_lock:
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
headers = {}
if self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
class CustomAuthProvider(AuthProvider): class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint.""" """Custom authentication provider that uses an external endpoint."""

View file

@ -34,7 +34,7 @@ dependencies = [
"openai>=1.107", # for expires_after support "openai>=1.107", # for expires_after support
"prompt-toolkit", "prompt-toolkit",
"python-dotenv", "python-dotenv",
"python-jose[cryptography]", "PyJWT>=2.8.0",
"pydantic>=2.11.9", "pydantic>=2.11.9",
"rich", "rich",
"starlette", "starlette",

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
from unittest.mock import AsyncMock, patch import json
from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs):
@pytest.fixture @pytest.fixture
def jwt_token_valid(): def jwt_token_valid():
from jose import jwt import jwt
return jwt.encode( return jwt.encode(
{ {
@ -389,8 +390,30 @@ def jwt_token_valid():
) )
@patch("httpx.AsyncClient.get", new=mock_jwks_response) @pytest.fixture
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid): def mock_jwks_urlopen():
"""Mock urllib.request.urlopen for PyJWKClient JWKS requests."""
with patch("urllib.request.urlopen") as mock_urlopen:
# Mock the JWKS response for PyJWKClient
mock_response = Mock()
mock_response.read.return_value = json.dumps(
{
"keys": [
{
"kid": "1234567890",
"kty": "oct",
"alg": "HS256",
"use": "sig",
"k": base64.b64encode(b"foobarbaz").decode(),
}
]
}
).encode()
mock_urlopen.return_value.__enter__.return_value = mock_response
yield mock_urlopen
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"} assert response.json() == {"message": "Authentication successful"}
@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
assert response.status_code == 401 assert response.status_code == 401
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen):
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"} assert response.json() == {"message": "Authentication successful"}

44
uv.lock generated
View file

@ -874,18 +874,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" }, { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" },
] ]
[[package]]
name = "ecdsa"
version = "0.19.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "six" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
]
[[package]] [[package]]
name = "eval-type-backport" name = "eval-type-backport"
version = "0.2.2" version = "0.2.2"
@ -1787,8 +1775,8 @@ dependencies = [
{ name = "pillow" }, { name = "pillow" },
{ name = "prompt-toolkit" }, { name = "prompt-toolkit" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "pyjwt" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" }, { name = "python-multipart" },
{ name = "rich" }, { name = "rich" },
{ name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlalchemy", extra = ["asyncio"] },
@ -1910,8 +1898,8 @@ requires-dist = [
{ name = "pillow" }, { name = "pillow" },
{ name = "prompt-toolkit" }, { name = "prompt-toolkit" },
{ name = "pydantic", specifier = ">=2.11.9" }, { name = "pydantic", specifier = ">=2.11.9" },
{ name = "pyjwt", specifier = ">=2.8.0" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "python-jose", extras = ["cryptography"] },
{ name = "python-multipart", specifier = ">=0.0.20" }, { name = "python-multipart", specifier = ">=0.0.20" },
{ name = "rich" }, { name = "rich" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
@ -3558,6 +3546,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
] ]
[[package]]
name = "pyjwt"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
]
[[package]] [[package]]
name = "pymilvus" name = "pymilvus"
version = "2.6.1" version = "2.6.1"
@ -3747,25 +3744,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" },
] ]
[[package]]
name = "python-jose"
version = "3.5.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ecdsa" },
{ name = "pyasn1" },
{ name = "rsa" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c6/77/3a1c9039db7124eb039772b935f2244fbb73fc8ee65b9acf2375da1c07bf/python_jose-3.5.0.tar.gz", hash = "sha256:fb4eaa44dbeb1c26dcc69e4bd7ec54a1cb8dd64d3b4d81ef08d90ff453f2b01b", size = 92726, upload-time = "2025-05-28T17:31:54.288Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d9/c3/0bd11992072e6a1c513b16500a5d07f91a24017c5909b02c72c62d7ad024/python_jose-3.5.0-py2.py3-none-any.whl", hash = "sha256:abd1202f23d34dfad2c3d28cb8617b90acf34132c7afd60abd0b0b7d3cb55771", size = 34624, upload-time = "2025-05-28T17:31:52.802Z" },
]
[package.optional-dependencies]
cryptography = [
{ name = "cryptography" },
]
[[package]] [[package]]
name = "python-multipart" name = "python-multipart"
version = "0.0.20" version = "0.0.20"