mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
feat(server): add attribute based access control for resources
This commit is contained in:
parent
7c0448456e
commit
b937a49436
8 changed files with 862 additions and 35 deletions
|
@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
|
|||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.resource import Resource
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
|
@ -31,28 +32,137 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
|||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
"""Structured representation of user attributes for access control.
|
||||
|
||||
This model defines a structured approach to representing user attributes
|
||||
with common standard categories for access control.
|
||||
|
||||
Standard attribute categories include:
|
||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||
- namespaces: Namespace-based access control for resource isolation
|
||||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: Optional[List[str]] = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: Optional[List[str]] = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: Optional[List[str]] = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
|
||||
|
||||
class ResourceWithACL(Resource):
|
||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||
|
||||
This class adds an optional access_attributes field that allows fine-grained control
|
||||
over which users can access each resource. When attributes are defined, a user must have
|
||||
matching attributes to access the resource.
|
||||
|
||||
Attribute Matching Algorithm:
|
||||
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
||||
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
||||
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
||||
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
||||
|
||||
Examples:
|
||||
# Resource visible to everyone (no access control)
|
||||
model = Model(identifier="llama-2", ...)
|
||||
|
||||
# Resource visible only to admins
|
||||
model = Model(
|
||||
identifier="gpt-4",
|
||||
access_attributes=AccessAttributes(roles=["admin"])
|
||||
)
|
||||
|
||||
# Resource visible to data scientists on the ML team
|
||||
model = Model(
|
||||
identifier="private-model",
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["data-scientist", "researcher"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
)
|
||||
# ^ User must have at least one of the roles AND be on the ml-team
|
||||
|
||||
# Resource visible to users with specific project access
|
||||
vector_db = VectorDB(
|
||||
identifier="customer-embeddings",
|
||||
access_attributes=AccessAttributes(
|
||||
projects=["customer-insights"],
|
||||
namespaces=["confidential"]
|
||||
)
|
||||
)
|
||||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
||||
"""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
class ModelWithACL(Model, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ShieldWithACL(Shield, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetWithACL(Dataset, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ToolWithACL(Tool, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Union[
|
||||
Model,
|
||||
Shield,
|
||||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
Benchmark,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ModelWithACL,
|
||||
ShieldWithACL,
|
||||
VectorDBWithACL,
|
||||
DatasetWithACL,
|
||||
ScoringFnWithACL,
|
||||
BenchmarkWithACL,
|
||||
ToolWithACL,
|
||||
ToolGroupWithACL,
|
||||
]
|
||||
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
Union[
|
||||
Model,
|
||||
Shield,
|
||||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
Benchmark,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ModelWithACL,
|
||||
ShieldWithACL,
|
||||
VectorDBWithACL,
|
||||
DatasetWithACL,
|
||||
ScoringFnWithACL,
|
||||
BenchmarkWithACL,
|
||||
ToolWithACL,
|
||||
ToolGroupWithACL,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
|
@ -7,21 +7,29 @@
|
|||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, ContextManager, Dict, Optional
|
||||
from typing import Any, ContextManager, Dict, List, Optional
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Context variable for request provider data
|
||||
# Context variable for request provider data and auth attributes
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(ContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
|
||||
self.provider_data = provider_data
|
||||
def __init__(
|
||||
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
):
|
||||
# Initialize with either provider_data or create a new dict
|
||||
self.provider_data = provider_data or {}
|
||||
|
||||
# Add auth attributes under a special key if provided
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -80,7 +88,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
|
|||
return None
|
||||
|
||||
|
||||
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
|
||||
"""Context manager that sets request provider data from headers for the duration of the context"""
|
||||
def request_provider_data_context(
|
||||
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
) -> ContextManager:
|
||||
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||
provider_data = parse_request_provider_data(headers)
|
||||
return RequestProviderDataContext(provider_data)
|
||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
|
||||
|
||||
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
|
||||
"""Helper to retrieve auth attributes from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__auth_attributes")
|
||||
|
|
|
@ -40,10 +40,12 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
@ -184,6 +186,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not obj:
|
||||
return None
|
||||
|
||||
# Check if user has permission to access this object
|
||||
if not self._check_access(obj):
|
||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
|
@ -200,6 +207,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
|
@ -212,7 +226,89 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
||||
# Apply attribute-based access control filtering
|
||||
if filtered_objs:
|
||||
filtered_objs = [obj for obj in filtered_objs if self._check_access(obj)]
|
||||
|
||||
return filtered_objs
|
||||
|
||||
def _check_access(self, obj: RoutableObjectWithProvider) -> bool:
|
||||
"""Check if the current user has access to the given object, based on access attributes.
|
||||
|
||||
Access control algorithm:
|
||||
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
|
||||
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
|
||||
3. For each attribute category in the resource's access_attributes:
|
||||
a. If the user lacks that category, access is DENIED
|
||||
b. If the user has the category but none of the required values, access is DENIED
|
||||
c. If the user has at least one matching value in each required category, access is GRANTED
|
||||
|
||||
Example:
|
||||
# Resource requires:
|
||||
access_attributes = AccessAttributes(
|
||||
roles=["admin", "data-scientist"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
|
||||
# User has:
|
||||
user_attributes = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "infra-team"],
|
||||
"projects": ["llama-3"]
|
||||
}
|
||||
|
||||
# Result: Access GRANTED
|
||||
# - User has the "data-scientist" role (matches one of the required roles)
|
||||
# - AND user is part of the "ml-team" (matches the required team)
|
||||
# - The extra "projects" attribute is ignored
|
||||
|
||||
Args:
|
||||
obj: The resource object to check access for
|
||||
|
||||
Returns:
|
||||
bool: True if access is granted, False if denied
|
||||
"""
|
||||
# If object has no access attributes, allow access by default
|
||||
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
|
||||
return True
|
||||
|
||||
# Get user attributes from context
|
||||
user_attributes = get_auth_attributes()
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
# Convert AccessAttributes to dictionary for checking
|
||||
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
|
||||
|
||||
# If the model_dump is empty (all fields are None), allow access
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
# Check each attribute category (requires ALL categories to match)
|
||||
for attr_key, required_values in obj_attributes.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
# No values for this category in user attributes
|
||||
if not user_values:
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
|
||||
)
|
||||
return False
|
||||
|
||||
# None of the values in this category match (need at least one match per category)
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': "
|
||||
f"no match for attribute '{attr_key}', required one of {required_values}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
|
||||
return True
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
|
|
@ -5,16 +5,118 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||
|
||||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Middleware that authenticates requests using an external auth endpoint.
|
||||
|
||||
This middleware:
|
||||
1. Extracts the Bearer token from the Authorization header
|
||||
2. Sends it to the configured auth endpoint along with request details
|
||||
3. Validates the response and extracts user attributes
|
||||
4. Makes these attributes available to the route handlers for access control
|
||||
|
||||
Authentication Request Format:
|
||||
```json
|
||||
{
|
||||
"api_key": "the-api-key-extracted-from-auth-header",
|
||||
"request": {
|
||||
"path": "/models/list",
|
||||
"headers": {
|
||||
"content-type": "application/json",
|
||||
"user-agent": "..."
|
||||
// All headers except Authorization
|
||||
},
|
||||
"params": {
|
||||
"limit": ["100"],
|
||||
"offset": ["0"]
|
||||
// Query parameters as key -> list of values
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Expected Auth Endpoint Response Format:
|
||||
```json
|
||||
{
|
||||
"access_attributes": { // Structured attribute format
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
"namespaces": ["research"]
|
||||
},
|
||||
"message": "Optional message about auth result"
|
||||
}
|
||||
```
|
||||
|
||||
Attribute-Based Access Control:
|
||||
The attributes returned by the auth endpoint are used to determine which
|
||||
resources the user can access. Resources can specify required attributes
|
||||
using the access_attributes field. For a user to access a resource:
|
||||
|
||||
1. All attribute categories specified in the resource must be present in the user's attributes
|
||||
2. For each category, the user must have at least one matching value
|
||||
|
||||
If the auth endpoint doesn't return any attributes, the user will only be able to
|
||||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_endpoint):
|
||||
self.app = app
|
||||
self.auth_endpoint = auth_endpoint
|
||||
|
@ -32,25 +134,57 @@ class AuthenticationMiddleware:
|
|||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
# Remove sensitive headers
|
||||
if "authorization" in request_headers:
|
||||
del request_headers["authorization"]
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
auth_data = {
|
||||
"api_key": api_key,
|
||||
"request": {
|
||||
"path": path,
|
||||
"headers": request_headers,
|
||||
"params": params,
|
||||
},
|
||||
}
|
||||
# Build the auth request model
|
||||
auth_request = AuthRequest(
|
||||
api_key=api_key,
|
||||
request=AuthRequestContext(
|
||||
path=path,
|
||||
headers=request_headers,
|
||||
params=params,
|
||||
),
|
||||
)
|
||||
|
||||
# Validate with authentication endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(self.auth_endpoint, json=auth_data)
|
||||
response = await client.post(
|
||||
self.auth_endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed: {response.status_code}")
|
||||
return await self._send_auth_error(send, "Authentication failed")
|
||||
|
||||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||
scope["user_attributes"] = user_attributes
|
||||
else:
|
||||
logger.warning("Authentication response did not contain any attributes")
|
||||
scope["user_attributes"] = {}
|
||||
|
||||
# Log authentication success with attribute details
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
except Exception:
|
||||
logger.exception("Error parsing authentication response")
|
||||
return await self._send_auth_error(send, "Invalid authentication response format")
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
except Exception:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
|
|
@ -28,7 +28,9 @@ from typing_extensions import Annotated
|
|||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
AUTH_ATTRIBUTES_VAR,
|
||||
PROVIDER_DATA_VAR,
|
||||
auth_attributes_context,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
|
@ -179,8 +181,11 @@ async def sse_generator(event_gen):
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
# Use context manager for request provider data
|
||||
with request_provider_data_context(request.headers):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
|
|
143
tests/unit/registry/test_registry_acl.py
Normal file
143
tests/unit/registry/test_registry_acl.py
Normal file
|
@ -0,0 +1,143 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import ModelWithACL
|
||||
from llama_stack.distribution.server.auth import AccessAttributes
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def kvstore():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_registry_acl.db")
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
await kvstore.initialize()
|
||||
yield kvstore
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def registry(kvstore):
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
async def test_registry_cache_with_acl(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-acl-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
||||
)
|
||||
|
||||
success = await registry.register(model)
|
||||
assert success
|
||||
|
||||
cached_model = registry.get_cached("model", "model-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.identifier == "model-acl"
|
||||
assert cached_model.access_attributes.roles == ["admin"]
|
||||
assert cached_model.access_attributes.teams == ["ai-team"]
|
||||
|
||||
fetched_model = await registry.get("model", "model-acl")
|
||||
assert fetched_model is not None
|
||||
assert fetched_model.identifier == "model-acl"
|
||||
assert fetched_model.access_attributes.roles == ["admin"]
|
||||
|
||||
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
||||
await registry.update(model)
|
||||
|
||||
updated_cached = registry.get_cached("model", "model-acl")
|
||||
assert updated_cached is not None
|
||||
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
||||
assert updated_cached.access_attributes.projects == ["project-x"]
|
||||
assert updated_cached.access_attributes.teams is None
|
||||
|
||||
new_registry = CachedDiskDistributionRegistry(registry._kvstore)
|
||||
await new_registry.initialize()
|
||||
|
||||
new_model = await new_registry.get("model", "model-acl")
|
||||
assert new_model is not None
|
||||
assert new_model.identifier == "model-acl"
|
||||
assert new_model.access_attributes.roles == ["admin", "user"]
|
||||
assert new_model.access_attributes.projects == ["project-x"]
|
||||
assert new_model.access_attributes.teams is None
|
||||
|
||||
|
||||
async def test_registry_empty_acl(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
|
||||
cached_model = registry.get_cached("model", "model-empty-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.access_attributes is not None
|
||||
assert cached_model.access_attributes.roles is None
|
||||
assert cached_model.access_attributes.teams is None
|
||||
assert cached_model.access_attributes.projects is None
|
||||
assert cached_model.access_attributes.namespaces is None
|
||||
|
||||
all_models = await registry.get_all()
|
||||
assert len(all_models) == 1
|
||||
|
||||
model = ModelWithACL(
|
||||
identifier="model-no-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource-2",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
|
||||
cached_model = registry.get_cached("model", "model-no-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.access_attributes is None
|
||||
|
||||
all_models = await registry.get_all()
|
||||
assert len(all_models) == 2
|
||||
|
||||
|
||||
async def test_registry_serialization(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-serialize",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["admin", "researcher"],
|
||||
teams=["ai-team", "ml-team"],
|
||||
projects=["project-a", "project-b"],
|
||||
namespaces=["prod", "staging"],
|
||||
),
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
registry.cache.clear()
|
||||
|
||||
loaded_model = await registry.get("model", "model-serialize")
|
||||
|
||||
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
|
||||
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
|
||||
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
|
||||
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
|
236
tests/unit/server/test_access_control.py
Normal file
236
tests/unit/server/test_access_control.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
class AsyncMock(MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return super(AsyncMock, self).__call__(*args, **kwargs)
|
||||
|
||||
|
||||
def _return_model(model):
|
||||
return model
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_setup():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_access_control.db")
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
|
||||
mock_inference = Mock()
|
||||
mock_inference.__provider_spec__ = MagicMock()
|
||||
mock_inference.__provider_spec__.api = "inference"
|
||||
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=registry,
|
||||
)
|
||||
yield registry, routing_table
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_admin_only = ModelWithACL(
|
||||
identifier="model-admin",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-admin",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
)
|
||||
model_data_scientist = ModelWithACL(
|
||||
identifier="model-data-scientist",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-data-scientist",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_admin_only)
|
||||
await registry.register(model_data_scientist)
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 3
|
||||
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
model = await routing_table.get_model("model-admin")
|
||||
assert model.identifier == "model-admin"
|
||||
model = await routing_table.get_model("model-data-scientist")
|
||||
assert model.identifier == "model-data-scientist"
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public"
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-admin")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
assert "model-public" in model_ids
|
||||
assert "model-data-scientist" in model_ids
|
||||
assert "model-admin" not in model_ids
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
model = await routing_table.get_model("model-data-scientist")
|
||||
assert model.identifier == "model-data-scientist"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-admin")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-updates",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-updates",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await registry.register(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
||||
await registry.update(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-updates")
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["admin"],
|
||||
}
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-attrs",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-empty-attrs",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
await registry.register(model)
|
||||
mock_get_auth_attributes.return_value = {}
|
||||
result = await routing_table.get_model("model-empty-attrs")
|
||||
assert result.identifier == "model-empty-attrs"
|
||||
all_models = await routing_table.list_models()
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
assert "model-empty-attrs" in model_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public-2",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public-2",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_restricted = ModelWithACL(
|
||||
identifier="model-restricted",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-restricted",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_restricted)
|
||||
mock_get_auth_attributes.return_value = None
|
||||
model = await routing_table.get_model("model-public-2")
|
||||
assert model.identifier == "model-public-2"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-restricted")
|
||||
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public-2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
|
||||
"""Test that newly created resources inherit access attributes from their creator."""
|
||||
registry, routing_table = test_setup
|
||||
|
||||
# Set creator's attributes
|
||||
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
|
||||
mock_get_auth_attributes.return_value = creator_attributes
|
||||
|
||||
# Create model without explicit access attributes
|
||||
model = ModelWithACL(
|
||||
identifier="auto-access-model",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="auto-access-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await routing_table.register_object(model)
|
||||
|
||||
# Verify the model got creator's attributes
|
||||
registered_model = await routing_table.get_model("auto-access-model")
|
||||
assert registered_model.access_attributes is not None
|
||||
assert registered_model.access_attributes.roles == ["data-scientist"]
|
||||
assert registered_model.access_attributes.teams == ["ml-team"]
|
||||
assert registered_model.access_attributes.projects == ["llama-3"]
|
||||
|
||||
# Verify another user without matching attributes can't access it
|
||||
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("auto-access-model")
|
||||
|
||||
# But a user with matching attributes can
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "platform-team"],
|
||||
}
|
||||
model = await routing_table.get_model("auto-access-model")
|
||||
assert model.identifier == "auto-access-model"
|
|
@ -13,6 +13,15 @@ from fastapi.testclient import TestClient
|
|||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_endpoint():
|
||||
return "http://mock-auth-service/validate"
|
||||
|
@ -45,6 +54,26 @@ def client(app):
|
|||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scope():
|
||||
return {
|
||||
"type": "http",
|
||||
"path": "/models/list",
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"authorization", b"Bearer test-api-key"),
|
||||
(b"user-agent", b"test-user-agent"),
|
||||
],
|
||||
"query_string": b"limit=100&offset=0",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
return AuthenticationMiddleware(mock_app, mock_auth_endpoint), mock_app
|
||||
|
||||
|
||||
async def mock_post_success(*args, **kwargs):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
|
@ -122,3 +151,59 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
|||
assert "authorization" in payload["request"]["headers"]
|
||||
assert "param1" in payload["request"]["params"]
|
||||
assert "param2" in payload["request"]["params"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_middleware_with_access_attributes(mock_middleware, mock_scope):
|
||||
middleware, mock_app = mock_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.post.return_value = MockResponse(
|
||||
200,
|
||||
{
|
||||
"access_attributes": {
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["project-x", "project-y"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
assert mock_scope["user_attributes"]["roles"] == ["admin", "user"]
|
||||
assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
|
||||
assert mock_scope["user_attributes"]["projects"] == ["project-x", "project-y"]
|
||||
|
||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
|
||||
"""Test middleware behavior with no access attributes"""
|
||||
middleware, mock_app = mock_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.post.return_value = MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful"
|
||||
# No access_attributes
|
||||
},
|
||||
)
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
assert mock_scope["user_attributes"] == {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue