Changes to access rule conditions:

* change from access_attributes to owner on dynamically created resources
 * define simpler string based conditions for more intuitive restriction
This commit is contained in:
Gordon Sim 2025-05-29 20:21:20 +01:00
parent 01ad876012
commit 96cd51a0c8
20 changed files with 427 additions and 431 deletions

View file

@ -4,16 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol
from typing import Any
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.datatypes import User
from .conditions import (
Condition,
ProtectedResource,
parse_conditions,
)
from .datatypes import (
AccessAttributes,
AccessRule,
Action,
AttributeReference,
Condition,
Scope,
)
@ -37,43 +39,6 @@ def matches_scope(
return action in scope.actions
def user_in_literal(
literal: str,
user_attributes: dict[str, list[str]] | None,
) -> bool:
for qualifier in ["role::", "team::", "project::", "namespace::"]:
if literal.startswith(qualifier):
if not user_attributes:
return False
ref = qualifier.replace("::", "s")
if ref in user_attributes:
value = literal.removeprefix(qualifier)
return value in user_attributes[ref]
else:
return False
return False
def user_in(
ref: AttributeReference | str,
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
if not ref.startswith("resource."):
return user_in_literal(ref, user_attributes)
name = ref.removeprefix("resource.")
required = resource_attributes and getattr(resource_attributes, name)
if not required:
return True
if not user_attributes or name not in user_attributes:
return False
actual = user_attributes[name]
for value in required:
if value in actual:
return True
return False
def as_list(obj: Any) -> list[Any]:
if isinstance(obj, list):
return obj
@ -82,55 +47,27 @@ def as_list(obj: Any) -> list[Any]:
def matches_conditions(
conditions: list[Condition],
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
resource: ProtectedResource,
user: User,
) -> bool:
for condition in conditions:
# must match all conditions
if not matches_condition(condition, resource_attributes, user_attributes):
if not condition.matches(resource, user):
return False
return True
def matches_condition(
condition: Condition | list[Condition],
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
if isinstance(condition, list):
return matches_conditions(as_list(condition), resource_attributes, user_attributes)
if condition.user_in:
for ref in as_list(condition.user_in):
# if multiple references are specified, all must match
if not user_in(ref, resource_attributes, user_attributes):
return False
return True
if condition.user_not_in:
for ref in as_list(condition.user_not_in):
# if multiple references are specified, none must match
if user_in(ref, resource_attributes, user_attributes):
return False
return True
return True
def default_policy() -> list[AccessRule]:
# for backwards compatibility, if no rules are provided , assume
# full access to all subject to attribute matching rules
# for backwards compatibility, if no rules are provided, assume
# full access subject to previous attribute matching rules
return [
AccessRule(
permit=Scope(actions=list(Action)),
when=Condition(user_in=list(AttributeReference)),
)
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
),
]
class ProtectedResource(Protocol):
type: str
identifier: str
access_attributes: AccessAttributes
def is_action_allowed(
policy: list[AccessRule],
action: Action,
@ -144,26 +81,23 @@ def is_action_allowed(
if not len(policy):
policy = default_policy()
resource_attributes = AccessAttributes()
if hasattr(resource, "access_attributes"):
resource_attributes = resource.access_attributes
qualified_resource_id = resource.type + "::" + resource.identifier
for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when:
if matches_condition(rule.when, resource_attributes, user.attributes):
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return False
elif rule.unless:
if not matches_condition(rule.unless, resource_attributes, user.attributes):
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return False
else:
return False
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
if rule.when:
if matches_condition(rule.when, resource_attributes, user.attributes):
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return True
elif rule.unless:
if not matches_condition(rule.unless, resource_attributes, user.attributes):
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return True
else:
return True

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
class User(Protocol):
principal: str
attributes: dict[str, list[str]] | None
class ProtectedResource(Protocol):
type: str
identifier: str
owner: User
class Condition(Protocol):
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
class UserInOwnersList:
def __init__(self, name: str):
self.name = name
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
if (
hasattr(resource, "owner")
and resource.owner
and resource.owner.attributes
and self.name in resource.owner.attributes
):
return resource.owner.attributes[self.name]
else:
return None
def matches(self, resource: ProtectedResource, user: User) -> bool:
required = self.owners_values(resource)
if not required:
return True
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
return False
user_values = user.attributes[self.name]
for value in required:
if value in user_values:
return True
return False
def __repr__(self):
return f"user in owners {self.name}"
class UserNotInOwnersList(UserInOwnersList):
def __init__(self, name: str):
super().__init__(name)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user not in owners {self.name}"
class UserWithValueInList:
def __init__(self, name: str, value: str):
self.name = name
self.value = value
def matches(self, resource: ProtectedResource, user: User) -> bool:
if user.attributes and self.name in user.attributes:
return self.value in user.attributes[self.name]
print(f"User does not have {self.value} in {self.name}")
return False
def __repr__(self):
return f"user with {self.value} in {self.name}"
class UserWithValueNotInList(UserWithValueInList):
def __init__(self, name: str, value: str):
super().__init__(name, value)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user with {self.value} not in {self.name}"
class UserIsOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return resource.owner.principal == user.principal if resource.owner else False
def __repr__(self):
return "user is owner"
class UserIsNotOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not resource.owner or resource.owner.principal != user.principal
def __repr__(self):
return "user is not owner"
def parse_condition(condition: str) -> Condition:
words = condition.split()
match words:
case ["user", "is", "owner"]:
return UserIsOwner()
case ["user", "is", "not", "owner"]:
return UserIsNotOwner()
case ["user", "with", value, "in", name]:
return UserWithValueInList(name, value)
case ["user", "with", value, "not", "in", name]:
return UserWithValueNotInList(name, value)
case ["user", "in", "owners", name]:
return UserInOwnersList(name)
case ["user", "not", "in", "owners", name]:
return UserNotInOwnersList(name)
case _:
raise ValueError(f"Invalid condition: {condition}")
def parse_conditions(conditions: list[str]) -> list[Condition]:
return [parse_condition(c) for c in conditions]

View file

@ -6,37 +6,10 @@
from enum import Enum
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, model_validator
from typing_extensions import Self
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)
from .conditions import parse_conditions
class Action(str, Enum):
@ -62,18 +35,6 @@ def _require_one_of(obj, a: str, b: str):
raise ValueError(f"on of {a} or {b} is required")
class AttributeReference(str, Enum):
RESOURCE_ROLES = "resource.roles"
RESOURCE_TEAMS = "resource.teams"
RESOURCE_PROJECTS = "resource.projects"
RESOURCE_NAMESPACES = "resource.namespaces"
class Condition(BaseModel):
user_in: AttributeReference | list[AttributeReference] | str | None = None
user_not_in: AttributeReference | list[AttributeReference] | str | None = None
class AccessRule(BaseModel):
"""Access rule based loosely on cedar policy language
@ -85,10 +46,14 @@ class AccessRule(BaseModel):
requests.
A rule may also specify a condition, either a 'when' or an 'unless', with additional
constraints as to where the rule applies. The constraints at present are whether the
user requesting access is in or not in some set. This set can either be a particular
set of attributes on the resource e.g. resource.roles or a literal value of some
notion of group, e.g. role::admin or namespace::foo.
constraints as to where the rule applies. The constraints supported at present are:
- 'user with <attr-value> in <attr-name>'
- 'user with <attr-value> not in <attr-name>'
- 'user is owner'
- 'user is not owner'
- 'user in owners <attr-name>'
- 'user not in owners <attr-name>'
Rules are tested in order to find a match. If a match is found, the request is
permitted or forbidden depending on the type of rule. If no match is found, the
@ -99,33 +64,31 @@ class AccessRule(BaseModel):
Some examples in yaml:
- permit:
principal: user-1
actions: [create, read, delete]
resource: model::*
principal: user-1
actions: [create, read, delete]
resource: model::*
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
actions: [read]
when:
user_in: resource.namespaces
description: any user has read access to any resource with matching attributes
actions: [read]
when: user in owner teams
description: any user has read access to any resource created by a member of their team
- forbid:
actions: [create, read, delete]
resource: vector_db::*
unless:
user_in: role::admin
actions: [create, read, delete]
resource: vector_db::*
unless: user with admin in roles
description: only user with admin role can use vector_db resources
"""
permit: Scope | None = None
forbid: Scope | None = None
when: Condition | list[Condition] | None = None
unless: Condition | list[Condition] | None = None
when: str | list[str] | None = None
unless: str | list[str] | None = None
description: str | None = None
@model_validator(mode="after")
@ -133,4 +96,12 @@ class AccessRule(BaseModel):
_require_one_of(self, "permit", "forbid")
_mutually_exclusive(self, "permit", "forbid")
_mutually_exclusive(self, "when", "unless")
if isinstance(self.when, list):
parse_conditions(self.when)
elif self.when:
parse_conditions([self.when])
if isinstance(self.unless, list):
parse_conditions(self.unless)
elif self.unless:
parse_conditions([self.unless])
return self

View file

@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
@ -36,97 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = str | list[str]
class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities.
class User(BaseModel):
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
This class adds an optional access_attributes field that allows fine-grained control
over which users can access each resource. When attributes are defined, a user must have
matching attributes to access the resource.
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
super().__init__(principal=principal, attributes=attributes)
Attribute Matching Algorithm:
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
4. Within each category, ANY value match is sufficient (OR relationship within a category)
Examples:
# Resource visible to everyone (no access control)
model = Model(identifier="llama-2", ...)
class ResourceWithOwner(Resource):
"""Extension of Resource that adds an optional owner, i.e. the user that created the
resource. This can be used to constrain access to the resource."""
# Resource visible only to admins
model = Model(
identifier="gpt-4",
access_attributes=AccessAttributes(roles=["admin"])
)
# Resource visible to data scientists on the ML team
model = Model(
identifier="private-model",
access_attributes=AccessAttributes(
roles=["data-scientist", "researcher"],
teams=["ml-team"]
)
)
# ^ User must have at least one of the roles AND be on the ml-team
# Resource visible to users with specific project access
vector_db = VectorDB(
identifier="customer-embeddings",
access_attributes=AccessAttributes(
projects=["customer-insights"],
namespaces=["confidential"]
)
)
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: AccessAttributes | None = None
owner: User | None = None
# Use the extended Resource for all routable objects
class ModelWithACL(Model, ResourceWithACL):
class ModelWithOwner(Model, ResourceWithOwner):
pass
class ShieldWithACL(Shield, ResourceWithACL):
class ShieldWithOwner(Shield, ResourceWithOwner):
pass
class VectorDBWithACL(VectorDB, ResourceWithACL):
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
pass
class DatasetWithACL(Dataset, ResourceWithACL):
class DatasetWithOwner(Dataset, ResourceWithOwner):
pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
pass
class BenchmarkWithACL(Benchmark, ResourceWithACL):
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
pass
class ToolWithACL(Tool, ResourceWithACL):
class ToolWithOwner(Tool, ResourceWithOwner):
pass
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
RoutableObjectWithProvider = Annotated[
ModelWithACL
| ShieldWithACL
| VectorDBWithACL
| DatasetWithACL
| ScoringFnWithACL
| BenchmarkWithACL
| ToolWithACL
| ToolGroupWithACL,
ModelWithOwner
| ShieldWithOwner
| VectorDBWithOwner
| DatasetWithOwner
| ScoringFnWithOwner
| BenchmarkWithOwner
| ToolWithOwner
| ToolGroupWithOwner,
Field(discriminator="type"),
]

View file

@ -10,6 +10,8 @@ import logging
from contextlib import AbstractContextManager
from typing import Any
from llama_stack.distribution.datatypes import User
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
@ -18,16 +20,6 @@ log = logging.getLogger(__name__)
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class User:
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]]
def __init__(self, principal: str, attributes: dict[str, list[str]]):
self.principal = principal
self.attributes = attributes
class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import (
BenchmarkWithACL,
BenchmarkWithOwner,
)
from llama_stack.log import get_logger
@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL(
benchmark = BenchmarkWithOwner(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,

View file

@ -10,7 +10,6 @@ from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import (
AccessAttributes,
AccessRule,
RoutableObject,
RoutableObjectWithProvider,
@ -195,9 +194,9 @@ class CommonRoutingTableImpl(RoutingTable):
creator = get_authenticated_user()
if not is_action_allowed(self.policy, "create", obj, creator):
raise AccessDeniedError()
if creator and creator.attributes:
obj.access_attributes = AccessAttributes(**creator.attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
if creator:
obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object

View file

@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
)
from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import (
DatasetWithACL,
DatasetWithOwner,
)
from llama_stack.log import get_logger
@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
if metadata is None:
metadata = {}
dataset = DatasetWithACL(
dataset = DatasetWithOwner(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithACL,
ModelWithOwner,
)
from llama_stack.log import get_logger
@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = ModelWithACL(
model = ModelWithOwner(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,

View file

@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
ScoringFnWithACL,
ScoringFnWithOwner,
)
from llama_stack.log import get_logger
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithACL(
scoring_fn = ScoringFnWithOwner(
identifier=scoring_fn_id,
description=description,
return_type=return_type,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import (
ShieldWithACL,
ShieldWithOwner,
)
from llama_stack.log import get_logger
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
if params is None:
params = {}
shield = ShieldWithACL(
shield = ShieldWithOwner(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.distribution.datatypes import ToolGroupWithOwner
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
@ -88,7 +88,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
toolgroup = ToolGroupWithACL(
toolgroup = ToolGroupWithOwner(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,

View file

@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
VectorDBWithACL,
VectorDBWithOwner,
)
from llama_stack.log import get_logger
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db

View file

@ -105,24 +105,16 @@ class AuthenticationMiddleware:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control
if validation_result.access_attributes:
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"roles": [token],
}
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token
# Store attributes in request scope
scope["user_attributes"] = user_attributes
scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)
return await self.app(scope, receive, send)

View file

@ -16,43 +16,18 @@ from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class TokenValidationResult(BaseModel):
principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
class AuthResponse(TokenValidationResult):
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
@ -78,7 +53,7 @@ class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token and return access attributes."""
pass
@ -88,10 +63,10 @@ class AuthProvider(ABC):
pass
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key):
if claim_key not in claims:
continue
claim = claims[claim_key]
if isinstance(claim, list):
@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
else:
values = claim.split()
current = getattr(attributes, attribute_key)
if current:
current.extend(values)
if attribute_key in attributes:
attributes[attribute_key].extend(values)
else:
setattr(attributes, attribute_key, values)
attributes[attribute_key] = values
return attributes
@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v
@model_validator(mode="after")
@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks:
return await self.validate_jwt_token(token, scope)
if self.config.introspection:
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult(
return User(
principal=principal,
access_attributes=access_attributes,
attributes=access_attributes,
)
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using token introspection as defined by RFC 7662."""
form = {
"token": token,
@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
raise ValueError("Token not active")
principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return TokenValidationResult(
return User(
principal=principal,
access_attributes=access_attributes,
attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
self.config = config
self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the custom authentication endpoint."""
if scope is None:
scope = {}
@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider):
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
print("MADE CALL")
if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}")
@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider):
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
return auth_response
return User(auth_response.principal, auth_response.attributes)
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e

View file

@ -11,7 +11,8 @@ from datetime import datetime, timezone
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore
@ -22,7 +23,7 @@ class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_db_id: str | None = None
started_at: datetime
access_attributes: AccessAttributes | None = None
owner: User | None = None
identifier: str | None = None
type: str = "session"
@ -42,14 +43,12 @@ class AgentPersistence:
# Get current user's auth attributes for new sessions
user = get_authenticated_user()
auth_attributes = user and user.attributes
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
owner=user,
turns=[],
identifier=name, # should this be qualified in any way?
)
@ -80,7 +79,7 @@ class AgentPersistence:
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"):
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())

View file

@ -12,8 +12,7 @@ import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@ -38,9 +37,10 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
# Get the session and verify access attributes were set
session_info = await agent_persistence.get_session_info(session_id)
assert session_info is not None
assert session_info.access_attributes is not None
assert session_info.access_attributes.roles == ["researcher"]
assert session_info.access_attributes.teams == ["ai-team"]
assert session_info.owner is not None
assert session_info.owner.attributes is not None
assert session_info.owner.attributes["roles"] == ["researcher"]
assert session_info.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
@ -54,7 +54,7 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
turns=[],
identifier="Restricted Session",
)
@ -89,7 +89,7 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
@ -143,7 +143,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_u
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)

View file

@ -8,19 +8,18 @@
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.server.auth_providers import AccessAttributes
from llama_stack.distribution.datatypes import ModelWithOwner, User
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-acl",
provider_id="test-provider",
provider_resource_id="model-acl-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
)
success = await cached_disk_dist_registry.register(model)
@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
assert cached_model is not None
assert cached_model.identifier == "model-acl"
assert cached_model.access_attributes.roles == ["admin"]
assert cached_model.access_attributes.teams == ["ai-team"]
assert cached_model.owner.principal == "testuser"
assert cached_model.owner.attributes["roles"] == ["admin"]
assert cached_model.owner.attributes["teams"] == ["ai-team"]
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
assert fetched_model is not None
assert fetched_model.identifier == "model-acl"
assert fetched_model.access_attributes.roles == ["admin"]
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
await cached_disk_dist_registry.update(model)
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
assert updated_cached is not None
assert updated_cached.access_attributes.roles == ["admin", "user"]
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
assert fetched_model.owner.attributes["roles"] == ["admin"]
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize()
@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
new_model = await new_registry.get("model", "model-acl")
assert new_model is not None
assert new_model.identifier == "model-acl"
assert new_model.access_attributes.roles == ["admin", "user"]
assert new_model.access_attributes.projects == ["project-x"]
assert new_model.access_attributes.teams is None
assert new_model.owner.principal == "testuser"
assert new_model.owner.attributes["roles"] == ["admin"]
assert new_model.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-empty-acl",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(),
owner=User("testuser", None),
)
await cached_disk_dist_registry.register(model)
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
assert cached_model is not None
assert cached_model.access_attributes is not None
assert cached_model.access_attributes.roles is None
assert cached_model.access_attributes.teams is None
assert cached_model.access_attributes.projects is None
assert cached_model.access_attributes.namespaces is None
assert cached_model.owner is not None
assert cached_model.owner.attributes is None
all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 1
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-no-acl",
provider_id="test-provider",
provider_resource_id="model-resource-2",
@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
assert cached_model is not None
assert cached_model.access_attributes is None
assert cached_model.owner is None
all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 2
@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_registry_serialization(cached_disk_dist_registry):
attributes = AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
projects=["project-a", "project-b"],
namespaces=["prod", "staging"],
)
attributes = {
"roles": ["admin", "researcher"],
"teams": ["ai-team", "ml-team"],
"projects": ["project-a", "project-b"],
"namespaces": ["prod", "staging"],
}
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-serialize",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=attributes,
owner=User("bob", attributes),
)
await cached_disk_dist_registry.register(model)
@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
loaded_model = await new_registry.get("model", "model-serialize")
assert loaded_model is not None
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]

View file

@ -8,13 +8,12 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
import yaml
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@ -45,25 +44,25 @@ async def test_setup(cached_disk_dist_registry):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-public",
provider_id="test_provider",
provider_resource_id="model-public",
model_type=ModelType.llm,
)
model_admin_only = ModelWithACL(
model_admin_only = ModelWithOwner(
identifier="model-admin",
provider_id="test_provider",
provider_resource_id="model-admin",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("testuser", {"roles": ["admin"]}),
)
model_data_scientist = ModelWithACL(
model_data_scientist = ModelWithOwner(
identifier="model-data-scientist",
provider_id="test_provider",
provider_resource_id="model-data-scientist",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
)
await registry.register(model_public)
await registry.register(model_admin_only)
@ -110,7 +109,7 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-updates",
provider_id="test_provider",
provider_resource_id="model-updates",
@ -125,7 +124,7 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"])
model_public.owner = User("testuser", {"roles": ["admin"]})
await registry.update(model_public)
mock_get_authenticated_user.return_value = User(
"test-user",
@ -149,12 +148,12 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-empty-attrs",
provider_id="test_provider",
provider_resource_id="model-empty-attrs",
model_type=ModelType.llm,
access_attributes=AccessAttributes(),
owner=User("testuser", {}),
)
await registry.register(model)
mock_get_authenticated_user.return_value = User(
@ -174,18 +173,18 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-public-2",
provider_id="test_provider",
provider_resource_id="model-public-2",
model_type=ModelType.llm,
)
model_restricted = ModelWithACL(
model_restricted = ModelWithOwner(
identifier="model-restricted",
provider_id="test_provider",
provider_resource_id="model-restricted",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("testuser", {"roles": ["admin"]}),
)
await registry.register(model_public)
await registry.register(model_restricted)
@ -212,7 +211,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
# Create model without explicit access attributes
model = ModelWithACL(
model = ModelWithOwner(
identifier="auto-access-model",
provider_id="test_provider",
provider_resource_id="auto-access-model",
@ -222,10 +221,11 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
# Verify the model got creator's attributes
registered_model = await routing_table.get_model("auto-access-model")
assert registered_model.access_attributes is not None
assert registered_model.access_attributes.roles == ["data-scientist"]
assert registered_model.access_attributes.teams == ["ml-team"]
assert registered_model.access_attributes.projects == ["llama-3"]
assert registered_model.owner is not None
assert registered_model.owner.attributes is not None
assert registered_model.owner.attributes["roles"] == ["data-scientist"]
assert registered_model.owner.attributes["teams"] == ["ml-team"]
assert registered_model.owner.attributes["projects"] == ["llama-3"]
# Verify another user without matching attributes can't access it
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
@ -354,15 +354,14 @@ def test_permit_when():
- permit:
principal: user-1
actions: [read]
when:
user_in: resource.namespaces
when: user in owners namespaces
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
@ -376,15 +375,15 @@ def test_permit_unless():
actions: [read]
resource: model::*
unless:
- user_not_in: resource.namespaces
- user_in: resource.teams
- user not in owners namespaces
- user in owners teams
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
@ -397,16 +396,16 @@ def test_forbid_when():
principal: user-1
actions: [read]
when:
user_in: resource.namespaces
user in owners namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
@ -419,35 +418,33 @@ def test_forbid_unless():
principal: user-1
actions: [read]
unless:
user_in: resource.namespaces
user in owners namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_condition_with_literal():
def test_user_has_attribute():
config = """
- permit:
actions: [read]
when:
user_in: role::admin
when: user with admin in roles
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
@ -455,35 +452,115 @@ def test_condition_with_literal():
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_condition_with_unrecognised_literal():
def test_user_does_not_have_attribute():
config = """
- permit:
actions: [read]
when:
user_in: whatever
unless: user with admin not in roles
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", None))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_empty_condition():
def test_is_owner():
config = """
- permit:
actions: [read]
when: {}
when: user is owner
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("user-2", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", None))
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_is_not_owner():
config = """
- permit:
actions: [read]
unless: user is not owner
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("user-2", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_invalid_rule_permit_and_forbid_both_specified():
config = """
- permit:
actions: [read]
forbid:
actions: [create]
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_neither_permit_or_forbid_specified():
config = """
- when: user is owner
unless: user with admin in roles
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_when_and_unless_both_specified():
config = """
- permit:
actions: [read]
when: user is owner
unless: user with admin in roles
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_condition():
config = """
- permit:
actions: [read]
when: random words that are not valid
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
@pytest.mark.parametrize(
"condition",
[
"user is owner",
"user is not owner",
"user with dev in teams",
"user with default not in namespaces",
"user in owners roles",
"user not in owners projects",
],
)
def test_condition_reprs(condition):
from llama_stack.distribution.access_control.conditions import parse_condition
assert condition == str(parse_condition(condition))

View file

@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):
{
"message": "Authentication successful",
"principal": "test-principal",
"access_attributes": {
"attributes": {
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
{
"message": "Authentication successful",
"principal": "test-principal",
"access_attributes": {
"attributes": {
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
@pytest.mark.asyncio
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
"""Test middleware behavior with no access attributes"""
middleware, mock_app = mock_http_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client_instance = AsyncMock()
mock_client.return_value.__aenter__.return_value = mock_client_instance
mock_client_instance.post.return_value = MockResponse(
200,
{
"message": "Authentication successful"
# No access_attributes
},
)
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
attributes = mock_scope["user_attributes"]
assert "roles" in attributes
assert attributes["roles"] == ["test.jwt.token"]
# oauth2 token provider tests
@ -380,16 +353,16 @@ def test_get_attributes_from_claims():
"aud": "llama-stack",
}
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
assert attributes.roles == ["my-user"]
assert attributes.teams == ["group1", "group2"]
assert attributes["roles"] == ["my-user"]
assert attributes["teams"] == ["group1", "group2"]
claims = {
"sub": "my-user",
"tenant": "my-tenant",
}
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
assert attributes.roles == ["my-user"]
assert attributes.namespaces == ["my-tenant"]
assert attributes["roles"] == ["my-user"]
assert attributes["namespaces"] == ["my-tenant"]
claims = {
"sub": "my-user",
@ -408,9 +381,9 @@ def test_get_attributes_from_claims():
"groups": "teams",
},
)
assert set(attributes.roles) == {"my-user", "my-username"}
assert set(attributes.teams) == {"my-team", "group1", "group2"}
assert attributes.namespaces == ["my-tenant"]
assert set(attributes["roles"]) == {"my-user", "my-username"}
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
assert attributes["namespaces"] == ["my-tenant"]
# TODO: add more tests for oauth2 token provider