diff --git a/llama_stack/distribution/access_control.py b/llama_stack/distribution/access_control.py deleted file mode 100644 index d560ec80f..000000000 --- a/llama_stack/distribution/access_control.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from llama_stack.distribution.datatypes import AccessAttributes -from llama_stack.log import get_logger - -logger = get_logger(__name__, category="core") - - -def check_access( - obj_identifier: str, - obj_attributes: AccessAttributes | None, - user_attributes: dict[str, Any] | None = None, -) -> bool: - """Check if the current user has access to the given object, based on access attributes. - - Access control algorithm: - 1. If the resource has no access_attributes, access is GRANTED to all authenticated users - 2. If the user has no attributes, access is DENIED to any object with access_attributes defined - 3. For each attribute category in the resource's access_attributes: - a. If the user lacks that category, access is DENIED - b. If the user has the category but none of the required values, access is DENIED - c. If the user has at least one matching value in each required category, access is GRANTED - - Example: - # Resource requires: - access_attributes = AccessAttributes( - roles=["admin", "data-scientist"], - teams=["ml-team"] - ) - - # User has: - user_attributes = { - "roles": ["data-scientist", "engineer"], - "teams": ["ml-team", "infra-team"], - "projects": ["llama-3"] - } - - # Result: Access GRANTED - # - User has the "data-scientist" role (matches one of the required roles) - # - AND user is part of the "ml-team" (matches the required team) - # - The extra "projects" attribute is ignored - - Args: - obj_identifier: The identifier of the resource object to check access for - obj_attributes: The access attributes of the resource object - user_attributes: The attributes of the current user - - Returns: - bool: True if access is granted, False if denied - """ - # If object has no access attributes, allow access by default - if not obj_attributes: - return True - - # If no user attributes, deny access to objects with access control - if not user_attributes: - return False - - dict_attribs = obj_attributes.model_dump(exclude_none=True) - if not dict_attribs: - return True - - # Check each attribute category (requires ALL categories to match) - # TODO: formalize this into a proper ABAC policy - for attr_key, required_values in dict_attribs.items(): - user_values = user_attributes.get(attr_key, []) - - if not user_values: - logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'") - return False - - if not any(val in user_values for val in required_values): - logger.debug( - f"Access denied to {obj_identifier}: " - f"no match for attribute '{attr_key}', required one of {required_values}" - ) - return False - - logger.debug(f"Access granted to {obj_identifier}") - return True diff --git a/llama_stack/distribution/access_control/__init__.py b/llama_stack/distribution/access_control/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/access_control/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/access_control/access_control.py b/llama_stack/distribution/access_control/access_control.py new file mode 100644 index 000000000..e68325783 --- /dev/null +++ b/llama_stack/distribution/access_control/access_control.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Protocol + +from llama_stack.distribution.request_headers import User + +from .datatypes import ( + AccessAttributes, + AccessRule, + Action, + AttributeReference, + Condition, + Scope, +) + + +def matches_resource(resource_scope: str, actual_resource: str) -> bool: + if resource_scope == actual_resource: + return True + return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1]) + + +def matches_scope( + scope: Scope, + action: Action, + resource: str, + user: str | None, +) -> bool: + if scope.resource and not matches_resource(scope.resource, resource): + return False + if scope.principal and scope.principal != user: + return False + return action in scope.actions + + +def user_in_literal( + literal: str, + user_attributes: dict[str, list[str]] | None, +) -> bool: + for qualifier in ["role::", "team::", "project::", "namespace::"]: + if literal.startswith(qualifier): + if not user_attributes: + return False + ref = qualifier.replace("::", "s") + if ref in user_attributes: + value = literal.removeprefix(qualifier) + return value in user_attributes[ref] + else: + return False + return False + + +def user_in( + ref: AttributeReference | str, + resource_attributes: AccessAttributes | None, + user_attributes: dict[str, list[str]] | None, +) -> bool: + if not ref.startswith("resource."): + return user_in_literal(ref, user_attributes) + name = ref.removeprefix("resource.") + required = resource_attributes and getattr(resource_attributes, name) + if not required: + return True + if not user_attributes or name not in user_attributes: + return False + actual = user_attributes[name] + for value in required: + if value in actual: + return True + return False + + +def as_list(obj: Any) -> list[Any]: + if isinstance(obj, list): + return obj + return [obj] + + +def matches_conditions( + conditions: list[Condition], + resource_attributes: AccessAttributes | None, + user_attributes: dict[str, list[str]] | None, +) -> bool: + for condition in conditions: + # must match all conditions + if not matches_condition(condition, resource_attributes, user_attributes): + return False + return True + + +def matches_condition( + condition: Condition | list[Condition], + resource_attributes: AccessAttributes | None, + user_attributes: dict[str, list[str]] | None, +) -> bool: + if isinstance(condition, list): + return matches_conditions(as_list(condition), resource_attributes, user_attributes) + if condition.user_in: + for ref in as_list(condition.user_in): + # if multiple references are specified, all must match + if not user_in(ref, resource_attributes, user_attributes): + return False + return True + if condition.user_not_in: + for ref in as_list(condition.user_not_in): + # if multiple references are specified, none must match + if user_in(ref, resource_attributes, user_attributes): + return False + return True + return True + + +def default_policy() -> list[AccessRule]: + # for backwards compatibility, if no rules are provided , assume + # full access to all subject to attribute matching rules + return [ + AccessRule( + permit=Scope(actions=list(Action)), + when=Condition(user_in=list(AttributeReference)), + ) + ] + + +class ProtectedResource(Protocol): + type: str + identifier: str + access_attributes: AccessAttributes + + +def is_action_allowed( + policy: list[AccessRule], + action: Action, + resource: ProtectedResource, + user: User | None, +) -> bool: + # If user is not set, assume authentication is not enabled + if not user: + return True + + if not len(policy): + policy = default_policy() + + resource_attributes = AccessAttributes() + if hasattr(resource, "access_attributes"): + resource_attributes = resource.access_attributes + qualified_resource_id = resource.type + "::" + resource.identifier + for rule in policy: + if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal): + if rule.when: + if matches_condition(rule.when, resource_attributes, user.attributes): + return False + elif rule.unless: + if not matches_condition(rule.unless, resource_attributes, user.attributes): + return False + else: + return False + elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal): + if rule.when: + if matches_condition(rule.when, resource_attributes, user.attributes): + return True + elif rule.unless: + if not matches_condition(rule.unless, resource_attributes, user.attributes): + return True + else: + return True + # assume access is denied unless we find a rule that permits access + return False + + +class AccessDeniedError(RuntimeError): + pass diff --git a/llama_stack/distribution/access_control/datatypes.py b/llama_stack/distribution/access_control/datatypes.py new file mode 100644 index 000000000..9f95f2c7e --- /dev/null +++ b/llama_stack/distribution/access_control/datatypes.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum + +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self + + +class AccessAttributes(BaseModel): + """Structured representation of user attributes for access control. + + This model defines a structured approach to representing user attributes + with common standard categories for access control. + + Standard attribute categories include: + - roles: Role-based attributes (e.g., admin, data-scientist) + - teams: Team-based attributes (e.g., ml-team, infra-team) + - projects: Project access attributes (e.g., llama-3, customer-insights) + - namespaces: Namespace-based access control for resource isolation + """ + + # Standard attribute categories - the minimal set we need now + roles: list[str] | None = Field( + default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')" + ) + + teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')") + + projects: list[str] | None = Field( + default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')" + ) + + namespaces: list[str] | None = Field( + default=None, description="Namespace-based access control for resource isolation" + ) + + +class Action(str, Enum): + CREATE = "create" + READ = "read" + UPDATE = "update" + DELETE = "delete" + + +class Scope(BaseModel): + principal: str | None = None + actions: Action | list[Action] + resource: str | None = None + + +def _mutually_exclusive(obj, a: str, b: str): + if getattr(obj, a) and getattr(obj, b): + raise ValueError(f"{a} and {b} are mutually exclusive") + + +def _require_one_of(obj, a: str, b: str): + if not getattr(obj, a) and not getattr(obj, b): + raise ValueError(f"on of {a} or {b} is required") + + +class AttributeReference(str, Enum): + RESOURCE_ROLES = "resource.roles" + RESOURCE_TEAMS = "resource.teams" + RESOURCE_PROJECTS = "resource.projects" + RESOURCE_NAMESPACES = "resource.namespaces" + + +class Condition(BaseModel): + user_in: AttributeReference | list[AttributeReference] | str | None = None + user_not_in: AttributeReference | list[AttributeReference] | str | None = None + + +class AccessRule(BaseModel): + """Access rule based loosely on cedar policy language + + A rule defines a list of action either to permit or to forbid. It may specify a + principal or a resource that must match for the rule to take effect. The resource + to match should be specified in the form of a type qualified identifier, e.g. + model::my-model or vector_db::some-db, or a wildcard for all resources of a type, + e.g. model::*. If the principal or resource are not specified, they will match all + requests. + + A rule may also specify a condition, either a 'when' or an 'unless', with additional + constraints as to where the rule applies. The constraints at present are whether the + user requesting access is in or not in some set. This set can either be a particular + set of attributes on the resource e.g. resource.roles or a literal value of some + notion of group, e.g. role::admin or namespace::foo. + + Rules are tested in order to find a match. If a match is found, the request is + permitted or forbidden depending on the type of rule. If no match is found, the + request is denied. If no rules are specified, a rule that allows any action as + long as the resource attributes match the user attributes is added + (i.e. the previous behaviour is the default). + + Some examples in yaml: + + - permit: + principal: user-1 + actions: [create, read, delete] + resource: model::* + description: user-1 has full access to all models + - permit: + principal: user-2 + actions: [read] + resource: model::model-1 + description: user-2 has read access to model-1 only + - permit: + actions: [read] + when: + user_in: resource.namespaces + description: any user has read access to any resource with matching attributes + - forbid: + actions: [create, read, delete] + resource: vector_db::* + unless: + user_in: role::admin + description: only user with admin role can use vector_db resources + + """ + + permit: Scope | None = None + forbid: Scope | None = None + when: Condition | list[Condition] | None = None + unless: Condition | list[Condition] | None = None + description: str | None = None + + @model_validator(mode="after") + def validate_rule_format(self) -> Self: + _require_one_of(self, "permit", "forbid") + _mutually_exclusive(self, "permit", "forbid") + _mutually_exclusive(self, "when", "unless") + return self diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index def7048c0..a01cbdfb8 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO +from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig @@ -35,35 +36,6 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2" RoutingKey = str | list[str] -class AccessAttributes(BaseModel): - """Structured representation of user attributes for access control. - - This model defines a structured approach to representing user attributes - with common standard categories for access control. - - Standard attribute categories include: - - roles: Role-based attributes (e.g., admin, data-scientist) - - teams: Team-based attributes (e.g., ml-team, infra-team) - - projects: Project access attributes (e.g., llama-3, customer-insights) - - namespaces: Namespace-based access control for resource isolation - """ - - # Standard attribute categories - the minimal set we need now - roles: list[str] | None = Field( - default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')" - ) - - teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')") - - projects: list[str] | None = Field( - default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')" - ) - - namespaces: list[str] | None = Field( - default=None, description="Namespace-based access control for resource isolation" - ) - - class ResourceWithACL(Resource): """Extension of Resource that adds attribute-based access control capabilities. @@ -234,6 +206,7 @@ class AuthenticationConfig(BaseModel): ..., description="Provider-specific configuration", ) + access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources") class AuthenticationRequiredError(Exception): diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index b03d2dee8..7cf16ea3c 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -18,15 +18,23 @@ log = logging.getLogger(__name__) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) +class User: + principal: str + # further attributes that may be used for access control decisions + attributes: dict[str, list[str]] + + def __init__(self, principal: str, attributes: dict[str, list[str]]): + self.principal = principal + self.attributes = attributes + + class RequestProviderDataContext(AbstractContextManager): """Context manager for request provider data""" - def __init__( - self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None - ): + def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None): self.provider_data = provider_data or {} - if auth_attributes: - self.provider_data["__auth_attributes"] = auth_attributes + if user: + self.provider_data["__authenticated_user"] = user self.token = None @@ -95,9 +103,9 @@ def request_provider_data_context( return RequestProviderDataContext(provider_data, auth_attributes) -def get_auth_attributes() -> dict[str, list[str]] | None: +def get_authenticated_user() -> User | None: """Helper to retrieve auth attributes from the provider data context""" provider_data = PROVIDER_DATA_VAR.get() if not provider_data: return None - return provider_data.get("__auth_attributes") + return provider_data.get("__authenticated_user") diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index b7c7cb87f..6e7bb5edd 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.datatypes import ( + AccessRule, AutoRoutedProviderSpec, Provider, RoutingTableProviderSpec, @@ -118,6 +119,7 @@ async def resolve_impls( run_config: StackRunConfig, provider_registry: ProviderRegistry, dist_registry: DistributionRegistry, + policy: list[AccessRule], ) -> dict[Api, Any]: """ Resolves provider implementations by: @@ -140,7 +142,7 @@ async def resolve_impls( sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) - return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config) + return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy) def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: @@ -247,6 +249,7 @@ async def instantiate_providers( router_apis: set[Api], dist_registry: DistributionRegistry, run_config: StackRunConfig, + policy: list[AccessRule], ) -> dict: """Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} @@ -261,7 +264,7 @@ async def instantiate_providers( if isinstance(provider.spec, RoutingTableProviderSpec): inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] - impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config) + impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy) if api_str.startswith("inner-"): inner_impls_by_provider_id[api_str][provider.provider_id] = impl @@ -312,6 +315,7 @@ async def instantiate_provider( inner_impls: dict[str, Any], dist_registry: DistributionRegistry, run_config: StackRunConfig, + policy: list[AccessRule], ): provider_spec = provider.spec if not hasattr(provider_spec, "module"): @@ -336,13 +340,15 @@ async def instantiate_provider( method = "get_routing_table_impl" config = None - args = [provider_spec.api, inner_impls, deps, dist_registry] + args = [provider_spec.api, inner_impls, deps, dist_registry, policy] else: method = "get_provider_impl" config_type = instantiate_class_type(provider_spec.config_class) config = config_type(**provider.config) args = [config, deps] + if "policy" in inspect.signature(getattr(module, method)).parameters: + args.append(policy) fn = getattr(module, method) impl = await fn(*args) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 1358d5812..0a0c13880 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.distribution.datatypes import RoutedProtocol +from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol from llama_stack.distribution.stack import StackRunConfig from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable @@ -18,6 +18,7 @@ async def get_routing_table_impl( impls_by_provider_id: dict[str, RoutedProtocol], _deps, dist_registry: DistributionRegistry, + policy: list[AccessRule], ) -> Any: from ..routing_tables.benchmarks import BenchmarksRoutingTable from ..routing_tables.datasets import DatasetsRoutingTable @@ -40,7 +41,7 @@ async def get_routing_table_impl( if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_tables[api.value](impls_by_provider_id, dist_registry) + impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy) await impl.initialize() return impl diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 8ec87ca50..c31f58bcf 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -8,14 +8,15 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn -from llama_stack.distribution.access_control import check_access +from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.distribution.datatypes import ( AccessAttributes, + AccessRule, RoutableObject, RoutableObjectWithProvider, RoutedProtocol, ) -from llama_stack.distribution.request_headers import get_auth_attributes +from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.distribution.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable @@ -73,9 +74,11 @@ class CommonRoutingTableImpl(RoutingTable): self, impls_by_provider_id: dict[str, RoutedProtocol], dist_registry: DistributionRegistry, + policy: list[AccessRule], ) -> None: self.impls_by_provider_id = impls_by_provider_id self.dist_registry = dist_registry + self.policy = policy async def initialize(self) -> None: async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: @@ -166,13 +169,15 @@ class CommonRoutingTableImpl(RoutingTable): return None # Check if user has permission to access this object - if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): - logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") + if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()): + logger.debug(f"Access denied to {type} '{identifier}'") return None return obj async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: + if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()): + raise AccessDeniedError() await self.dist_registry.delete(obj.type, obj.identifier) await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) @@ -187,11 +192,12 @@ class CommonRoutingTableImpl(RoutingTable): p = self.impls_by_provider_id[obj.provider_id] # If object supports access control but no attributes set, use creator's attributes - if not obj.access_attributes: - creator_attributes = get_auth_attributes() - if creator_attributes: - obj.access_attributes = AccessAttributes(**creator_attributes) - logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") + creator = get_authenticated_user() + if not is_action_allowed(self.policy, "create", obj, creator): + raise AccessDeniedError() + if creator and creator.attributes: + obj.access_attributes = AccessAttributes(**creator.attributes) + logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") registered_obj = await register_object_with_provider(obj, p) # TODO: This needs to be fixed for all APIs once they return the registered object @@ -210,9 +216,7 @@ class CommonRoutingTableImpl(RoutingTable): # Apply attribute-based access control filtering if filtered_objs: filtered_objs = [ - obj - for obj in filtered_objs - if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) + obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user()) ] return filtered_objs diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index d70f06691..2994e0dc9 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -30,10 +30,7 @@ from pydantic import BaseModel, ValidationError from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis -from llama_stack.distribution.request_headers import ( - PROVIDER_DATA_VAR, - request_provider_data_context, -) +from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.server.endpoints import ( find_matching_endpoint, @@ -213,11 +210,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): async def endpoint(request: Request, **kwargs): # Get auth attributes from the request scope user_attributes = request.scope.get("user_attributes", {}) + principal = request.scope.get("principal", "") + user = User(principal, user_attributes) await log_request_pre_validation(request) # Use context manager with both provider data and auth attributes - with request_provider_data_context(request.headers, user_attributes): + with request_provider_data_context(request.headers, user): is_streaming = is_streaming_request(func.__name__, request, **kwargs) try: diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index fc68dc016..5a9708497 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -223,7 +223,10 @@ async def construct_stack( run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None ) -> dict[Api, Any]: dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) - impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) + policy = run_config.server.auth.access_policy if run_config.server.auth else [] + impls = await resolve_impls( + run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy + ) # Add internal implementations after all other providers are resolved add_internal_implementations(impls, run_config) diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 7503b8c90..4a77e65b9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -6,12 +6,12 @@ from typing import Any -from llama_stack.distribution.datatypes import Api +from llama_stack.distribution.datatypes import AccessRule, Api from .config import MetaReferenceAgentsImplConfig -async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]): +async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): from .agents import MetaReferenceAgentsImpl impl = MetaReferenceAgentsImpl( @@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap deps[Api.safety], deps[Api.tool_runtime], deps[Api.tool_groups], + policy, ) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2e387e7e8..937bd0341 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -60,6 +60,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO +from llama_stack.distribution.datatypes import AccessRule from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin): vector_io_api: VectorIO, persistence_store: KVStore, created_at: str, + policy: list[AccessRule], ): self.agent_id = agent_id self.agent_config = agent_config self.inference_api = inference_api self.safety_api = safety_api self.vector_io_api = vector_io_api - self.storage = AgentPersistence(agent_id, persistence_store) + self.storage = AgentPersistence(agent_id, persistence_store, policy) self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api self.created_at = created_at diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index bcbfcbe31..399448ec1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -40,6 +40,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO +from llama_stack.distribution.datatypes import AccessRule from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.responses.responses_store import ResponsesStore @@ -61,6 +62,7 @@ class MetaReferenceAgentsImpl(Agents): safety_api: Safety, tool_runtime_api: ToolRuntime, tool_groups_api: ToolGroups, + policy: list[AccessRule], ): self.config = config self.inference_api = inference_api @@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents): self.in_memory_store = InmemoryKVStoreImpl() self.openai_responses_impl: OpenAIResponsesImpl | None = None + self.policy = policy async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) @@ -129,6 +132,7 @@ class MetaReferenceAgentsImpl(Agents): self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store ), created_at=agent_info.created_at, + policy=self.policy, ) async def create_agent_session( diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 5031a4a90..322bcaf04 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -10,9 +10,9 @@ import uuid from datetime import datetime, timezone from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn -from llama_stack.distribution.access_control import check_access -from llama_stack.distribution.datatypes import AccessAttributes -from llama_stack.distribution.request_headers import get_auth_attributes +from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule +from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -23,6 +23,8 @@ class AgentSessionInfo(Session): vector_db_id: str | None = None started_at: datetime access_attributes: AccessAttributes | None = None + identifier: str | None = None + type: str = "session" class AgentInfo(AgentConfig): @@ -30,15 +32,17 @@ class AgentInfo(AgentConfig): class AgentPersistence: - def __init__(self, agent_id: str, kvstore: KVStore): + def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]): self.agent_id = agent_id self.kvstore = kvstore + self.policy = policy async def create_session(self, name: str) -> str: session_id = str(uuid.uuid4()) # Get current user's auth attributes for new sessions - auth_attributes = get_auth_attributes() + user = get_authenticated_user() + auth_attributes = user and user.attributes access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None session_info = AgentSessionInfo( @@ -47,7 +51,10 @@ class AgentPersistence: started_at=datetime.now(timezone.utc), access_attributes=access_attributes, turns=[], + identifier=name, # should this be qualified in any way? ) + if not is_action_allowed(self.policy, "create", session_info, user): + raise AccessDeniedError() await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", @@ -76,7 +83,7 @@ class AgentPersistence: if not hasattr(session_info, "access_attributes"): return True - return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes()) + return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: """Get session info if the user has access to it. For internal use by sub-session methods.""" diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 2a30fd0b8..9cbdc8e51 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl): @pytest.mark.asyncio async def test_models_routing_table(cached_disk_dist_registry): - table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry) + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple models and verify listing @@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_shields_routing_table(cached_disk_dist_registry): - table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry) + table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple shields and verify listing @@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_vectordbs_routing_table(cached_disk_dist_registry): - table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry) + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) await table.initialize() - m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await m_table.initialize() await m_table.register_model( model_id="test-model", - provider_id="test_providere", + provider_id="test_provider", metadata={"embedding_dimension": 128}, model_type=ModelType.embedding, ) @@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry): async def test_datasets_routing_table(cached_disk_dist_registry): - table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry) + table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple datasets and verify listing @@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_scoring_functions_routing_table(cached_disk_dist_registry): - table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry) + table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple scoring functions and verify listing @@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_benchmarks_routing_table(cached_disk_dist_registry): - table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry) + table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple benchmarks and verify listing @@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_tool_groups_routing_table(cached_disk_dist_registry): - table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry) + table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple tool groups and verify listing diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py index 9549f6df6..7a7d52892 100644 --- a/tests/unit/providers/agent/test_meta_reference_agent.py +++ b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis): mock_apis["safety_api"], mock_apis["tool_runtime_api"], mock_apis["tool_groups_api"], + {}, ) await impl.initialize() yield impl diff --git a/tests/unit/providers/agents/test_persistence_access_control.py b/tests/unit/providers/agents/test_persistence_access_control.py index 48fa647a8..3563ae60c 100644 --- a/tests/unit/providers/agents/test_persistence_access_control.py +++ b/tests/unit/providers/agents/test_persistence_access_control.py @@ -13,23 +13,24 @@ import pytest from llama_stack.apis.agents import Turn from llama_stack.apis.inference import CompletionMessage, StopReason from llama_stack.distribution.datatypes import AccessAttributes +from llama_stack.distribution.request_headers import User from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo @pytest.fixture async def test_setup(sqlite_kvstore): - agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore) + agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={}) yield agent_persistence @pytest.mark.asyncio -@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") -async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup): +@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") +async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup): agent_persistence = test_setup # Set creator's attributes for the session creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]} - mock_get_auth_attributes.return_value = creator_attributes + mock_get_authenticated_user.return_value = User("test_user", creator_attributes) # Create a session session_id = await agent_persistence.create_session("Test Session") @@ -43,8 +44,8 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes, @pytest.mark.asyncio -@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") -async def test_session_access_control(mock_get_auth_attributes, test_setup): +@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") +async def test_session_access_control(mock_get_authenticated_user, test_setup): agent_persistence = test_setup # Create a session with specific access attributes @@ -55,6 +56,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup): started_at=datetime.now(), access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), turns=[], + identifier="Restricted Session", ) await agent_persistence.kvstore.set( @@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup): ) # User with matching attributes can access - mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]} + mock_get_authenticated_user.return_value = User( + "testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]} + ) retrieved_session = await agent_persistence.get_session_info(session_id) assert retrieved_session is not None assert retrieved_session.session_id == session_id # User without matching attributes cannot access - mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]}) retrieved_session = await agent_persistence.get_session_info(session_id) assert retrieved_session is None @pytest.mark.asyncio -@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") -async def test_turn_access_control(mock_get_auth_attributes, test_setup): +@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") +async def test_turn_access_control(mock_get_authenticated_user, test_setup): agent_persistence = test_setup # Create a session with restricted access @@ -87,6 +91,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): started_at=datetime.now(), access_attributes=AccessAttributes(roles=["admin"]), turns=[], + identifier="Restricted Session", ) await agent_persistence.kvstore.set( @@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): ) # Admin can add turn - mock_get_auth_attributes.return_value = {"roles": ["admin"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]}) await agent_persistence.add_turn_to_session(session_id, turn) # Admin can get turn @@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): assert retrieved_turn.turn_id == turn_id # Regular user cannot get turn - mock_get_auth_attributes.return_value = {"roles": ["user"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]}) with pytest.raises(ValueError): await agent_persistence.get_session_turn(session_id, turn_id) @@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") -async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup): +@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") +async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup): agent_persistence = test_setup # Create a session with restricted access @@ -140,6 +145,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes started_at=datetime.now(), access_attributes=AccessAttributes(roles=["admin"]), turns=[], + identifier="Restricted Session", ) await agent_persistence.kvstore.set( @@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes turn_id = str(uuid.uuid4()) # Admin user can set inference iterations - mock_get_auth_attributes.return_value = {"roles": ["admin"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]}) await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5) # Admin user can get inference iterations @@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes assert infer_iters == 5 # Regular user cannot get inference iterations - mock_get_auth_attributes.return_value = {"roles": ["user"]} + mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]}) infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id) assert infer_iters is None diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index e352ba54d..70faee748 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -7,10 +7,14 @@ from unittest.mock import MagicMock, Mock, patch import pytest +import yaml +from pydantic import TypeAdapter from llama_stack.apis.datatypes import Api from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL +from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL +from llama_stack.distribution.request_headers import User from llama_stack.distribution.routing_tables.models import ModelsRoutingTable @@ -32,13 +36,14 @@ async def test_setup(cached_disk_dist_registry): routing_table = ModelsRoutingTable( impls_by_provider_id={"test_provider": mock_inference}, dist_registry=cached_disk_dist_registry, + policy={}, ) yield cached_disk_dist_registry, routing_table @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_access_control_with_cache(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup model_public = ModelWithACL( identifier="model-public", @@ -64,7 +69,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): await registry.register(model_admin_only) await registry.register(model_data_scientist) - mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]} + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]}) all_models = await routing_table.list_models() assert len(all_models.data) == 2 @@ -75,7 +80,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): with pytest.raises(ValueError): await routing_table.get_model("model-data-scientist") - mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]} + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]}) all_models = await routing_table.list_models() assert len(all_models.data) == 1 assert all_models.data[0].identifier == "model-public" @@ -86,7 +91,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): with pytest.raises(ValueError): await routing_table.get_model("model-data-scientist") - mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]} + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]}) all_models = await routing_table.list_models() assert len(all_models.data) == 2 model_ids = [m.identifier for m in all_models.data] @@ -102,8 +107,8 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_access_control_and_updates(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup model_public = ModelWithACL( identifier="model-updates", @@ -112,28 +117,37 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): model_type=ModelType.llm, ) await registry.register(model_public) - mock_get_auth_attributes.return_value = { - "roles": ["user"], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": ["user"], + }, + ) model = await routing_table.get_model("model-updates") assert model.identifier == "model-updates" model_public.access_attributes = AccessAttributes(roles=["admin"]) await registry.update(model_public) - mock_get_auth_attributes.return_value = { - "roles": ["user"], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": ["user"], + }, + ) with pytest.raises(ValueError): await routing_table.get_model("model-updates") - mock_get_auth_attributes.return_value = { - "roles": ["admin"], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": ["admin"], + }, + ) model = await routing_table.get_model("model-updates") assert model.identifier == "model-updates" @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup model = ModelWithACL( identifier="model-empty-attrs", @@ -143,9 +157,12 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se access_attributes=AccessAttributes(), ) await registry.register(model) - mock_get_auth_attributes.return_value = { - "roles": [], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": [], + }, + ) result = await routing_table.get_model("model-empty-attrs") assert result.identifier == "model-empty-attrs" all_models = await routing_table.list_models() @@ -154,8 +171,8 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_no_user_attributes(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_no_user_attributes(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup model_public = ModelWithACL( identifier="model-public-2", @@ -172,7 +189,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup): ) await registry.register(model_public) await registry.register(model_restricted) - mock_get_auth_attributes.return_value = None + mock_get_authenticated_user.return_value = User("test-user", None) model = await routing_table.get_model("model-public-2") assert model.identifier == "model-public-2" @@ -185,14 +202,14 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") -async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup): +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup): """Test that newly created resources inherit access attributes from their creator.""" registry, routing_table = test_setup # Set creator's attributes creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]} - mock_get_auth_attributes.return_value = creator_attributes + mock_get_authenticated_user.return_value = User("test-user", creator_attributes) # Create model without explicit access attributes model = ModelWithACL( @@ -211,15 +228,262 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup) assert registered_model.access_attributes.projects == ["llama-3"] # Verify another user without matching attributes can't access it - mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]} + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]}) with pytest.raises(ValueError): await routing_table.get_model("auto-access-model") # But a user with matching attributes can - mock_get_auth_attributes.return_value = { - "roles": ["data-scientist", "engineer"], - "teams": ["ml-team", "platform-team"], - "projects": ["llama-3"], - } + mock_get_authenticated_user.return_value = User( + "test-user", + { + "roles": ["data-scientist", "engineer"], + "teams": ["ml-team", "platform-team"], + "projects": ["llama-3"], + }, + ) model = await routing_table.get_model("auto-access-model") assert model.identifier == "auto-access-model" + + +@pytest.fixture +async def test_setup_with_access_policy(cached_disk_dist_registry): + mock_inference = Mock() + mock_inference.__provider_spec__ = MagicMock() + mock_inference.__provider_spec__.api = Api.inference + mock_inference.register_model = AsyncMock(side_effect=_return_model) + mock_inference.unregister_model = AsyncMock(side_effect=_return_model) + + config = """ + - permit: + principal: user-1 + actions: [create, read, delete] + description: user-1 has full access to all models + - permit: + principal: user-2 + actions: [read] + resource: model::model-1 + description: user-2 has read access to model-1 only + - permit: + principal: user-3 + actions: [read] + resource: model::model-2 + description: user-3 has read access to model-2 only + - forbid: + actions: [create, read, delete] + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + routing_table = ModelsRoutingTable( + impls_by_provider_id={"test_provider": mock_inference}, + dist_registry=cached_disk_dist_registry, + policy=policy, + ) + yield routing_table + + +@pytest.mark.asyncio +@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") +async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy): + routing_table = test_setup_with_access_policy + mock_get_authenticated_user.return_value = User( + "user-1", + { + "roles": ["admin"], + "projects": ["foo", "bar"], + }, + ) + await routing_table.register_model("model-1", provider_id="test_provider") + await routing_table.register_model("model-2", provider_id="test_provider") + await routing_table.register_model("model-3", provider_id="test_provider") + model = await routing_table.get_model("model-1") + assert model.identifier == "model-1" + model = await routing_table.get_model("model-2") + assert model.identifier == "model-2" + model = await routing_table.get_model("model-3") + assert model.identifier == "model-3" + + mock_get_authenticated_user.return_value = User( + "user-2", + { + "roles": ["user"], + "projects": ["foo"], + }, + ) + model = await routing_table.get_model("model-1") + assert model.identifier == "model-1" + with pytest.raises(ValueError): + await routing_table.get_model("model-2") + with pytest.raises(ValueError): + await routing_table.get_model("model-3") + with pytest.raises(AccessDeniedError): + await routing_table.register_model("model-4", provider_id="test_provider") + with pytest.raises(AccessDeniedError): + await routing_table.unregister_model("model-1") + + mock_get_authenticated_user.return_value = User( + "user-3", + { + "roles": ["user"], + "projects": ["bar"], + }, + ) + model = await routing_table.get_model("model-2") + assert model.identifier == "model-2" + with pytest.raises(ValueError): + await routing_table.get_model("model-1") + with pytest.raises(ValueError): + await routing_table.get_model("model-3") + with pytest.raises(AccessDeniedError): + await routing_table.register_model("model-5", provider_id="test_provider") + with pytest.raises(AccessDeniedError): + await routing_table.unregister_model("model-2") + + mock_get_authenticated_user.return_value = User( + "user-1", + { + "roles": ["admin"], + "projects": ["foo", "bar"], + }, + ) + await routing_table.unregister_model("model-3") + with pytest.raises(ValueError): + await routing_table.get_model("model-3") + + +def test_permit_when(): + config = """ + - permit: + principal: user-1 + actions: [read] + when: + user_in: resource.namespaces + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithACL( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + access_attributes=AccessAttributes(namespaces=["foo"]), + ) + assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) + + +def test_permit_unless(): + config = """ + - permit: + principal: user-1 + actions: [read] + resource: model::* + unless: + - user_not_in: resource.namespaces + - user_in: resource.teams + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithACL( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + access_attributes=AccessAttributes(namespaces=["foo"]), + ) + assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) + + +def test_forbid_when(): + config = """ + - forbid: + principal: user-1 + actions: [read] + when: + user_in: resource.namespaces + - permit: + actions: [read] + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithACL( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + access_attributes=AccessAttributes(namespaces=["foo"]), + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) + assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) + + +def test_forbid_unless(): + config = """ + - forbid: + principal: user-1 + actions: [read] + unless: + user_in: resource.namespaces + - permit: + actions: [read] + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithACL( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + access_attributes=AccessAttributes(namespaces=["foo"]), + ) + assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) + + +def test_condition_with_literal(): + config = """ + - permit: + actions: [read] + when: + user_in: role::admin + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithACL( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + access_attributes=AccessAttributes(namespaces=["foo"]), + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) + assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-4", None)) + + +def test_condition_with_unrecognised_literal(): + config = """ + - permit: + actions: [read] + when: + user_in: whatever + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithACL( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + access_attributes=AccessAttributes(namespaces=["foo"]), + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert not is_action_allowed(policy, "read", model, User("user-2", None)) + + +def test_empty_condition(): + config = """ + - permit: + actions: [read] + when: {} + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithACL( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + ) + assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert is_action_allowed(policy, "read", model, User("user-2", None)) diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index bb4c15dbc..acf4da0a3 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -100,9 +100,10 @@ async def test_resolve_impls_basic(): add_protocol_methods(SampleImpl, Inference) mock_module.get_provider_impl = AsyncMock(return_value=impl) + mock_module.get_provider_impl.__text_signature__ = "()" sys.modules["test_module"] = mock_module - impls = await resolve_impls(run_config, provider_registry, dist_registry) + impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={}) assert Api.inference in impls assert isinstance(impls[Api.inference], InferenceRouter)