mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
test: Update JWKS tests to properly mock authentication
PyJWKClient uses urllib.request.urlopen to fetch JWKS keys, not httpx.AsyncClient.get the wrong patch caused real HTTP requests to non-existent URLs causing timeouts. Closes: #4256 Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
a7c7c72467
commit
2c87f46ca9
1 changed files with 46 additions and 30 deletions
|
|
@ -364,23 +364,6 @@ def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
async def mock_jwks_response(*args, **kwargs):
|
|
||||||
return MockResponse(
|
|
||||||
200,
|
|
||||||
{
|
|
||||||
"keys": [
|
|
||||||
{
|
|
||||||
"kid": "1234567890",
|
|
||||||
"kty": "oct",
|
|
||||||
"alg": "HS256",
|
|
||||||
"use": "sig",
|
|
||||||
"k": base64.b64encode(b"foobarbaz").decode(),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def jwt_token_valid():
|
def jwt_token_valid():
|
||||||
import jwt
|
import jwt
|
||||||
|
|
@ -421,28 +404,60 @@ def mock_jwks_urlopen():
|
||||||
yield mock_urlopen
|
yield mock_urlopen
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_jwks_urlopen_with_auth_required():
|
||||||
|
"""Mock urllib.request.urlopen that requires Bearer token for JWKS requests."""
|
||||||
|
with patch("urllib.request.urlopen") as mock_urlopen:
|
||||||
|
|
||||||
|
def side_effect(request, **kwargs):
|
||||||
|
# Check if Authorization header is present
|
||||||
|
auth_header = request.headers.get("Authorization") if hasattr(request, "headers") else None
|
||||||
|
|
||||||
|
if not auth_header or not auth_header.startswith("Bearer "):
|
||||||
|
# Simulate 401 Unauthorized
|
||||||
|
import urllib.error
|
||||||
|
|
||||||
|
raise urllib.error.HTTPError(
|
||||||
|
url=request.full_url if hasattr(request, "full_url") else "",
|
||||||
|
code=401,
|
||||||
|
msg="Unauthorized",
|
||||||
|
hdrs={},
|
||||||
|
fp=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the JWKS response for PyJWKClient
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.read.return_value = json.dumps(
|
||||||
|
{
|
||||||
|
"keys": [
|
||||||
|
{
|
||||||
|
"kid": "1234567890",
|
||||||
|
"kty": "oct",
|
||||||
|
"alg": "HS256",
|
||||||
|
"use": "sig",
|
||||||
|
"k": base64.b64encode(b"foobarbaz").decode(),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
).encode()
|
||||||
|
return mock_response
|
||||||
|
|
||||||
|
mock_urlopen.side_effect = side_effect
|
||||||
|
yield mock_urlopen
|
||||||
|
|
||||||
|
|
||||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
|
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
|
||||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"message": "Authentication successful"}
|
assert response.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, mock_jwks_urlopen, suppress_auth_errors):
|
||||||
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, suppress_auth_errors):
|
|
||||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Invalid JWT token" in response.json()["error"]["message"]
|
assert "Invalid JWT token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
async def mock_auth_jwks_response(*args, **kwargs):
|
|
||||||
if "headers" not in kwargs or "Authorization" not in kwargs["headers"]:
|
|
||||||
return MockResponse(401, {})
|
|
||||||
authz = kwargs["headers"]["Authorization"]
|
|
||||||
if authz != "Bearer my-jwks-token":
|
|
||||||
return MockResponse(401, {})
|
|
||||||
return await mock_jwks_response(args, kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def oauth2_app_with_jwks_token():
|
def oauth2_app_with_jwks_token():
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
@ -472,8 +487,9 @@ def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
|
||||||
return TestClient(oauth2_app_with_jwks_token)
|
return TestClient(oauth2_app_with_jwks_token)
|
||||||
|
|
||||||
|
|
||||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
def test_oauth2_with_jwks_token_expected(
|
||||||
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid, suppress_auth_errors):
|
oauth2_client, jwt_token_valid, mock_jwks_urlopen_with_auth_required, suppress_auth_errors
|
||||||
|
):
|
||||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue