diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 57a552514..2db60c91c 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -364,23 +364,6 @@ def test_invalid_auth_header_format_oauth2(oauth2_client): assert "Invalid Authorization header format" in response.json()["error"]["message"] -async def mock_jwks_response(*args, **kwargs): - return MockResponse( - 200, - { - "keys": [ - { - "kid": "1234567890", - "kty": "oct", - "alg": "HS256", - "use": "sig", - "k": base64.b64encode(b"foobarbaz").decode(), - } - ] - }, - ) - - @pytest.fixture def jwt_token_valid(): import jwt @@ -421,28 +404,60 @@ def mock_jwks_urlopen(): yield mock_urlopen +@pytest.fixture +def mock_jwks_urlopen_with_auth_required(): + """Mock urllib.request.urlopen that requires Bearer token for JWKS requests.""" + with patch("urllib.request.urlopen") as mock_urlopen: + + def side_effect(request, **kwargs): + # Check if Authorization header is present + auth_header = request.headers.get("Authorization") if hasattr(request, "headers") else None + + if not auth_header or not auth_header.startswith("Bearer "): + # Simulate 401 Unauthorized + import urllib.error + + raise urllib.error.HTTPError( + url=request.full_url if hasattr(request, "full_url") else "", + code=401, + msg="Unauthorized", + hdrs={}, + fp=None, + ) + + # Mock the JWKS response for PyJWKClient + mock_response = Mock() + mock_response.read.return_value = json.dumps( + { + "keys": [ + { + "kid": "1234567890", + "kty": "oct", + "alg": "HS256", + "use": "sig", + "k": base64.b64encode(b"foobarbaz").decode(), + } + ] + } + ).encode() + return mock_response + + mock_urlopen.side_effect = side_effect + yield mock_urlopen + + def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen): response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} -@patch("httpx.AsyncClient.get", new=mock_jwks_response) -def test_invalid_oauth2_authentication(oauth2_client, invalid_token, suppress_auth_errors): +def test_invalid_oauth2_authentication(oauth2_client, invalid_token, mock_jwks_urlopen, suppress_auth_errors): response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"}) assert response.status_code == 401 assert "Invalid JWT token" in response.json()["error"]["message"] -async def mock_auth_jwks_response(*args, **kwargs): - if "headers" not in kwargs or "Authorization" not in kwargs["headers"]: - return MockResponse(401, {}) - authz = kwargs["headers"]["Authorization"] - if authz != "Bearer my-jwks-token": - return MockResponse(401, {}) - return await mock_jwks_response(args, kwargs) - - @pytest.fixture def oauth2_app_with_jwks_token(): app = FastAPI() @@ -472,8 +487,9 @@ def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token): return TestClient(oauth2_app_with_jwks_token) -@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) -def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid, suppress_auth_errors): +def test_oauth2_with_jwks_token_expected( + oauth2_client, jwt_token_valid, mock_jwks_urlopen_with_auth_required, suppress_auth_errors +): response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) assert response.status_code == 401