mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
fix: replace python-jose with PyJWT for JWT handling (#3756)
# What does this PR do? This commit migrates the authentication system from python-jose to PyJWT to eliminate the dependency on the archived rsa package. The migration includes: - Refactored OAuth2TokenAuthProvider to use PyJWT's PyJWKClient for clean JWKS handling - Removed manual JWKS fetching, caching and key extraction logic in favor of PyJWT's built-in functionality The new implementation is cleaner, more maintainable, and follows PyJWT best practices while maintaining full backward compatibility. ## Test Plan Unit tests. Auth CI. --------- Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
968c364a3e
commit
1136daf310
4 changed files with 93 additions and 86 deletions
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs):
|
|||
|
||||
@pytest.fixture
|
||||
def jwt_token_valid():
|
||||
from jose import jwt
|
||||
import jwt
|
||||
|
||||
return jwt.encode(
|
||||
{
|
||||
|
@ -389,8 +390,30 @@ def jwt_token_valid():
|
|||
)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
|
||||
@pytest.fixture
|
||||
def mock_jwks_urlopen():
|
||||
"""Mock urllib.request.urlopen for PyJWKClient JWKS requests."""
|
||||
with patch("urllib.request.urlopen") as mock_urlopen:
|
||||
# Mock the JWKS response for PyJWKClient
|
||||
mock_response = Mock()
|
||||
mock_response.read.return_value = json.dumps(
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "1234567890",
|
||||
"kty": "oct",
|
||||
"alg": "HS256",
|
||||
"use": "sig",
|
||||
"k": base64.b64encode(b"foobarbaz").decode(),
|
||||
}
|
||||
]
|
||||
}
|
||||
).encode()
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
yield mock_urlopen
|
||||
|
||||
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
|
|||
assert response.status_code == 401
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
|
||||
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen):
|
||||
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue