diff --git a/llama_stack/distribution/access_control.py b/llama_stack/distribution/access_control.py
new file mode 100644
index 000000000..7c7f12937
--- /dev/null
+++ b/llama_stack/distribution/access_control.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from llama_stack.distribution.datatypes import RoutableObjectWithProvider
+from llama_stack.log import get_logger
+
+logger = get_logger(__name__, category="core")
+
+
+def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
+    """Check if the current user has access to the given object, based on access attributes.
+
+    Access control algorithm:
+    1. If the resource has no access_attributes, access is GRANTED to all authenticated users
+    2. If the user has no attributes, access is DENIED to any object with access_attributes defined
+    3. For each attribute category in the resource's access_attributes:
+        a. If the user lacks that category, access is DENIED
+        b. If the user has the category but none of the required values, access is DENIED
+        c. If the user has at least one matching value in each required category, access is GRANTED
+
+    Example:
+        # Resource requires:
+        access_attributes = AccessAttributes(
+            roles=["admin", "data-scientist"],
+            teams=["ml-team"]
+        )
+
+        # User has:
+        user_attributes = {
+            "roles": ["data-scientist", "engineer"],
+            "teams": ["ml-team", "infra-team"],
+            "projects": ["llama-3"]
+        }
+
+        # Result: Access GRANTED
+        # - User has the "data-scientist" role (matches one of the required roles)
+        # - AND user is part of the "ml-team" (matches the required team)
+        # - The extra "projects" attribute is ignored
+
+    Args:
+        obj: The resource object to check access for
+        user_attributes: Attribute categories mapped to the values the user holds, e.g.
+            {"roles": ["admin"]}. None or an empty dict means the user has no attributes.
+            A bare string value is treated as a single value, never matched by substring.
+
+    Returns:
+        bool: True if access is granted, False if denied
+    """
+    # If object has no access attributes, allow access by default
+    access_attributes = getattr(obj, "access_attributes", None)
+    if not access_attributes:
+        return True
+
+    # If no user attributes, deny access to objects with access control
+    if not user_attributes:
+        return False
+
+    obj_attributes = access_attributes.model_dump(exclude_none=True)
+    if not obj_attributes:
+        return True
+
+    # Check each attribute category (requires ALL categories to match)
+    for attr_key, required_values in obj_attributes.items():
+        user_values = user_attributes.get(attr_key, [])
+        if isinstance(user_values, str):
+            # `in` on a str is a substring test ("admin" in "administrator"); wrap so only exact values match
+            user_values = [user_values]
+
+        if not user_values:
+            logger.debug(
+                f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
+            )
+            return False
+
+        if not any(val in user_values for val in required_values):
+            logger.debug(
+                f"Access denied to {obj.type} '{obj.identifier}': "
+                f"no match for attribute '{attr_key}', required one of {required_values}"
+            )
+            return False
+
+    logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
+    return True
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index e16e047e5..48f1925dd 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -14,6 +14,7 @@ from
llama_stack.apis.datasets import Dataset, DatasetInput from llama_stack.apis.eval import Eval from llama_stack.apis.inference import Inference from llama_stack.apis.models import Model, ModelInput +from llama_stack.apis.resource import Resource from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput @@ -31,6 +32,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2" RoutingKey = Union[str, List[str]] +class AccessAttributes(BaseModel): + """Structured representation of user attributes for access control. + + This model defines a structured approach to representing user attributes + with common standard categories for access control. + + Standard attribute categories include: + - roles: Role-based attributes (e.g., admin, data-scientist) + - teams: Team-based attributes (e.g., ml-team, infra-team) + - projects: Project access attributes (e.g., llama-3, customer-insights) + - namespaces: Namespace-based access control for resource isolation + """ + + # Standard attribute categories - the minimal set we need now + roles: Optional[List[str]] = Field( + default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')" + ) + + teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')") + + projects: Optional[List[str]] = Field( + default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')" + ) + + namespaces: Optional[List[str]] = Field( + default=None, description="Namespace-based access control for resource isolation" + ) + + +class ResourceWithACL(Resource): + """Extension of Resource that adds attribute-based access control capabilities. + + This class adds an optional access_attributes field that allows fine-grained control + over which users can access each resource. 
When attributes are defined, a user must have + matching attributes to access the resource. + + Attribute Matching Algorithm: + 1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users + 2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects") + 3. The matching algorithm requires ALL categories to match (AND relationship between categories) + 4. Within each category, ANY value match is sufficient (OR relationship within a category) + + Examples: + # Resource visible to everyone (no access control) + model = Model(identifier="llama-2", ...) + + # Resource visible only to admins + model = Model( + identifier="gpt-4", + access_attributes=AccessAttributes(roles=["admin"]) + ) + + # Resource visible to data scientists on the ML team + model = Model( + identifier="private-model", + access_attributes=AccessAttributes( + roles=["data-scientist", "researcher"], + teams=["ml-team"] + ) + ) + # ^ User must have at least one of the roles AND be on the ml-team + + # Resource visible to users with specific project access + vector_db = VectorDB( + identifier="customer-embeddings", + access_attributes=AccessAttributes( + projects=["customer-insights"], + namespaces=["confidential"] + ) + ) + # ^ User must have access to the customer-insights project AND have confidential namespace + """ + + access_attributes: Optional[AccessAttributes] = None + + +# Use the extended Resource for all routable objects +class ModelWithACL(Model, ResourceWithACL): + pass + + +class ShieldWithACL(Shield, ResourceWithACL): + pass + + +class VectorDBWithACL(VectorDB, ResourceWithACL): + pass + + +class DatasetWithACL(Dataset, ResourceWithACL): + pass + + +class ScoringFnWithACL(ScoringFn, ResourceWithACL): + pass + + +class BenchmarkWithACL(Benchmark, ResourceWithACL): + pass + + +class ToolWithACL(Tool, ResourceWithACL): + pass + + +class ToolGroupWithACL(ToolGroup, ResourceWithACL): + pass + + RoutableObject = 
Union[
     Model,
     Shield,
@@ -45,14 +155,14 @@ RoutableObject = Union[
 
 RoutableObjectWithProvider = Annotated[
     Union[
-        Model,
-        Shield,
-        VectorDB,
-        Dataset,
-        ScoringFn,
-        Benchmark,
-        Tool,
-        ToolGroup,
+        ModelWithACL,
+        ShieldWithACL,
+        VectorDBWithACL,
+        DatasetWithACL,
+        ScoringFnWithACL,
+        BenchmarkWithACL,
+        ToolWithACL,
+        ToolGroupWithACL,
     ],
     Field(discriminator="type"),
 ]
diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py
index 8709fc040..f9cde2cdf 100644
--- a/llama_stack/distribution/request_headers.py
+++ b/llama_stack/distribution/request_headers.py
@@ -7,21 +7,28 @@
 import contextvars
 import json
 import logging
-from typing import Any, ContextManager, Dict, Optional
+from typing import Any, ContextManager, Dict, List, Optional
 
 from .utils.dynamic import instantiate_class_type
 
 log = logging.getLogger(__name__)
 
-# Context variable for request provider data
+# Context variable for request provider data and auth attributes
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
 
 
 class RequestProviderDataContext(ContextManager):
     """Context manager for request provider data"""
 
-    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
-        self.provider_data = provider_data
+    def __init__(
+        self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
+    ):
+        """Store provider data and, when given, the auth attributes under the `__auth_attributes` key."""
+        # Copy so that adding the reserved key never mutates the dict the caller passed in
+        self.provider_data = dict(provider_data) if provider_data else {}
+        if auth_attributes:
+            self.provider_data["__auth_attributes"] = auth_attributes
+
         self.token = None
 
     def __enter__(self):
@@ -80,7 +87,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
         return None
 
 
-def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
-    """Context manager that sets request provider data from headers for the duration of the context"""
+def request_provider_data_context(
+    headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
+) -> ContextManager:
+    """Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
     provider_data = parse_request_provider_data(headers)
-    return RequestProviderDataContext(provider_data)
+    return RequestProviderDataContext(provider_data, auth_attributes)
+
+
+def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
+    """Helper to retrieve auth attributes from the provider data context"""
+    provider_data = PROVIDER_DATA_VAR.get()
+    if not provider_data:
+        return None
+    return provider_data.get("__auth_attributes")
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 6277096d8..a2bc10fc1 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -41,11 +41,22 @@ from llama_stack.apis.tools import (
     ToolHost,
 )
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
+from llama_stack.distribution.access_control import check_access
 from llama_stack.distribution.datatypes import (
+    AccessAttributes,
+    BenchmarkWithACL,
+    DatasetWithACL,
+    ModelWithACL,
     RoutableObject,
     RoutableObjectWithProvider,
     RoutedProtocol,
+    ScoringFnWithACL,
+    ShieldWithACL,
+    ToolGroupWithACL,
+    ToolWithACL,
+    VectorDBWithACL,
 )
+from llama_stack.distribution.request_headers import get_auth_attributes
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
 
@@ -186,6 +197,11 @@ class CommonRoutingTableImpl(RoutingTable):
         if not obj:
             return None
 
+        # Check if user has permission to access this object
+        if not check_access(obj, get_auth_attributes()):
+            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
+            return None
+
         return obj
 
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
@@ -202,6 +218,13 @@ class CommonRoutingTableImpl(RoutingTable):
 
         p =
self.impls_by_provider_id[obj.provider_id]
 
+        # If object supports access control but no attributes set, use creator's attributes
+        if not obj.access_attributes:
+            creator_attributes = get_auth_attributes()
+            if creator_attributes:
+                obj.access_attributes = AccessAttributes(**creator_attributes)
+                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
+
         registered_obj = await register_object_with_provider(obj, p)
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:
@@ -214,7 +237,13 @@ class CommonRoutingTableImpl(RoutingTable):
 
     async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
         objs = await self.dist_registry.get_all()
-        return [obj for obj in objs if obj.type == type]
+        objs_of_type = [obj for obj in objs if obj.type == type]
+        if not objs_of_type:
+            return objs_of_type
+
+        # Resolve the caller's attributes once rather than once per object, then apply access control
+        user_attributes = get_auth_attributes()
+        return [obj for obj in objs_of_type if check_access(obj, user_attributes)]
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@@ -251,7 +280,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
             raise ValueError("Embedding model must have an embedding dimension in its metadata")
-        model = Model(
+        model = ModelWithACL(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
@@ -297,7 +326,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
                 )
         if params is None:
             params = {}
-        shield = Shield(
+        shield = ShieldWithACL(
             identifier=shield_id,
             provider_resource_id=provider_shield_id,
             provider_id=provider_id,
@@ -351,7 +380,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             "embedding_model": embedding_model,
             "embedding_dimension": model.metadata["embedding_dimension"],
         }
-        vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
+        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
         await self.register_object(vector_db)
         return vector_db
 
@@ -405,7 +434,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         if metadata is None:
             metadata = {}
 
-        dataset = Dataset(
+        dataset = DatasetWithACL(
             identifier=dataset_id,
             provider_resource_id=provider_dataset_id,
             provider_id=provider_id,
@@ -452,7 +481,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
                 raise ValueError(
                     "No provider specified and multiple providers available. Please specify a provider_id."
                 )
-        scoring_fn = ScoringFn(
+        scoring_fn = ScoringFnWithACL(
             identifier=scoring_fn_id,
             description=description,
             return_type=return_type,
@@ -494,7 +523,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
                 )
         if provider_benchmark_id is None:
             provider_benchmark_id = benchmark_id
-        benchmark = Benchmark(
+        benchmark = BenchmarkWithACL(
             identifier=benchmark_id,
             dataset_id=dataset_id,
             scoring_functions=scoring_functions,
@@ -537,7 +566,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
 
         for tool_def in tool_defs:
             tools.append(
-                Tool(
+                ToolWithACL(
                     identifier=tool_def.name,
                     toolgroup_id=toolgroup_id,
                     description=tool_def.description or "",
@@ -562,7 +591,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
             await self.register_object(tool)
 
         await self.dist_registry.register(
-            ToolGroup(
+            ToolGroupWithACL(
                 identifier=toolgroup_id,
                 provider_id=provider_id,
                 provider_resource_id=toolgroup_id,
diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py
index bb577bae5..52e6a013c 100644
--- a/llama_stack/distribution/server/auth.py
+++ b/llama_stack/distribution/server/auth.py
@@ -5,16 +5,118 @@
 # the root directory of this source tree.
import json +from typing import Dict, List, Optional from urllib.parse import parse_qs import httpx +from pydantic import BaseModel, Field +from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") +class AuthRequestContext(BaseModel): + path: str = Field(description="The path of the request being authenticated") + + headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)") + + params: Dict[str, List[str]] = Field( + description="Query parameters from the original request, parsed as dictionary of lists" + ) + + +class AuthRequest(BaseModel): + api_key: str = Field(description="The API key extracted from the Authorization header") + + request: AuthRequestContext = Field(description="Context information about the request being authenticated") + + +class AuthResponse(BaseModel): + """The format of the authentication response from the auth endpoint.""" + + access_attributes: Optional[AccessAttributes] = Field( + default=None, + description=""" + Structured user attributes for attribute-based access control. + + These attributes determine which resources the user can access. + The model provides standard categories like "roles", "teams", "projects", and "namespaces". + Each attribute category contains a list of values that the user has for that category. + During access control checks, these values are compared against resource requirements. + + Example with standard categories: + ```json + { + "roles": ["admin", "data-scientist"], + "teams": ["ml-team"], + "projects": ["llama-3"], + "namespaces": ["research"] + } + ``` + """, + ) + + message: Optional[str] = Field( + default=None, description="Optional message providing additional context about the authentication result." + ) + + class AuthenticationMiddleware: + """Middleware that authenticates requests using an external auth endpoint. + + This middleware: + 1. 
Extracts the Bearer token from the Authorization header + 2. Sends it to the configured auth endpoint along with request details + 3. Validates the response and extracts user attributes + 4. Makes these attributes available to the route handlers for access control + + Authentication Request Format: + ```json + { + "api_key": "the-api-key-extracted-from-auth-header", + "request": { + "path": "/models/list", + "headers": { + "content-type": "application/json", + "user-agent": "..." + // All headers except Authorization + }, + "params": { + "limit": ["100"], + "offset": ["0"] + // Query parameters as key -> list of values + } + } + } + ``` + + Expected Auth Endpoint Response Format: + ```json + { + "access_attributes": { // Structured attribute format + "roles": ["admin", "user"], + "teams": ["ml-team", "nlp-team"], + "projects": ["llama-3", "project-x"], + "namespaces": ["research"] + }, + "message": "Optional message about auth result" + } + ``` + + Attribute-Based Access Control: + The attributes returned by the auth endpoint are used to determine which + resources the user can access. Resources can specify required attributes + using the access_attributes field. For a user to access a resource: + + 1. All attribute categories specified in the resource must be present in the user's attributes + 2. For each category, the user must have at least one matching value + + If the auth endpoint doesn't return any attributes, the user will only be able to + access resources that don't have access_attributes defined. 
+    """
+
     def __init__(self, app, auth_endpoint):
         self.app = app
         self.auth_endpoint = auth_endpoint
@@ -32,25 +134,58 @@ class AuthenticationMiddleware:
         path = scope.get("path", "")
         request_headers = {k.decode(): v.decode() for k, v in headers.items()}
 
+        # Remove sensitive headers so the credential is only forwarded as the explicit api_key field
+        request_headers.pop("authorization", None)
+
         query_string = scope.get("query_string", b"").decode()
         params = parse_qs(query_string)
 
-        auth_data = {
-            "api_key": api_key,
-            "request": {
-                "path": path,
-                "headers": request_headers,
-                "params": params,
-            },
-        }
+        # Build the auth request model
+        auth_request = AuthRequest(
+            api_key=api_key,
+            request=AuthRequestContext(
+                path=path,
+                headers=request_headers,
+                params=params,
+            ),
+        )
 
         # Validate with authentication endpoint
         try:
             async with httpx.AsyncClient() as client:
-                response = await client.post(self.auth_endpoint, json=auth_data)
+                response = await client.post(
+                    self.auth_endpoint,
+                    json=auth_request.model_dump(),
+                    timeout=10.0,  # Add a reasonable timeout
+                )
                 if response.status_code != 200:
                     logger.warning(f"Authentication failed: {response.status_code}")
                     return await self._send_auth_error(send, "Authentication failed")
+
+                # Parse and validate the auth response
+                try:
+                    response_data = response.json()
+                    auth_response = AuthResponse(**response_data)
+
+                    # Store attributes in request scope for access control
+                    if auth_response.access_attributes:
+                        user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
+                    else:
+                        # NOTE(review): this puts the raw API key into the user's attributes; confirm that
+                        # downstream consumers never persist or log them, or switch to a non-reversible digest.
+                        logger.warning("No access attributes, setting namespace to api_key by default")
+                        user_attributes = {
+                            "namespaces": [api_key],
+                        }
+
+                    scope["user_attributes"] = user_attributes
+                    logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
+                except Exception:
+                    logger.exception("Error parsing authentication response")
+                    return await self._send_auth_error(send, "Invalid authentication response format")
+        except httpx.TimeoutException:
+            logger.exception("Authentication request timed out")
+            return await self._send_auth_error(send, "Authentication service timeout")
         except Exception:
             logger.exception("Error during authentication")
             return await self._send_auth_error(send, "Authentication service error")
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 212e65804..3bdeeef7c 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -179,8 +179,11 @@ async def sse_generator(event_gen):
 
 def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
-        # Use context manager for request provider data
-        with request_provider_data_context(request.headers):
+        # Get auth attributes from the request scope
+        user_attributes = request.scope.get("user_attributes", {})
+
+        # Use context manager with both provider data and auth attributes
+        with request_provider_data_context(request.headers, user_attributes):
             is_streaming = is_streaming_request(func.__name__, request, **kwargs)
 
             try:
diff --git a/scripts/unit-tests.sh b/scripts/unit-tests.sh
index dbc25e06b..5cfaa989b 100755
--- a/scripts/unit-tests.sh
+++ b/scripts/unit-tests.sh
@@ -16,4 +16,4 @@ if [ $FOUND_PYTHON -ne 0 ]; then
     uv python install $PYTHON_VERSION
 fi
 
-uv run --python $PYTHON_VERSION --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest -s -v tests/unit/ $@
+uv run --python $PYTHON_VERSION --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --asyncio-mode=auto -s -v tests/unit/ "$@"
diff --git a/tests/unit/registry/test_registry_acl.py b/tests/unit/registry/test_registry_acl.py
new file mode 100644
index 000000000..ee8f28176
--- /dev/null
+++ b/tests/unit/registry/test_registry_acl.py
@@ -0,0 +1,151 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import shutil
+import tempfile
+
+import pytest
+
+from llama_stack.apis.models import ModelType
+from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
+from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
+
+
+@pytest.fixture(scope="function")
+async def kvstore():
+    temp_dir = tempfile.mkdtemp()
+    db_path = os.path.join(temp_dir, "test_registry_acl.db")
+    kvstore_config = SqliteKVStoreConfig(db_path=db_path)
+    kvstore = SqliteKVStoreImpl(kvstore_config)
+    await kvstore.initialize()
+    yield kvstore
+    shutil.rmtree(temp_dir)
+
+
+@pytest.fixture(scope="function")
+async def registry(kvstore):
+    registry = CachedDiskDistributionRegistry(kvstore)
+    await registry.initialize()
+    return registry
+
+
+@pytest.mark.asyncio
+async def test_registry_cache_with_acl(registry):
+    """ACL attributes survive registration, update, and reload from the backing kvstore."""
+    model = ModelWithACL(
+        identifier="model-acl",
+        provider_id="test-provider",
+        provider_resource_id="model-acl-resource",
+        model_type=ModelType.llm,
+        access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
+    )
+
+    success = await registry.register(model)
+    assert success
+
+    cached_model = registry.get_cached("model", "model-acl")
+    assert cached_model is not None
+    assert cached_model.identifier == "model-acl"
+    assert cached_model.access_attributes.roles == ["admin"]
+    assert cached_model.access_attributes.teams == ["ai-team"]
+
+    fetched_model = await registry.get("model", "model-acl")
+    assert fetched_model is not None
+    assert fetched_model.identifier == "model-acl"
+    assert fetched_model.access_attributes.roles == ["admin"]
+
+    model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
+    await registry.update(model)
+
+    updated_cached = registry.get_cached("model", "model-acl")
+    assert updated_cached is not None
+    assert updated_cached.access_attributes.roles == ["admin", "user"]
+    assert updated_cached.access_attributes.projects == ["project-x"]
+    assert updated_cached.access_attributes.teams is None
+
+    new_registry = CachedDiskDistributionRegistry(registry.kvstore)
+    await new_registry.initialize()
+
+    new_model = await new_registry.get("model", "model-acl")
+    assert new_model is not None
+    assert new_model.identifier == "model-acl"
+    assert new_model.access_attributes.roles == ["admin", "user"]
+    assert new_model.access_attributes.projects == ["project-x"]
+    assert new_model.access_attributes.teams is None
+
+
+@pytest.mark.asyncio
+async def test_registry_empty_acl(registry):
+    model = ModelWithACL(
+        identifier="model-empty-acl",
+        provider_id="test-provider",
+        provider_resource_id="model-resource",
+        model_type=ModelType.llm,
+        access_attributes=AccessAttributes(),
+    )
+
+    await registry.register(model)
+
+    cached_model = registry.get_cached("model", "model-empty-acl")
+    assert cached_model is not None
+    assert cached_model.access_attributes is not None
+    assert cached_model.access_attributes.roles is None
+    assert cached_model.access_attributes.teams is None
+    assert cached_model.access_attributes.projects is None
+    assert cached_model.access_attributes.namespaces is None
+
+    all_models = await registry.get_all()
+    assert len(all_models) == 1
+
+    model = ModelWithACL(
+        identifier="model-no-acl",
+        provider_id="test-provider",
+        provider_resource_id="model-resource-2",
+        model_type=ModelType.llm,
+    )
+
+    await registry.register(model)
+
+    cached_model = registry.get_cached("model", "model-no-acl")
+    assert cached_model is not None
+    assert cached_model.access_attributes is None
+
+    all_models = await registry.get_all()
+    assert len(all_models) == 2
+
+
+@pytest.mark.asyncio
+async def test_registry_serialization(registry):
+    attributes = AccessAttributes(
+        roles=["admin", "researcher"],
+        teams=["ai-team", "ml-team"],
+        projects=["project-a", "project-b"],
+        namespaces=["prod", "staging"],
+    )
+
+    model = ModelWithACL(
+        identifier="model-serialize",
+        provider_id="test-provider",
+        provider_resource_id="model-resource",
+        model_type=ModelType.llm,
+        access_attributes=attributes,
+    )
+
+    await registry.register(model)
+
+    new_registry = CachedDiskDistributionRegistry(registry.kvstore)
+    await new_registry.initialize()
+
+    loaded_model = await new_registry.get("model", "model-serialize")
+    assert loaded_model is not None
+
+    assert loaded_model.access_attributes.roles == ["admin", "researcher"]
+    assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
+    assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
+    assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py
new file mode 100644
index 000000000..ab0feb1a9
--- /dev/null
+++ b/tests/unit/server/test_access_control.py
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+ +import os +import shutil +import tempfile +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from llama_stack.apis.datatypes import Api +from llama_stack.apis.models import ModelType +from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL +from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable +from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl + + +class AsyncMock(MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + +def _return_model(model): + return model + + +@pytest.fixture +async def test_setup(): + temp_dir = tempfile.mkdtemp() + db_path = os.path.join(temp_dir, "test_access_control.db") + kvstore_config = SqliteKVStoreConfig(db_path=db_path) + kvstore = SqliteKVStoreImpl(kvstore_config) + await kvstore.initialize() + registry = CachedDiskDistributionRegistry(kvstore) + await registry.initialize() + + mock_inference = Mock() + mock_inference.__provider_spec__ = MagicMock() + mock_inference.__provider_spec__.api = Api.inference + mock_inference.register_model = AsyncMock(side_effect=_return_model) + routing_table = ModelsRoutingTable( + impls_by_provider_id={"test_provider": mock_inference}, + dist_registry=registry, + ) + yield registry, routing_table + shutil.rmtree(temp_dir) + + +@pytest.mark.asyncio +@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): + registry, routing_table = test_setup + model_public = ModelWithACL( + identifier="model-public", + provider_id="test_provider", + provider_resource_id="model-public", + model_type=ModelType.llm, + ) + model_admin_only = ModelWithACL( + identifier="model-admin", + 
provider_id="test_provider", + provider_resource_id="model-admin", + model_type=ModelType.llm, + access_attributes=AccessAttributes(roles=["admin"]), + ) + model_data_scientist = ModelWithACL( + identifier="model-data-scientist", + provider_id="test_provider", + provider_resource_id="model-data-scientist", + model_type=ModelType.llm, + access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]), + ) + await registry.register(model_public) + await registry.register(model_admin_only) + await registry.register(model_data_scientist) + + mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]} + all_models = await routing_table.list_models() + assert len(all_models.data) == 2 + + model = await routing_table.get_model("model-public") + assert model.identifier == "model-public" + model = await routing_table.get_model("model-admin") + assert model.identifier == "model-admin" + with pytest.raises(ValueError): + await routing_table.get_model("model-data-scientist") + + mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]} + all_models = await routing_table.list_models() + assert len(all_models.data) == 1 + assert all_models.data[0].identifier == "model-public" + model = await routing_table.get_model("model-public") + assert model.identifier == "model-public" + with pytest.raises(ValueError): + await routing_table.get_model("model-admin") + with pytest.raises(ValueError): + await routing_table.get_model("model-data-scientist") + + mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]} + all_models = await routing_table.list_models() + assert len(all_models.data) == 2 + model_ids = [m.identifier for m in all_models.data] + assert "model-public" in model_ids + assert "model-data-scientist" in model_ids + assert "model-admin" not in model_ids + model = await routing_table.get_model("model-public") + assert model.identifier == "model-public" 
+ model = await routing_table.get_model("model-data-scientist") + assert model.identifier == "model-data-scientist" + with pytest.raises(ValueError): + await routing_table.get_model("model-admin") + + +@pytest.mark.asyncio +@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): + registry, routing_table = test_setup + model_public = ModelWithACL( + identifier="model-updates", + provider_id="test_provider", + provider_resource_id="model-updates", + model_type=ModelType.llm, + ) + await registry.register(model_public) + mock_get_auth_attributes.return_value = { + "roles": ["user"], + } + model = await routing_table.get_model("model-updates") + assert model.identifier == "model-updates" + model_public.access_attributes = AccessAttributes(roles=["admin"]) + await registry.update(model_public) + mock_get_auth_attributes.return_value = { + "roles": ["user"], + } + with pytest.raises(ValueError): + await routing_table.get_model("model-updates") + mock_get_auth_attributes.return_value = { + "roles": ["admin"], + } + model = await routing_table.get_model("model-updates") + assert model.identifier == "model-updates" + + +@pytest.mark.asyncio +@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup): + registry, routing_table = test_setup + model = ModelWithACL( + identifier="model-empty-attrs", + provider_id="test_provider", + provider_resource_id="model-empty-attrs", + model_type=ModelType.llm, + access_attributes=AccessAttributes(), + ) + await registry.register(model) + mock_get_auth_attributes.return_value = { + "roles": [], + } + result = await routing_table.get_model("model-empty-attrs") + assert result.identifier == "model-empty-attrs" + all_models = await routing_table.list_models() + model_ids = [m.identifier for m in all_models.data] + assert 
"model-empty-attrs" in model_ids + + +@pytest.mark.asyncio +@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +async def test_no_user_attributes(mock_get_auth_attributes, test_setup): + registry, routing_table = test_setup + model_public = ModelWithACL( + identifier="model-public-2", + provider_id="test_provider", + provider_resource_id="model-public-2", + model_type=ModelType.llm, + ) + model_restricted = ModelWithACL( + identifier="model-restricted", + provider_id="test_provider", + provider_resource_id="model-restricted", + model_type=ModelType.llm, + access_attributes=AccessAttributes(roles=["admin"]), + ) + await registry.register(model_public) + await registry.register(model_restricted) + mock_get_auth_attributes.return_value = None + model = await routing_table.get_model("model-public-2") + assert model.identifier == "model-public-2" + + with pytest.raises(ValueError): + await routing_table.get_model("model-restricted") + + all_models = await routing_table.list_models() + assert len(all_models.data) == 1 + assert all_models.data[0].identifier == "model-public-2" + + +@pytest.mark.asyncio +@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup): + """Test that newly created resources inherit access attributes from their creator.""" + registry, routing_table = test_setup + + # Set creator's attributes + creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]} + mock_get_auth_attributes.return_value = creator_attributes + + # Create model without explicit access attributes + model = ModelWithACL( + identifier="auto-access-model", + provider_id="test_provider", + provider_resource_id="auto-access-model", + model_type=ModelType.llm, + ) + await routing_table.register_object(model) + + # Verify the model got creator's attributes + registered_model = await 
routing_table.get_model("auto-access-model") + assert registered_model.access_attributes is not None + assert registered_model.access_attributes.roles == ["data-scientist"] + assert registered_model.access_attributes.teams == ["ml-team"] + assert registered_model.access_attributes.projects == ["llama-3"] + + # Verify another user without matching attributes can't access it + mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]} + with pytest.raises(ValueError): + await routing_table.get_model("auto-access-model") + + # But a user with matching attributes can + mock_get_auth_attributes.return_value = { + "roles": ["data-scientist", "engineer"], + "teams": ["ml-team", "platform-team"], + "projects": ["llama-3"], + } + model = await routing_table.get_model("auto-access-model") + assert model.identifier == "auto-access-model" diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 70f08dbd6..5e93719d2 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -13,6 +13,15 @@ from fastapi.testclient import TestClient from llama_stack.distribution.server.auth import AuthenticationMiddleware +class MockResponse: + def __init__(self, status_code, json_data): + self.status_code = status_code + self._json_data = json_data + + def json(self): + return self._json_data + + @pytest.fixture def mock_auth_endpoint(): return "http://mock-auth-service/validate" @@ -45,16 +54,32 @@ def client(app): return TestClient(app) +@pytest.fixture +def mock_scope(): + return { + "type": "http", + "path": "/models/list", + "headers": [ + (b"content-type", b"application/json"), + (b"authorization", b"Bearer test-api-key"), + (b"user-agent", b"test-user-agent"), + ], + "query_string": b"limit=100&offset=0", + } + + +@pytest.fixture +def mock_middleware(mock_auth_endpoint): + mock_app = AsyncMock() + return AuthenticationMiddleware(mock_app, mock_auth_endpoint), mock_app + + async def mock_post_success(*args, 
**kwargs): - mock_response = AsyncMock() - mock_response.status_code = 200 - return mock_response + return MockResponse(200, {"message": "Authentication successful"}) async def mock_post_failure(*args, **kwargs): - mock_response = AsyncMock() - mock_response.status_code = 401 - return mock_response + return MockResponse(401, {"message": "Authentication failed"}) async def mock_post_exception(*args, **kwargs): @@ -96,8 +121,7 @@ def test_auth_service_error(client, valid_api_key): def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint): with patch("httpx.AsyncClient.post") as mock_post: - mock_response = AsyncMock() - mock_response.status_code = 200 + mock_response = MockResponse(200, {"message": "Authentication successful"}) mock_post.return_value = mock_response client.get( @@ -119,6 +143,64 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint): payload = kwargs["json"] assert payload["api_key"] == valid_api_key assert payload["request"]["path"] == "/test" - assert "authorization" in payload["request"]["headers"] + assert "authorization" not in payload["request"]["headers"] assert "param1" in payload["request"]["params"] assert "param2" in payload["request"]["params"] + + +@pytest.mark.asyncio +async def test_auth_middleware_with_access_attributes(mock_middleware, mock_scope): + middleware, mock_app = mock_middleware + mock_receive = AsyncMock() + mock_send = AsyncMock() + + with patch("httpx.AsyncClient") as mock_client: + mock_client_instance = AsyncMock() + mock_client.return_value.__aenter__.return_value = mock_client_instance + + mock_client_instance.post.return_value = MockResponse( + 200, + { + "access_attributes": { + "roles": ["admin", "user"], + "teams": ["ml-team"], + "projects": ["project-x", "project-y"], + } + }, + ) + + await middleware(mock_scope, mock_receive, mock_send) + + assert "user_attributes" in mock_scope + assert mock_scope["user_attributes"]["roles"] == ["admin", "user"] + assert 
mock_scope["user_attributes"]["teams"] == ["ml-team"] + assert mock_scope["user_attributes"]["projects"] == ["project-x", "project-y"] + + mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) + + +@pytest.mark.asyncio +async def test_auth_middleware_no_attributes(mock_middleware, mock_scope): + """Test middleware behavior with no access attributes""" + middleware, mock_app = mock_middleware + mock_receive = AsyncMock() + mock_send = AsyncMock() + + with patch("httpx.AsyncClient") as mock_client: + mock_client_instance = AsyncMock() + mock_client.return_value.__aenter__.return_value = mock_client_instance + + mock_client_instance.post.return_value = MockResponse( + 200, + { + "message": "Authentication successful" + # No access_attributes + }, + ) + + await middleware(mock_scope, mock_receive, mock_send) + + assert "user_attributes" in mock_scope + attributes = mock_scope["user_attributes"] + assert "namespaces" in attributes + assert attributes["namespaces"] == ["test-api-key"]