mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
test: suppress expected error logs in SSE test
Use pytest's caplog fixture to suppress ERROR logs when deliberately triggering errors in test_sse_generator_error_before_response_starts. This keeps test output clean while still validating error handling behavior.
This commit is contained in:
parent
cb2185b936
commit
62626d4d94
5 changed files with 60 additions and 17 deletions
|
|
@ -4,9 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_sessionstart(session) -> None:
|
||||
if "LLAMA_STACK_LOGGING" not in os.environ:
|
||||
|
|
@ -17,4 +20,10 @@ def pytest_sessionstart(session) -> None:
|
|||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def suppress_httpx_logs(caplog):
|
||||
"""Suppress httpx INFO logs for all unit tests"""
|
||||
caplog.set_level(logging.WARNING, logger="httpx")
|
||||
|
||||
|
||||
pytest_plugins = ["tests.unit.fixtures"]
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import base64
|
||||
import json
|
||||
import logging # allow-direct-logging
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -27,6 +28,13 @@ from llama_stack.core.server.auth_providers import (
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suppress_auth_errors(caplog):
|
||||
"""Suppress expected ERROR/WARNING logs for tests that deliberately trigger authentication errors"""
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth")
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth_providers")
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
|
|
@ -237,20 +245,20 @@ def test_valid_http_authentication(http_client, valid_api_key):
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_failure)
|
||||
def test_invalid_http_authentication(http_client, invalid_api_key):
|
||||
def test_invalid_http_authentication(http_client, invalid_api_key, suppress_auth_errors):
|
||||
response = http_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Authentication failed" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_exception)
|
||||
def test_http_auth_service_error(http_client, valid_api_key):
|
||||
def test_http_auth_service_error(http_client, valid_api_key, suppress_auth_errors):
|
||||
response = http_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Authentication service error" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint):
|
||||
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint, suppress_auth_errors):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = MockResponse(200, {"message": "Authentication successful"})
|
||||
mock_post.return_value = mock_response
|
||||
|
|
@ -420,7 +428,7 @@ def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_u
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
||||
def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
|
||||
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, suppress_auth_errors):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid JWT token" in response.json()["error"]["message"]
|
||||
|
|
@ -465,7 +473,7 @@ def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
|
||||
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid, suppress_auth_errors):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
|
@ -726,21 +734,21 @@ def test_valid_introspection_authentication(introspection_client, valid_api_key)
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_inactive)
|
||||
def test_inactive_introspection_authentication(introspection_client, invalid_api_key):
|
||||
def test_inactive_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token not active" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_invalid)
|
||||
def test_invalid_introspection_authentication(introspection_client, invalid_api_key):
|
||||
def test_invalid_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Not JSON" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_failed)
|
||||
def test_failed_introspection_authentication(introspection_client, invalid_api_key):
|
||||
def test_failed_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token introspection failed: 500" in response.json()["error"]["message"]
|
||||
|
|
@ -957,20 +965,22 @@ def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_toke
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
|
||||
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
|
||||
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token, suppress_auth_errors):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
|
||||
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
|
||||
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token, suppress_auth_errors):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token validation failed" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
|
||||
def test_kubernetes_auth_request_payload(
|
||||
kubernetes_auth_client, valid_token, mock_kubernetes_api_server, suppress_auth_errors
|
||||
):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging # allow-direct-logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
|
|
@ -15,6 +16,13 @@ from llama_stack.core.datatypes import AuthenticationConfig, AuthProviderType, G
|
|||
from llama_stack.core.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suppress_auth_errors(caplog):
|
||||
"""Suppress expected ERROR logs for tests that deliberately trigger authentication errors"""
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth")
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth_providers")
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
|
|
@ -119,7 +127,7 @@ def test_authenticated_endpoint_with_valid_github_token(mock_client_class, githu
|
|||
|
||||
|
||||
@patch("llama_stack.core.server.auth_providers.httpx.AsyncClient")
|
||||
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
|
||||
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client, suppress_auth_errors):
|
||||
"""Test accessing protected endpoint with invalid GitHub token"""
|
||||
# Mock the GitHub API to return 401 Unauthorized
|
||||
mock_client = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging # allow-direct-logging
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
|
@ -17,6 +18,12 @@ from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreCo
|
|||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suppress_quota_warnings(caplog):
|
||||
"""Suppress expected WARNING logs for SQLite backend and quota exceeded"""
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.quota")
|
||||
|
||||
|
||||
class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that injects 'authenticated_client_id' to mimic AuthenticationMiddleware.
|
||||
|
|
@ -70,13 +77,13 @@ def auth_app(tmp_path, request):
|
|||
return app
|
||||
|
||||
|
||||
def test_authenticated_quota_allows_up_to_limit(auth_app):
|
||||
def test_authenticated_quota_allows_up_to_limit(auth_app, suppress_quota_warnings):
|
||||
client = TestClient(auth_app)
|
||||
assert client.get("/test").status_code == 200
|
||||
assert client.get("/test").status_code == 200
|
||||
|
||||
|
||||
def test_authenticated_quota_blocks_after_limit(auth_app):
|
||||
def test_authenticated_quota_blocks_after_limit(auth_app, suppress_quota_warnings):
|
||||
client = TestClient(auth_app)
|
||||
client.get("/test")
|
||||
client.get("/test")
|
||||
|
|
@ -85,7 +92,7 @@ def test_authenticated_quota_blocks_after_limit(auth_app):
|
|||
assert resp.json()["error"]["message"] == "Quota exceeded"
|
||||
|
||||
|
||||
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
||||
def test_anonymous_quota_allows_up_to_limit(tmp_path, request, suppress_quota_warnings):
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
|
|
@ -107,7 +114,7 @@ def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
|||
assert client.get("/test").status_code == 200
|
||||
|
||||
|
||||
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
||||
def test_anonymous_quota_blocks_after_limit(tmp_path, request, suppress_quota_warnings):
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
|
|
|
|||
|
|
@ -5,12 +5,21 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging # allow-direct-logging
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.core.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suppress_sse_errors(caplog):
|
||||
"""Suppress expected ERROR logs for tests that deliberately trigger SSE errors"""
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.server")
|
||||
|
||||
|
||||
async def test_sse_generator_basic():
|
||||
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||
async def async_event_gen():
|
||||
|
|
@ -70,7 +79,7 @@ async def test_sse_generator_client_disconnected_before_response_starts():
|
|||
assert len(seen_events) == 0
|
||||
|
||||
|
||||
async def test_sse_generator_error_before_response_starts():
|
||||
async def test_sse_generator_error_before_response_starts(suppress_sse_errors):
|
||||
# Raise an error before the response starts
|
||||
async def async_event_gen():
|
||||
raise Exception("Test error")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue