diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 97bdd179d..ea3ff2b64 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -92,7 +92,8 @@ jobs: run: | echo "Waiting for Llama Stack server..." for i in {1..30}; do - if curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://localhost:8321/v1/health | grep -q "OK"; then + # Note: /v1/health does not require authentication + if curl -s -L http://localhost:8321/v1/health | grep -q "OK"; then echo "Llama Stack server is up!" if grep -q "Enabling authentication with provider: ${{ matrix.auth-provider }}" server.log; then echo "Llama Stack server is configured to use ${{ matrix.auth-provider }} auth" @@ -111,4 +112,27 @@ jobs: - name: Test auth run: | - curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers|jq + echo "Testing /v1/version without token (should succeed)..." + if curl -s -L -o /dev/null -w "%{http_code}" http://127.0.0.1:8321/v1/version | grep -q "200"; then + echo "/v1/version accessible without token (200)" + else + echo "/v1/version returned non-200 status without token" + exit 1 + fi + + echo "Testing /v1/providers without token (should fail with 401)..." + if curl -s -L -o /dev/null -w "%{http_code}" http://127.0.0.1:8321/v1/providers | grep -q "401"; then + echo "/v1/providers blocked without token (401)" + else + echo "/v1/providers did not return 401 without token" + exit 1 + fi + + echo "Testing /v1/providers with valid token (should succeed)..." + curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers | jq + if [ $? -eq 0 ]; then + echo "/v1/providers accessible with valid token" + else + echo "/v1/providers failed with valid token" + exit 1 + fi diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index 72f203621..8b0996e69 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -73,7 +73,7 @@ class Inspect(Protocol): """ ... - @webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1) + @webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False) async def health(self) -> HealthInfo: """Get health status. @@ -83,7 +83,7 @@ class Inspect(Protocol): """ ... - @webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1) + @webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False) async def version(self) -> VersionInfo: """Get version. diff --git a/llama_stack/core/server/auth.py b/llama_stack/core/server/auth.py index c98d3bec0..8a4c8956f 100644 --- a/llama_stack/core/server/auth.py +++ b/llama_stack/core/server/auth.py @@ -27,6 +27,11 @@ class AuthenticationMiddleware: 3. Extracts user attributes from the provider's response 4. Makes these attributes available to the route handlers for access control + Unauthenticated Access: + Endpoints can opt out of authentication by setting require_authentication=False + in their @webmethod decorator. This is typically used for operational endpoints + like /health and /version to support monitoring, load balancers, and observability tools. + The middleware supports multiple authentication providers through the AuthProvider interface: - Kubernetes: Validates tokens against the Kubernetes API server - Custom: Validates tokens against a custom endpoint @@ -88,7 +93,26 @@ class AuthenticationMiddleware: async def __call__(self, scope, receive, send): if scope["type"] == "http": - # First, handle authentication + # Find the route and check if authentication is required + path = scope.get("path", "") + method = scope.get("method", hdrs.METH_GET) + + if not hasattr(self, "route_impls"): + self.route_impls = initialize_route_impls(self.impls) + + webmethod = None + try: + _, _, _, webmethod = find_matching_route(method, path, self.route_impls) + except ValueError: + # If no matching endpoint is found, pass here to run auth anyways + pass + + # If webmethod explicitly sets require_authentication=False, allow without auth + if webmethod and webmethod.require_authentication is False: + logger.debug(f"Allowing unauthenticated access to endpoint: {path}") + return await self.app(scope, receive, send) + + # Handle authentication headers = dict(scope.get("headers", [])) auth_header = headers.get(b"authorization", b"").decode() @@ -127,19 +151,7 @@ class AuthenticationMiddleware: ) # Scope-based API access control - path = scope.get("path", "") - method = scope.get("method", hdrs.METH_GET) - - if not hasattr(self, "route_impls"): - self.route_impls = initialize_route_impls(self.impls) - - try: - _, _, _, webmethod = find_matching_route(method, path, self.route_impls) - except ValueError: - # If no matching endpoint is found, pass through to FastAPI - return await self.app(scope, receive, send) - - if webmethod.required_scope: + if webmethod and webmethod.required_scope: user = user_from_scope(scope) if not _has_required_scope(webmethod.required_scope, user): return await self._send_auth_error( diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index c17d6e353..8444d2a34 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -61,6 +61,7 @@ class WebMethod: descriptive_name: str | None = None required_scope: str | None = None deprecated: bool | None = False + require_authentication: bool | None = True CallableT = TypeVar("CallableT", bound=Callable[..., Any]) @@ -77,6 +78,7 @@ def webmethod( descriptive_name: str | None = None, required_scope: str | None = None, deprecated: bool | None = False, + require_authentication: bool | None = True, ) -> Callable[[CallableT], CallableT]: """ Decorator that supplies additional metadata to an endpoint operation function. @@ -86,6 +88,7 @@ def webmethod( :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON. :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON. :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer'). + :param require_authentication: Whether this endpoint requires authentication (default True). """ def wrap(func: CallableT) -> CallableT: @@ -100,6 +103,7 @@ def webmethod( descriptive_name=descriptive_name, required_scope=required_scope, deprecated=deprecated, + require_authentication=require_authentication if require_authentication is not None else True, ) # Store all webmethods in a list to support multiple decorators diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 205e0ce65..9dbabe195 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -122,7 +122,7 @@ def mock_impls(): @pytest.fixture -def scope_middleware_with_mocks(mock_auth_endpoint): +def middleware_with_mocks(mock_auth_endpoint): """Create AuthenticationMiddleware with mocked route implementations""" mock_app = AsyncMock() auth_config = AuthenticationConfig( @@ -137,18 +137,20 @@ def scope_middleware_with_mocks(mock_auth_endpoint): # Mock the route_impls to simulate finding routes with required scopes from llama_stack.schema_utils import WebMethod - scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read") - - public_webmethod = WebMethod(route="/test/public", method="GET") + routes = { + ("POST", "/test/scoped"): WebMethod(route="/test/scoped", method="POST", required_scope="test.read"), + ("GET", "/test/public"): WebMethod(route="/test/public", method="GET"), + ("GET", "/health"): WebMethod(route="/health", method="GET", require_authentication=False), + ("GET", "/version"): WebMethod(route="/version", method="GET", require_authentication=False), + ("GET", "/models/list"): WebMethod(route="/models/list", method="GET", require_authentication=True), + } # Mock the route finding logic def mock_find_matching_route(method, path, route_impls): - if method == "POST" and path == "/test/scoped": - return None, {}, "/test/scoped", scoped_webmethod - elif method == "GET" and path == "/test/public": - return None, {}, "/test/public", public_webmethod - else: - raise ValueError("No matching route") + webmethod = routes.get((method, path)) + if webmethod: + return None, {}, path, webmethod + raise ValueError("No matching route") import llama_stack.core.server.auth @@ -659,9 +661,9 @@ def test_valid_introspection_with_custom_mapping_authentication( # Scope-based authorization tests @patch("httpx.AsyncClient.post", new=mock_post_success_with_scope) -async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key): +async def test_scope_authorization_success(middleware_with_mocks, valid_api_key): """Test that user with required scope can access protected endpoint""" - middleware, mock_app = scope_middleware_with_mocks + middleware, mock_app = middleware_with_mocks mock_receive = AsyncMock() mock_send = AsyncMock() @@ -680,9 +682,9 @@ async def test_scope_authorization_success(scope_middleware_with_mocks, valid_ap @patch("httpx.AsyncClient.post", new=mock_post_success_no_scope) -async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key): +async def test_scope_authorization_denied(middleware_with_mocks, valid_api_key): """Test that user without required scope gets 403 access denied""" - middleware, mock_app = scope_middleware_with_mocks + middleware, mock_app = middleware_with_mocks mock_receive = AsyncMock() mock_send = AsyncMock() @@ -710,9 +712,9 @@ async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api @patch("httpx.AsyncClient.post", new=mock_post_success_no_scope) -async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key): +async def test_public_endpoint_no_scope_required(middleware_with_mocks, valid_api_key): """Test that public endpoints work without specific scopes""" - middleware, mock_app = scope_middleware_with_mocks + middleware, mock_app = middleware_with_mocks mock_receive = AsyncMock() mock_send = AsyncMock() @@ -730,9 +732,9 @@ async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, va mock_send.assert_not_called() -async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks): +async def test_scope_authorization_no_auth_disabled(middleware_with_mocks): """Test that when auth is disabled (no user), scope checks are bypassed""" - middleware, mock_app = scope_middleware_with_mocks + middleware, mock_app = middleware_with_mocks mock_receive = AsyncMock() mock_send = AsyncMock() @@ -907,3 +909,41 @@ def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mo request_body = call_args[1]["json"] assert request_body["apiVersion"] == "authentication.k8s.io/v1" assert request_body["kind"] == "SelfSubjectReview" + + +async def test_unauthenticated_endpoint_access_health(middleware_with_mocks): + """Test that /health endpoints can be accessed without authentication""" + middleware, mock_app = middleware_with_mocks + + # Test request to /health without auth header (level prefix v1 is added by router) + scope = {"type": "http", "path": "/health", "headers": [], "method": "GET"} + receive = AsyncMock() + send = AsyncMock() + + # Should allow the request to proceed without authentication + await middleware(scope, receive, send) + + # Verify that the request was passed to the app + mock_app.assert_called_once_with(scope, receive, send) + + # Verify that no error response was sent + assert not any(call[0][0].get("status") == 401 for call in send.call_args_list) + + +async def test_unauthenticated_endpoint_denied_for_other_paths(middleware_with_mocks): + """Test that endpoints other than /health and /version require authentication""" + middleware, mock_app = middleware_with_mocks + + # Test request to /models/list without auth header + scope = {"type": "http", "path": "/models/list", "headers": [], "method": "GET"} + receive = AsyncMock() + send = AsyncMock() + + # Should return 401 error + await middleware(scope, receive, send) + + # Verify that the app was NOT called + mock_app.assert_not_called() + + # Verify that a 401 error response was sent + assert any(call[0][0].get("status") == 401 for call in send.call_args_list)