diff --git a/llama_stack/distribution/access_control/access_control.py b/llama_stack/distribution/access_control/access_control.py index e68325783..84d506d8f 100644 --- a/llama_stack/distribution/access_control/access_control.py +++ b/llama_stack/distribution/access_control/access_control.py @@ -4,16 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Protocol +from typing import Any -from llama_stack.distribution.request_headers import User +from llama_stack.distribution.datatypes import User +from .conditions import ( + Condition, + ProtectedResource, + parse_conditions, +) from .datatypes import ( - AccessAttributes, AccessRule, Action, - AttributeReference, - Condition, Scope, ) @@ -37,43 +39,6 @@ def matches_scope( return action in scope.actions -def user_in_literal( - literal: str, - user_attributes: dict[str, list[str]] | None, -) -> bool: - for qualifier in ["role::", "team::", "project::", "namespace::"]: - if literal.startswith(qualifier): - if not user_attributes: - return False - ref = qualifier.replace("::", "s") - if ref in user_attributes: - value = literal.removeprefix(qualifier) - return value in user_attributes[ref] - else: - return False - return False - - -def user_in( - ref: AttributeReference | str, - resource_attributes: AccessAttributes | None, - user_attributes: dict[str, list[str]] | None, -) -> bool: - if not ref.startswith("resource."): - return user_in_literal(ref, user_attributes) - name = ref.removeprefix("resource.") - required = resource_attributes and getattr(resource_attributes, name) - if not required: - return True - if not user_attributes or name not in user_attributes: - return False - actual = user_attributes[name] - for value in required: - if value in actual: - return True - return False - - def as_list(obj: Any) -> list[Any]: if isinstance(obj, list): return obj @@ -82,55 +47,27 @@ def as_list(obj: Any) -> list[Any]: def matches_conditions( conditions: list[Condition], - resource_attributes: AccessAttributes | None, - user_attributes: dict[str, list[str]] | None, + resource: ProtectedResource, + user: User, ) -> bool: for condition in conditions: # must match all conditions - if not matches_condition(condition, resource_attributes, user_attributes): + if not condition.matches(resource, user): return False return True -def matches_condition( - condition: Condition | list[Condition], - resource_attributes: AccessAttributes | None, - user_attributes: dict[str, list[str]] | None, -) -> bool: - if isinstance(condition, list): - return matches_conditions(as_list(condition), resource_attributes, user_attributes) - if condition.user_in: - for ref in as_list(condition.user_in): - # if multiple references are specified, all must match - if not user_in(ref, resource_attributes, user_attributes): - return False - return True - if condition.user_not_in: - for ref in as_list(condition.user_not_in): - # if multiple references are specified, none must match - if user_in(ref, resource_attributes, user_attributes): - return False - return True - return True - - def default_policy() -> list[AccessRule]: - # for backwards compatibility, if no rules are provided , assume - # full access to all subject to attribute matching rules + # for backwards compatibility, if no rules are provided, assume + # full access subject to previous attribute matching rules return [ AccessRule( permit=Scope(actions=list(Action)), - when=Condition(user_in=list(AttributeReference)), - ) + when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]], + ), ] -class ProtectedResource(Protocol): - type: str - identifier: str - access_attributes: AccessAttributes - - def is_action_allowed( policy: list[AccessRule], action: Action, @@ -144,26 +81,23 @@ def is_action_allowed( if not len(policy): policy = default_policy() - resource_attributes = AccessAttributes() - if hasattr(resource, "access_attributes"): - resource_attributes = resource.access_attributes qualified_resource_id = resource.type + "::" + resource.identifier for rule in policy: if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal): if rule.when: - if matches_condition(rule.when, resource_attributes, user.attributes): + if matches_conditions(parse_conditions(as_list(rule.when)), resource, user): return False elif rule.unless: - if not matches_condition(rule.unless, resource_attributes, user.attributes): + if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user): return False else: return False elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal): if rule.when: - if matches_condition(rule.when, resource_attributes, user.attributes): + if matches_conditions(parse_conditions(as_list(rule.when)), resource, user): return True elif rule.unless: - if not matches_condition(rule.unless, resource_attributes, user.attributes): + if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user): return True else: return True diff --git a/llama_stack/distribution/access_control/conditions.py b/llama_stack/distribution/access_control/conditions.py new file mode 100644 index 000000000..25a267124 --- /dev/null +++ b/llama_stack/distribution/access_control/conditions.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Protocol + + +class User(Protocol): + principal: str + attributes: dict[str, list[str]] | None + + +class ProtectedResource(Protocol): + type: str + identifier: str + owner: User + + +class Condition(Protocol): + def matches(self, resource: ProtectedResource, user: User) -> bool: ... + + +class UserInOwnersList: + def __init__(self, name: str): + self.name = name + + def owners_values(self, resource: ProtectedResource) -> list[str] | None: + if ( + hasattr(resource, "owner") + and resource.owner + and resource.owner.attributes + and self.name in resource.owner.attributes + ): + return resource.owner.attributes[self.name] + else: + return None + + def matches(self, resource: ProtectedResource, user: User) -> bool: + required = self.owners_values(resource) + if not required: + return True + if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]: + return False + user_values = user.attributes[self.name] + for value in required: + if value in user_values: + return True + return False + + def __repr__(self): + return f"user in owners {self.name}" + + +class UserNotInOwnersList(UserInOwnersList): + def __init__(self, name: str): + super().__init__(name) + + def matches(self, resource: ProtectedResource, user: User) -> bool: + return not super().matches(resource, user) + + def __repr__(self): + return f"user not in owners {self.name}" + + +class UserWithValueInList: + def __init__(self, name: str, value: str): + self.name = name + self.value = value + + def matches(self, resource: ProtectedResource, user: User) -> bool: + if user.attributes and self.name in user.attributes: + return self.value in user.attributes[self.name] + print(f"User does not have {self.value} in {self.name}") + return False + + def __repr__(self): + return f"user with {self.value} in {self.name}" + + +class UserWithValueNotInList(UserWithValueInList): + def __init__(self, name: str, value: str): + super().__init__(name, value) + + def matches(self, resource: ProtectedResource, user: User) -> bool: + return not super().matches(resource, user) + + def __repr__(self): + return f"user with {self.value} not in {self.name}" + + +class UserIsOwner: + def matches(self, resource: ProtectedResource, user: User) -> bool: + return resource.owner.principal == user.principal if resource.owner else False + + def __repr__(self): + return "user is owner" + + +class UserIsNotOwner: + def matches(self, resource: ProtectedResource, user: User) -> bool: + return not resource.owner or resource.owner.principal != user.principal + + def __repr__(self): + return "user is not owner" + + +def parse_condition(condition: str) -> Condition: + words = condition.split() + match words: + case ["user", "is", "owner"]: + return UserIsOwner() + case ["user", "is", "not", "owner"]: + return UserIsNotOwner() + case ["user", "with", value, "in", name]: + return UserWithValueInList(name, value) + case ["user", "with", value, "not", "in", name]: + return UserWithValueNotInList(name, value) + case ["user", "in", "owners", name]: + return UserInOwnersList(name) + case ["user", "not", "in", "owners", name]: + return UserNotInOwnersList(name) + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def parse_conditions(conditions: list[str]) -> list[Condition]: + return [parse_condition(c) for c in conditions] diff --git a/llama_stack/distribution/access_control/datatypes.py b/llama_stack/distribution/access_control/datatypes.py index 9f95f2c7e..3e6c624dc 100644 --- a/llama_stack/distribution/access_control/datatypes.py +++ b/llama_stack/distribution/access_control/datatypes.py @@ -6,37 +6,10 @@ from enum import Enum -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, model_validator from typing_extensions import Self - -class AccessAttributes(BaseModel): - """Structured representation of user attributes for access control. - - This model defines a structured approach to representing user attributes - with common standard categories for access control. - - Standard attribute categories include: - - roles: Role-based attributes (e.g., admin, data-scientist) - - teams: Team-based attributes (e.g., ml-team, infra-team) - - projects: Project access attributes (e.g., llama-3, customer-insights) - - namespaces: Namespace-based access control for resource isolation - """ - - # Standard attribute categories - the minimal set we need now - roles: list[str] | None = Field( - default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')" - ) - - teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')") - - projects: list[str] | None = Field( - default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')" - ) - - namespaces: list[str] | None = Field( - default=None, description="Namespace-based access control for resource isolation" - ) +from .conditions import parse_conditions class Action(str, Enum): @@ -62,18 +35,6 @@ def _require_one_of(obj, a: str, b: str): raise ValueError(f"on of {a} or {b} is required") -class AttributeReference(str, Enum): - RESOURCE_ROLES = "resource.roles" - RESOURCE_TEAMS = "resource.teams" - RESOURCE_PROJECTS = "resource.projects" - RESOURCE_NAMESPACES = "resource.namespaces" - - -class Condition(BaseModel): - user_in: AttributeReference | list[AttributeReference] | str | None = None - user_not_in: AttributeReference | list[AttributeReference] | str | None = None - - class AccessRule(BaseModel): """Access rule based loosely on cedar policy language @@ -85,10 +46,14 @@ class AccessRule(BaseModel): requests. A rule may also specify a condition, either a 'when' or an 'unless', with additional - constraints as to where the rule applies. The constraints at present are whether the - user requesting access is in or not in some set. This set can either be a particular - set of attributes on the resource e.g. resource.roles or a literal value of some - notion of group, e.g. role::admin or namespace::foo. + constraints as to where the rule applies. The constraints supported at present are: + + - 'user with in ' + - 'user with not in ' + - 'user is owner' + - 'user is not owner' + - 'user in owners ' + - 'user not in owners ' Rules are tested in order to find a match. If a match is found, the request is permitted or forbidden depending on the type of rule. If no match is found, the @@ -99,33 +64,31 @@ class AccessRule(BaseModel): Some examples in yaml: - permit: - principal: user-1 - actions: [create, read, delete] - resource: model::* + principal: user-1 + actions: [create, read, delete] + resource: model::* description: user-1 has full access to all models - permit: - principal: user-2 - actions: [read] - resource: model::model-1 + principal: user-2 + actions: [read] + resource: model::model-1 description: user-2 has read access to model-1 only - permit: - actions: [read] - when: - user_in: resource.namespaces - description: any user has read access to any resource with matching attributes + actions: [read] + when: user in owner teams + description: any user has read access to any resource created by a member of their team - forbid: - actions: [create, read, delete] - resource: vector_db::* - unless: - user_in: role::admin + actions: [create, read, delete] + resource: vector_db::* + unless: user with admin in roles description: only user with admin role can use vector_db resources """ permit: Scope | None = None forbid: Scope | None = None - when: Condition | list[Condition] | None = None - unless: Condition | list[Condition] | None = None + when: str | list[str] | None = None + unless: str | list[str] | None = None description: str | None = None @model_validator(mode="after") @@ -133,4 +96,12 @@ class AccessRule(BaseModel): _require_one_of(self, "permit", "forbid") _mutually_exclusive(self, "permit", "forbid") _mutually_exclusive(self, "when", "unless") + if isinstance(self.when, list): + parse_conditions(self.when) + elif self.when: + parse_conditions([self.when]) + if isinstance(self.unless, list): + parse_conditions(self.unless) + elif self.unless: + parse_conditions([self.unless]) return self diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index a01cbdfb8..abc3f0065 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO -from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule +from llama_stack.distribution.access_control.datatypes import AccessRule from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig @@ -36,97 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2" RoutingKey = str | list[str] -class ResourceWithACL(Resource): - """Extension of Resource that adds attribute-based access control capabilities. +class User(BaseModel): + principal: str + # further attributes that may be used for access control decisions + attributes: dict[str, list[str]] | None = None - This class adds an optional access_attributes field that allows fine-grained control - over which users can access each resource. When attributes are defined, a user must have - matching attributes to access the resource. + def __init__(self, principal: str, attributes: dict[str, list[str]] | None): + super().__init__(principal=principal, attributes=attributes) - Attribute Matching Algorithm: - 1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users - 2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects") - 3. The matching algorithm requires ALL categories to match (AND relationship between categories) - 4. Within each category, ANY value match is sufficient (OR relationship within a category) - Examples: - # Resource visible to everyone (no access control) - model = Model(identifier="llama-2", ...) +class ResourceWithOwner(Resource): + """Extension of Resource that adds an optional owner, i.e. the user that created the + resource. This can be used to constrain access to the resource.""" - # Resource visible only to admins - model = Model( - identifier="gpt-4", - access_attributes=AccessAttributes(roles=["admin"]) - ) - - # Resource visible to data scientists on the ML team - model = Model( - identifier="private-model", - access_attributes=AccessAttributes( - roles=["data-scientist", "researcher"], - teams=["ml-team"] - ) - ) - # ^ User must have at least one of the roles AND be on the ml-team - - # Resource visible to users with specific project access - vector_db = VectorDB( - identifier="customer-embeddings", - access_attributes=AccessAttributes( - projects=["customer-insights"], - namespaces=["confidential"] - ) - ) - # ^ User must have access to the customer-insights project AND have confidential namespace - """ - - access_attributes: AccessAttributes | None = None + owner: User | None = None # Use the extended Resource for all routable objects -class ModelWithACL(Model, ResourceWithACL): +class ModelWithOwner(Model, ResourceWithOwner): pass -class ShieldWithACL(Shield, ResourceWithACL): +class ShieldWithOwner(Shield, ResourceWithOwner): pass -class VectorDBWithACL(VectorDB, ResourceWithACL): +class VectorDBWithOwner(VectorDB, ResourceWithOwner): pass -class DatasetWithACL(Dataset, ResourceWithACL): +class DatasetWithOwner(Dataset, ResourceWithOwner): pass -class ScoringFnWithACL(ScoringFn, ResourceWithACL): +class ScoringFnWithOwner(ScoringFn, ResourceWithOwner): pass -class BenchmarkWithACL(Benchmark, ResourceWithACL): +class BenchmarkWithOwner(Benchmark, ResourceWithOwner): pass -class ToolWithACL(Tool, ResourceWithACL): +class ToolWithOwner(Tool, ResourceWithOwner): pass -class ToolGroupWithACL(ToolGroup, ResourceWithACL): +class ToolGroupWithOwner(ToolGroup, ResourceWithOwner): pass RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup RoutableObjectWithProvider = Annotated[ - ModelWithACL - | ShieldWithACL - | VectorDBWithACL - | DatasetWithACL - | ScoringFnWithACL - | BenchmarkWithACL - | ToolWithACL - | ToolGroupWithACL, + ModelWithOwner + | ShieldWithOwner + | VectorDBWithOwner + | DatasetWithOwner + | ScoringFnWithOwner + | BenchmarkWithOwner + | ToolWithOwner + | ToolGroupWithOwner, Field(discriminator="type"), ] diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 7cf16ea3c..81d494e04 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -10,6 +10,8 @@ import logging from contextlib import AbstractContextManager from typing import Any +from llama_stack.distribution.datatypes import User + from .utils.dynamic import instantiate_class_type log = logging.getLogger(__name__) @@ -18,16 +20,6 @@ log = logging.getLogger(__name__) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) -class User: - principal: str - # further attributes that may be used for access control decisions - attributes: dict[str, list[str]] - - def __init__(self, principal: str, attributes: dict[str, list[str]]): - self.principal = principal - self.attributes = attributes - - class RequestProviderDataContext(AbstractContextManager): """Context manager for request provider data""" diff --git a/llama_stack/distribution/routing_tables/benchmarks.py b/llama_stack/distribution/routing_tables/benchmarks.py index 589a00c02..815483494 100644 --- a/llama_stack/distribution/routing_tables/benchmarks.py +++ b/llama_stack/distribution/routing_tables/benchmarks.py @@ -8,7 +8,7 @@ from typing import Any from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.distribution.datatypes import ( - BenchmarkWithACL, + BenchmarkWithOwner, ) from llama_stack.log import get_logger @@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): ) if provider_benchmark_id is None: provider_benchmark_id = benchmark_id - benchmark = BenchmarkWithACL( + benchmark = BenchmarkWithOwner( identifier=benchmark_id, dataset_id=dataset_id, scoring_functions=scoring_functions, diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index c31f58bcf..b79c8a2a8 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -10,7 +10,6 @@ from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.distribution.datatypes import ( - AccessAttributes, AccessRule, RoutableObject, RoutableObjectWithProvider, @@ -195,9 +194,9 @@ class CommonRoutingTableImpl(RoutingTable): creator = get_authenticated_user() if not is_action_allowed(self.policy, "create", obj, creator): raise AccessDeniedError() - if creator and creator.attributes: - obj.access_attributes = AccessAttributes(**creator.attributes) - logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") + if creator: + obj.owner = creator + logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}") registered_obj = await register_object_with_provider(obj, p) # TODO: This needs to be fixed for all APIs once they return the registered object diff --git a/llama_stack/distribution/routing_tables/datasets.py b/llama_stack/distribution/routing_tables/datasets.py index 4401ad47e..fb34f40b6 100644 --- a/llama_stack/distribution/routing_tables/datasets.py +++ b/llama_stack/distribution/routing_tables/datasets.py @@ -19,7 +19,7 @@ from llama_stack.apis.datasets import ( ) from llama_stack.apis.resource import ResourceType from llama_stack.distribution.datatypes import ( - DatasetWithACL, + DatasetWithOwner, ) from llama_stack.log import get_logger @@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): if metadata is None: metadata = {} - dataset = DatasetWithACL( + dataset = DatasetWithOwner( identifier=dataset_id, provider_resource_id=provider_dataset_id, provider_id=provider_id, diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index 7216d9935..c6a10ea9b 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -9,7 +9,7 @@ from typing import Any from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.distribution.datatypes import ( - ModelWithACL, + ModelWithOwner, ) from llama_stack.log import get_logger @@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: raise ValueError("Embedding model must have an embedding dimension in its metadata") - model = ModelWithACL( + model = ModelWithOwner( identifier=model_id, provider_resource_id=provider_model_id, provider_id=provider_id, diff --git a/llama_stack/distribution/routing_tables/scoring_functions.py b/llama_stack/distribution/routing_tables/scoring_functions.py index d85f64b57..742cc3ca6 100644 --- a/llama_stack/distribution/routing_tables/scoring_functions.py +++ b/llama_stack/distribution/routing_tables/scoring_functions.py @@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import ( ScoringFunctions, ) from llama_stack.distribution.datatypes import ( - ScoringFnWithACL, + ScoringFnWithOwner, ) from llama_stack.log import get_logger @@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." ) - scoring_fn = ScoringFnWithACL( + scoring_fn = ScoringFnWithOwner( identifier=scoring_fn_id, description=description, return_type=return_type, diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py index 7f62596c9..5215981b9 100644 --- a/llama_stack/distribution/routing_tables/shields.py +++ b/llama_stack/distribution/routing_tables/shields.py @@ -9,7 +9,7 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields from llama_stack.distribution.datatypes import ( - ShieldWithACL, + ShieldWithOwner, ) from llama_stack.log import get_logger @@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): ) if params is None: params = {} - shield = ShieldWithACL( + shield = ShieldWithOwner( identifier=shield_id, provider_resource_id=provider_shield_id, provider_id=provider_id, diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index 3f103ed22..20f1c984f 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -8,7 +8,7 @@ from typing import Any from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups -from llama_stack.distribution.datatypes import ToolGroupWithACL +from llama_stack.distribution.datatypes import ToolGroupWithOwner from llama_stack.log import get_logger from .common import CommonRoutingTableImpl @@ -88,7 +88,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): mcp_endpoint: URL | None = None, args: dict[str, Any] | None = None, ) -> None: - toolgroup = ToolGroupWithACL( + toolgroup = ToolGroupWithOwner( identifier=toolgroup_id, provider_id=provider_id, provider_resource_id=toolgroup_id, diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index dc6c0d0ef..542e965f8 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs from llama_stack.distribution.datatypes import ( - VectorDBWithACL, + VectorDBWithOwner, ) from llama_stack.log import get_logger @@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): "embedding_model": embedding_model, "embedding_dimension": model.metadata["embedding_dimension"], } - vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) + vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) await self.register_object(vector_db) return vector_db diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index fb26b49a7..81b1ffd37 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -105,24 +105,16 @@ class AuthenticationMiddleware: logger.exception("Error during authentication") return await self._send_auth_error(send, "Authentication service error") - # Store attributes in request scope for access control - if validation_result.access_attributes: - user_attributes = validation_result.access_attributes.model_dump(exclude_none=True) - else: - logger.warning("No access attributes, setting namespace to token by default") - user_attributes = { - "roles": [token], - } - # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) # can identify the requester and enforce per-client rate limits. scope["authenticated_client_id"] = token # Store attributes in request scope - scope["user_attributes"] = user_attributes scope["principal"] = validation_result.principal + if validation_result.attributes: + scope["user_attributes"] = validation_result.attributes logger.debug( - f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes" + f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes" ) return await self.app(scope, receive, send) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 723a65b77..7cb038494 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -16,43 +16,18 @@ from jose import jwt from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self -from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType +from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") -class TokenValidationResult(BaseModel): - principal: str | None = Field( - default=None, - description="The principal (username or persistent identifier) of the authenticated user", - ) - access_attributes: AccessAttributes | None = Field( - default=None, - description=""" - Structured user attributes for attribute-based access control. - - These attributes determine which resources the user can access. - The model provides standard categories like "roles", "teams", "projects", and "namespaces". - Each attribute category contains a list of values that the user has for that category. - During access control checks, these values are compared against resource requirements. - - Example with standard categories: - ```json - { - "roles": ["admin", "data-scientist"], - "teams": ["ml-team"], - "projects": ["llama-3"], - "namespaces": ["research"] - } - ``` - """, - ) - - -class AuthResponse(TokenValidationResult): +class AuthResponse(BaseModel): """The format of the authentication response from the auth endpoint.""" + principal: str + # further attributes that may be used for access control decisions + attributes: dict[str, list[str]] | None = None message: str | None = Field( default=None, description="Optional message providing additional context about the authentication result." ) @@ -78,7 +53,7 @@ class AuthProvider(ABC): """Abstract base class for authentication providers.""" @abstractmethod - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_token(self, token: str, scope: dict | None = None) -> User: """Validate a token and return access attributes.""" pass @@ -88,10 +63,10 @@ class AuthProvider(ABC): pass -def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: - attributes = AccessAttributes() +def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]: + attributes: dict[str, list[str]] = {} for claim_key, attribute_key in mapping.items(): - if claim_key not in claims or not hasattr(attributes, attribute_key): + if claim_key not in claims: continue claim = claims[claim_key] if isinstance(claim, list): @@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) else: values = claim.split() - current = getattr(attributes, attribute_key) - if current: - current.extend(values) + if attribute_key in attributes: + attributes[attribute_key].extend(values) else: - setattr(attributes, attribute_key, values) + attributes[attribute_key] = values return attributes @@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel): for key, value in v.items(): if not value: raise ValueError(f"claims_mapping value cannot be empty: {key}") - if value not in AccessAttributes.model_fields: - raise ValueError(f"claims_mapping value is not a valid attribute: {value}") return v @model_validator(mode="after") @@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider): self._jwks: dict[str, str] = {} self._jwks_lock = Lock() - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_token(self, token: str, scope: dict | None = None) -> User: if self.config.jwks: return await self.validate_jwt_token(token, scope) if self.config.introspection: return await self.introspect_token(token, scope) raise ValueError("One of jwks or introspection must be configured") - async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using the JWT token.""" await self._refresh_jwks() @@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider): # We should incorporate these into the access attributes. principal = claims["sub"] access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping) - return TokenValidationResult( + return User( principal=principal, - access_attributes=access_attributes, + attributes=access_attributes, ) - async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def introspect_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using token introspection as defined by RFC 7662.""" form = { "token": token, @@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider): raise ValueError("Token not active") principal = fields["sub"] or fields["username"] access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping) - return TokenValidationResult( + return User( principal=principal, - access_attributes=access_attributes, + attributes=access_attributes, ) except httpx.TimeoutException: logger.exception("Token introspection request timed out") @@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider): self.config = config self._client = None - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using the custom authentication endpoint.""" if scope is None: scope = {} @@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider): json=auth_request.model_dump(), timeout=10.0, # Add a reasonable timeout ) + print("MADE CALL") if response.status_code != 200: logger.warning(f"Authentication failed with status code: {response.status_code}") raise ValueError(f"Authentication failed: {response.status_code}") @@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider): try: response_data = response.json() auth_response = AuthResponse(**response_data) - return auth_response + return User(auth_response.principal, auth_response.attributes) except Exception as e: logger.exception("Error parsing authentication response") raise ValueError("Invalid authentication response format") from e diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 322bcaf04..25dbb5df7 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -11,7 +11,8 @@ from datetime import datetime, timezone from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed -from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule +from llama_stack.distribution.access_control.datatypes import AccessRule +from llama_stack.distribution.datatypes import User from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.providers.utils.kvstore import KVStore @@ -22,7 +23,7 @@ class AgentSessionInfo(Session): # TODO: is this used anywhere? vector_db_id: str | None = None started_at: datetime - access_attributes: AccessAttributes | None = None + owner: User | None = None identifier: str | None = None type: str = "session" @@ -42,14 +43,12 @@ class AgentPersistence: # Get current user's auth attributes for new sessions user = get_authenticated_user() - auth_attributes = user and user.attributes - access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None session_info = AgentSessionInfo( session_id=session_id, session_name=name, started_at=datetime.now(timezone.utc), - access_attributes=access_attributes, + owner=user, turns=[], identifier=name, # should this be qualified in any way? ) @@ -80,7 +79,7 @@ class AgentPersistence: def _check_session_access(self, session_info: AgentSessionInfo) -> bool: """Check if current user has access to the session.""" # Handle backward compatibility for old sessions without access control - if not hasattr(session_info, "access_attributes"): + if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): return True return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) diff --git a/tests/unit/providers/agents/test_persistence_access_control.py b/tests/unit/providers/agents/test_persistence_access_control.py index 3563ae60c..d5b876a09 100644 --- a/tests/unit/providers/agents/test_persistence_access_control.py +++ b/tests/unit/providers/agents/test_persistence_access_control.py @@ -12,8 +12,7 @@ import pytest from llama_stack.apis.agents import Turn from llama_stack.apis.inference import CompletionMessage, StopReason -from llama_stack.distribution.datatypes import AccessAttributes -from llama_stack.distribution.request_headers import User +from llama_stack.distribution.datatypes import User from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo @@ -38,9 +37,10 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us # Get the session and verify access attributes were set session_info = await agent_persistence.get_session_info(session_id) assert session_info is not None - assert session_info.access_attributes is not None - assert session_info.access_attributes.roles == ["researcher"] - assert session_info.access_attributes.teams == ["ai-team"] + assert session_info.owner is not None + assert session_info.owner.attributes is not None + assert session_info.owner.attributes["roles"] == ["researcher"] + assert session_info.owner.attributes["teams"] == ["ai-team"] @pytest.mark.asyncio @@ -54,7 +54,7 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup): session_id=session_id, session_name="Restricted Session", started_at=datetime.now(), - access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), + owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}), turns=[], identifier="Restricted Session", ) @@ -89,7 +89,7 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup): session_id=session_id, session_name="Restricted Session", started_at=datetime.now(), - access_attributes=AccessAttributes(roles=["admin"]), + owner=User("someone", {"roles": ["admin"]}), turns=[], identifier="Restricted Session", ) @@ -143,7 +143,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_u session_id=session_id, session_name="Restricted Session", started_at=datetime.now(), - access_attributes=AccessAttributes(roles=["admin"]), + owner=User("someone", {"roles": ["admin"]}), turns=[], identifier="Restricted Session", ) diff --git a/tests/unit/registry/test_registry_acl.py b/tests/unit/registry/test_registry_acl.py index 25ea37bfa..48b3ac51b 100644 --- a/tests/unit/registry/test_registry_acl.py +++ b/tests/unit/registry/test_registry_acl.py @@ -8,19 +8,18 @@ import pytest from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import ModelWithACL -from llama_stack.distribution.server.auth_providers import AccessAttributes +from llama_stack.distribution.datatypes import ModelWithOwner, User from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry @pytest.mark.asyncio async def test_registry_cache_with_acl(cached_disk_dist_registry): - model = ModelWithACL( + model = ModelWithOwner( identifier="model-acl", provider_id="test-provider", provider_resource_id="model-acl-resource", model_type=ModelType.llm, - access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]), + owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}), ) success = await cached_disk_dist_registry.register(model) @@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry): cached_model = cached_disk_dist_registry.get_cached("model", "model-acl") assert cached_model is not None assert cached_model.identifier == "model-acl" - assert cached_model.access_attributes.roles == ["admin"] - assert cached_model.access_attributes.teams == ["ai-team"] + assert cached_model.owner.principal == "testuser" + assert cached_model.owner.attributes["roles"] == ["admin"] + assert cached_model.owner.attributes["teams"] == ["ai-team"] fetched_model = await cached_disk_dist_registry.get("model", "model-acl") assert fetched_model is not None assert fetched_model.identifier == "model-acl" - assert fetched_model.access_attributes.roles == ["admin"] - - model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"]) - await cached_disk_dist_registry.update(model) - - updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl") - assert updated_cached is not None - assert updated_cached.access_attributes.roles == ["admin", "user"] - assert updated_cached.access_attributes.projects == ["project-x"] - assert updated_cached.access_attributes.teams is None + assert fetched_model.owner.attributes["roles"] == ["admin"] new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore) await new_registry.initialize() @@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry): new_model = await new_registry.get("model", "model-acl") assert new_model is not None assert new_model.identifier == "model-acl" - assert new_model.access_attributes.roles == ["admin", "user"] - assert new_model.access_attributes.projects == ["project-x"] - assert new_model.access_attributes.teams is None + assert new_model.owner.principal == "testuser" + assert new_model.owner.attributes["roles"] == ["admin"] + assert new_model.owner.attributes["teams"] == ["ai-team"] @pytest.mark.asyncio async def test_registry_empty_acl(cached_disk_dist_registry): - model = ModelWithACL( + model = ModelWithOwner( identifier="model-empty-acl", provider_id="test-provider", provider_resource_id="model-resource", model_type=ModelType.llm, - access_attributes=AccessAttributes(), + owner=User("testuser", None), ) await cached_disk_dist_registry.register(model) cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl") assert cached_model is not None - assert cached_model.access_attributes is not None - assert cached_model.access_attributes.roles is None - assert cached_model.access_attributes.teams is None - assert cached_model.access_attributes.projects is None - assert cached_model.access_attributes.namespaces is None + assert cached_model.owner is not None + assert cached_model.owner.attributes is None all_models = await cached_disk_dist_registry.get_all() assert len(all_models) == 1 - model = ModelWithACL( + model = ModelWithOwner( identifier="model-no-acl", provider_id="test-provider", provider_resource_id="model-resource-2", @@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry): cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl") assert cached_model is not None - assert cached_model.access_attributes is None + assert cached_model.owner is None all_models = await cached_disk_dist_registry.get_all() assert len(all_models) == 2 @@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry): @pytest.mark.asyncio async def test_registry_serialization(cached_disk_dist_registry): - attributes = AccessAttributes( - roles=["admin", "researcher"], - teams=["ai-team", "ml-team"], - projects=["project-a", "project-b"], - namespaces=["prod", "staging"], - ) + attributes = { + "roles": ["admin", "researcher"], + "teams": ["ai-team", "ml-team"], + "projects": ["project-a", "project-b"], + "namespaces": ["prod", "staging"], + } - model = ModelWithACL( + model = ModelWithOwner( identifier="model-serialize", provider_id="test-provider", provider_resource_id="model-resource", model_type=ModelType.llm, - access_attributes=attributes, + owner=User("bob", attributes), ) await cached_disk_dist_registry.register(model) @@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry): loaded_model = await new_registry.get("model", "model-serialize") assert loaded_model is not None - assert loaded_model.access_attributes.roles == ["admin", "researcher"] - assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"] - assert loaded_model.access_attributes.projects == ["project-a", "project-b"] - assert loaded_model.access_attributes.namespaces == ["prod", "staging"] + assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"] + assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"] + assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"] + assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"] diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index 70faee748..f9ad47b0c 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -8,13 +8,12 @@ from unittest.mock import MagicMock, Mock, patch import pytest import yaml -from pydantic import TypeAdapter +from pydantic import TypeAdapter, ValidationError from llama_stack.apis.datatypes import Api from llama_stack.apis.models import ModelType from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed -from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL -from llama_stack.distribution.request_headers import User +from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User from llama_stack.distribution.routing_tables.models import ModelsRoutingTable @@ -45,25 +44,25 @@ async def test_setup(cached_disk_dist_registry): @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") async def test_access_control_with_cache(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup - model_public = ModelWithACL( + model_public = ModelWithOwner( identifier="model-public", provider_id="test_provider", provider_resource_id="model-public", model_type=ModelType.llm, ) - model_admin_only = ModelWithACL( + model_admin_only = ModelWithOwner( identifier="model-admin", provider_id="test_provider", provider_resource_id="model-admin", model_type=ModelType.llm, - access_attributes=AccessAttributes(roles=["admin"]), + owner=User("testuser", {"roles": ["admin"]}), ) - model_data_scientist = ModelWithACL( + model_data_scientist = ModelWithOwner( identifier="model-data-scientist", provider_id="test_provider", provider_resource_id="model-data-scientist", model_type=ModelType.llm, - access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]), + owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}), ) await registry.register(model_public) await registry.register(model_admin_only) @@ -110,7 +109,7 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") async def test_access_control_and_updates(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup - model_public = ModelWithACL( + model_public = ModelWithOwner( identifier="model-updates", provider_id="test_provider", provider_resource_id="model-updates", @@ -125,7 +124,7 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu ) model = await routing_table.get_model("model-updates") assert model.identifier == "model-updates" - model_public.access_attributes = AccessAttributes(roles=["admin"]) + model_public.owner = User("testuser", {"roles": ["admin"]}) await registry.update(model_public) mock_get_authenticated_user.return_value = User( "test-user", @@ -149,12 +148,12 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup - model = ModelWithACL( + model = ModelWithOwner( identifier="model-empty-attrs", provider_id="test_provider", provider_resource_id="model-empty-attrs", model_type=ModelType.llm, - access_attributes=AccessAttributes(), + owner=User("testuser", {}), ) await registry.register(model) mock_get_authenticated_user.return_value = User( @@ -174,18 +173,18 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") async def test_no_user_attributes(mock_get_authenticated_user, test_setup): registry, routing_table = test_setup - model_public = ModelWithACL( + model_public = ModelWithOwner( identifier="model-public-2", provider_id="test_provider", provider_resource_id="model-public-2", model_type=ModelType.llm, ) - model_restricted = ModelWithACL( + model_restricted = ModelWithOwner( identifier="model-restricted", provider_id="test_provider", provider_resource_id="model-restricted", model_type=ModelType.llm, - access_attributes=AccessAttributes(roles=["admin"]), + owner=User("testuser", {"roles": ["admin"]}), ) await registry.register(model_public) await registry.register(model_restricted) @@ -212,7 +211,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set mock_get_authenticated_user.return_value = User("test-user", creator_attributes) # Create model without explicit access attributes - model = ModelWithACL( + model = ModelWithOwner( identifier="auto-access-model", provider_id="test_provider", provider_resource_id="auto-access-model", @@ -222,10 +221,11 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set # Verify the model got creator's attributes registered_model = await routing_table.get_model("auto-access-model") - assert registered_model.access_attributes is not None - assert registered_model.access_attributes.roles == ["data-scientist"] - assert registered_model.access_attributes.teams == ["ml-team"] - assert registered_model.access_attributes.projects == ["llama-3"] + assert registered_model.owner is not None + assert registered_model.owner.attributes is not None + assert registered_model.owner.attributes["roles"] == ["data-scientist"] + assert registered_model.owner.attributes["teams"] == ["ml-team"] + assert registered_model.owner.attributes["projects"] == ["llama-3"] # Verify another user without matching attributes can't access it mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]}) @@ -354,15 +354,14 @@ def test_permit_when(): - permit: principal: user-1 actions: [read] - when: - user_in: resource.namespaces + when: user in owners namespaces """ policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) - model = ModelWithACL( + model = ModelWithOwner( identifier="mymodel", provider_id="myprovider", model_type=ModelType.llm, - access_attributes=AccessAttributes(namespaces=["foo"]), + owner=User("testuser", {"namespaces": ["foo"]}), ) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) @@ -376,15 +375,15 @@ def test_permit_unless(): actions: [read] resource: model::* unless: - - user_not_in: resource.namespaces - - user_in: resource.teams + - user not in owners namespaces + - user in owners teams """ policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) - model = ModelWithACL( + model = ModelWithOwner( identifier="mymodel", provider_id="myprovider", model_type=ModelType.llm, - access_attributes=AccessAttributes(namespaces=["foo"]), + owner=User("testuser", {"namespaces": ["foo"]}), ) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) @@ -397,16 +396,16 @@ def test_forbid_when(): principal: user-1 actions: [read] when: - user_in: resource.namespaces + user in owners namespaces - permit: actions: [read] """ policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) - model = ModelWithACL( + model = ModelWithOwner( identifier="mymodel", provider_id="myprovider", model_type=ModelType.llm, - access_attributes=AccessAttributes(namespaces=["foo"]), + owner=User("testuser", {"namespaces": ["foo"]}), ) assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) @@ -419,35 +418,33 @@ def test_forbid_unless(): principal: user-1 actions: [read] unless: - user_in: resource.namespaces + user in owners namespaces - permit: actions: [read] """ policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) - model = ModelWithACL( + model = ModelWithOwner( identifier="mymodel", provider_id="myprovider", model_type=ModelType.llm, - access_attributes=AccessAttributes(namespaces=["foo"]), + owner=User("testuser", {"namespaces": ["foo"]}), ) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) -def test_condition_with_literal(): +def test_user_has_attribute(): config = """ - permit: actions: [read] - when: - user_in: role::admin + when: user with admin in roles """ policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) - model = ModelWithACL( + model = ModelWithOwner( identifier="mymodel", provider_id="myprovider", model_type=ModelType.llm, - access_attributes=AccessAttributes(namespaces=["foo"]), ) assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) @@ -455,35 +452,115 @@ def test_condition_with_literal(): assert not is_action_allowed(policy, "read", model, User("user-4", None)) -def test_condition_with_unrecognised_literal(): +def test_user_does_not_have_attribute(): config = """ - permit: actions: [read] - when: - user_in: whatever + unless: user with admin not in roles """ policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) - model = ModelWithACL( + model = ModelWithOwner( identifier="mymodel", provider_id="myprovider", model_type=ModelType.llm, - access_attributes=AccessAttributes(namespaces=["foo"]), ) assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) - assert not is_action_allowed(policy, "read", model, User("user-2", None)) + assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) + assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-4", None)) -def test_empty_condition(): +def test_is_owner(): config = """ - permit: actions: [read] - when: {} + when: user is owner """ policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) - model = ModelWithACL( + model = ModelWithOwner( identifier="mymodel", provider_id="myprovider", model_type=ModelType.llm, + owner=User("user-2", {"namespaces": ["foo"]}), ) - assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) - assert is_action_allowed(policy, "read", model, User("user-2", None)) + assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) + assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-4", None)) + + +def test_is_not_owner(): + config = """ + - permit: + actions: [read] + unless: user is not owner + """ + policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + model = ModelWithOwner( + identifier="mymodel", + provider_id="myprovider", + model_type=ModelType.llm, + owner=User("user-2", {"namespaces": ["foo"]}), + ) + assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) + assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) + assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]})) + assert not is_action_allowed(policy, "read", model, User("user-4", None)) + + +def test_invalid_rule_permit_and_forbid_both_specified(): + config = """ + - permit: + actions: [read] + forbid: + actions: [create] + """ + with pytest.raises(ValidationError): + TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + + +def test_invalid_rule_neither_permit_or_forbid_specified(): + config = """ + - when: user is owner + unless: user with admin in roles + """ + with pytest.raises(ValidationError): + TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + + +def test_invalid_rule_when_and_unless_both_specified(): + config = """ + - permit: + actions: [read] + when: user is owner + unless: user with admin in roles + """ + with pytest.raises(ValidationError): + TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + + +def test_invalid_condition(): + config = """ + - permit: + actions: [read] + when: random words that are not valid + """ + with pytest.raises(ValidationError): + TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) + + +@pytest.mark.parametrize( + "condition", + [ + "user is owner", + "user is not owner", + "user with dev in teams", + "user with default not in namespaces", + "user in owners roles", + "user not in owners projects", + ], +) +def test_condition_reprs(condition): + from llama_stack.distribution.access_control.conditions import parse_condition + + assert condition == str(parse_condition(condition)) diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 408acb88a..e159aefd1 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs): { "message": "Authentication successful", "principal": "test-principal", - "access_attributes": { + "attributes": { "roles": ["admin", "user"], "teams": ["ml-team", "nlp-team"], "projects": ["llama-3", "project-x"], @@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock { "message": "Authentication successful", "principal": "test-principal", - "access_attributes": { + "attributes": { "roles": ["admin", "user"], "teams": ["ml-team", "nlp-team"], "projects": ["llama-3", "project-x"], @@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) -@pytest.mark.asyncio -async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope): - """Test middleware behavior with no access attributes""" - middleware, mock_app = mock_http_middleware - mock_receive = AsyncMock() - mock_send = AsyncMock() - - with patch("httpx.AsyncClient") as mock_client: - mock_client_instance = AsyncMock() - mock_client.return_value.__aenter__.return_value = mock_client_instance - - mock_client_instance.post.return_value = MockResponse( - 200, - { - "message": "Authentication successful" - # No access_attributes - }, - ) - - await middleware(mock_scope, mock_receive, mock_send) - - assert "user_attributes" in mock_scope - attributes = mock_scope["user_attributes"] - assert "roles" in attributes - assert attributes["roles"] == ["test.jwt.token"] - - # oauth2 token provider tests @@ -380,16 +353,16 @@ def test_get_attributes_from_claims(): "aud": "llama-stack", } attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"}) - assert attributes.roles == ["my-user"] - assert attributes.teams == ["group1", "group2"] + assert attributes["roles"] == ["my-user"] + assert attributes["teams"] == ["group1", "group2"] claims = { "sub": "my-user", "tenant": "my-tenant", } attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"}) - assert attributes.roles == ["my-user"] - assert attributes.namespaces == ["my-tenant"] + assert attributes["roles"] == ["my-user"] + assert attributes["namespaces"] == ["my-tenant"] claims = { "sub": "my-user", @@ -408,9 +381,9 @@ def test_get_attributes_from_claims(): "groups": "teams", }, ) - assert set(attributes.roles) == {"my-user", "my-username"} - assert set(attributes.teams) == {"my-team", "group1", "group2"} - assert attributes.namespaces == ["my-tenant"] + assert set(attributes["roles"]) == {"my-user", "my-username"} + assert set(attributes["teams"]) == {"my-team", "group1", "group2"} + assert attributes["namespaces"] == ["my-tenant"] # TODO: add more tests for oauth2 token provider