Changes to access rule conditions:

* change from access_attributes to owner on dynamically created resources
 * define simpler string based conditions for more intuitive restriction
This commit is contained in:
Gordon Sim 2025-05-29 20:21:20 +01:00
parent 01ad876012
commit 96cd51a0c8
20 changed files with 427 additions and 431 deletions

View file

@ -4,16 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Protocol from typing import Any
from llama_stack.distribution.request_headers import User from llama_stack.distribution.datatypes import User
from .conditions import (
Condition,
ProtectedResource,
parse_conditions,
)
from .datatypes import ( from .datatypes import (
AccessAttributes,
AccessRule, AccessRule,
Action, Action,
AttributeReference,
Condition,
Scope, Scope,
) )
@ -37,43 +39,6 @@ def matches_scope(
return action in scope.actions return action in scope.actions
def user_in_literal(
literal: str,
user_attributes: dict[str, list[str]] | None,
) -> bool:
for qualifier in ["role::", "team::", "project::", "namespace::"]:
if literal.startswith(qualifier):
if not user_attributes:
return False
ref = qualifier.replace("::", "s")
if ref in user_attributes:
value = literal.removeprefix(qualifier)
return value in user_attributes[ref]
else:
return False
return False
def user_in(
ref: AttributeReference | str,
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
if not ref.startswith("resource."):
return user_in_literal(ref, user_attributes)
name = ref.removeprefix("resource.")
required = resource_attributes and getattr(resource_attributes, name)
if not required:
return True
if not user_attributes or name not in user_attributes:
return False
actual = user_attributes[name]
for value in required:
if value in actual:
return True
return False
def as_list(obj: Any) -> list[Any]: def as_list(obj: Any) -> list[Any]:
if isinstance(obj, list): if isinstance(obj, list):
return obj return obj
@ -82,55 +47,27 @@ def as_list(obj: Any) -> list[Any]:
def matches_conditions( def matches_conditions(
conditions: list[Condition], conditions: list[Condition],
resource_attributes: AccessAttributes | None, resource: ProtectedResource,
user_attributes: dict[str, list[str]] | None, user: User,
) -> bool: ) -> bool:
for condition in conditions: for condition in conditions:
# must match all conditions # must match all conditions
if not matches_condition(condition, resource_attributes, user_attributes): if not condition.matches(resource, user):
return False return False
return True return True
def matches_condition(
condition: Condition | list[Condition],
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
if isinstance(condition, list):
return matches_conditions(as_list(condition), resource_attributes, user_attributes)
if condition.user_in:
for ref in as_list(condition.user_in):
# if multiple references are specified, all must match
if not user_in(ref, resource_attributes, user_attributes):
return False
return True
if condition.user_not_in:
for ref in as_list(condition.user_not_in):
# if multiple references are specified, none must match
if user_in(ref, resource_attributes, user_attributes):
return False
return True
return True
def default_policy() -> list[AccessRule]: def default_policy() -> list[AccessRule]:
# for backwards compatibility, if no rules are provided , assume # for backwards compatibility, if no rules are provided, assume
# full access to all subject to attribute matching rules # full access subject to previous attribute matching rules
return [ return [
AccessRule( AccessRule(
permit=Scope(actions=list(Action)), permit=Scope(actions=list(Action)),
when=Condition(user_in=list(AttributeReference)), when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
) ),
] ]
class ProtectedResource(Protocol):
type: str
identifier: str
access_attributes: AccessAttributes
def is_action_allowed( def is_action_allowed(
policy: list[AccessRule], policy: list[AccessRule],
action: Action, action: Action,
@ -144,26 +81,23 @@ def is_action_allowed(
if not len(policy): if not len(policy):
policy = default_policy() policy = default_policy()
resource_attributes = AccessAttributes()
if hasattr(resource, "access_attributes"):
resource_attributes = resource.access_attributes
qualified_resource_id = resource.type + "::" + resource.identifier qualified_resource_id = resource.type + "::" + resource.identifier
for rule in policy: for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal): if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when: if rule.when:
if matches_condition(rule.when, resource_attributes, user.attributes): if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return False return False
elif rule.unless: elif rule.unless:
if not matches_condition(rule.unless, resource_attributes, user.attributes): if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return False return False
else: else:
return False return False
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal): elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
if rule.when: if rule.when:
if matches_condition(rule.when, resource_attributes, user.attributes): if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return True return True
elif rule.unless: elif rule.unless:
if not matches_condition(rule.unless, resource_attributes, user.attributes): if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return True return True
else: else:
return True return True

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
class User(Protocol):
principal: str
attributes: dict[str, list[str]] | None
class ProtectedResource(Protocol):
type: str
identifier: str
owner: User
class Condition(Protocol):
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
class UserInOwnersList:
def __init__(self, name: str):
self.name = name
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
if (
hasattr(resource, "owner")
and resource.owner
and resource.owner.attributes
and self.name in resource.owner.attributes
):
return resource.owner.attributes[self.name]
else:
return None
def matches(self, resource: ProtectedResource, user: User) -> bool:
required = self.owners_values(resource)
if not required:
return True
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
return False
user_values = user.attributes[self.name]
for value in required:
if value in user_values:
return True
return False
def __repr__(self):
return f"user in owners {self.name}"
class UserNotInOwnersList(UserInOwnersList):
def __init__(self, name: str):
super().__init__(name)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user not in owners {self.name}"
class UserWithValueInList:
def __init__(self, name: str, value: str):
self.name = name
self.value = value
def matches(self, resource: ProtectedResource, user: User) -> bool:
if user.attributes and self.name in user.attributes:
return self.value in user.attributes[self.name]
print(f"User does not have {self.value} in {self.name}")
return False
def __repr__(self):
return f"user with {self.value} in {self.name}"
class UserWithValueNotInList(UserWithValueInList):
def __init__(self, name: str, value: str):
super().__init__(name, value)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user with {self.value} not in {self.name}"
class UserIsOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return resource.owner.principal == user.principal if resource.owner else False
def __repr__(self):
return "user is owner"
class UserIsNotOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not resource.owner or resource.owner.principal != user.principal
def __repr__(self):
return "user is not owner"
def parse_condition(condition: str) -> Condition:
words = condition.split()
match words:
case ["user", "is", "owner"]:
return UserIsOwner()
case ["user", "is", "not", "owner"]:
return UserIsNotOwner()
case ["user", "with", value, "in", name]:
return UserWithValueInList(name, value)
case ["user", "with", value, "not", "in", name]:
return UserWithValueNotInList(name, value)
case ["user", "in", "owners", name]:
return UserInOwnersList(name)
case ["user", "not", "in", "owners", name]:
return UserNotInOwnersList(name)
case _:
raise ValueError(f"Invalid condition: {condition}")
def parse_conditions(conditions: list[str]) -> list[Condition]:
return [parse_condition(c) for c in conditions]

View file

@ -6,37 +6,10 @@
from enum import Enum from enum import Enum
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, model_validator
from typing_extensions import Self from typing_extensions import Self
from .conditions import parse_conditions
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class Action(str, Enum): class Action(str, Enum):
@ -62,18 +35,6 @@ def _require_one_of(obj, a: str, b: str):
raise ValueError(f"on of {a} or {b} is required") raise ValueError(f"on of {a} or {b} is required")
class AttributeReference(str, Enum):
RESOURCE_ROLES = "resource.roles"
RESOURCE_TEAMS = "resource.teams"
RESOURCE_PROJECTS = "resource.projects"
RESOURCE_NAMESPACES = "resource.namespaces"
class Condition(BaseModel):
user_in: AttributeReference | list[AttributeReference] | str | None = None
user_not_in: AttributeReference | list[AttributeReference] | str | None = None
class AccessRule(BaseModel): class AccessRule(BaseModel):
"""Access rule based loosely on cedar policy language """Access rule based loosely on cedar policy language
@ -85,10 +46,14 @@ class AccessRule(BaseModel):
requests. requests.
A rule may also specify a condition, either a 'when' or an 'unless', with additional A rule may also specify a condition, either a 'when' or an 'unless', with additional
constraints as to where the rule applies. The constraints at present are whether the constraints as to where the rule applies. The constraints supported at present are:
user requesting access is in or not in some set. This set can either be a particular
set of attributes on the resource e.g. resource.roles or a literal value of some - 'user with <attr-value> in <attr-name>'
notion of group, e.g. role::admin or namespace::foo. - 'user with <attr-value> not in <attr-name>'
- 'user is owner'
- 'user is not owner'
- 'user in owners <attr-name>'
- 'user not in owners <attr-name>'
Rules are tested in order to find a match. If a match is found, the request is Rules are tested in order to find a match. If a match is found, the request is
permitted or forbidden depending on the type of rule. If no match is found, the permitted or forbidden depending on the type of rule. If no match is found, the
@ -110,22 +75,20 @@ class AccessRule(BaseModel):
description: user-2 has read access to model-1 only description: user-2 has read access to model-1 only
- permit: - permit:
actions: [read] actions: [read]
when: when: user in owner teams
user_in: resource.namespaces description: any user has read access to any resource created by a member of their team
description: any user has read access to any resource with matching attributes
- forbid: - forbid:
actions: [create, read, delete] actions: [create, read, delete]
resource: vector_db::* resource: vector_db::*
unless: unless: user with admin in roles
user_in: role::admin
description: only user with admin role can use vector_db resources description: only user with admin role can use vector_db resources
""" """
permit: Scope | None = None permit: Scope | None = None
forbid: Scope | None = None forbid: Scope | None = None
when: Condition | list[Condition] | None = None when: str | list[str] | None = None
unless: Condition | list[Condition] | None = None unless: str | list[str] | None = None
description: str | None = None description: str | None = None
@model_validator(mode="after") @model_validator(mode="after")
@ -133,4 +96,12 @@ class AccessRule(BaseModel):
_require_one_of(self, "permit", "forbid") _require_one_of(self, "permit", "forbid")
_mutually_exclusive(self, "permit", "forbid") _mutually_exclusive(self, "permit", "forbid")
_mutually_exclusive(self, "when", "unless") _mutually_exclusive(self, "when", "unless")
if isinstance(self.when, list):
parse_conditions(self.when)
elif self.when:
parse_conditions([self.when])
if isinstance(self.unless, list):
parse_conditions(self.unless)
elif self.unless:
parse_conditions([self.unless])
return self return self

View file

@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
@ -36,97 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = str | list[str] RoutingKey = str | list[str]
class ResourceWithACL(Resource): class User(BaseModel):
"""Extension of Resource that adds attribute-based access control capabilities. principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
This class adds an optional access_attributes field that allows fine-grained control def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
over which users can access each resource. When attributes are defined, a user must have super().__init__(principal=principal, attributes=attributes)
matching attributes to access the resource.
Attribute Matching Algorithm:
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
4. Within each category, ANY value match is sufficient (OR relationship within a category)
Examples: class ResourceWithOwner(Resource):
# Resource visible to everyone (no access control) """Extension of Resource that adds an optional owner, i.e. the user that created the
model = Model(identifier="llama-2", ...) resource. This can be used to constrain access to the resource."""
# Resource visible only to admins owner: User | None = None
model = Model(
identifier="gpt-4",
access_attributes=AccessAttributes(roles=["admin"])
)
# Resource visible to data scientists on the ML team
model = Model(
identifier="private-model",
access_attributes=AccessAttributes(
roles=["data-scientist", "researcher"],
teams=["ml-team"]
)
)
# ^ User must have at least one of the roles AND be on the ml-team
# Resource visible to users with specific project access
vector_db = VectorDB(
identifier="customer-embeddings",
access_attributes=AccessAttributes(
projects=["customer-insights"],
namespaces=["confidential"]
)
)
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: AccessAttributes | None = None
# Use the extended Resource for all routable objects # Use the extended Resource for all routable objects
class ModelWithACL(Model, ResourceWithACL): class ModelWithOwner(Model, ResourceWithOwner):
pass pass
class ShieldWithACL(Shield, ResourceWithACL): class ShieldWithOwner(Shield, ResourceWithOwner):
pass pass
class VectorDBWithACL(VectorDB, ResourceWithACL): class VectorDBWithOwner(VectorDB, ResourceWithOwner):
pass pass
class DatasetWithACL(Dataset, ResourceWithACL): class DatasetWithOwner(Dataset, ResourceWithOwner):
pass pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL): class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
pass pass
class BenchmarkWithACL(Benchmark, ResourceWithACL): class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
pass pass
class ToolWithACL(Tool, ResourceWithACL): class ToolWithOwner(Tool, ResourceWithOwner):
pass pass
class ToolGroupWithACL(ToolGroup, ResourceWithACL): class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass pass
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
RoutableObjectWithProvider = Annotated[ RoutableObjectWithProvider = Annotated[
ModelWithACL ModelWithOwner
| ShieldWithACL | ShieldWithOwner
| VectorDBWithACL | VectorDBWithOwner
| DatasetWithACL | DatasetWithOwner
| ScoringFnWithACL | ScoringFnWithOwner
| BenchmarkWithACL | BenchmarkWithOwner
| ToolWithACL | ToolWithOwner
| ToolGroupWithACL, | ToolGroupWithOwner,
Field(discriminator="type"), Field(discriminator="type"),
] ]

View file

@ -10,6 +10,8 @@ import logging
from contextlib import AbstractContextManager from contextlib import AbstractContextManager
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import User
from .utils.dynamic import instantiate_class_type from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -18,16 +20,6 @@ log = logging.getLogger(__name__)
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class User:
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]]
def __init__(self, principal: str, attributes: dict[str, list[str]]):
self.principal = principal
self.attributes = attributes
class RequestProviderDataContext(AbstractContextManager): class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data""" """Context manager for request provider data"""

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BenchmarkWithACL, BenchmarkWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
) )
if provider_benchmark_id is None: if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL( benchmark = BenchmarkWithOwner(
identifier=benchmark_id, identifier=benchmark_id,
dataset_id=dataset_id, dataset_id=dataset_id,
scoring_functions=scoring_functions, scoring_functions=scoring_functions,

View file

@ -10,7 +10,6 @@ from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AccessAttributes,
AccessRule, AccessRule,
RoutableObject, RoutableObject,
RoutableObjectWithProvider, RoutableObjectWithProvider,
@ -195,9 +194,9 @@ class CommonRoutingTableImpl(RoutingTable):
creator = get_authenticated_user() creator = get_authenticated_user()
if not is_action_allowed(self.policy, "create", obj, creator): if not is_action_allowed(self.policy, "create", obj, creator):
raise AccessDeniedError() raise AccessDeniedError()
if creator and creator.attributes: if creator:
obj.access_attributes = AccessAttributes(**creator.attributes) obj.owner = creator
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p) registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object # TODO: This needs to be fixed for all APIs once they return the registered object

View file

@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
) )
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
DatasetWithACL, DatasetWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
if metadata is None: if metadata is None:
metadata = {} metadata = {}
dataset = DatasetWithACL( dataset = DatasetWithOwner(
identifier=dataset_id, identifier=dataset_id,
provider_resource_id=provider_dataset_id, provider_resource_id=provider_dataset_id,
provider_id=provider_id, provider_id=provider_id,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ModelWithACL, ModelWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding: if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata") raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = ModelWithACL( model = ModelWithOwner(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
provider_id=provider_id, provider_id=provider_id,

View file

@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions, ScoringFunctions,
) )
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ScoringFnWithACL, ScoringFnWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
raise ValueError( raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id." "No provider specified and multiple providers available. Please specify a provider_id."
) )
scoring_fn = ScoringFnWithACL( scoring_fn = ScoringFnWithOwner(
identifier=scoring_fn_id, identifier=scoring_fn_id,
description=description, description=description,
return_type=return_type, return_type=return_type,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ShieldWithACL, ShieldWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
) )
if params is None: if params is None:
params = {} params = {}
shield = ShieldWithACL( shield = ShieldWithOwner(
identifier=shield_id, identifier=shield_id,
provider_resource_id=provider_shield_id, provider_resource_id=provider_shield_id,
provider_id=provider_id, provider_id=provider_id,

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL from llama_stack.distribution.datatypes import ToolGroupWithOwner
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
@ -88,7 +88,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None, mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None, args: dict[str, Any] | None = None,
) -> None: ) -> None:
toolgroup = ToolGroupWithACL( toolgroup = ToolGroupWithOwner(
identifier=toolgroup_id, identifier=toolgroup_id,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=toolgroup_id, provider_resource_id=toolgroup_id,

View file

@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
VectorDBWithACL, VectorDBWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model, "embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"], "embedding_dimension": model.metadata["embedding_dimension"],
} }
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db) await self.register_object(vector_db)
return vector_db return vector_db

View file

@ -105,24 +105,16 @@ class AuthenticationMiddleware:
logger.exception("Error during authentication") logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error") return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control
if validation_result.access_attributes:
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"roles": [token],
}
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits. # can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token scope["authenticated_client_id"] = token
# Store attributes in request scope # Store attributes in request scope
scope["user_attributes"] = user_attributes
scope["principal"] = validation_result.principal scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug( logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes" f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
) )
return await self.app(scope, receive, send) return await self.app(scope, receive, send)

View file

@ -16,43 +16,18 @@ from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="auth")
class TokenValidationResult(BaseModel): class AuthResponse(BaseModel):
principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
class AuthResponse(TokenValidationResult):
"""The format of the authentication response from the auth endpoint.""" """The format of the authentication response from the auth endpoint."""
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
message: str | None = Field( message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result." default=None, description="Optional message providing additional context about the authentication result."
) )
@ -78,7 +53,7 @@ class AuthProvider(ABC):
"""Abstract base class for authentication providers.""" """Abstract base class for authentication providers."""
@abstractmethod @abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token and return access attributes.""" """Validate a token and return access attributes."""
pass pass
@ -88,10 +63,10 @@ class AuthProvider(ABC):
pass pass
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes = AccessAttributes() attributes: dict[str, list[str]] = {}
for claim_key, attribute_key in mapping.items(): for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key): if claim_key not in claims:
continue continue
claim = claims[claim_key] claim = claims[claim_key]
if isinstance(claim, list): if isinstance(claim, list):
@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
else: else:
values = claim.split() values = claim.split()
current = getattr(attributes, attribute_key) if attribute_key in attributes:
if current: attributes[attribute_key].extend(values)
current.extend(values)
else: else:
setattr(attributes, attribute_key, values) attributes[attribute_key] = values
return attributes return attributes
@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
for key, value in v.items(): for key, value in v.items():
if not value: if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}") raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v return v
@model_validator(mode="after") @model_validator(mode="after")
@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
self._jwks: dict[str, str] = {} self._jwks: dict[str, str] = {}
self._jwks_lock = Lock() self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks: if self.config.jwks:
return await self.validate_jwt_token(token, scope) return await self.validate_jwt_token(token, scope)
if self.config.introspection: if self.config.introspection:
return await self.introspect_token(token, scope) return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured") raise ValueError("One of jwks or introspection must be configured")
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token.""" """Validate a token using the JWT token."""
await self._refresh_jwks() await self._refresh_jwks()
@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
# We should incorporate these into the access attributes. # We should incorporate these into the access attributes.
principal = claims["sub"] principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping) access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult( return User(
principal=principal, principal=principal,
access_attributes=access_attributes, attributes=access_attributes,
) )
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def introspect_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using token introspection as defined by RFC 7662.""" """Validate a token using token introspection as defined by RFC 7662."""
form = { form = {
"token": token, "token": token,
@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
raise ValueError("Token not active") raise ValueError("Token not active")
principal = fields["sub"] or fields["username"] principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping) access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return TokenValidationResult( return User(
principal=principal, principal=principal,
access_attributes=access_attributes, attributes=access_attributes,
) )
except httpx.TimeoutException: except httpx.TimeoutException:
logger.exception("Token introspection request timed out") logger.exception("Token introspection request timed out")
@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
self.config = config self.config = config
self._client = None self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the custom authentication endpoint.""" """Validate a token using the custom authentication endpoint."""
if scope is None: if scope is None:
scope = {} scope = {}
@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider):
json=auth_request.model_dump(), json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout timeout=10.0, # Add a reasonable timeout
) )
print("MADE CALL")
if response.status_code != 200: if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}") logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}") raise ValueError(f"Authentication failed: {response.status_code}")
@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider):
try: try:
response_data = response.json() response_data = response.json()
auth_response = AuthResponse(**response_data) auth_response = AuthResponse(**response_data)
return auth_response return User(auth_response.principal, auth_response.attributes)
except Exception as e: except Exception as e:
logger.exception("Error parsing authentication response") logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e raise ValueError("Invalid authentication response format") from e

View file

@ -11,7 +11,8 @@ from datetime import datetime, timezone
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
@ -22,7 +23,7 @@ class AgentSessionInfo(Session):
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: str | None = None vector_db_id: str | None = None
started_at: datetime started_at: datetime
access_attributes: AccessAttributes | None = None owner: User | None = None
identifier: str | None = None identifier: str | None = None
type: str = "session" type: str = "session"
@ -42,14 +43,12 @@ class AgentPersistence:
# Get current user's auth attributes for new sessions # Get current user's auth attributes for new sessions
user = get_authenticated_user() user = get_authenticated_user()
auth_attributes = user and user.attributes
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo( session_info = AgentSessionInfo(
session_id=session_id, session_id=session_id,
session_name=name, session_name=name,
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
access_attributes=access_attributes, owner=user,
turns=[], turns=[],
identifier=name, # should this be qualified in any way? identifier=name, # should this be qualified in any way?
) )
@ -80,7 +79,7 @@ class AgentPersistence:
def _check_session_access(self, session_info: AgentSessionInfo) -> bool: def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session.""" """Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control # Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"): if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())

View file

@ -12,8 +12,7 @@ import pytest
from llama_stack.apis.agents import Turn from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@ -38,9 +37,10 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
# Get the session and verify access attributes were set # Get the session and verify access attributes were set
session_info = await agent_persistence.get_session_info(session_id) session_info = await agent_persistence.get_session_info(session_id)
assert session_info is not None assert session_info is not None
assert session_info.access_attributes is not None assert session_info.owner is not None
assert session_info.access_attributes.roles == ["researcher"] assert session_info.owner.attributes is not None
assert session_info.access_attributes.teams == ["ai-team"] assert session_info.owner.attributes["roles"] == ["researcher"]
assert session_info.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio @pytest.mark.asyncio
@ -54,7 +54,7 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
session_id=session_id, session_id=session_id,
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
turns=[], turns=[],
identifier="Restricted Session", identifier="Restricted Session",
) )
@ -89,7 +89,7 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
session_id=session_id, session_id=session_id,
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]), owner=User("someone", {"roles": ["admin"]}),
turns=[], turns=[],
identifier="Restricted Session", identifier="Restricted Session",
) )
@ -143,7 +143,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_u
session_id=session_id, session_id=session_id,
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]), owner=User("someone", {"roles": ["admin"]}),
turns=[], turns=[],
identifier="Restricted Session", identifier="Restricted Session",
) )

View file

@ -8,19 +8,18 @@
import pytest import pytest
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL from llama_stack.distribution.datatypes import ModelWithOwner, User
from llama_stack.distribution.server.auth_providers import AccessAttributes
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_cache_with_acl(cached_disk_dist_registry): async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithACL( model = ModelWithOwner(
identifier="model-acl", identifier="model-acl",
provider_id="test-provider", provider_id="test-provider",
provider_resource_id="model-acl-resource", provider_resource_id="model-acl-resource",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]), owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
) )
success = await cached_disk_dist_registry.register(model) success = await cached_disk_dist_registry.register(model)
@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.identifier == "model-acl" assert cached_model.identifier == "model-acl"
assert cached_model.access_attributes.roles == ["admin"] assert cached_model.owner.principal == "testuser"
assert cached_model.access_attributes.teams == ["ai-team"] assert cached_model.owner.attributes["roles"] == ["admin"]
assert cached_model.owner.attributes["teams"] == ["ai-team"]
fetched_model = await cached_disk_dist_registry.get("model", "model-acl") fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
assert fetched_model is not None assert fetched_model is not None
assert fetched_model.identifier == "model-acl" assert fetched_model.identifier == "model-acl"
assert fetched_model.access_attributes.roles == ["admin"] assert fetched_model.owner.attributes["roles"] == ["admin"]
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
await cached_disk_dist_registry.update(model)
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
assert updated_cached is not None
assert updated_cached.access_attributes.roles == ["admin", "user"]
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore) new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize() await new_registry.initialize()
@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
new_model = await new_registry.get("model", "model-acl") new_model = await new_registry.get("model", "model-acl")
assert new_model is not None assert new_model is not None
assert new_model.identifier == "model-acl" assert new_model.identifier == "model-acl"
assert new_model.access_attributes.roles == ["admin", "user"] assert new_model.owner.principal == "testuser"
assert new_model.access_attributes.projects == ["project-x"] assert new_model.owner.attributes["roles"] == ["admin"]
assert new_model.access_attributes.teams is None assert new_model.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_empty_acl(cached_disk_dist_registry): async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithACL( model = ModelWithOwner(
identifier="model-empty-acl", identifier="model-empty-acl",
provider_id="test-provider", provider_id="test-provider",
provider_resource_id="model-resource", provider_resource_id="model-resource",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(), owner=User("testuser", None),
) )
await cached_disk_dist_registry.register(model) await cached_disk_dist_registry.register(model)
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.access_attributes is not None assert cached_model.owner is not None
assert cached_model.access_attributes.roles is None assert cached_model.owner.attributes is None
assert cached_model.access_attributes.teams is None
assert cached_model.access_attributes.projects is None
assert cached_model.access_attributes.namespaces is None
all_models = await cached_disk_dist_registry.get_all() all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 1 assert len(all_models) == 1
model = ModelWithACL( model = ModelWithOwner(
identifier="model-no-acl", identifier="model-no-acl",
provider_id="test-provider", provider_id="test-provider",
provider_resource_id="model-resource-2", provider_resource_id="model-resource-2",
@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.access_attributes is None assert cached_model.owner is None
all_models = await cached_disk_dist_registry.get_all() all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 2 assert len(all_models) == 2
@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_serialization(cached_disk_dist_registry): async def test_registry_serialization(cached_disk_dist_registry):
attributes = AccessAttributes( attributes = {
roles=["admin", "researcher"], "roles": ["admin", "researcher"],
teams=["ai-team", "ml-team"], "teams": ["ai-team", "ml-team"],
projects=["project-a", "project-b"], "projects": ["project-a", "project-b"],
namespaces=["prod", "staging"], "namespaces": ["prod", "staging"],
) }
model = ModelWithACL( model = ModelWithOwner(
identifier="model-serialize", identifier="model-serialize",
provider_id="test-provider", provider_id="test-provider",
provider_resource_id="model-resource", provider_resource_id="model-resource",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=attributes, owner=User("bob", attributes),
) )
await cached_disk_dist_registry.register(model) await cached_disk_dist_registry.register(model)
@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
loaded_model = await new_registry.get("model", "model-serialize") loaded_model = await new_registry.get("model", "model-serialize")
assert loaded_model is not None assert loaded_model is not None
assert loaded_model.access_attributes.roles == ["admin", "researcher"] assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"] assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
assert loaded_model.access_attributes.projects == ["project-a", "project-b"] assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
assert loaded_model.access_attributes.namespaces == ["prod", "staging"] assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]

View file

@ -8,13 +8,12 @@ from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
import yaml import yaml
from pydantic import TypeAdapter from pydantic import TypeAdapter, ValidationError
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@ -45,25 +44,25 @@ async def test_setup(cached_disk_dist_registry):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup): async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithOwner(
identifier="model-public", identifier="model-public",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-public", provider_resource_id="model-public",
model_type=ModelType.llm, model_type=ModelType.llm,
) )
model_admin_only = ModelWithACL( model_admin_only = ModelWithOwner(
identifier="model-admin", identifier="model-admin",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-admin", provider_resource_id="model-admin",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]), owner=User("testuser", {"roles": ["admin"]}),
) )
model_data_scientist = ModelWithACL( model_data_scientist = ModelWithOwner(
identifier="model-data-scientist", identifier="model-data-scientist",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-data-scientist", provider_resource_id="model-data-scientist",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]), owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
) )
await registry.register(model_public) await registry.register(model_public)
await registry.register(model_admin_only) await registry.register(model_admin_only)
@ -110,7 +109,7 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup): async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithOwner(
identifier="model-updates", identifier="model-updates",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-updates", provider_resource_id="model-updates",
@ -125,7 +124,7 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
) )
model = await routing_table.get_model("model-updates") model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates" assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"]) model_public.owner = User("testuser", {"roles": ["admin"]})
await registry.update(model_public) await registry.update(model_public)
mock_get_authenticated_user.return_value = User( mock_get_authenticated_user.return_value = User(
"test-user", "test-user",
@ -149,12 +148,12 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup): async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model = ModelWithACL( model = ModelWithOwner(
identifier="model-empty-attrs", identifier="model-empty-attrs",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-empty-attrs", provider_resource_id="model-empty-attrs",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(), owner=User("testuser", {}),
) )
await registry.register(model) await registry.register(model)
mock_get_authenticated_user.return_value = User( mock_get_authenticated_user.return_value = User(
@ -174,18 +173,18 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup): async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithOwner(
identifier="model-public-2", identifier="model-public-2",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-public-2", provider_resource_id="model-public-2",
model_type=ModelType.llm, model_type=ModelType.llm,
) )
model_restricted = ModelWithACL( model_restricted = ModelWithOwner(
identifier="model-restricted", identifier="model-restricted",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-restricted", provider_resource_id="model-restricted",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]), owner=User("testuser", {"roles": ["admin"]}),
) )
await registry.register(model_public) await registry.register(model_public)
await registry.register(model_restricted) await registry.register(model_restricted)
@ -212,7 +211,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
mock_get_authenticated_user.return_value = User("test-user", creator_attributes) mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
# Create model without explicit access attributes # Create model without explicit access attributes
model = ModelWithACL( model = ModelWithOwner(
identifier="auto-access-model", identifier="auto-access-model",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="auto-access-model", provider_resource_id="auto-access-model",
@ -222,10 +221,11 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
# Verify the model got creator's attributes # Verify the model got creator's attributes
registered_model = await routing_table.get_model("auto-access-model") registered_model = await routing_table.get_model("auto-access-model")
assert registered_model.access_attributes is not None assert registered_model.owner is not None
assert registered_model.access_attributes.roles == ["data-scientist"] assert registered_model.owner.attributes is not None
assert registered_model.access_attributes.teams == ["ml-team"] assert registered_model.owner.attributes["roles"] == ["data-scientist"]
assert registered_model.access_attributes.projects == ["llama-3"] assert registered_model.owner.attributes["teams"] == ["ml-team"]
assert registered_model.owner.attributes["projects"] == ["llama-3"]
# Verify another user without matching attributes can't access it # Verify another user without matching attributes can't access it
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]}) mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
@ -354,15 +354,14 @@ def test_permit_when():
- permit: - permit:
principal: user-1 principal: user-1
actions: [read] actions: [read]
when: when: user in owners namespaces
user_in: resource.namespaces
""" """
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL( model = ModelWithOwner(
identifier="mymodel", identifier="mymodel",
provider_id="myprovider", provider_id="myprovider",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]), owner=User("testuser", {"namespaces": ["foo"]}),
) )
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
@ -376,15 +375,15 @@ def test_permit_unless():
actions: [read] actions: [read]
resource: model::* resource: model::*
unless: unless:
- user_not_in: resource.namespaces - user not in owners namespaces
- user_in: resource.teams - user in owners teams
""" """
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL( model = ModelWithOwner(
identifier="mymodel", identifier="mymodel",
provider_id="myprovider", provider_id="myprovider",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]), owner=User("testuser", {"namespaces": ["foo"]}),
) )
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
@ -397,16 +396,16 @@ def test_forbid_when():
principal: user-1 principal: user-1
actions: [read] actions: [read]
when: when:
user_in: resource.namespaces user in owners namespaces
- permit: - permit:
actions: [read] actions: [read]
""" """
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL( model = ModelWithOwner(
identifier="mymodel", identifier="mymodel",
provider_id="myprovider", provider_id="myprovider",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]), owner=User("testuser", {"namespaces": ["foo"]}),
) )
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
@ -419,35 +418,33 @@ def test_forbid_unless():
principal: user-1 principal: user-1
actions: [read] actions: [read]
unless: unless:
user_in: resource.namespaces user in owners namespaces
- permit: - permit:
actions: [read] actions: [read]
""" """
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL( model = ModelWithOwner(
identifier="mymodel", identifier="mymodel",
provider_id="myprovider", provider_id="myprovider",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]), owner=User("testuser", {"namespaces": ["foo"]}),
) )
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]})) assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_condition_with_literal(): def test_user_has_attribute():
config = """ config = """
- permit: - permit:
actions: [read] actions: [read]
when: when: user with admin in roles
user_in: role::admin
""" """
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL( model = ModelWithOwner(
identifier="mymodel", identifier="mymodel",
provider_id="myprovider", provider_id="myprovider",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
) )
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]})) assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
@ -455,35 +452,115 @@ def test_condition_with_literal():
assert not is_action_allowed(policy, "read", model, User("user-4", None)) assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_condition_with_unrecognised_literal(): def test_user_does_not_have_attribute():
config = """ config = """
- permit: - permit:
actions: [read] actions: [read]
when: unless: user with admin not in roles
user_in: whatever
""" """
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL( model = ModelWithOwner(
identifier="mymodel", identifier="mymodel",
provider_id="myprovider", provider_id="myprovider",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
) )
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", None)) assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_empty_condition(): def test_is_owner():
config = """ config = """
- permit: - permit:
actions: [read] actions: [read]
when: {} when: user is owner
""" """
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config)) policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL( model = ModelWithOwner(
identifier="mymodel", identifier="mymodel",
provider_id="myprovider", provider_id="myprovider",
model_type=ModelType.llm, model_type=ModelType.llm,
owner=User("user-2", {"namespaces": ["foo"]}),
) )
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]})) assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", None)) assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_is_not_owner():
config = """
- permit:
actions: [read]
unless: user is not owner
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("user-2", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_invalid_rule_permit_and_forbid_both_specified():
config = """
- permit:
actions: [read]
forbid:
actions: [create]
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_neither_permit_or_forbid_specified():
config = """
- when: user is owner
unless: user with admin in roles
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_when_and_unless_both_specified():
config = """
- permit:
actions: [read]
when: user is owner
unless: user with admin in roles
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_condition():
config = """
- permit:
actions: [read]
when: random words that are not valid
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
@pytest.mark.parametrize(
"condition",
[
"user is owner",
"user is not owner",
"user with dev in teams",
"user with default not in namespaces",
"user in owners roles",
"user not in owners projects",
],
)
def test_condition_reprs(condition):
from llama_stack.distribution.access_control.conditions import parse_condition
assert condition == str(parse_condition(condition))

View file

@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):
{ {
"message": "Authentication successful", "message": "Authentication successful",
"principal": "test-principal", "principal": "test-principal",
"access_attributes": { "attributes": {
"roles": ["admin", "user"], "roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"], "teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"], "projects": ["llama-3", "project-x"],
@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
{ {
"message": "Authentication successful", "message": "Authentication successful",
"principal": "test-principal", "principal": "test-principal",
"access_attributes": { "attributes": {
"roles": ["admin", "user"], "roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"], "teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"], "projects": ["llama-3", "project-x"],
@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
@pytest.mark.asyncio
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
"""Test middleware behavior with no access attributes"""
middleware, mock_app = mock_http_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client_instance = AsyncMock()
mock_client.return_value.__aenter__.return_value = mock_client_instance
mock_client_instance.post.return_value = MockResponse(
200,
{
"message": "Authentication successful"
# No access_attributes
},
)
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
attributes = mock_scope["user_attributes"]
assert "roles" in attributes
assert attributes["roles"] == ["test.jwt.token"]
# oauth2 token provider tests # oauth2 token provider tests
@ -380,16 +353,16 @@ def test_get_attributes_from_claims():
"aud": "llama-stack", "aud": "llama-stack",
} }
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"}) attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
assert attributes.roles == ["my-user"] assert attributes["roles"] == ["my-user"]
assert attributes.teams == ["group1", "group2"] assert attributes["teams"] == ["group1", "group2"]
claims = { claims = {
"sub": "my-user", "sub": "my-user",
"tenant": "my-tenant", "tenant": "my-tenant",
} }
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"}) attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
assert attributes.roles == ["my-user"] assert attributes["roles"] == ["my-user"]
assert attributes.namespaces == ["my-tenant"] assert attributes["namespaces"] == ["my-tenant"]
claims = { claims = {
"sub": "my-user", "sub": "my-user",
@ -408,9 +381,9 @@ def test_get_attributes_from_claims():
"groups": "teams", "groups": "teams",
}, },
) )
assert set(attributes.roles) == {"my-user", "my-username"} assert set(attributes["roles"]) == {"my-user", "my-username"}
assert set(attributes.teams) == {"my-team", "group1", "group2"} assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
assert attributes.namespaces == ["my-tenant"] assert attributes["namespaces"] == ["my-tenant"]
# TODO: add more tests for oauth2 token provider # TODO: add more tests for oauth2 token provider