From ebea3c8277dc3368aa5db4fd3a682ef6d7327273 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Thu, 24 Jul 2025 14:56:17 -0700 Subject: [PATCH] api access - Create BaseServerMiddleware base class for server middleware - Refactor TracingMiddleware to extend BaseServerMiddleware - Consolidate route matching logic in base class - Update server.py to use user_from_scope utility - Add required_scope parameter to WebMethod in schema_utils.py - Create AccessControlMiddleware with simplified scope checking - Update telemetry API to use required_scope protection - Add comprehensive test coverage for access control logic - Integrate access control middleware into server setup - Rename AccessControlMiddleware to AuthorizationMiddleware for better clarity - Update imports and references in server.py and tests - Keep the same functionality and API - Merge authorization logic directly into AuthenticationMiddleware - Remove separate access_control.py file - Update middleware setup in server.py to use single middleware - Rename and update tests to test the merged functionality - AuthenticationMiddleware now handles both authentication and authorization --- docs/source/distributions/configuration.md | 41 ++++ llama_stack/apis/telemetry/telemetry.py | 16 +- llama_stack/distribution/request_headers.py | 12 ++ llama_stack/distribution/server/auth.py | 50 ++++- llama_stack/distribution/server/server.py | 37 ++-- llama_stack/schema_utils.py | 4 + tests/unit/server/test_auth.py | 207 +++++++++++++++++++- 7 files changed, 331 insertions(+), 36 deletions(-) diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 6362effe8..775749dd6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -504,6 +504,47 @@ created by users sharing a team with them: description: any user has read access to any resource created by a user with the same team ``` +#### API Endpoint Authorization with Scopes + +In addition to resource-based access control, Llama Stack supports endpoint-level authorization using OAuth 2.0 style scopes. When authentication is enabled, specific API endpoints require users to have particular scopes in their authentication token. + +**Scope-Gated APIs:** +The following APIs are currently gated by scopes: + +- **Telemetry API** (scope: `telemetry.read`): + - `POST /telemetry/traces` - Query traces + - `GET /telemetry/traces/{trace_id}` - Get trace by ID + - `GET /telemetry/traces/{trace_id}/spans/{span_id}` - Get span by ID + - `POST /telemetry/spans/{span_id}/tree` - Get span tree + - `POST /telemetry/spans` - Query spans + - `POST /telemetry/metrics/{metric_name}` - Query metrics + +**Authentication Configuration:** + +For **JWT/OAuth2 providers**, scopes should be included in the JWT's claims: +```json +{ + "sub": "user123", + "scope": "telemetry.read", + "aud": "llama-stack" +} +``` + +For **custom authentication providers**, the endpoint must return user attributes including the `scopes` array: +```json +{ + "principal": "user123", + "attributes": { + "scopes": ["telemetry.read"] + } +} +``` + +**Behavior:** +- Users without the required scope receive a 403 Forbidden response +- When authentication is disabled, scope checks are bypassed +- Endpoints without `required_scope` work normally for all authenticated users + ### Quota Configuration The `quota` section allows you to enable server-side request throttling for both diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index d621e601e..96b317c29 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho # Add this constant near the top of the file, after the imports DEFAULT_TTL_DAYS = 7 +REQUIRED_SCOPE = "telemetry.read" + @json_schema_type class SpanStatus(Enum): @@ -259,7 +261,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/traces", method="POST") + @webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE) async def query_traces( self, attribute_filters: list[QueryCondition] | None = None, @@ -277,7 +279,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") + @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE) async def get_trace(self, trace_id: str) -> Trace: """Get a trace by its ID. @@ -286,7 +288,9 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET") + @webmethod( + route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE + ) async def get_span(self, trace_id: str, span_id: str) -> Span: """Get a span by its ID. @@ -296,7 +300,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST") + @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE) async def get_span_tree( self, span_id: str, @@ -312,7 +316,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/spans", method="POST") + @webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE) async def query_spans( self, attribute_filters: list[QueryCondition], @@ -345,7 +349,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/metrics/{metric_name}", method="POST") + @webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE) async def query_metrics( self, metric_name: str, diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 81d494e04..509c2be44 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None: if not provider_data: return None return provider_data.get("__authenticated_user") + + +def user_from_scope(scope: dict) -> User | None: + """Create a User object from ASGI scope data (set by authentication middleware)""" + user_attributes = scope.get("user_attributes", {}) + principal = scope.get("principal", "") + + # auth not enabled + if not principal and not user_attributes: + return None + + return User(principal=principal, attributes=user_attributes) diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index fadbf7b49..87c1a2ab6 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -7,9 +7,12 @@ import json import httpx +from aiohttp import hdrs -from llama_stack.distribution.datatypes import AuthenticationConfig +from llama_stack.distribution.datatypes import AuthenticationConfig, User +from llama_stack.distribution.request_headers import user_from_scope from llama_stack.distribution.server.auth_providers import create_auth_provider +from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -78,12 +81,14 @@ class AuthenticationMiddleware: access resources that don't have access_attributes defined. """ - def __init__(self, app, auth_config: AuthenticationConfig): + def __init__(self, app, auth_config: AuthenticationConfig, impls): self.app = app + self.impls = impls self.auth_provider = create_auth_provider(auth_config) async def __call__(self, scope, receive, send): if scope["type"] == "http": + # First, handle authentication headers = dict(scope.get("headers", [])) auth_header = headers.get(b"authorization", b"").decode() @@ -121,15 +126,50 @@ class AuthenticationMiddleware: f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes" ) + # Scope-based API access control + path = scope.get("path", "") + method = scope.get("method", hdrs.METH_GET) + + if not hasattr(self, "route_impls"): + self.route_impls = initialize_route_impls(self.impls) + + try: + _, _, _, webmethod = find_matching_route(method, path, self.route_impls) + except ValueError: + # If no matching endpoint is found, pass through to FastAPI + return await self.app(scope, receive, send) + + if webmethod.required_scope: + user = user_from_scope(scope) + if not _has_required_scope(webmethod.required_scope, user): + return await self._send_auth_error( + send, + f"Access denied: user does not have required scope: {webmethod.required_scope}", + status=403, + ) + return await self.app(scope, receive, send) - async def _send_auth_error(self, send, message): + async def _send_auth_error(self, send, message, status=401): await send( { "type": "http.response.start", - "status": 401, + "status": status, "headers": [[b"content-type", b"application/json"]], } ) - error_msg = json.dumps({"error": {"message": message}}).encode() + error_key = "message" if status == 401 else "detail" + error_msg = json.dumps({"error": {error_key: message}}).encode() await send({"type": "http.response.body", "body": error_msg}) + + +def _has_required_scope(required_scope: str, user: User | None) -> bool: + # if no user, assume auth is not enabled + if not user: + return True + + if not user.attributes: + return False + + user_scopes = user.attributes.get("scopes", []) + return required_scope in user_scopes diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 199875204..26ea5f90c 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -41,7 +41,11 @@ from llama_stack.distribution.datatypes import ( ) from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.external import ExternalApiSpec, load_external_apis -from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context +from llama_stack.distribution.request_headers import ( + PROVIDER_DATA_VAR, + request_provider_data_context, + user_from_scope, +) from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.server.routes import ( find_matching_route, @@ -223,9 +227,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: @functools.wraps(func) async def route_handler(request: Request, **kwargs): # Get auth attributes from the request scope - user_attributes = request.scope.get("user_attributes", {}) - principal = request.scope.get("principal", "") - user = User(principal=principal, attributes=user_attributes) + user = user_from_scope(request.scope) await log_request_pre_validation(request) @@ -437,10 +439,21 @@ def main(args: argparse.Namespace | None = None): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) - # Add authentication middleware if configured + try: + # Create and set the event loop that will be used for both construction and server runtime + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Construct the stack in the persistent event loop + impls = loop.run_until_complete(construct_stack(config)) + + except InvalidProviderError as e: + logger.error(f"Error: {str(e)}") + sys.exit(1) + if config.server.auth: logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") - app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) + app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls) else: if config.server.quota: quota = config.server.quota @@ -471,18 +484,6 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) - try: - # Create and set the event loop that will be used for both construction and server runtime - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Construct the stack in the persistent event loop - impls = loop.run_until_complete(construct_stack(config)) - - except InvalidProviderError as e: - logger.error(f"Error: {str(e)}") - sys.exit(1) - if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index 694de333e..93382a881 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -22,6 +22,7 @@ class WebMethod: # A descriptive name of the corresponding span created by tracing descriptive_name: str | None = None experimental: bool | None = False + required_scope: str | None = None T = TypeVar("T", bound=Callable[..., Any]) @@ -36,6 +37,7 @@ def webmethod( raw_bytes_request_body: bool | None = False, descriptive_name: str | None = None, experimental: bool | None = False, + required_scope: str | None = None, ) -> Callable[[T], T]: """ Decorator that supplies additional metadata to an endpoint operation function. @@ -45,6 +47,7 @@ def webmethod( :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON. :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON. :param experimental: True if the operation is experimental and subject to change. + :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer'). """ def wrap(func: T) -> T: @@ -57,6 +60,7 @@ def webmethod( raw_bytes_request_body=raw_bytes_request_body, descriptive_name=descriptive_name, experimental=experimental, + required_scope=required_scope, ) return func diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 7012a7f17..adf0140e2 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -19,7 +19,8 @@ from llama_stack.distribution.datatypes import ( OAuth2JWKSConfig, OAuth2TokenAuthConfig, ) -from llama_stack.distribution.server.auth import AuthenticationMiddleware +from llama_stack.distribution.request_headers import User +from llama_stack.distribution.server.auth import AuthenticationMiddleware, _has_required_scope from llama_stack.distribution.server.auth_providers import ( get_attributes_from_claims, ) @@ -73,7 +74,7 @@ def http_app(mock_auth_endpoint): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -111,7 +112,50 @@ def mock_http_middleware(mock_auth_endpoint): ), access_policy=[], ) - return AuthenticationMiddleware(mock_app, auth_config), mock_app + return AuthenticationMiddleware(mock_app, auth_config, {}), mock_app + + +@pytest.fixture +def mock_impls(): + """Mock implementations for scope testing""" + return {} + + +@pytest.fixture +def scope_middleware_with_mocks(mock_auth_endpoint): + """Create AuthenticationMiddleware with mocked route implementations""" + mock_app = AsyncMock() + auth_config = AuthenticationConfig( + provider_config=CustomAuthConfig( + type=AuthProviderType.CUSTOM, + endpoint=mock_auth_endpoint, + ), + access_policy=[], + ) + middleware = AuthenticationMiddleware(mock_app, auth_config, {}) + + # Mock the route_impls to simulate finding routes with required scopes + from llama_stack.schema_utils import WebMethod + + scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read") + + public_webmethod = WebMethod(route="/test/public", method="GET") + + # Mock the route finding logic + def mock_find_matching_route(method, path, route_impls): + if method == "POST" and path == "/test/scoped": + return None, {}, "/test/scoped", scoped_webmethod + elif method == "GET" and path == "/test/public": + return None, {}, "/test/public", public_webmethod + else: + raise ValueError("No matching route") + + import llama_stack.distribution.server.auth + + llama_stack.distribution.server.auth.find_matching_route = mock_find_matching_route + llama_stack.distribution.server.auth.initialize_route_impls = lambda impls: {} + + return middleware, mock_app async def mock_post_success(*args, **kwargs): @@ -138,6 +182,36 @@ async def mock_post_exception(*args, **kwargs): raise Exception("Connection error") +async def mock_post_success_with_scope(*args, **kwargs): + """Mock auth response for user with test.read scope""" + return MockResponse( + 200, + { + "message": "Authentication successful", + "principal": "test-user", + "attributes": { + "scopes": ["test.read", "other.scope"], + "roles": ["user"], + }, + }, + ) + + +async def mock_post_success_no_scope(*args, **kwargs): + """Mock auth response for user without required scope""" + return MockResponse( + 200, + { + "message": "Authentication successful", + "principal": "test-user", + "attributes": { + "scopes": ["other.scope"], + "roles": ["user"], + }, + }, + ) + + # HTTP Endpoint Tests def test_missing_auth_header(http_client): response = http_client.get("/test") @@ -252,7 +326,7 @@ def oauth2_app(): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -351,7 +425,7 @@ def oauth2_app_with_jwks_token(): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -442,7 +516,7 @@ def introspection_app(mock_introspection_endpoint): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -472,7 +546,7 @@ def introspection_app_with_custom_mapping(mock_introspection_endpoint): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -581,3 +655,122 @@ def test_valid_introspection_with_custom_mapping_authentication( ) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} + + +# Scope-based authorization tests +@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope) +async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key): + """Test that user with required scope can access protected endpoint""" + middleware, mock_app = scope_middleware_with_mocks + mock_receive = AsyncMock() + mock_send = AsyncMock() + + scope = { + "type": "http", + "path": "/test/scoped", + "method": "POST", + "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())], + } + + await middleware(scope, mock_receive, mock_send) + + # Should call the downstream app (no 403 error sent) + mock_app.assert_called_once_with(scope, mock_receive, mock_send) + mock_send.assert_not_called() + + +@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope) +async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key): + """Test that user without required scope gets 403 access denied""" + middleware, mock_app = scope_middleware_with_mocks + mock_receive = AsyncMock() + mock_send = AsyncMock() + + scope = { + "type": "http", + "path": "/test/scoped", + "method": "POST", + "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())], + } + + await middleware(scope, mock_receive, mock_send) + + # Should send 403 error, not call downstream app + mock_app.assert_not_called() + assert mock_send.call_count == 2 # start + body + + # Check the response + start_call = mock_send.call_args_list[0][0][0] + assert start_call["status"] == 403 + + body_call = mock_send.call_args_list[1][0][0] + body_text = body_call["body"].decode() + assert "Access denied" in body_text + assert "test.read" in body_text + + +@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope) +async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key): + """Test that public endpoints work without specific scopes""" + middleware, mock_app = scope_middleware_with_mocks + mock_receive = AsyncMock() + mock_send = AsyncMock() + + scope = { + "type": "http", + "path": "/test/public", + "method": "GET", + "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())], + } + + await middleware(scope, mock_receive, mock_send) + + # Should call the downstream app (no error) + mock_app.assert_called_once_with(scope, mock_receive, mock_send) + mock_send.assert_not_called() + + +async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks): + """Test that when auth is disabled (no user), scope checks are bypassed""" + middleware, mock_app = scope_middleware_with_mocks + mock_receive = AsyncMock() + mock_send = AsyncMock() + + scope = { + "type": "http", + "path": "/test/scoped", + "method": "POST", + "headers": [], # No authorization header + } + + await middleware(scope, mock_receive, mock_send) + + # Should send 401 auth error, not call downstream app + mock_app.assert_not_called() + assert mock_send.call_count == 2 # start + body + + # Check the response + start_call = mock_send.call_args_list[0][0][0] + assert start_call["status"] == 401 + + body_call = mock_send.call_args_list[1][0][0] + body_text = body_call["body"].decode() + assert "Authentication required" in body_text + + +def test_has_required_scope_function(): + """Test the _has_required_scope function directly""" + # Test user with required scope + user_with_scope = User(principal="test-user", attributes={"scopes": ["test.read", "other.scope"]}) + assert _has_required_scope("test.read", user_with_scope) + + # Test user without required scope + user_without_scope = User(principal="test-user", attributes={"scopes": ["other.scope"]}) + assert not _has_required_scope("test.read", user_without_scope) + + # Test user with no scopes attribute + user_no_scopes = User(principal="test-user", attributes={}) + assert not _has_required_scope("test.read", user_no_scopes) + + # Test no user (auth disabled) + assert _has_required_scope("test.read", None)