diff --git a/llama_stack/distribution/access_control.py b/llama_stack/distribution/access_control.py deleted file mode 100644 index d560ec80f..000000000 --- a/llama_stack/distribution/access_control.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from llama_stack.distribution.datatypes import AccessAttributes -from llama_stack.log import get_logger - -logger = get_logger(__name__, category="core") - - -def check_access( - obj_identifier: str, - obj_attributes: AccessAttributes | None, - user_attributes: dict[str, Any] | None = None, -) -> bool: - """Check if the current user has access to the given object, based on access attributes. - - Access control algorithm: - 1. If the resource has no access_attributes, access is GRANTED to all authenticated users - 2. If the user has no attributes, access is DENIED to any object with access_attributes defined - 3. For each attribute category in the resource's access_attributes: - a. If the user lacks that category, access is DENIED - b. If the user has the category but none of the required values, access is DENIED - c. 
If the user has at least one matching value in each required category, access is GRANTED - - Example: - # Resource requires: - access_attributes = AccessAttributes( - roles=["admin", "data-scientist"], - teams=["ml-team"] - ) - - # User has: - user_attributes = { - "roles": ["data-scientist", "engineer"], - "teams": ["ml-team", "infra-team"], - "projects": ["llama-3"] - } - - # Result: Access GRANTED - # - User has the "data-scientist" role (matches one of the required roles) - # - AND user is part of the "ml-team" (matches the required team) - # - The extra "projects" attribute is ignored - - Args: - obj_identifier: The identifier of the resource object to check access for - obj_attributes: The access attributes of the resource object - user_attributes: The attributes of the current user - - Returns: - bool: True if access is granted, False if denied - """ - # If object has no access attributes, allow access by default - if not obj_attributes: - return True - - # If no user attributes, deny access to objects with access control - if not user_attributes: - return False - - dict_attribs = obj_attributes.model_dump(exclude_none=True) - if not dict_attribs: - return True - - # Check each attribute category (requires ALL categories to match) - # TODO: formalize this into a proper ABAC policy - for attr_key, required_values in dict_attribs.items(): - user_values = user_attributes.get(attr_key, []) - - if not user_values: - logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'") - return False - - if not any(val in user_values for val in required_values): - logger.debug( - f"Access denied to {obj_identifier}: " - f"no match for attribute '{attr_key}', required one of {required_values}" - ) - return False - - logger.debug(f"Access granted to {obj_identifier}") - return True diff --git a/llama_stack/distribution/access_control/__init__.py b/llama_stack/distribution/access_control/__init__.py new file mode 100644 index 
000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/access_control/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/access_control/access_control.py b/llama_stack/distribution/access_control/access_control.py new file mode 100644 index 000000000..84d506d8f --- /dev/null +++ b/llama_stack/distribution/access_control/access_control.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.distribution.datatypes import User + +from .conditions import ( + Condition, + ProtectedResource, + parse_conditions, +) +from .datatypes import ( + AccessRule, + Action, + Scope, +) + + +def matches_resource(resource_scope: str, actual_resource: str) -> bool: + if resource_scope == actual_resource: + return True + return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1]) + + +def matches_scope( + scope: Scope, + action: Action, + resource: str, + user: str | None, +) -> bool: + if scope.resource and not matches_resource(scope.resource, resource): + return False + if scope.principal and scope.principal != user: + return False + return action in scope.actions + + +def as_list(obj: Any) -> list[Any]: + if isinstance(obj, list): + return obj + return [obj] + + +def matches_conditions( + conditions: list[Condition], + resource: ProtectedResource, + user: User, +) -> bool: + for condition in conditions: + # must match all conditions + if not condition.matches(resource, user): + return False + return True + + +def default_policy() -> list[AccessRule]: + # for backwards compatibility, if no rules are provided, assume 
+ # full access subject to previous attribute matching rules + return [ + AccessRule( + permit=Scope(actions=list(Action)), + when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]], + ), + ] + + +def is_action_allowed( + policy: list[AccessRule], + action: Action, + resource: ProtectedResource, + user: User | None, +) -> bool: + # If user is not set, assume authentication is not enabled + if not user: + return True + + if not len(policy): + policy = default_policy() + + qualified_resource_id = resource.type + "::" + resource.identifier + for rule in policy: + if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal): + if rule.when: + if matches_conditions(parse_conditions(as_list(rule.when)), resource, user): + return False + elif rule.unless: + if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user): + return False + else: + return False + elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal): + if rule.when: + if matches_conditions(parse_conditions(as_list(rule.when)), resource, user): + return True + elif rule.unless: + if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user): + return True + else: + return True + # assume access is denied unless we find a rule that permits access + return False + + +class AccessDeniedError(RuntimeError): + pass diff --git a/llama_stack/distribution/access_control/conditions.py b/llama_stack/distribution/access_control/conditions.py new file mode 100644 index 000000000..25a267124 --- /dev/null +++ b/llama_stack/distribution/access_control/conditions.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Protocol + + +class User(Protocol): + principal: str + attributes: dict[str, list[str]] | None + + +class ProtectedResource(Protocol): + type: str + identifier: str + owner: User + + +class Condition(Protocol): + def matches(self, resource: ProtectedResource, user: User) -> bool: ... + + +class UserInOwnersList: + def __init__(self, name: str): + self.name = name + + def owners_values(self, resource: ProtectedResource) -> list[str] | None: + if ( + hasattr(resource, "owner") + and resource.owner + and resource.owner.attributes + and self.name in resource.owner.attributes + ): + return resource.owner.attributes[self.name] + else: + return None + + def matches(self, resource: ProtectedResource, user: User) -> bool: + required = self.owners_values(resource) + if not required: + return True + if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]: + return False + user_values = user.attributes[self.name] + for value in required: + if value in user_values: + return True + return False + + def __repr__(self): + return f"user in owners {self.name}" + + +class UserNotInOwnersList(UserInOwnersList): + def __init__(self, name: str): + super().__init__(name) + + def matches(self, resource: ProtectedResource, user: User) -> bool: + return not super().matches(resource, user) + + def __repr__(self): + return f"user not in owners {self.name}" + + +class UserWithValueInList: + def __init__(self, name: str, value: str): + self.name = name + self.value = value + + def matches(self, resource: ProtectedResource, user: User) -> bool: + if user.attributes and self.name in user.attributes: + return self.value in user.attributes[self.name] + print(f"User does not have {self.value} in {self.name}") + return False + + def __repr__(self): + return f"user with {self.value} in {self.name}" + + +class UserWithValueNotInList(UserWithValueInList): + def __init__(self, name: str, value: str): + super().__init__(name, value) + + def 
matches(self, resource: ProtectedResource, user: User) -> bool: + return not super().matches(resource, user) + + def __repr__(self): + return f"user with {self.value} not in {self.name}" + + +class UserIsOwner: + def matches(self, resource: ProtectedResource, user: User) -> bool: + return resource.owner.principal == user.principal if resource.owner else False + + def __repr__(self): + return "user is owner" + + +class UserIsNotOwner: + def matches(self, resource: ProtectedResource, user: User) -> bool: + return not resource.owner or resource.owner.principal != user.principal + + def __repr__(self): + return "user is not owner" + + +def parse_condition(condition: str) -> Condition: + words = condition.split() + match words: + case ["user", "is", "owner"]: + return UserIsOwner() + case ["user", "is", "not", "owner"]: + return UserIsNotOwner() + case ["user", "with", value, "in", name]: + return UserWithValueInList(name, value) + case ["user", "with", value, "not", "in", name]: + return UserWithValueNotInList(name, value) + case ["user", "in", "owners", name]: + return UserInOwnersList(name) + case ["user", "not", "in", "owners", name]: + return UserNotInOwnersList(name) + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def parse_conditions(conditions: list[str]) -> list[Condition]: + return [parse_condition(c) for c in conditions] diff --git a/llama_stack/distribution/access_control/datatypes.py b/llama_stack/distribution/access_control/datatypes.py new file mode 100644 index 000000000..3e6c624dc --- /dev/null +++ b/llama_stack/distribution/access_control/datatypes.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from enum import Enum + +from pydantic import BaseModel, model_validator +from typing_extensions import Self + +from .conditions import parse_conditions + + +class Action(str, Enum): + CREATE = "create" + READ = "read" + UPDATE = "update" + DELETE = "delete" + + +class Scope(BaseModel): + principal: str | None = None + actions: Action | list[Action] + resource: str | None = None + + +def _mutually_exclusive(obj, a: str, b: str): + if getattr(obj, a) and getattr(obj, b): + raise ValueError(f"{a} and {b} are mutually exclusive") + + +def _require_one_of(obj, a: str, b: str): + if not getattr(obj, a) and not getattr(obj, b): + raise ValueError(f"on of {a} or {b} is required") + + +class AccessRule(BaseModel): + """Access rule based loosely on cedar policy language + + A rule defines a list of actions either to permit or to forbid. It may specify a + principal or a resource that must match for the rule to take effect. The resource + to match should be specified in the form of a type qualified identifier, e.g. + model::my-model or vector_db::some-db, or a wildcard for all resources of a type, + e.g. model::*. If the principal or resource are not specified, they will match all + requests. + + A rule may also specify a condition, either a 'when' or an 'unless', with additional + constraints as to where the rule applies. The constraints supported at present are: + + - 'user with <value> in <name>' + - 'user with <value> not in <name>' + - 'user is owner' + - 'user is not owner' + - 'user in owners <name>' + - 'user not in owners <name>' + + Rules are tested in order to find a match. If a match is found, the request is + permitted or forbidden depending on the type of rule. If no match is found, the + request is denied. If no rules are specified, a rule that allows any action as + long as the resource attributes match the user attributes is added + (i.e. the previous behaviour is the default). 
+ + Some examples in yaml: + + - permit: + principal: user-1 + actions: [create, read, delete] + resource: model::* + description: user-1 has full access to all models + - permit: + principal: user-2 + actions: [read] + resource: model::model-1 + description: user-2 has read access to model-1 only + - permit: + actions: [read] + when: user in owners teams + description: any user has read access to any resource created by a member of their team + - forbid: + actions: [create, read, delete] + resource: vector_db::* + unless: user with admin in roles + description: only user with admin role can use vector_db resources + + """ + + permit: Scope | None = None + forbid: Scope | None = None + when: str | list[str] | None = None + unless: str | list[str] | None = None + description: str | None = None + + @model_validator(mode="after") + def validate_rule_format(self) -> Self: + _require_one_of(self, "permit", "forbid") + _mutually_exclusive(self, "permit", "forbid") + _mutually_exclusive(self, "when", "unless") + if isinstance(self.when, list): + parse_conditions(self.when) + elif self.when: + parse_conditions([self.when]) + if isinstance(self.unless, list): + parse_conditions(self.unless) + elif self.unless: + parse_conditions([self.unless]) + return self diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index def7048c0..abc3f0065 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO +from llama_stack.distribution.access_control.datatypes import AccessRule from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from 
llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig @@ -35,126 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2" RoutingKey = str | list[str] -class AccessAttributes(BaseModel): - """Structured representation of user attributes for access control. +class User(BaseModel): + principal: str + # further attributes that may be used for access control decisions + attributes: dict[str, list[str]] | None = None - This model defines a structured approach to representing user attributes - with common standard categories for access control. - - Standard attribute categories include: - - roles: Role-based attributes (e.g., admin, data-scientist) - - teams: Team-based attributes (e.g., ml-team, infra-team) - - projects: Project access attributes (e.g., llama-3, customer-insights) - - namespaces: Namespace-based access control for resource isolation - """ - - # Standard attribute categories - the minimal set we need now - roles: list[str] | None = Field( - default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')" - ) - - teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')") - - projects: list[str] | None = Field( - default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')" - ) - - namespaces: list[str] | None = Field( - default=None, description="Namespace-based access control for resource isolation" - ) + def __init__(self, principal: str, attributes: dict[str, list[str]] | None): + super().__init__(principal=principal, attributes=attributes) -class ResourceWithACL(Resource): - """Extension of Resource that adds attribute-based access control capabilities. +class ResourceWithOwner(Resource): + """Extension of Resource that adds an optional owner, i.e. the user that created the + resource. 
This can be used to constrain access to the resource.""" - This class adds an optional access_attributes field that allows fine-grained control - over which users can access each resource. When attributes are defined, a user must have - matching attributes to access the resource. - - Attribute Matching Algorithm: - 1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users - 2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects") - 3. The matching algorithm requires ALL categories to match (AND relationship between categories) - 4. Within each category, ANY value match is sufficient (OR relationship within a category) - - Examples: - # Resource visible to everyone (no access control) - model = Model(identifier="llama-2", ...) - - # Resource visible only to admins - model = Model( - identifier="gpt-4", - access_attributes=AccessAttributes(roles=["admin"]) - ) - - # Resource visible to data scientists on the ML team - model = Model( - identifier="private-model", - access_attributes=AccessAttributes( - roles=["data-scientist", "researcher"], - teams=["ml-team"] - ) - ) - # ^ User must have at least one of the roles AND be on the ml-team - - # Resource visible to users with specific project access - vector_db = VectorDB( - identifier="customer-embeddings", - access_attributes=AccessAttributes( - projects=["customer-insights"], - namespaces=["confidential"] - ) - ) - # ^ User must have access to the customer-insights project AND have confidential namespace - """ - - access_attributes: AccessAttributes | None = None + owner: User | None = None # Use the extended Resource for all routable objects -class ModelWithACL(Model, ResourceWithACL): +class ModelWithOwner(Model, ResourceWithOwner): pass -class ShieldWithACL(Shield, ResourceWithACL): +class ShieldWithOwner(Shield, ResourceWithOwner): pass -class VectorDBWithACL(VectorDB, ResourceWithACL): +class VectorDBWithOwner(VectorDB, 
ResourceWithOwner): pass -class DatasetWithACL(Dataset, ResourceWithACL): +class DatasetWithOwner(Dataset, ResourceWithOwner): pass -class ScoringFnWithACL(ScoringFn, ResourceWithACL): +class ScoringFnWithOwner(ScoringFn, ResourceWithOwner): pass -class BenchmarkWithACL(Benchmark, ResourceWithACL): +class BenchmarkWithOwner(Benchmark, ResourceWithOwner): pass -class ToolWithACL(Tool, ResourceWithACL): +class ToolWithOwner(Tool, ResourceWithOwner): pass -class ToolGroupWithACL(ToolGroup, ResourceWithACL): +class ToolGroupWithOwner(ToolGroup, ResourceWithOwner): pass RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup RoutableObjectWithProvider = Annotated[ - ModelWithACL - | ShieldWithACL - | VectorDBWithACL - | DatasetWithACL - | ScoringFnWithACL - | BenchmarkWithACL - | ToolWithACL - | ToolGroupWithACL, + ModelWithOwner + | ShieldWithOwner + | VectorDBWithOwner + | DatasetWithOwner + | ScoringFnWithOwner + | BenchmarkWithOwner + | ToolWithOwner + | ToolGroupWithOwner, Field(discriminator="type"), ] @@ -234,6 +175,7 @@ class AuthenticationConfig(BaseModel): ..., description="Provider-specific configuration", ) + access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources") class AuthenticationRequiredError(Exception): diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index b03d2dee8..81d494e04 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -10,6 +10,8 @@ import logging from contextlib import AbstractContextManager from typing import Any +from llama_stack.distribution.datatypes import User + from .utils.dynamic import instantiate_class_type log = logging.getLogger(__name__) @@ -21,12 +23,10 @@ PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) class RequestProviderDataContext(AbstractContextManager): """Context manager for request provider 
data""" - def __init__( - self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None - ): + def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None): self.provider_data = provider_data or {} - if auth_attributes: - self.provider_data["__auth_attributes"] = auth_attributes + if user: + self.provider_data["__authenticated_user"] = user self.token = None @@ -95,9 +95,9 @@ def request_provider_data_context( return RequestProviderDataContext(provider_data, auth_attributes) -def get_auth_attributes() -> dict[str, list[str]] | None: +def get_authenticated_user() -> User | None: """Helper to retrieve auth attributes from the provider data context""" provider_data = PROVIDER_DATA_VAR.get() if not provider_data: return None - return provider_data.get("__auth_attributes") + return provider_data.get("__authenticated_user") diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index b7c7cb87f..6e7bb5edd 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.datatypes import ( + AccessRule, AutoRoutedProviderSpec, Provider, RoutingTableProviderSpec, @@ -118,6 +119,7 @@ async def resolve_impls( run_config: StackRunConfig, provider_registry: ProviderRegistry, dist_registry: DistributionRegistry, + policy: list[AccessRule], ) -> dict[Api, Any]: """ Resolves provider implementations by: @@ -140,7 +142,7 @@ async def resolve_impls( sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) - return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config) + return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy) def 
specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: @@ -247,6 +249,7 @@ async def instantiate_providers( router_apis: set[Api], dist_registry: DistributionRegistry, run_config: StackRunConfig, + policy: list[AccessRule], ) -> dict: """Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} @@ -261,7 +264,7 @@ async def instantiate_providers( if isinstance(provider.spec, RoutingTableProviderSpec): inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] - impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config) + impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy) if api_str.startswith("inner-"): inner_impls_by_provider_id[api_str][provider.provider_id] = impl @@ -312,6 +315,7 @@ async def instantiate_provider( inner_impls: dict[str, Any], dist_registry: DistributionRegistry, run_config: StackRunConfig, + policy: list[AccessRule], ): provider_spec = provider.spec if not hasattr(provider_spec, "module"): @@ -336,13 +340,15 @@ async def instantiate_provider( method = "get_routing_table_impl" config = None - args = [provider_spec.api, inner_impls, deps, dist_registry] + args = [provider_spec.api, inner_impls, deps, dist_registry, policy] else: method = "get_provider_impl" config_type = instantiate_class_type(provider_spec.config_class) config = config_type(**provider.config) args = [config, deps] + if "policy" in inspect.signature(getattr(module, method)).parameters: + args.append(policy) fn = getattr(module, method) impl = await fn(*args) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 1358d5812..0a0c13880 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.distribution.datatypes import 
RoutedProtocol +from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol from llama_stack.distribution.stack import StackRunConfig from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable @@ -18,6 +18,7 @@ async def get_routing_table_impl( impls_by_provider_id: dict[str, RoutedProtocol], _deps, dist_registry: DistributionRegistry, + policy: list[AccessRule], ) -> Any: from ..routing_tables.benchmarks import BenchmarksRoutingTable from ..routing_tables.datasets import DatasetsRoutingTable @@ -40,7 +41,7 @@ async def get_routing_table_impl( if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_tables[api.value](impls_by_provider_id, dist_registry) + impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy) await impl.initialize() return impl diff --git a/llama_stack/distribution/routing_tables/benchmarks.py b/llama_stack/distribution/routing_tables/benchmarks.py index 589a00c02..815483494 100644 --- a/llama_stack/distribution/routing_tables/benchmarks.py +++ b/llama_stack/distribution/routing_tables/benchmarks.py @@ -8,7 +8,7 @@ from typing import Any from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.distribution.datatypes import ( - BenchmarkWithACL, + BenchmarkWithOwner, ) from llama_stack.log import get_logger @@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): ) if provider_benchmark_id is None: provider_benchmark_id = benchmark_id - benchmark = BenchmarkWithACL( + benchmark = BenchmarkWithOwner( identifier=benchmark_id, dataset_id=dataset_id, scoring_functions=scoring_functions, diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 8ec87ca50..b79c8a2a8 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ 
b/llama_stack/distribution/routing_tables/common.py @@ -8,14 +8,14 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn -from llama_stack.distribution.access_control import check_access +from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.distribution.datatypes import ( - AccessAttributes, + AccessRule, RoutableObject, RoutableObjectWithProvider, RoutedProtocol, ) -from llama_stack.distribution.request_headers import get_auth_attributes +from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.distribution.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable @@ -73,9 +73,11 @@ class CommonRoutingTableImpl(RoutingTable): self, impls_by_provider_id: dict[str, RoutedProtocol], dist_registry: DistributionRegistry, + policy: list[AccessRule], ) -> None: self.impls_by_provider_id = impls_by_provider_id self.dist_registry = dist_registry + self.policy = policy async def initialize(self) -> None: async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: @@ -166,13 +168,15 @@ class CommonRoutingTableImpl(RoutingTable): return None # Check if user has permission to access this object - if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): - logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") + if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()): + logger.debug(f"Access denied to {type} '{identifier}'") return None return obj async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: + if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()): + raise AccessDeniedError() await self.dist_registry.delete(obj.type, obj.identifier) await 
unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) @@ -187,11 +191,12 @@ class CommonRoutingTableImpl(RoutingTable): p = self.impls_by_provider_id[obj.provider_id] # If object supports access control but no attributes set, use creator's attributes - if not obj.access_attributes: - creator_attributes = get_auth_attributes() - if creator_attributes: - obj.access_attributes = AccessAttributes(**creator_attributes) - logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") + creator = get_authenticated_user() + if not is_action_allowed(self.policy, "create", obj, creator): + raise AccessDeniedError() + if creator: + obj.owner = creator + logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}") registered_obj = await register_object_with_provider(obj, p) # TODO: This needs to be fixed for all APIs once they return the registered object @@ -210,9 +215,7 @@ class CommonRoutingTableImpl(RoutingTable): # Apply attribute-based access control filtering if filtered_objs: filtered_objs = [ - obj - for obj in filtered_objs - if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) + obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user()) ] return filtered_objs diff --git a/llama_stack/distribution/routing_tables/datasets.py b/llama_stack/distribution/routing_tables/datasets.py index 4401ad47e..fb34f40b6 100644 --- a/llama_stack/distribution/routing_tables/datasets.py +++ b/llama_stack/distribution/routing_tables/datasets.py @@ -19,7 +19,7 @@ from llama_stack.apis.datasets import ( ) from llama_stack.apis.resource import ResourceType from llama_stack.distribution.datatypes import ( - DatasetWithACL, + DatasetWithOwner, ) from llama_stack.log import get_logger @@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): if metadata is None: metadata = {} - dataset = 
DatasetWithACL( + dataset = DatasetWithOwner( identifier=dataset_id, provider_resource_id=provider_dataset_id, provider_id=provider_id, diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index 7216d9935..c6a10ea9b 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -9,7 +9,7 @@ from typing import Any from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.distribution.datatypes import ( - ModelWithACL, + ModelWithOwner, ) from llama_stack.log import get_logger @@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: raise ValueError("Embedding model must have an embedding dimension in its metadata") - model = ModelWithACL( + model = ModelWithOwner( identifier=model_id, provider_resource_id=provider_model_id, provider_id=provider_id, diff --git a/llama_stack/distribution/routing_tables/scoring_functions.py b/llama_stack/distribution/routing_tables/scoring_functions.py index d85f64b57..742cc3ca6 100644 --- a/llama_stack/distribution/routing_tables/scoring_functions.py +++ b/llama_stack/distribution/routing_tables/scoring_functions.py @@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import ( ScoringFunctions, ) from llama_stack.distribution.datatypes import ( - ScoringFnWithACL, + ScoringFnWithOwner, ) from llama_stack.log import get_logger @@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." 
) - scoring_fn = ScoringFnWithACL( + scoring_fn = ScoringFnWithOwner( identifier=scoring_fn_id, description=description, return_type=return_type, diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py index 7f62596c9..5215981b9 100644 --- a/llama_stack/distribution/routing_tables/shields.py +++ b/llama_stack/distribution/routing_tables/shields.py @@ -9,7 +9,7 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields from llama_stack.distribution.datatypes import ( - ShieldWithACL, + ShieldWithOwner, ) from llama_stack.log import get_logger @@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): ) if params is None: params = {} - shield = ShieldWithACL( + shield = ShieldWithOwner( identifier=shield_id, provider_resource_id=provider_shield_id, provider_id=provider_id, diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index 2f7dc3e06..b86f057bd 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -8,7 +8,7 @@ from typing import Any from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups -from llama_stack.distribution.datatypes import ToolGroupWithACL +from llama_stack.distribution.datatypes import ToolGroupWithOwner from llama_stack.log import get_logger from .common import CommonRoutingTableImpl @@ -106,7 +106,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): mcp_endpoint: URL | None = None, args: dict[str, Any] | None = None, ) -> None: - toolgroup = ToolGroupWithACL( + toolgroup = ToolGroupWithOwner( identifier=toolgroup_id, provider_id=provider_id, provider_resource_id=toolgroup_id, diff --git 
a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index dc6c0d0ef..542e965f8 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs from llama_stack.distribution.datatypes import ( - VectorDBWithACL, + VectorDBWithOwner, ) from llama_stack.log import get_logger @@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): "embedding_model": embedding_model, "embedding_dimension": model.metadata["embedding_dimension"], } - vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) + vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) await self.register_object(vector_db) return vector_db diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index fb26b49a7..81b1ffd37 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -105,24 +105,16 @@ class AuthenticationMiddleware: logger.exception("Error during authentication") return await self._send_auth_error(send, "Authentication service error") - # Store attributes in request scope for access control - if validation_result.access_attributes: - user_attributes = validation_result.access_attributes.model_dump(exclude_none=True) - else: - logger.warning("No access attributes, setting namespace to token by default") - user_attributes = { - "roles": [token], - } - # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) # can identify the requester and enforce per-client rate limits. 
scope["authenticated_client_id"] = token # Store attributes in request scope - scope["user_attributes"] = user_attributes scope["principal"] = validation_result.principal + if validation_result.attributes: + scope["user_attributes"] = validation_result.attributes logger.debug( - f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes" + f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes" ) return await self.app(scope, receive, send) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 723a65b77..7cb038494 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -16,43 +16,18 @@ from jose import jwt from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self -from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType +from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") -class TokenValidationResult(BaseModel): - principal: str | None = Field( - default=None, - description="The principal (username or persistent identifier) of the authenticated user", - ) - access_attributes: AccessAttributes | None = Field( - default=None, - description=""" - Structured user attributes for attribute-based access control. - - These attributes determine which resources the user can access. - The model provides standard categories like "roles", "teams", "projects", and "namespaces". - Each attribute category contains a list of values that the user has for that category. - During access control checks, these values are compared against resource requirements. 
- - Example with standard categories: - ```json - { - "roles": ["admin", "data-scientist"], - "teams": ["ml-team"], - "projects": ["llama-3"], - "namespaces": ["research"] - } - ``` - """, - ) - - -class AuthResponse(TokenValidationResult): +class AuthResponse(BaseModel): """The format of the authentication response from the auth endpoint.""" + principal: str + # further attributes that may be used for access control decisions + attributes: dict[str, list[str]] | None = None message: str | None = Field( default=None, description="Optional message providing additional context about the authentication result." ) @@ -78,7 +53,7 @@ class AuthProvider(ABC): """Abstract base class for authentication providers.""" @abstractmethod - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_token(self, token: str, scope: dict | None = None) -> User: """Validate a token and return access attributes.""" pass @@ -88,10 +63,10 @@ class AuthProvider(ABC): pass -def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: - attributes = AccessAttributes() +def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]: + attributes: dict[str, list[str]] = {} for claim_key, attribute_key in mapping.items(): - if claim_key not in claims or not hasattr(attributes, attribute_key): + if claim_key not in claims: continue claim = claims[claim_key] if isinstance(claim, list): @@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) else: values = claim.split() - current = getattr(attributes, attribute_key) - if current: - current.extend(values) + if attribute_key in attributes: + attributes[attribute_key].extend(values) else: - setattr(attributes, attribute_key, values) + attributes[attribute_key] = values return attributes @@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel): for key, value in 
v.items(): if not value: raise ValueError(f"claims_mapping value cannot be empty: {key}") - if value not in AccessAttributes.model_fields: - raise ValueError(f"claims_mapping value is not a valid attribute: {value}") return v @model_validator(mode="after") @@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider): self._jwks: dict[str, str] = {} self._jwks_lock = Lock() - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_token(self, token: str, scope: dict | None = None) -> User: if self.config.jwks: return await self.validate_jwt_token(token, scope) if self.config.introspection: return await self.introspect_token(token, scope) raise ValueError("One of jwks or introspection must be configured") - async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using the JWT token.""" await self._refresh_jwks() @@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider): # We should incorporate these into the access attributes. 
principal = claims["sub"] access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping) - return TokenValidationResult( + return User( principal=principal, - access_attributes=access_attributes, + attributes=access_attributes, ) - async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def introspect_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using token introspection as defined by RFC 7662.""" form = { "token": token, @@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider): raise ValueError("Token not active") principal = fields["sub"] or fields["username"] access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping) - return TokenValidationResult( + return User( principal=principal, - access_attributes=access_attributes, + attributes=access_attributes, ) except httpx.TimeoutException: logger.exception("Token introspection request timed out") @@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider): self.config = config self._client = None - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using the custom authentication endpoint.""" if scope is None: scope = {} @@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider): json=auth_request.model_dump(), timeout=10.0, # Add a reasonable timeout ) + print("MADE CALL") if response.status_code != 200: logger.warning(f"Authentication failed with status code: {response.status_code}") raise ValueError(f"Authentication failed: {response.status_code}") @@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider): try: response_data = response.json() auth_response = AuthResponse(**response_data) - return auth_response + return User(auth_response.principal, auth_response.attributes) except Exception as e: logger.exception("Error parsing 
authentication response") raise ValueError("Invalid authentication response format") from e diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index b4089bffc..5fdfba574 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -33,10 +33,7 @@ from pydantic import BaseModel, ValidationError from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis -from llama_stack.distribution.request_headers import ( - PROVIDER_DATA_VAR, - request_provider_data_context, -) +from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.server.routes import ( find_matching_route, @@ -217,11 +214,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: async def route_handler(request: Request, **kwargs): # Get auth attributes from the request scope user_attributes = request.scope.get("user_attributes", {}) + principal = request.scope.get("principal", "") + user = User(principal, user_attributes) await log_request_pre_validation(request) # Use context manager with both provider data and auth attributes - with request_provider_data_context(request.headers, user_attributes): + with request_provider_data_context(request.headers, user): is_streaming = is_streaming_request(func.__name__, request, **kwargs) try: diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index fc68dc016..5a9708497 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -223,7 +223,10 @@ async def construct_stack( run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None ) -> dict[Api, Any]: dist_registry, _ = await 
create_dist_registry(run_config.metadata_store, run_config.image_name) - impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) + policy = run_config.server.auth.access_policy if run_config.server.auth else [] + impls = await resolve_impls( + run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy + ) # Add internal implementations after all other providers are resolved add_internal_implementations(impls, run_config) diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 7503b8c90..4a77e65b9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -6,12 +6,12 @@ from typing import Any -from llama_stack.distribution.datatypes import Api +from llama_stack.distribution.datatypes import AccessRule, Api from .config import MetaReferenceAgentsImplConfig -async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]): +async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): from .agents import MetaReferenceAgentsImpl impl = MetaReferenceAgentsImpl( @@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap deps[Api.safety], deps[Api.tool_runtime], deps[Api.tool_groups], + policy, ) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2e387e7e8..937bd0341 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -60,6 +60,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools 
import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO +from llama_stack.distribution.datatypes import AccessRule from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin): vector_io_api: VectorIO, persistence_store: KVStore, created_at: str, + policy: list[AccessRule], ): self.agent_id = agent_id self.agent_config = agent_config self.inference_api = inference_api self.safety_api = safety_api self.vector_io_api = vector_io_api - self.storage = AgentPersistence(agent_id, persistence_store) + self.storage = AgentPersistence(agent_id, persistence_store, policy) self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api self.created_at = created_at diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 4c3dcab15..ea3c5da97 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -41,6 +41,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO +from llama_stack.distribution.datatypes import AccessRule from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.responses.responses_store import ResponsesStore @@ -62,6 +63,7 @@ class MetaReferenceAgentsImpl(Agents): safety_api: Safety, tool_runtime_api: ToolRuntime, tool_groups_api: ToolGroups, + policy: list[AccessRule], ): self.config = config self.inference_api = inference_api @@ -72,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents): self.in_memory_store = InmemoryKVStoreImpl() self.openai_responses_impl: OpenAIResponsesImpl | 
None = None + self.policy = policy async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) @@ -130,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents): self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store ), created_at=agent_info.created_at, + policy=self.policy, ) async def create_agent_session( diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 5031a4a90..25dbb5df7 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -10,9 +10,10 @@ import uuid from datetime import datetime, timezone from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn -from llama_stack.distribution.access_control import check_access -from llama_stack.distribution.datatypes import AccessAttributes -from llama_stack.distribution.request_headers import get_auth_attributes +from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.distribution.access_control.datatypes import AccessRule +from llama_stack.distribution.datatypes import User +from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -22,7 +23,9 @@ class AgentSessionInfo(Session): # TODO: is this used anywhere? 
vector_db_id: str | None = None started_at: datetime - access_attributes: AccessAttributes | None = None + owner: User | None = None + identifier: str | None = None + type: str = "session" class AgentInfo(AgentConfig): @@ -30,24 +33,27 @@ class AgentInfo(AgentConfig): class AgentPersistence: - def __init__(self, agent_id: str, kvstore: KVStore): + def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]): self.agent_id = agent_id self.kvstore = kvstore + self.policy = policy async def create_session(self, name: str) -> str: session_id = str(uuid.uuid4()) # Get current user's auth attributes for new sessions - auth_attributes = get_auth_attributes() - access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None + user = get_authenticated_user() session_info = AgentSessionInfo( session_id=session_id, session_name=name, started_at=datetime.now(timezone.utc), - access_attributes=access_attributes, + owner=user, turns=[], + identifier=name, # should this be qualified in any way? ) + if not is_action_allowed(self.policy, "create", session_info, user): + raise AccessDeniedError() await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", @@ -73,10 +79,10 @@ class AgentPersistence: def _check_session_access(self, session_info: AgentSessionInfo) -> bool: """Check if current user has access to the session.""" # Handle backward compatibility for old sessions without access control - if not hasattr(session_info, "access_attributes"): + if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): return True - return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes()) + return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: """Get session info if the user has access to it. 
For internal use by sub-session methods.""" diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 2a30fd0b8..9cbdc8e51 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl): @pytest.mark.asyncio async def test_models_routing_table(cached_disk_dist_registry): - table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry) + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple models and verify listing @@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_shields_routing_table(cached_disk_dist_registry): - table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry) + table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple shields and verify listing @@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_vectordbs_routing_table(cached_disk_dist_registry): - table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry) + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) await table.initialize() - m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await m_table.initialize() await m_table.register_model( model_id="test-model", - provider_id="test_providere", + provider_id="test_provider", metadata={"embedding_dimension": 128}, model_type=ModelType.embedding, ) @@ -209,7 +209,7 @@ async def 
test_vectordbs_routing_table(cached_disk_dist_registry): async def test_datasets_routing_table(cached_disk_dist_registry): - table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry) + table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple datasets and verify listing @@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_scoring_functions_routing_table(cached_disk_dist_registry): - table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry) + table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple scoring functions and verify listing @@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_benchmarks_routing_table(cached_disk_dist_registry): - table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry) + table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple benchmarks and verify listing @@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_tool_groups_routing_table(cached_disk_dist_registry): - table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry) + table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple tool groups and verify listing diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py index 9549f6df6..7a7d52892 100644 --- a/tests/unit/providers/agent/test_meta_reference_agent.py +++ 
b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis): mock_apis["safety_api"], mock_apis["tool_runtime_api"], mock_apis["tool_groups_api"], + {}, ) await impl.initialize() yield impl diff --git a/tests/unit/providers/agents/test_persistence_access_control.py b/tests/unit/providers/agents/test_persistence_access_control.py index 48fa647a8..d5b876a09 100644 --- a/tests/unit/providers/agents/test_persistence_access_control.py +++ b/tests/unit/providers/agents/test_persistence_access_control.py @@ -12,24 +12,24 @@ import pytest from llama_stack.apis.agents import Turn from llama_stack.apis.inference import CompletionMessage, StopReason -from llama_stack.distribution.datatypes import AccessAttributes +from llama_stack.distribution.datatypes import User from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo @pytest.fixture async def test_setup(sqlite_kvstore): - agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore) + agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={}) yield agent_persistence @pytest.mark.asyncio -@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") -async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup): +@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") +async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup): agent_persistence = test_setup # Set creator's attributes for the session creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]} - mock_get_auth_attributes.return_value = creator_attributes + mock_get_authenticated_user.return_value = User("test_user", creator_attributes) # Create a session session_id = await agent_persistence.create_session("Test Session") @@ -37,14 +37,15 @@ async 
def test_session_creation_with_access_attributes(mock_get_auth_attributes, # Get the session and verify access attributes were set session_info = await agent_persistence.get_session_info(session_id) assert session_info is not None - assert session_info.access_attributes is not None - assert session_info.access_attributes.roles == ["researcher"] - assert session_info.access_attributes.teams == ["ai-team"] + assert session_info.owner is not None + assert session_info.owner.attributes is not None + assert session_info.owner.attributes["roles"] == ["researcher"] + assert session_info.owner.attributes["teams"] == ["ai-team"] @pytest.mark.asyncio -@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") -async def test_session_access_control(mock_get_auth_attributes, test_setup): +@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") +async def test_session_access_control(mock_get_authenticated_user, test_setup): agent_persistence = test_setup # Create a session with specific access attributes @@ -53,8 +54,9 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup): session_id=session_id, session_name="Restricted Session", started_at=datetime.now(), - access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), + owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}), turns=[], + identifier="Restricted Session", ) await agent_persistence.kvstore.set( @@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup): ) # User with matching attributes can access - mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]} + mock_get_authenticated_user.return_value = User( + "testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]} + ) retrieved_session = await agent_persistence.get_session_info(session_id) assert retrieved_session is not 
None assert retrieved_session.session_id == session_id # User without matching attributes cannot access - mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]}) retrieved_session = await agent_persistence.get_session_info(session_id) assert retrieved_session is None @pytest.mark.asyncio -@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") -async def test_turn_access_control(mock_get_auth_attributes, test_setup): +@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") +async def test_turn_access_control(mock_get_authenticated_user, test_setup): agent_persistence = test_setup # Create a session with restricted access @@ -85,8 +89,9 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): session_id=session_id, session_name="Restricted Session", started_at=datetime.now(), - access_attributes=AccessAttributes(roles=["admin"]), + owner=User("someone", {"roles": ["admin"]}), turns=[], + identifier="Restricted Session", ) await agent_persistence.kvstore.set( @@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): ) # Admin can add turn - mock_get_auth_attributes.return_value = {"roles": ["admin"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]}) await agent_persistence.add_turn_to_session(session_id, turn) # Admin can get turn @@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): assert retrieved_turn.turn_id == turn_id # Regular user cannot get turn - mock_get_auth_attributes.return_value = {"roles": ["user"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]}) with pytest.raises(ValueError): await agent_persistence.get_session_turn(session_id, turn_id) @@ -128,8 +133,8 @@ async def 
test_turn_access_control(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") -async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup): +@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") +async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup): agent_persistence = test_setup # Create a session with restricted access @@ -138,8 +143,9 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes session_id=session_id, session_name="Restricted Session", started_at=datetime.now(), - access_attributes=AccessAttributes(roles=["admin"]), + owner=User("someone", {"roles": ["admin"]}), turns=[], + identifier="Restricted Session", ) await agent_persistence.kvstore.set( @@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes turn_id = str(uuid.uuid4()) # Admin user can set inference iterations - mock_get_auth_attributes.return_value = {"roles": ["admin"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]}) await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5) # Admin user can get inference iterations @@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes assert infer_iters == 5 # Regular user cannot get inference iterations - mock_get_auth_attributes.return_value = {"roles": ["user"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]}) infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id) assert infer_iters is None diff --git a/tests/unit/registry/test_registry_acl.py b/tests/unit/registry/test_registry_acl.py index 25ea37bfa..48b3ac51b 100644 --- a/tests/unit/registry/test_registry_acl.py +++ 
b/tests/unit/registry/test_registry_acl.py @@ -8,19 +8,18 @@ import pytest from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import ModelWithACL -from llama_stack.distribution.server.auth_providers import AccessAttributes +from llama_stack.distribution.datatypes import ModelWithOwner, User from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry @pytest.mark.asyncio async def test_registry_cache_with_acl(cached_disk_dist_registry): - model = ModelWithACL( + model = ModelWithOwner( identifier="model-acl", provider_id="test-provider", provider_resource_id="model-acl-resource", model_type=ModelType.llm, - access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]), + owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}), ) success = await cached_disk_dist_registry.register(model) @@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry): cached_model = cached_disk_dist_registry.get_cached("model", "model-acl") assert cached_model is not None assert cached_model.identifier == "model-acl" - assert cached_model.access_attributes.roles == ["admin"] - assert cached_model.access_attributes.teams == ["ai-team"] + assert cached_model.owner.principal == "testuser" + assert cached_model.owner.attributes["roles"] == ["admin"] + assert cached_model.owner.attributes["teams"] == ["ai-team"] fetched_model = await cached_disk_dist_registry.get("model", "model-acl") assert fetched_model is not None assert fetched_model.identifier == "model-acl" - assert fetched_model.access_attributes.roles == ["admin"] - - model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"]) - await cached_disk_dist_registry.update(model) - - updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl") - assert updated_cached is not None - assert updated_cached.access_attributes.roles == ["admin", "user"] - assert 
updated_cached.access_attributes.projects == ["project-x"] - assert updated_cached.access_attributes.teams is None + assert fetched_model.owner.attributes["roles"] == ["admin"] new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore) await new_registry.initialize() @@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry): new_model = await new_registry.get("model", "model-acl") assert new_model is not None assert new_model.identifier == "model-acl" - assert new_model.access_attributes.roles == ["admin", "user"] - assert new_model.access_attributes.projects == ["project-x"] - assert new_model.access_attributes.teams is None + assert new_model.owner.principal == "testuser" + assert new_model.owner.attributes["roles"] == ["admin"] + assert new_model.owner.attributes["teams"] == ["ai-team"] @pytest.mark.asyncio async def test_registry_empty_acl(cached_disk_dist_registry): - model = ModelWithACL( + model = ModelWithOwner( identifier="model-empty-acl", provider_id="test-provider", provider_resource_id="model-resource", model_type=ModelType.llm, - access_attributes=AccessAttributes(), + owner=User("testuser", None), ) await cached_disk_dist_registry.register(model) cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl") assert cached_model is not None - assert cached_model.access_attributes is not None - assert cached_model.access_attributes.roles is None - assert cached_model.access_attributes.teams is None - assert cached_model.access_attributes.projects is None - assert cached_model.access_attributes.namespaces is None + assert cached_model.owner is not None + assert cached_model.owner.attributes is None all_models = await cached_disk_dist_registry.get_all() assert len(all_models) == 1 - model = ModelWithACL( + model = ModelWithOwner( identifier="model-no-acl", provider_id="test-provider", provider_resource_id="model-resource-2", @@ -91,7 +79,7 @@ async def 
test_registry_empty_acl(cached_disk_dist_registry): cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl") assert cached_model is not None - assert cached_model.access_attributes is None + assert cached_model.owner is None all_models = await cached_disk_dist_registry.get_all() assert len(all_models) == 2 @@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry): @pytest.mark.asyncio async def test_registry_serialization(cached_disk_dist_registry): - attributes = AccessAttributes( - roles=["admin", "researcher"], - teams=["ai-team", "ml-team"], - projects=["project-a", "project-b"], - namespaces=["prod", "staging"], - ) + attributes = { + "roles": ["admin", "researcher"], + "teams": ["ai-team", "ml-team"], + "projects": ["project-a", "project-b"], + "namespaces": ["prod", "staging"], + } - model = ModelWithACL( + model = ModelWithOwner( identifier="model-serialize", provider_id="test-provider", provider_resource_id="model-resource", model_type=ModelType.llm, - access_attributes=attributes, + owner=User("bob", attributes), ) await cached_disk_dist_registry.register(model) @@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry): loaded_model = await new_registry.get("model", "model-serialize") assert loaded_model is not None - assert loaded_model.access_attributes.roles == ["admin", "researcher"] - assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"] - assert loaded_model.access_attributes.projects == ["project-a", "project-b"] - assert loaded_model.access_attributes.namespaces == ["prod", "staging"] + assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"] + assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"] + assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"] + assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"] diff --git a/tests/unit/server/test_access_control.py 
b/tests/unit/server/test_access_control.py index e352ba54d..f9ad47b0c 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -7,10 +7,13 @@ from unittest.mock import MagicMock, Mock, patch import pytest +import yaml +from pydantic import TypeAdapter, ValidationError from llama_stack.apis.datatypes import Api from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL +from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User from llama_stack.distribution.routing_tables.models import ModelsRoutingTable @@ -32,39 +35,40 @@ async def test_setup(cached_disk_dist_registry): routing_table = ModelsRoutingTable( impls_by_provider_id={"test_provider": mock_inference}, dist_registry=cached_disk_dist_registry, + policy={}, ) yield cached_disk_dist_registry, routing_table @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_access_control_with_cache(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup - model_public = ModelWithACL( + model_public = ModelWithOwner( identifier="model-public", provider_id="test_provider", provider_resource_id="model-public", model_type=ModelType.llm, ) - model_admin_only = ModelWithACL( + model_admin_only = ModelWithOwner( identifier="model-admin", provider_id="test_provider", provider_resource_id="model-admin", model_type=ModelType.llm, - access_attributes=AccessAttributes(roles=["admin"]), + owner=User("testuser", {"roles": ["admin"]}), ) - model_data_scientist = ModelWithACL( + model_data_scientist = ModelWithOwner( identifier="model-data-scientist", 
provider_id="test_provider", provider_resource_id="model-data-scientist", model_type=ModelType.llm, - access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]), + owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}), ) await registry.register(model_public) await registry.register(model_admin_only) await registry.register(model_data_scientist) - mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]} + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]}) all_models = await routing_table.list_models() assert len(all_models.data) == 2 @@ -75,7 +79,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): with pytest.raises(ValueError): await routing_table.get_model("model-data-scientist") - mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]} + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]}) all_models = await routing_table.list_models() assert len(all_models.data) == 1 assert all_models.data[0].identifier == "model-public" @@ -86,7 +90,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): with pytest.raises(ValueError): await routing_table.get_model("model-data-scientist") - mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]} + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]}) all_models = await routing_table.list_models() assert len(all_models.data) == 2 model_ids = [m.identifier for m in all_models.data] @@ -102,50 +106,62 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def 
test_access_control_and_updates(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_access_control_and_updates(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup - model_public = ModelWithACL( + model_public = ModelWithOwner( identifier="model-updates", provider_id="test_provider", provider_resource_id="model-updates", model_type=ModelType.llm, ) await registry.register(model_public) - mock_get_auth_attributes.return_value = { - "roles": ["user"], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": ["user"], + }, + ) model = await routing_table.get_model("model-updates") assert model.identifier == "model-updates" - model_public.access_attributes = AccessAttributes(roles=["admin"]) + model_public.owner = User("testuser", {"roles": ["admin"]}) await registry.update(model_public) - mock_get_auth_attributes.return_value = { - "roles": ["user"], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": ["user"], + }, + ) with pytest.raises(ValueError): await routing_table.get_model("model-updates") - mock_get_auth_attributes.return_value = { - "roles": ["admin"], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": ["admin"], + }, + ) model = await routing_table.get_model("model-updates") assert model.identifier == "model-updates" @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup - model = ModelWithACL( + model = ModelWithOwner( identifier="model-empty-attrs", provider_id="test_provider", provider_resource_id="model-empty-attrs", 
model_type=ModelType.llm, - access_attributes=AccessAttributes(), + owner=User("testuser", {}), ) await registry.register(model) - mock_get_auth_attributes.return_value = { - "roles": [], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": [], + }, + ) result = await routing_table.get_model("model-empty-attrs") assert result.identifier == "model-empty-attrs" all_models = await routing_table.list_models() @@ -154,25 +170,25 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_no_user_attributes(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_no_user_attributes(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup - model_public = ModelWithACL( + model_public = ModelWithOwner( identifier="model-public-2", provider_id="test_provider", provider_resource_id="model-public-2", model_type=ModelType.llm, ) - model_restricted = ModelWithACL( + model_restricted = ModelWithOwner( identifier="model-restricted", provider_id="test_provider", provider_resource_id="model-restricted", model_type=ModelType.llm, - access_attributes=AccessAttributes(roles=["admin"]), + owner=User("testuser", {"roles": ["admin"]}), ) await registry.register(model_public) await registry.register(model_restricted) - mock_get_auth_attributes.return_value = None + mock_get_authenticated_user.return_value = User("test-user", None) model = await routing_table.get_model("model-public-2") assert model.identifier == "model-public-2" @@ -185,17 +201,17 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup): 
+@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup): """Test that newly created resources inherit access attributes from their creator.""" registry, routing_table = test_setup # Set creator's attributes creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]} - mock_get_auth_attributes.return_value = creator_attributes + mock_get_authenticated_user.return_value = User("test-user", creator_attributes) # Create model without explicit access attributes - model = ModelWithACL( + model = ModelWithOwner( identifier="auto-access-model", provider_id="test_provider", provider_resource_id="auto-access-model", @@ -205,21 +221,346 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup) # Verify the model got creator's attributes registered_model = await routing_table.get_model("auto-access-model") - assert registered_model.access_attributes is not None - assert registered_model.access_attributes.roles == ["data-scientist"] - assert registered_model.access_attributes.teams == ["ml-team"] - assert registered_model.access_attributes.projects == ["llama-3"] + assert registered_model.owner is not None + assert registered_model.owner.attributes is not None + assert registered_model.owner.attributes["roles"] == ["data-scientist"] + assert registered_model.owner.attributes["teams"] == ["ml-team"] + assert registered_model.owner.attributes["projects"] == ["llama-3"] # Verify another user without matching attributes can't access it - mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]} + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]}) with pytest.raises(ValueError): await routing_table.get_model("auto-access-model") # But a user with matching attributes can - mock_get_auth_attributes.return_value = { - 
"roles": ["data-scientist", "engineer"], - "teams": ["ml-team", "platform-team"], - "projects": ["llama-3"], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": ["data-scientist", "engineer"], + "teams": ["ml-team", "platform-team"], + "projects": ["llama-3"], + }, + ) model = await routing_table.get_model("auto-access-model") assert model.identifier == "auto-access-model" + + +@pytest.fixture +async def test_setup_with_access_policy(cached_disk_dist_registry): + mock_inference = Mock() + mock_inference.__provider_spec__ = MagicMock() + mock_inference.__provider_spec__.api = Api.inference + mock_inference.register_model = AsyncMock(side_effect=_return_model) + mock_inference.unregister_model = AsyncMock(side_effect=_return_model) + + config = """ + - permit: + principal: user-1 + actions: [create, read, delete] + description: user-1 has full access to all models + - permit: + principal: user-2 + actions: [read] + resource: model::model-1 + description: user-2 has read access to model-1 only + - permit: + principal: user-3 + actions: [read] + resource: model::model-2 + description: user-3 has read access to model-2 only + - forbid: + actions: [create, read, delete] + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + routing_table = ModelsRoutingTable( + impls_by_provider_id={"test_provider": mock_inference}, + dist_registry=cached_disk_dist_registry, + policy=policy, + ) + yield routing_table + + +@pytest.mark.asyncio +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy): + routing_table = test_setup_with_access_policy + mock_get_authenticated_user.return_value = User( + "user-1", + { + "roles": ["admin"], + "projects": ["foo", "bar"], + }, + ) + await routing_table.register_model("model-1", provider_id="test_provider") + await routing_table.register_model("model-2", 
provider_id="test_provider") + await routing_table.register_model("model-3", provider_id="test_provider") + model = await routing_table.get_model("model-1") + assert model.identifier == "model-1" + model = await routing_table.get_model("model-2") + assert model.identifier == "model-2" + model = await routing_table.get_model("model-3") + assert model.identifier == "model-3" + + mock_get_authenticated_user.return_value = User( + "user-2", + { + "roles": ["user"], + "projects": ["foo"], + }, + ) + model = await routing_table.get_model("model-1") + assert model.identifier == "model-1" + with pytest.raises(ValueError): + await routing_table.get_model("model-2") + with pytest.raises(ValueError): + await routing_table.get_model("model-3") + with pytest.raises(AccessDeniedError): + await routing_table.register_model("model-4", provider_id="test_provider") + with pytest.raises(AccessDeniedError): + await routing_table.unregister_model("model-1") + + mock_get_authenticated_user.return_value = User( + "user-3", + { + "roles": ["user"], + "projects": ["bar"], + }, + ) + model = await routing_table.get_model("model-2") + assert model.identifier == "model-2" + with pytest.raises(ValueError): + await routing_table.get_model("model-1") + with pytest.raises(ValueError): + await routing_table.get_model("model-3") + with pytest.raises(AccessDeniedError): + await routing_table.register_model("model-5", provider_id="test_provider") + with pytest.raises(AccessDeniedError): + await routing_table.unregister_model("model-2") + + mock_get_authenticated_user.return_value = User( + "user-1", + { + "roles": ["admin"], + "projects": ["foo", "bar"], + }, + ) + await routing_table.unregister_model("model-3") + with pytest.raises(ValueError): + await routing_table.get_model("model-3") + + +def test_permit_when(): + config = """ + - permit: + principal: user-1 + actions: [read] + when: user in owners namespaces + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + 
model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + owner=User("testuser", {"namespaces": ["foo"]}), + ) + assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) + + +def test_permit_unless(): + config = """ + - permit: + principal: user-1 + actions: [read] + resource: model::* + unless: + - user not in owners namespaces + - user in owners teams + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + owner=User("testuser", {"namespaces": ["foo"]}), + ) + assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) + + +def test_forbid_when(): + config = """ + - forbid: + principal: user-1 + actions: [read] + when: + user in owners namespaces + - permit: + actions: [read] + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + owner=User("testuser", {"namespaces": ["foo"]}), + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) + assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) + + +def test_forbid_unless(): + config = """ + - forbid: + principal: user-1 + actions: [read] + unless: + user in owners namespaces + - permit: + actions: 
[read] + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + owner=User("testuser", {"namespaces": ["foo"]}), + ) + assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) + + +def test_user_has_attribute(): + config = """ + - permit: + actions: [read] + when: user with admin in roles + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) + assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-4", None)) + + +def test_user_does_not_have_attribute(): + config = """ + - permit: + actions: [read] + unless: user with admin not in roles + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) + assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-4", None)) + + +def test_is_owner(): + config = """ + - permit: + actions: [read] + when: user is owner + """ + policy = 
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + owner=User("user-2", {"namespaces": ["foo"]}), + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) + assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-4", None)) + + +def test_is_not_owner(): + config = """ + - permit: + actions: [read] + unless: user is not owner + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + owner=User("user-2", {"namespaces": ["foo"]}), + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) + assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-4", None)) + + +def test_invalid_rule_permit_and_forbid_both_specified(): + config = """ + - permit: + actions: [read] + forbid: + actions: [create] + """ + with pytest.raises(ValidationError): + TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + + +def test_invalid_rule_neither_permit_or_forbid_specified(): + config = """ + - when: user is owner + unless: user with admin in roles + """ + with pytest.raises(ValidationError): + TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + + +def test_invalid_rule_when_and_unless_both_specified(): + config = """ + - permit: + actions: [read] + when: user is owner + unless: user with admin in roles + """ + with pytest.raises(ValidationError): 
+ TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + + +def test_invalid_condition(): + config = """ + - permit: + actions: [read] + when: random words that are not valid + """ + with pytest.raises(ValidationError): + TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + + +@pytest.mark.parametrize( + "condition", + [ + "user is owner", + "user is not owner", + "user with dev in teams", + "user with default not in namespaces", + "user in owners roles", + "user not in owners projects", + ], +) +def test_condition_reprs(condition): + from llama_stack.distribution.access_control.conditions import parse_condition + + assert condition == str(parse_condition(condition)) diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 408acb88a..e159aefd1 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs): { "message": "Authentication successful", "principal": "test-principal", - "access_attributes": { + "attributes": { "roles": ["admin", "user"], "teams": ["ml-team", "nlp-team"], "projects": ["llama-3", "project-x"], @@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock { "message": "Authentication successful", "principal": "test-principal", - "access_attributes": { + "attributes": { "roles": ["admin", "user"], "teams": ["ml-team", "nlp-team"], "projects": ["llama-3", "project-x"], @@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) -@pytest.mark.asyncio -async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope): - """Test middleware behavior with no access attributes""" - middleware, mock_app = mock_http_middleware - mock_receive = AsyncMock() - mock_send = AsyncMock() - - with patch("httpx.AsyncClient") as mock_client: - 
mock_client_instance = AsyncMock() - mock_client.return_value.__aenter__.return_value = mock_client_instance - - mock_client_instance.post.return_value = MockResponse( - 200, - { - "message": "Authentication successful" - # No access_attributes - }, - ) - - await middleware(mock_scope, mock_receive, mock_send) - - assert "user_attributes" in mock_scope - attributes = mock_scope["user_attributes"] - assert "roles" in attributes - assert attributes["roles"] == ["test.jwt.token"] - - # oauth2 token provider tests @@ -380,16 +353,16 @@ def test_get_attributes_from_claims(): "aud": "llama-stack", } attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"}) - assert attributes.roles == ["my-user"] - assert attributes.teams == ["group1", "group2"] + assert attributes["roles"] == ["my-user"] + assert attributes["teams"] == ["group1", "group2"] claims = { "sub": "my-user", "tenant": "my-tenant", } attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"}) - assert attributes.roles == ["my-user"] - assert attributes.namespaces == ["my-tenant"] + assert attributes["roles"] == ["my-user"] + assert attributes["namespaces"] == ["my-tenant"] claims = { "sub": "my-user", @@ -408,9 +381,9 @@ def test_get_attributes_from_claims(): "groups": "teams", }, ) - assert set(attributes.roles) == {"my-user", "my-username"} - assert set(attributes.teams) == {"my-team", "group1", "group2"} - assert attributes.namespaces == ["my-tenant"] + assert set(attributes["roles"]) == {"my-user", "my-username"} + assert set(attributes["teams"]) == {"my-team", "group1", "group2"} + assert attributes["namespaces"] == ["my-tenant"] # TODO: add more tests for oauth2 token provider diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index bb4c15dbc..acf4da0a3 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -100,9 +100,10 @@ async def test_resolve_impls_basic(): 
add_protocol_methods(SampleImpl, Inference) mock_module.get_provider_impl = AsyncMock(return_value=impl) + mock_module.get_provider_impl.__text_signature__ = "()" sys.modules["test_module"] = mock_module - impls = await resolve_impls(run_config, provider_registry, dist_registry) + impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={}) assert Api.inference in impls assert isinstance(impls[Api.inference], InferenceRouter)