From 1136daf310b6f9cf5215fc682e0b37d242b2ebdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 14 Oct 2025 09:35:48 +0200 Subject: [PATCH] fix: replace python-jose with PyJWT for JWT handling (#3756) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This commit migrates the authentication system from python-jose to PyJWT to eliminate the dependency on the archived rsa package. The migration includes: - Refactored OAuth2TokenAuthProvider to use PyJWT's PyJWKClient for clean JWKS handling - Removed manual JWKS fetching, caching and key extraction logic in favor of PyJWT's built-in functionality The new implementation is cleaner, more maintainable, and follows PyJWT best practices while maintaining full backward compatibility. ## Test Plan Unit tests. Auth CI. --------- Signed-off-by: Sébastien Han --- llama_stack/core/server/auth_providers.py | 94 ++++++++++++----------- pyproject.toml | 2 +- tests/unit/server/test_auth.py | 34 ++++++-- uv.lock | 49 ++++-------- 4 files changed, 93 insertions(+), 86 deletions(-) diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 38188c49a..05a21c8d4 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -5,13 +5,11 @@ # the root directory of this source tree. import ssl -import time from abc import ABC, abstractmethod -from asyncio import Lock from urllib.parse import parse_qs, urljoin, urlparse import httpx -from jose import jwt +import jwt from pydantic import BaseModel, Field from llama_stack.apis.common.errors import TokenValidationError @@ -98,9 +96,7 @@ class OAuth2TokenAuthProvider(AuthProvider): def __init__(self, config: OAuth2TokenAuthConfig): self.config = config - self._jwks_at: float = 0.0 - self._jwks: dict[str, str] = {} - self._jwks_lock = Lock() + self._jwks_client: jwt.PyJWKClient | None = None async def validate_token(self, token: str, scope: dict | None = None) -> User: if self.config.jwks: @@ -109,23 +105,60 @@ class OAuth2TokenAuthProvider(AuthProvider): return await self.introspect_token(token, scope) raise ValueError("One of jwks or introspection must be configured") + def _get_jwks_client(self) -> jwt.PyJWKClient: + if self._jwks_client is None: + ssl_context = None + if not self.config.verify_tls: + # Disable SSL verification if verify_tls is False + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif self.config.tls_cafile: + # Use custom CA file if provided + ssl_context = ssl.create_default_context( + cafile=self.config.tls_cafile.as_posix(), + ) + # If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults) + + # Prepare headers for JWKS request - this is needed for Kubernetes to authenticate + # to the JWK endpoint, we must use the token in the config to authenticate + headers = {} + if self.config.jwks and self.config.jwks.token: + headers["Authorization"] = f"Bearer {self.config.jwks.token}" + + self._jwks_client = jwt.PyJWKClient( + self.config.jwks.uri if self.config.jwks else None, + cache_keys=True, + max_cached_keys=10, + lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None, + headers=headers, + ssl_context=ssl_context, + ) + return self._jwks_client + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using the JWT token.""" - await self._refresh_jwks() - try: - header = jwt.get_unverified_header(token) - kid = header["kid"] - if kid not in self._jwks: - raise ValueError(f"Unknown key ID: {kid}") - key_data = self._jwks[kid] - algorithm = header.get("alg", "RS256") + jwks_client: jwt.PyJWKClient = self._get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(token) + algorithm = jwt.get_unverified_header(token)["alg"] claims = jwt.decode( token, - key_data, + signing_key.key, algorithms=[algorithm], audience=self.config.audience, issuer=self.config.issuer, + options={"verify_exp": True, "verify_aud": True, "verify_iss": True}, + ) + + # Decode and verify the JWT + claims = jwt.decode( + token, + signing_key.key, + algorithms=[algorithm], + audience=self.config.audience, + issuer=self.config.issuer, + options={"verify_exp": True, "verify_aud": True, "verify_iss": True}, ) except Exception as exc: raise ValueError("Invalid JWT token") from exc @@ -201,37 +234,6 @@ class OAuth2TokenAuthProvider(AuthProvider): else: return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header" - async def _refresh_jwks(self) -> None: - """ - Refresh the JWKS cache. - - This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`). - If the cache is expired, we refresh the JWKS from the JWKS URI. - - Notes: for Kubernetes which doesn't fully implement the OIDC protocol: - * It doesn't have user authentication flows - * It doesn't have refresh tokens - """ - async with self._jwks_lock: - if self.config.jwks is None: - raise ValueError("JWKS is not configured") - if time.time() - self._jwks_at > self.config.jwks.key_recheck_period: - headers = {} - if self.config.jwks.token: - headers["Authorization"] = f"Bearer {self.config.jwks.token}" - verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls - async with httpx.AsyncClient(verify=verify) as client: - res = await client.get(self.config.jwks.uri, timeout=5, headers=headers) - res.raise_for_status() - jwks_data = res.json()["keys"] - updated = {} - for k in jwks_data: - kid = k["kid"] - # Store the entire key object as it may be needed for different algorithms - updated[kid] = k - self._jwks = updated - self._jwks_at = time.time() - class CustomAuthProvider(AuthProvider): """Custom authentication provider that uses an external endpoint.""" diff --git a/pyproject.toml b/pyproject.toml index 81997c249..d55de794d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "openai>=1.107", # for expires_after support "prompt-toolkit", "python-dotenv", - "python-jose[cryptography]", + "pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support. "pydantic>=2.11.9", "rich", "starlette", diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 9dbabe195..04ae89db8 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import base64 -from unittest.mock import AsyncMock, patch +import json +from unittest.mock import AsyncMock, Mock, patch import pytest from fastapi import FastAPI @@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs): @pytest.fixture def jwt_token_valid(): - from jose import jwt + import jwt return jwt.encode( { @@ -389,8 +390,30 @@ def jwt_token_valid(): ) -@patch("httpx.AsyncClient.get", new=mock_jwks_response) -def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid): +@pytest.fixture +def mock_jwks_urlopen(): + """Mock urllib.request.urlopen for PyJWKClient JWKS requests.""" + with patch("urllib.request.urlopen") as mock_urlopen: + # Mock the JWKS response for PyJWKClient + mock_response = Mock() + mock_response.read.return_value = json.dumps( + { + "keys": [ + { + "kid": "1234567890", + "kty": "oct", + "alg": "HS256", + "use": "sig", + "k": base64.b64encode(b"foobarbaz").decode(), + } + ] + } + ).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + yield mock_urlopen + + +def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen): response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} @@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid): assert response.status_code == 401 -@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) -def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid): +def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen): response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} diff --git a/uv.lock b/uv.lock index 0fcb02768..747e82aaa 100644 --- a/uv.lock +++ b/uv.lock @@ -874,18 +874,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" }, ] -[[package]] -name = "ecdsa" -version = "0.19.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, -] - [[package]] name = "eval-type-backport" version = "0.2.2" @@ -1787,8 +1775,8 @@ dependencies = [ { name = "pillow" }, { name = "prompt-toolkit" }, { name = "pydantic" }, + { name = "pyjwt", extra = ["crypto"] }, { name = "python-dotenv" }, - { name = "python-jose", extra = ["cryptography"] }, { name = "python-multipart" }, { name = "rich" }, { name = "sqlalchemy", extra = ["asyncio"] }, @@ -1910,8 +1898,8 @@ requires-dist = [ { name = "pillow" }, { name = "prompt-toolkit" }, { name = "pydantic", specifier = ">=2.11.9" }, + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.0" }, { name = "python-dotenv" }, - { name = "python-jose", extras = ["cryptography"] }, { name = "python-multipart", specifier = ">=0.0.20" }, { name = "rich" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, @@ -3558,6 +3546,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pymilvus" version = "2.6.1" @@ -3747,25 +3749,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, ] -[[package]] -name = "python-jose" -version = "3.5.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "ecdsa" }, - { name = "pyasn1" }, - { name = "rsa" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/77/3a1c9039db7124eb039772b935f2244fbb73fc8ee65b9acf2375da1c07bf/python_jose-3.5.0.tar.gz", hash = "sha256:fb4eaa44dbeb1c26dcc69e4bd7ec54a1cb8dd64d3b4d81ef08d90ff453f2b01b", size = 92726, upload-time = "2025-05-28T17:31:54.288Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/c3/0bd11992072e6a1c513b16500a5d07f91a24017c5909b02c72c62d7ad024/python_jose-3.5.0-py2.py3-none-any.whl", hash = "sha256:abd1202f23d34dfad2c3d28cb8617b90acf34132c7afd60abd0b0b7d3cb55771", size = 34624, upload-time = "2025-05-28T17:31:52.802Z" }, -] - -[package.optional-dependencies] -cryptography = [ - { name = "cryptography" }, -] - [[package]] name = "python-multipart" version = "0.0.20"