feat: fine grained access control policy (#2264)

This allows a set of rules to be defined for determining access to
resources. The rules are (loosely) based on the cedar policy format.

A rule defines a list of actions either to permit or to forbid. It may
specify a principal or a resource that must match for the rule to take
effect. It may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies.

A list of rules is held for each type to be protected and tried in order
to find a match. If a match is found, the request is permitted or
forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified for a given type, a rule
that allows any action as long as the resource attributes match the user
attributes is added (i.e. the previous behaviour is the default).

Some examples in yaml:

```
    model:
    - permit:
      principal: user-1
      actions: [create, read, delete]
      comment: user-1 has full access to all models
    - permit:
      principal: user-2
      actions: [read]
      resource: model-1
      comment: user-2 has read access to model-1 only
    - permit:
      actions: [read]
      when:
        user_in: resource.namespaces
      comment: any user has read access to models with matching attributes
    vector_db:
    - forbid:
      actions: [create, read, delete]
      unless:
        user_in: role::admin
      comment: only user with admin role can use vector_db resources
```

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
grs 2025-06-03 17:51:12 -04:00 committed by GitHub
parent 8bee2954be
commit 7c1998db25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 956 additions and 450 deletions

View file

@ -1,86 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(
    obj_identifier: str,
    obj_attributes: AccessAttributes | None,
    user_attributes: dict[str, Any] | None = None,
) -> bool:
    """Decide whether a user may access an object by comparing attribute sets.

    The decision is made in this order:

    1. An object without access attributes is accessible to everyone.
    2. Otherwise, a user without attributes is denied.
    3. If the object's attributes are all unset (the dump with
       ``exclude_none=True`` is empty), access is granted.
    4. Otherwise every attribute category set on the object must be present
       on the user with at least one value in common; the first category
       that fails this denies access.

    Example:
        # Resource requires:
        access_attributes = AccessAttributes(
            roles=["admin", "data-scientist"],
            teams=["ml-team"]
        )

        # User has:
        user_attributes = {
            "roles": ["data-scientist", "engineer"],
            "teams": ["ml-team", "infra-team"],
            "projects": ["llama-3"]
        }

        # Result: access granted. One role and one team overlap, and the
        # user's extra "projects" category is never consulted.

    Args:
        obj_identifier: Identifier of the object, used only in log messages.
        obj_attributes: Access attributes attached to the object, if any.
        user_attributes: Attributes of the current user, keyed by category.

    Returns:
        True if access is granted, False if it is denied.
    """
    # Unprotected object: nothing to compare against.
    if not obj_attributes:
        return True

    # Protected object, but the user carries no attributes at all.
    if not user_attributes:
        return False

    required_by_category = obj_attributes.model_dump(exclude_none=True)
    if not required_by_category:
        return True

    # Categories are ANDed together; values within a category are ORed.
    for attr_key, required_values in required_by_category.items():
        held_values = user_attributes.get(attr_key, [])
        if not held_values:
            logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
            return False

        has_overlap = any(candidate in held_values for candidate in required_values)
        if not has_overlap:
            logger.debug(
                f"Access denied to {obj_identifier}: "
                f"no match for attribute '{attr_key}', required one of {required_values}"
            )
            return False

    logger.debug(f"Access granted to {obj_identifier}")
    return True

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import User
from .conditions import (
Condition,
ProtectedResource,
parse_conditions,
)
from .datatypes import (
AccessRule,
Action,
Scope,
)
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
    """Tell whether a rule's resource scope covers a qualified resource id.

    Args:
        resource_scope: Either an exact qualified id (e.g. ``model::my-model``)
            or a type wildcard ending in ``::*`` (e.g. ``model::*``).
        actual_resource: The qualified id of the resource being accessed.

    Returns:
        True on an exact match, or when the scope is a ``::*`` wildcard and
        the resource id starts with the scope's ``<type>::`` prefix.
    """
    if resource_scope != actual_resource:
        is_wildcard = resource_scope.endswith("::*")
        # Dropping only the trailing "*" keeps the "::" separator in the prefix.
        return is_wildcard and actual_resource.startswith(resource_scope[:-1])
    return True
def matches_scope(
    scope: Scope,
    action: Action,
    resource: str,
    user: str | None,
) -> bool:
    """Tell whether a rule scope applies to the given request.

    Args:
        scope: The ``permit`` or ``forbid`` scope of a rule.
        action: The action being attempted.
        resource: The type-qualified id of the target resource.
        user: The principal of the requesting user, if known.

    Returns:
        True when the scope's resource (if set) covers ``resource``, the
        scope's principal (if set) equals ``user``, and ``action`` is one of
        the scope's actions.
    """
    if scope.resource and not matches_resource(scope.resource, resource):
        return False
    if scope.principal and scope.principal != user:
        return False
    # ``Scope.actions`` may be a single action rather than a list. Because
    # actions are strings, a bare ``in`` against a single action would be a
    # substring test, so normalise to a list first.
    allowed_actions = scope.actions if isinstance(scope.actions, list) else [scope.actions]
    return action in allowed_actions
def as_list(obj: Any) -> list[Any]:
    """Return ``obj`` unchanged if it is a list, otherwise wrap it in one."""
    return obj if isinstance(obj, list) else [obj]
def matches_conditions(
    conditions: list[Condition],
    resource: ProtectedResource,
    user: User,
) -> bool:
    """Tell whether every condition holds for the resource/user pair.

    Conditions are evaluated in order and evaluation stops at the first one
    that does not match. An empty list matches.
    """
    return all(condition.matches(resource, user) for condition in conditions)
def default_policy() -> list[AccessRule]:
    """Build the policy that is used when no rules are configured.

    For backwards compatibility this is a single rule permitting every
    action, subject to a "user in owners <name>" condition for each of the
    attribute categories roles, teams, projects and namespaces.
    """
    legacy_categories = ["roles", "teams", "projects", "namespaces"]
    owner_conditions = ["user in owners " + name for name in legacy_categories]
    permit_everything = Scope(actions=list(Action))
    return [AccessRule(permit=permit_everything, when=owner_conditions)]
def _conditions_satisfied(rule: AccessRule, resource: ProtectedResource, user: User) -> bool:
    """Evaluate a rule's optional 'when' / 'unless' clause; a rule with neither always applies."""
    if rule.when:
        return matches_conditions(parse_conditions(as_list(rule.when)), resource, user)
    if rule.unless:
        return not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user)
    return True


def is_action_allowed(
    policy: list[AccessRule],
    action: Action,
    resource: ProtectedResource,
    user: User | None,
) -> bool:
    """Decide whether ``user`` may perform ``action`` on ``resource``.

    Rules are tried in order. The first rule whose scope matches the request
    and whose 'when' / 'unless' clause is satisfied decides the outcome:
    a forbid rule denies, a permit rule allows. If no rule decides, access
    is denied.

    Args:
        policy: The rules to evaluate; when empty, ``default_policy()`` is used.
        action: The action being attempted.
        resource: The resource being acted upon.
        user: The requesting user, or None.

    Returns:
        True if the action is allowed, False otherwise.
    """
    # If user is not set, assume authentication is not enabled
    if not user:
        return True

    if not policy:
        policy = default_policy()

    qualified_resource_id = resource.type + "::" + resource.identifier
    for rule in policy:
        if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
            if _conditions_satisfied(rule, resource, user):
                return False
        elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
            if _conditions_satisfied(rule, resource, user):
                return True

    # assume access is denied unless we find a rule that permits access
    return False
class AccessDeniedError(RuntimeError):
    """Raised when the access policy does not allow the requested action."""

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
class User(Protocol):
principal: str
attributes: dict[str, list[str]] | None
class ProtectedResource(Protocol):
    """Structural type for a resource that access rules can be evaluated against."""

    # Kind of resource.
    type: str
    # Identifier of the resource.
    identifier: str
    # Owner of the resource. It may be unset: the conditions in this module
    # all check for a missing owner before reading its fields, so the
    # annotation is optional to match how the attribute is actually used.
    owner: User | None
class Condition(Protocol):
    """Structural type for a single 'when' / 'unless' constraint."""

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        """Return True if the constraint holds for this resource and user."""
        ...
class UserInOwnersList:
    """Condition 'user in owners <name>'.

    Holds when the user shares at least one value with the resource owner in
    the attribute called ``name``. If the owner has no values for that
    attribute (or there is no owner), the condition holds for every user.
    """

    def __init__(self, name: str):
        self.name = name

    def owners_values(self, resource: ProtectedResource) -> list[str] | None:
        """Return the owner's values for this attribute, or None if unavailable."""
        owner = getattr(resource, "owner", None)
        if not owner:
            return None
        owner_attributes = owner.attributes
        if owner_attributes and self.name in owner_attributes:
            return owner_attributes[self.name]
        return None

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        """Return True if the user overlaps with the owner on this attribute."""
        required = self.owners_values(resource)
        if not required:
            # Nothing is required by the owner, so nothing can be missing.
            return True

        held = (user.attributes or {}).get(self.name)
        if not held:
            return False
        return any(value in held for value in required)

    def __repr__(self):
        return f"user in owners {self.name}"
class UserNotInOwnersList(UserInOwnersList):
    """Condition 'user not in owners <name>': the negation of UserInOwnersList."""

    # The constructor is inherited unchanged from UserInOwnersList.

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        """Return True exactly when the parent condition does not hold."""
        in_owners_list = super().matches(resource, user)
        return not in_owners_list

    def __repr__(self):
        return f"user not in owners {self.name}"
class UserWithValueInList:
    """Condition 'user with <value> in <name>'.

    Holds when the user's attribute called ``name`` contains ``value``.
    """

    def __init__(self, name: str, value: str):
        self.name = name
        self.value = value

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        """Return True if ``value`` is among the user's values for ``name``.

        The resource is not consulted. A user without attributes, or without
        the named attribute, never matches.
        """
        if user.attributes and self.name in user.attributes:
            return self.value in user.attributes[self.name]

        # Emit a debug log record rather than printing to stdout on every
        # failed check. The import is local so this block stays self-contained.
        import logging

        logging.getLogger(__name__).debug("User does not have %s in %s", self.value, self.name)
        return False

    def __repr__(self):
        return f"user with {self.value} in {self.name}"
class UserWithValueNotInList(UserWithValueInList):
    """Condition 'user with <value> not in <name>': the negation of UserWithValueInList."""

    # The constructor is inherited unchanged from UserWithValueInList.

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        """Return True exactly when the parent condition does not hold."""
        has_value = super().matches(resource, user)
        return not has_value

    def __repr__(self):
        return f"user with {self.value} not in {self.name}"
class UserIsOwner:
    """Condition 'user is owner'."""

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        """Return True if the resource has an owner with the user's principal."""
        owner = resource.owner
        if not owner:
            return False
        return owner.principal == user.principal

    def __repr__(self):
        return "user is owner"
class UserIsNotOwner:
    """Condition 'user is not owner'."""

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        """Return True if the resource has no owner or a different owner."""
        owner = resource.owner
        if not owner:
            return True
        return owner.principal != user.principal

    def __repr__(self):
        return "user is not owner"
def parse_condition(condition: str) -> Condition:
words = condition.split()
match words:
case ["user", "is", "owner"]:
return UserIsOwner()
case ["user", "is", "not", "owner"]:
return UserIsNotOwner()
case ["user", "with", value, "in", name]:
return UserWithValueInList(name, value)
case ["user", "with", value, "not", "in", name]:
return UserWithValueNotInList(name, value)
case ["user", "in", "owners", name]:
return UserInOwnersList(name)
case ["user", "not", "in", "owners", name]:
return UserNotInOwnersList(name)
case _:
raise ValueError(f"Invalid condition: {condition}")
def parse_conditions(conditions: list[str]) -> list[Condition]:
    """Parse each textual condition, in order, into a Condition object.

    Raises:
        ValueError: Propagated from ``parse_condition`` for invalid text.
    """
    parsed = []
    for text in conditions:
        parsed.append(parse_condition(text))
    return parsed

View file

@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel, model_validator
from typing_extensions import Self
from .conditions import parse_conditions
class Action(str, Enum):
    """Actions that an access rule can permit or forbid.

    Members are also plain strings, so they compare equal to their values.
    """

    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"
class Scope(BaseModel):
    """What a permit or forbid rule applies to."""

    # Principal the rule is restricted to; unset means any principal.
    principal: str | None = None
    # The action, or list of actions, the rule covers.
    actions: Action | list[Action]
    # Type-qualified resource id or "<type>::*" wildcard; unset means any resource.
    resource: str | None = None
def _mutually_exclusive(obj, a: str, b: str):
    """Raise ValueError if attributes ``a`` and ``b`` of ``obj`` are both truthy."""
    if not getattr(obj, a):
        return
    if getattr(obj, b):
        raise ValueError(f"{a} and {b} are mutually exclusive")
def _require_one_of(obj, a: str, b: str):
    """Raise ValueError if neither attribute ``a`` nor ``b`` of ``obj`` is truthy."""
    if not getattr(obj, a) and not getattr(obj, b):
        # Message previously read "on of ..."; corrected so the error is clear.
        raise ValueError(f"one of {a} or {b} is required")
class AccessRule(BaseModel):
    """Access rule based loosely on cedar policy language

    A rule defines a list of actions either to permit or to forbid. It may specify a
    principal or a resource that must match for the rule to take effect. The resource
    to match should be specified in the form of a type qualified identifier, e.g.
    model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
    e.g. model::*. If the principal or resource are not specified, they will match all
    requests.

    A rule may also specify a condition, either a 'when' or an 'unless', with additional
    constraints as to where the rule applies. The constraints supported at present are:

     - 'user with <attr-value> in <attr-name>'
     - 'user with <attr-value> not in <attr-name>'
     - 'user is owner'
     - 'user is not owner'
     - 'user in owners <attr-name>'
     - 'user not in owners <attr-name>'

    Rules are tested in order to find a match. If a match is found, the request is
    permitted or forbidden depending on the type of rule. If no match is found, the
    request is denied. If no rules are specified, a rule that allows any action as
    long as the resource attributes match the user attributes is added
    (i.e. the previous behaviour is the default).

    Some examples in yaml:

    - permit:
        principal: user-1
        actions: [create, read, delete]
        resource: model::*
      description: user-1 has full access to all models
    - permit:
        principal: user-2
        actions: [read]
        resource: model::model-1
      description: user-2 has read access to model-1 only
    - permit:
        actions: [read]
      when: user in owners teams
      description: any user has read access to any resource created by a member of their team
    - forbid:
        actions: [create, read, delete]
        resource: vector_db::*
      unless: user with admin in roles
      description: only user with admin role can use vector_db resources

    """

    permit: Scope | None = None
    forbid: Scope | None = None
    when: str | list[str] | None = None
    unless: str | list[str] | None = None
    description: str | None = None

    @model_validator(mode="after")
    def validate_rule_format(self) -> Self:
        """Check the rule's shape and that its conditions parse.

        Exactly one of permit/forbid must be set, at most one of when/unless
        may be set, and every condition string must be parseable. The parsed
        conditions are discarded; parsing here only surfaces errors early.

        Raises:
            ValueError: If any of the checks above fails.
        """
        _require_one_of(self, "permit", "forbid")
        _mutually_exclusive(self, "permit", "forbid")
        _mutually_exclusive(self, "when", "unless")
        for conditions in (self.when, self.unless):
            if isinstance(conditions, list):
                parse_conditions(conditions)
            elif conditions:
                parse_conditions([conditions])
        return self

View file

@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
@ -35,126 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = str | list[str] RoutingKey = str | list[str]
class AccessAttributes(BaseModel): class User(BaseModel):
"""Structured representation of user attributes for access control. principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
This model defines a structured approach to representing user attributes def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
with common standard categories for access control. super().__init__(principal=principal, attributes=attributes)
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class ResourceWithACL(Resource): class ResourceWithOwner(Resource):
"""Extension of Resource that adds attribute-based access control capabilities. """Extension of Resource that adds an optional owner, i.e. the user that created the
resource. This can be used to constrain access to the resource."""
This class adds an optional access_attributes field that allows fine-grained control owner: User | None = None
over which users can access each resource. When attributes are defined, a user must have
matching attributes to access the resource.
Attribute Matching Algorithm:
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
4. Within each category, ANY value match is sufficient (OR relationship within a category)
Examples:
# Resource visible to everyone (no access control)
model = Model(identifier="llama-2", ...)
# Resource visible only to admins
model = Model(
identifier="gpt-4",
access_attributes=AccessAttributes(roles=["admin"])
)
# Resource visible to data scientists on the ML team
model = Model(
identifier="private-model",
access_attributes=AccessAttributes(
roles=["data-scientist", "researcher"],
teams=["ml-team"]
)
)
# ^ User must have at least one of the roles AND be on the ml-team
# Resource visible to users with specific project access
vector_db = VectorDB(
identifier="customer-embeddings",
access_attributes=AccessAttributes(
projects=["customer-insights"],
namespaces=["confidential"]
)
)
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: AccessAttributes | None = None
# Use the extended Resource for all routable objects # Use the extended Resource for all routable objects
class ModelWithACL(Model, ResourceWithACL): class ModelWithOwner(Model, ResourceWithOwner):
pass pass
class ShieldWithACL(Shield, ResourceWithACL): class ShieldWithOwner(Shield, ResourceWithOwner):
pass pass
class VectorDBWithACL(VectorDB, ResourceWithACL): class VectorDBWithOwner(VectorDB, ResourceWithOwner):
pass pass
class DatasetWithACL(Dataset, ResourceWithACL): class DatasetWithOwner(Dataset, ResourceWithOwner):
pass pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL): class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
pass pass
class BenchmarkWithACL(Benchmark, ResourceWithACL): class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
pass pass
class ToolWithACL(Tool, ResourceWithACL): class ToolWithOwner(Tool, ResourceWithOwner):
pass pass
class ToolGroupWithACL(ToolGroup, ResourceWithACL): class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass pass
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
RoutableObjectWithProvider = Annotated[ RoutableObjectWithProvider = Annotated[
ModelWithACL ModelWithOwner
| ShieldWithACL | ShieldWithOwner
| VectorDBWithACL | VectorDBWithOwner
| DatasetWithACL | DatasetWithOwner
| ScoringFnWithACL | ScoringFnWithOwner
| BenchmarkWithACL | BenchmarkWithOwner
| ToolWithACL | ToolWithOwner
| ToolGroupWithACL, | ToolGroupWithOwner,
Field(discriminator="type"), Field(discriminator="type"),
] ]
@ -234,6 +175,7 @@ class AuthenticationConfig(BaseModel):
..., ...,
description="Provider-specific configuration", description="Provider-specific configuration",
) )
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
class AuthenticationRequiredError(Exception): class AuthenticationRequiredError(Exception):

View file

@ -10,6 +10,8 @@ import logging
from contextlib import AbstractContextManager from contextlib import AbstractContextManager
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import User
from .utils.dynamic import instantiate_class_type from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -21,12 +23,10 @@ PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(AbstractContextManager): class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data""" """Context manager for request provider data"""
def __init__( def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
):
self.provider_data = provider_data or {} self.provider_data = provider_data or {}
if auth_attributes: if user:
self.provider_data["__auth_attributes"] = auth_attributes self.provider_data["__authenticated_user"] = user
self.token = None self.token = None
@ -95,9 +95,9 @@ def request_provider_data_context(
return RequestProviderDataContext(provider_data, auth_attributes) return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> dict[str, list[str]] | None: def get_authenticated_user() -> User | None:
"""Helper to retrieve auth attributes from the provider data context""" """Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get() provider_data = PROVIDER_DATA_VAR.get()
if not provider_data: if not provider_data:
return None return None
return provider_data.get("__auth_attributes") return provider_data.get("__authenticated_user")

View file

@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AccessRule,
AutoRoutedProviderSpec, AutoRoutedProviderSpec,
Provider, Provider,
RoutingTableProviderSpec, RoutingTableProviderSpec,
@ -118,6 +119,7 @@ async def resolve_impls(
run_config: StackRunConfig, run_config: StackRunConfig,
provider_registry: ProviderRegistry, provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> dict[Api, Any]: ) -> dict[Api, Any]:
""" """
Resolves provider implementations by: Resolves provider implementations by:
@ -140,7 +142,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config) return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -247,6 +249,7 @@ async def instantiate_providers(
router_apis: set[Api], router_apis: set[Api],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
run_config: StackRunConfig, run_config: StackRunConfig,
policy: list[AccessRule],
) -> dict: ) -> dict:
"""Instantiates providers asynchronously while managing dependencies.""" """Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {} impls: dict[Api, Any] = {}
@ -261,7 +264,7 @@ async def instantiate_providers(
if isinstance(provider.spec, RoutingTableProviderSpec): if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config) impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
if api_str.startswith("inner-"): if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@ -312,6 +315,7 @@ async def instantiate_provider(
inner_impls: dict[str, Any], inner_impls: dict[str, Any],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
run_config: StackRunConfig, run_config: StackRunConfig,
policy: list[AccessRule],
): ):
provider_spec = provider.spec provider_spec = provider.spec
if not hasattr(provider_spec, "module"): if not hasattr(provider_spec, "module"):
@ -336,13 +340,15 @@ async def instantiate_provider(
method = "get_routing_table_impl" method = "get_routing_table_impl"
config = None config = None
args = [provider_spec.api, inner_impls, deps, dist_registry] args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
else: else:
method = "get_provider_impl" method = "get_provider_impl"
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config) config = config_type(**provider.config)
args = [config, deps] args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
fn = getattr(module, method) fn = getattr(module, method)
impl = await fn(*args) impl = await fn(*args)

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol
from llama_stack.distribution.stack import StackRunConfig from llama_stack.distribution.stack import StackRunConfig
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.datatypes import Api, RoutingTable
@ -18,6 +18,7 @@ async def get_routing_table_impl(
impls_by_provider_id: dict[str, RoutedProtocol], impls_by_provider_id: dict[str, RoutedProtocol],
_deps, _deps,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> Any: ) -> Any:
from ..routing_tables.benchmarks import BenchmarksRoutingTable from ..routing_tables.benchmarks import BenchmarksRoutingTable
from ..routing_tables.datasets import DatasetsRoutingTable from ..routing_tables.datasets import DatasetsRoutingTable
@ -40,7 +41,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables: if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map") raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry) impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BenchmarkWithACL, BenchmarkWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
) )
if provider_benchmark_id is None: if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL( benchmark = BenchmarkWithOwner(
identifier=benchmark_id, identifier=benchmark_id,
dataset_id=dataset_id, dataset_id=dataset_id,
scoring_functions=scoring_functions, scoring_functions=scoring_functions,

View file

@ -8,14 +8,14 @@ from typing import Any
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AccessAttributes, AccessRule,
RoutableObject, RoutableObject,
RoutableObjectWithProvider, RoutableObjectWithProvider,
RoutedProtocol, RoutedProtocol,
) )
from llama_stack.distribution.request_headers import get_auth_attributes from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.datatypes import Api, RoutingTable
@ -73,9 +73,11 @@ class CommonRoutingTableImpl(RoutingTable):
self, self,
impls_by_provider_id: dict[str, RoutedProtocol], impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> None: ) -> None:
self.impls_by_provider_id = impls_by_provider_id self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry self.dist_registry = dist_registry
self.policy = policy
async def initialize(self) -> None: async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
@ -166,13 +168,15 @@ class CommonRoutingTableImpl(RoutingTable):
return None return None
# Check if user has permission to access this object # Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") logger.debug(f"Access denied to {type} '{identifier}'")
return None return None
return obj return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
raise AccessDeniedError()
await self.dist_registry.delete(obj.type, obj.identifier) await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
@ -187,11 +191,12 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id] p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes # If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes: creator = get_authenticated_user()
creator_attributes = get_auth_attributes() if not is_action_allowed(self.policy, "create", obj, creator):
if creator_attributes: raise AccessDeniedError()
obj.access_attributes = AccessAttributes(**creator_attributes) if creator:
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p) registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object # TODO: This needs to be fixed for all APIs once they return the registered object
@ -210,9 +215,7 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering # Apply attribute-based access control filtering
if filtered_objs: if filtered_objs:
filtered_objs = [ filtered_objs = [
obj obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
] ]
return filtered_objs return filtered_objs

View file

@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
) )
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
DatasetWithACL, DatasetWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
if metadata is None: if metadata is None:
metadata = {} metadata = {}
dataset = DatasetWithACL( dataset = DatasetWithOwner(
identifier=dataset_id, identifier=dataset_id,
provider_resource_id=provider_dataset_id, provider_resource_id=provider_dataset_id,
provider_id=provider_id, provider_id=provider_id,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ModelWithACL, ModelWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding: if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata") raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = ModelWithACL( model = ModelWithOwner(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
provider_id=provider_id, provider_id=provider_id,

View file

@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions, ScoringFunctions,
) )
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ScoringFnWithACL, ScoringFnWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
raise ValueError( raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id." "No provider specified and multiple providers available. Please specify a provider_id."
) )
scoring_fn = ScoringFnWithACL( scoring_fn = ScoringFnWithOwner(
identifier=scoring_fn_id, identifier=scoring_fn_id,
description=description, description=description,
return_type=return_type, return_type=return_type,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ShieldWithACL, ShieldWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
) )
if params is None: if params is None:
params = {} params = {}
shield = ShieldWithACL( shield = ShieldWithOwner(
identifier=shield_id, identifier=shield_id,
provider_resource_id=provider_shield_id, provider_resource_id=provider_shield_id,
provider_id=provider_id, provider_id=provider_id,

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL from llama_stack.distribution.datatypes import ToolGroupWithOwner
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
@ -106,7 +106,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None, mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None, args: dict[str, Any] | None = None,
) -> None: ) -> None:
toolgroup = ToolGroupWithACL( toolgroup = ToolGroupWithOwner(
identifier=toolgroup_id, identifier=toolgroup_id,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=toolgroup_id, provider_resource_id=toolgroup_id,

View file

@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
VectorDBWithACL, VectorDBWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model, "embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"], "embedding_dimension": model.metadata["embedding_dimension"],
} }
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db) await self.register_object(vector_db)
return vector_db return vector_db

View file

@ -105,24 +105,16 @@ class AuthenticationMiddleware:
logger.exception("Error during authentication") logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error") return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control
if validation_result.access_attributes:
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"roles": [token],
}
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits. # can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token scope["authenticated_client_id"] = token
# Store attributes in request scope # Store attributes in request scope
scope["user_attributes"] = user_attributes
scope["principal"] = validation_result.principal scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug( logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes" f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
) )
return await self.app(scope, receive, send) return await self.app(scope, receive, send)

View file

@ -16,43 +16,18 @@ from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="auth")
class TokenValidationResult(BaseModel): class AuthResponse(BaseModel):
principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
class AuthResponse(TokenValidationResult):
"""The format of the authentication response from the auth endpoint.""" """The format of the authentication response from the auth endpoint."""
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
message: str | None = Field( message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result." default=None, description="Optional message providing additional context about the authentication result."
) )
@ -78,7 +53,7 @@ class AuthProvider(ABC):
"""Abstract base class for authentication providers.""" """Abstract base class for authentication providers."""
@abstractmethod @abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token and return access attributes.""" """Validate a token and return access attributes."""
pass pass
@ -88,10 +63,10 @@ class AuthProvider(ABC):
pass pass
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes = AccessAttributes() attributes: dict[str, list[str]] = {}
for claim_key, attribute_key in mapping.items(): for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key): if claim_key not in claims:
continue continue
claim = claims[claim_key] claim = claims[claim_key]
if isinstance(claim, list): if isinstance(claim, list):
@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
else: else:
values = claim.split() values = claim.split()
current = getattr(attributes, attribute_key) if attribute_key in attributes:
if current: attributes[attribute_key].extend(values)
current.extend(values)
else: else:
setattr(attributes, attribute_key, values) attributes[attribute_key] = values
return attributes return attributes
@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
for key, value in v.items(): for key, value in v.items():
if not value: if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}") raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v return v
@model_validator(mode="after") @model_validator(mode="after")
@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
self._jwks: dict[str, str] = {} self._jwks: dict[str, str] = {}
self._jwks_lock = Lock() self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks: if self.config.jwks:
return await self.validate_jwt_token(token, scope) return await self.validate_jwt_token(token, scope)
if self.config.introspection: if self.config.introspection:
return await self.introspect_token(token, scope) return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured") raise ValueError("One of jwks or introspection must be configured")
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token.""" """Validate a token using the JWT token."""
await self._refresh_jwks() await self._refresh_jwks()
@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
# We should incorporate these into the access attributes. # We should incorporate these into the access attributes.
principal = claims["sub"] principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping) access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult( return User(
principal=principal, principal=principal,
access_attributes=access_attributes, attributes=access_attributes,
) )
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def introspect_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using token introspection as defined by RFC 7662.""" """Validate a token using token introspection as defined by RFC 7662."""
form = { form = {
"token": token, "token": token,
@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
raise ValueError("Token not active") raise ValueError("Token not active")
principal = fields["sub"] or fields["username"] principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping) access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return TokenValidationResult( return User(
principal=principal, principal=principal,
access_attributes=access_attributes, attributes=access_attributes,
) )
except httpx.TimeoutException: except httpx.TimeoutException:
logger.exception("Token introspection request timed out") logger.exception("Token introspection request timed out")
@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
self.config = config self.config = config
self._client = None self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the custom authentication endpoint.""" """Validate a token using the custom authentication endpoint."""
if scope is None: if scope is None:
scope = {} scope = {}
@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider):
json=auth_request.model_dump(), json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout timeout=10.0, # Add a reasonable timeout
) )
print("MADE CALL")
if response.status_code != 200: if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}") logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}") raise ValueError(f"Authentication failed: {response.status_code}")
@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider):
try: try:
response_data = response.json() response_data = response.json()
auth_response = AuthResponse(**response_data) auth_response = AuthResponse(**response_data)
return auth_response return User(auth_response.principal, auth_response.attributes)
except Exception as e: except Exception as e:
logger.exception("Error parsing authentication response") logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e raise ValueError("Invalid authentication response format") from e

View file

@ -33,10 +33,7 @@ from pydantic import BaseModel, ValidationError
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import ( from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import ( from llama_stack.distribution.server.routes import (
find_matching_route, find_matching_route,
@ -217,11 +214,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
async def route_handler(request: Request, **kwargs): async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope # Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {}) user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal, user_attributes)
await log_request_pre_validation(request) await log_request_pre_validation(request)
# Use context manager with both provider data and auth attributes # Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes): with request_provider_data_context(request.headers, user):
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try: try:

View file

@ -223,7 +223,10 @@ async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]: ) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) policy = run_config.server.auth.access_policy if run_config.server.auth else []
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved # Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config) add_internal_implementations(impls, run_config)

View file

@ -6,12 +6,12 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import AccessRule, Api
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]): async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
from .agents import MetaReferenceAgentsImpl from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl( impl = MetaReferenceAgentsImpl(
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety], deps[Api.safety],
deps[Api.tool_runtime], deps[Api.tool_runtime],
deps[Api.tool_groups], deps[Api.tool_groups],
policy,
) )
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -60,6 +60,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin):
vector_io_api: VectorIO, vector_io_api: VectorIO,
persistence_store: KVStore, persistence_store: KVStore,
created_at: str, created_at: str,
policy: list[AccessRule],
): ):
self.agent_id = agent_id self.agent_id = agent_id
self.agent_config = agent_config self.agent_config = agent_config
self.inference_api = inference_api self.inference_api = inference_api
self.safety_api = safety_api self.safety_api = safety_api
self.vector_io_api = vector_io_api self.vector_io_api = vector_io_api
self.storage = AgentPersistence(agent_id, persistence_store) self.storage = AgentPersistence(agent_id, persistence_store, policy)
self.tool_runtime_api = tool_runtime_api self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api self.tool_groups_api = tool_groups_api
self.created_at = created_at self.created_at = created_at

View file

@ -41,6 +41,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -62,6 +63,7 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety, safety_api: Safety,
tool_runtime_api: ToolRuntime, tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups, tool_groups_api: ToolGroups,
policy: list[AccessRule],
): ):
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
@ -72,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
self.in_memory_store = InmemoryKVStoreImpl() self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None self.openai_responses_impl: OpenAIResponsesImpl | None = None
self.policy = policy
async def initialize(self) -> None: async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store) self.persistence_store = await kvstore_impl(self.config.persistence_store)
@ -130,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
), ),
created_at=agent_info.created_at, created_at=agent_info.created_at,
policy=self.policy,
) )
async def create_agent_session( async def create_agent_session(

View file

@ -10,9 +10,10 @@ import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.distribution.request_headers import get_auth_attributes from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -22,7 +23,9 @@ class AgentSessionInfo(Session):
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: str | None = None vector_db_id: str | None = None
started_at: datetime started_at: datetime
access_attributes: AccessAttributes | None = None owner: User | None = None
identifier: str | None = None
type: str = "session"
class AgentInfo(AgentConfig): class AgentInfo(AgentConfig):
@ -30,24 +33,27 @@ class AgentInfo(AgentConfig):
class AgentPersistence: class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore): def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id self.agent_id = agent_id
self.kvstore = kvstore self.kvstore = kvstore
self.policy = policy
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions # Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes() user = get_authenticated_user()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo( session_info = AgentSessionInfo(
session_id=session_id, session_id=session_id,
session_name=name, session_name=name,
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
access_attributes=access_attributes, owner=user,
turns=[], turns=[],
identifier=name, # should this be qualified in any way?
) )
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError()
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
@ -73,10 +79,10 @@ class AgentPersistence:
def _check_session_access(self, session_info: AgentSessionInfo) -> bool: def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session.""" """Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control # Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"): if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True return True
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes()) return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods.""" """Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry): async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry) table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple models and verify listing # Register multiple models and verify listing
@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry): async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry) table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple shields and verify listing # Register multiple shields and verify listing
@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry): async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry) table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry) m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize() await m_table.initialize()
await m_table.register_model( await m_table.register_model(
model_id="test-model", model_id="test-model",
provider_id="test_providere", provider_id="test_provider",
metadata={"embedding_dimension": 128}, metadata={"embedding_dimension": 128},
model_type=ModelType.embedding, model_type=ModelType.embedding,
) )
@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
async def test_datasets_routing_table(cached_disk_dist_registry): async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry) table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple datasets and verify listing # Register multiple datasets and verify listing
@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry): async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry) table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple scoring functions and verify listing # Register multiple scoring functions and verify listing
@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry): async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry) table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple benchmarks and verify listing # Register multiple benchmarks and verify listing
@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry): async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry) table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple tool groups and verify listing # Register multiple tool groups and verify listing

View file

@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"], mock_apis["safety_api"],
mock_apis["tool_runtime_api"], mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"], mock_apis["tool_groups_api"],
{},
) )
await impl.initialize() await impl.initialize()
yield impl yield impl

View file

@ -12,24 +12,24 @@ import pytest
from llama_stack.apis.agents import Turn from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture @pytest.fixture
async def test_setup(sqlite_kvstore): async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore) agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence yield agent_persistence
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup): async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
# Set creator's attributes for the session # Set creator's attributes for the session
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]} creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
mock_get_auth_attributes.return_value = creator_attributes mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
# Create a session # Create a session
session_id = await agent_persistence.create_session("Test Session") session_id = await agent_persistence.create_session("Test Session")
@ -37,14 +37,15 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
# Get the session and verify access attributes were set # Get the session and verify access attributes were set
session_info = await agent_persistence.get_session_info(session_id) session_info = await agent_persistence.get_session_info(session_id)
assert session_info is not None assert session_info is not None
assert session_info.access_attributes is not None assert session_info.owner is not None
assert session_info.access_attributes.roles == ["researcher"] assert session_info.owner.attributes is not None
assert session_info.access_attributes.teams == ["ai-team"] assert session_info.owner.attributes["roles"] == ["researcher"]
assert session_info.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_auth_attributes, test_setup): async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
# Create a session with specific access attributes # Create a session with specific access attributes
@ -53,8 +54,9 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
session_id=session_id, session_id=session_id,
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
turns=[], turns=[],
identifier="Restricted Session",
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(
@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
) )
# User with matching attributes can access # User with matching attributes can access
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]} mock_get_authenticated_user.return_value = User(
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
)
retrieved_session = await agent_persistence.get_session_info(session_id) retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is not None assert retrieved_session is not None
assert retrieved_session.session_id == session_id assert retrieved_session.session_id == session_id
# User without matching attributes cannot access # User without matching attributes cannot access
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
retrieved_session = await agent_persistence.get_session_info(session_id) retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is None assert retrieved_session is None
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_auth_attributes, test_setup): async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
# Create a session with restricted access # Create a session with restricted access
@ -85,8 +89,9 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
session_id=session_id, session_id=session_id,
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]), owner=User("someone", {"roles": ["admin"]}),
turns=[], turns=[],
identifier="Restricted Session",
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(
@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
) )
# Admin can add turn # Admin can add turn
mock_get_auth_attributes.return_value = {"roles": ["admin"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.add_turn_to_session(session_id, turn) await agent_persistence.add_turn_to_session(session_id, turn)
# Admin can get turn # Admin can get turn
@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
assert retrieved_turn.turn_id == turn_id assert retrieved_turn.turn_id == turn_id
# Regular user cannot get turn # Regular user cannot get turn
mock_get_auth_attributes.return_value = {"roles": ["user"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
with pytest.raises(ValueError): with pytest.raises(ValueError):
await agent_persistence.get_session_turn(session_id, turn_id) await agent_persistence.get_session_turn(session_id, turn_id)
@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup): async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
# Create a session with restricted access # Create a session with restricted access
@ -138,8 +143,9 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
session_id=session_id, session_id=session_id,
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]), owner=User("someone", {"roles": ["admin"]}),
turns=[], turns=[],
identifier="Restricted Session",
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(
@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
# Admin user can set inference iterations # Admin user can set inference iterations
mock_get_auth_attributes.return_value = {"roles": ["admin"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5) await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
# Admin user can get inference iterations # Admin user can get inference iterations
@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
assert infer_iters == 5 assert infer_iters == 5
# Regular user cannot get inference iterations # Regular user cannot get inference iterations
mock_get_auth_attributes.return_value = {"roles": ["user"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id) infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters is None assert infer_iters is None

View file

@ -8,19 +8,18 @@
import pytest import pytest
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL from llama_stack.distribution.datatypes import ModelWithOwner, User
from llama_stack.distribution.server.auth_providers import AccessAttributes
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_cache_with_acl(cached_disk_dist_registry): async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithACL( model = ModelWithOwner(
identifier="model-acl", identifier="model-acl",
provider_id="test-provider", provider_id="test-provider",
provider_resource_id="model-acl-resource", provider_resource_id="model-acl-resource",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]), owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
) )
success = await cached_disk_dist_registry.register(model) success = await cached_disk_dist_registry.register(model)
@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.identifier == "model-acl" assert cached_model.identifier == "model-acl"
assert cached_model.access_attributes.roles == ["admin"] assert cached_model.owner.principal == "testuser"
assert cached_model.access_attributes.teams == ["ai-team"] assert cached_model.owner.attributes["roles"] == ["admin"]
assert cached_model.owner.attributes["teams"] == ["ai-team"]
fetched_model = await cached_disk_dist_registry.get("model", "model-acl") fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
assert fetched_model is not None assert fetched_model is not None
assert fetched_model.identifier == "model-acl" assert fetched_model.identifier == "model-acl"
assert fetched_model.access_attributes.roles == ["admin"] assert fetched_model.owner.attributes["roles"] == ["admin"]
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
await cached_disk_dist_registry.update(model)
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
assert updated_cached is not None
assert updated_cached.access_attributes.roles == ["admin", "user"]
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore) new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize() await new_registry.initialize()
@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
new_model = await new_registry.get("model", "model-acl") new_model = await new_registry.get("model", "model-acl")
assert new_model is not None assert new_model is not None
assert new_model.identifier == "model-acl" assert new_model.identifier == "model-acl"
assert new_model.access_attributes.roles == ["admin", "user"] assert new_model.owner.principal == "testuser"
assert new_model.access_attributes.projects == ["project-x"] assert new_model.owner.attributes["roles"] == ["admin"]
assert new_model.access_attributes.teams is None assert new_model.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_empty_acl(cached_disk_dist_registry): async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithACL( model = ModelWithOwner(
identifier="model-empty-acl", identifier="model-empty-acl",
provider_id="test-provider", provider_id="test-provider",
provider_resource_id="model-resource", provider_resource_id="model-resource",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(), owner=User("testuser", None),
) )
await cached_disk_dist_registry.register(model) await cached_disk_dist_registry.register(model)
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.access_attributes is not None assert cached_model.owner is not None
assert cached_model.access_attributes.roles is None assert cached_model.owner.attributes is None
assert cached_model.access_attributes.teams is None
assert cached_model.access_attributes.projects is None
assert cached_model.access_attributes.namespaces is None
all_models = await cached_disk_dist_registry.get_all() all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 1 assert len(all_models) == 1
model = ModelWithACL( model = ModelWithOwner(
identifier="model-no-acl", identifier="model-no-acl",
provider_id="test-provider", provider_id="test-provider",
provider_resource_id="model-resource-2", provider_resource_id="model-resource-2",
@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.access_attributes is None assert cached_model.owner is None
all_models = await cached_disk_dist_registry.get_all() all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 2 assert len(all_models) == 2
@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_serialization(cached_disk_dist_registry): async def test_registry_serialization(cached_disk_dist_registry):
attributes = AccessAttributes( attributes = {
roles=["admin", "researcher"], "roles": ["admin", "researcher"],
teams=["ai-team", "ml-team"], "teams": ["ai-team", "ml-team"],
projects=["project-a", "project-b"], "projects": ["project-a", "project-b"],
namespaces=["prod", "staging"], "namespaces": ["prod", "staging"],
) }
model = ModelWithACL( model = ModelWithOwner(
identifier="model-serialize", identifier="model-serialize",
provider_id="test-provider", provider_id="test-provider",
provider_resource_id="model-resource", provider_resource_id="model-resource",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=attributes, owner=User("bob", attributes),
) )
await cached_disk_dist_registry.register(model) await cached_disk_dist_registry.register(model)
@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
loaded_model = await new_registry.get("model", "model-serialize") loaded_model = await new_registry.get("model", "model-serialize")
assert loaded_model is not None assert loaded_model is not None
assert loaded_model.access_attributes.roles == ["admin", "researcher"] assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"] assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
assert loaded_model.access_attributes.projects == ["project-a", "project-b"] assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
assert loaded_model.access_attributes.namespaces == ["prod", "staging"] assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]

View file

@ -7,10 +7,13 @@
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
import yaml
from pydantic import TypeAdapter, ValidationError
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@ -32,39 +35,40 @@ async def test_setup(cached_disk_dist_registry):
routing_table = ModelsRoutingTable( routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference}, impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry, dist_registry=cached_disk_dist_registry,
policy={},
) )
yield cached_disk_dist_registry, routing_table yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithOwner(
identifier="model-public", identifier="model-public",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-public", provider_resource_id="model-public",
model_type=ModelType.llm, model_type=ModelType.llm,
) )
model_admin_only = ModelWithACL( model_admin_only = ModelWithOwner(
identifier="model-admin", identifier="model-admin",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-admin", provider_resource_id="model-admin",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]), owner=User("testuser", {"roles": ["admin"]}),
) )
model_data_scientist = ModelWithACL( model_data_scientist = ModelWithOwner(
identifier="model-data-scientist", identifier="model-data-scientist",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-data-scientist", provider_resource_id="model-data-scientist",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]), owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
) )
await registry.register(model_public) await registry.register(model_public)
await registry.register(model_admin_only) await registry.register(model_admin_only)
await registry.register(model_data_scientist) await registry.register(model_data_scientist)
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]} mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
all_models = await routing_table.list_models() all_models = await routing_table.list_models()
assert len(all_models.data) == 2 assert len(all_models.data) == 2
@ -75,7 +79,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist") await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]} mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
all_models = await routing_table.list_models() all_models = await routing_table.list_models()
assert len(all_models.data) == 1 assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public" assert all_models.data[0].identifier == "model-public"
@ -86,7 +90,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist") await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]} mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
all_models = await routing_table.list_models() all_models = await routing_table.list_models()
assert len(all_models.data) == 2 assert len(all_models.data) == 2
model_ids = [m.identifier for m in all_models.data] model_ids = [m.identifier for m in all_models.data]
@ -102,50 +106,62 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithOwner(
identifier="model-updates", identifier="model-updates",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-updates", provider_resource_id="model-updates",
model_type=ModelType.llm, model_type=ModelType.llm,
) )
await registry.register(model_public) await registry.register(model_public)
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": ["user"], "test-user",
} {
"roles": ["user"],
},
)
model = await routing_table.get_model("model-updates") model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates" assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"]) model_public.owner = User("testuser", {"roles": ["admin"]})
await registry.update(model_public) await registry.update(model_public)
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": ["user"], "test-user",
} {
"roles": ["user"],
},
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-updates") await routing_table.get_model("model-updates")
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": ["admin"], "test-user",
} {
"roles": ["admin"],
},
)
model = await routing_table.get_model("model-updates") model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates" assert model.identifier == "model-updates"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup): async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model = ModelWithACL( model = ModelWithOwner(
identifier="model-empty-attrs", identifier="model-empty-attrs",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-empty-attrs", provider_resource_id="model-empty-attrs",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(), owner=User("testuser", {}),
) )
await registry.register(model) await registry.register(model)
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": [], "test-user",
} {
"roles": [],
},
)
result = await routing_table.get_model("model-empty-attrs") result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs" assert result.identifier == "model-empty-attrs"
all_models = await routing_table.list_models() all_models = await routing_table.list_models()
@ -154,25 +170,25 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup): async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithOwner(
identifier="model-public-2", identifier="model-public-2",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-public-2", provider_resource_id="model-public-2",
model_type=ModelType.llm, model_type=ModelType.llm,
) )
model_restricted = ModelWithACL( model_restricted = ModelWithOwner(
identifier="model-restricted", identifier="model-restricted",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="model-restricted", provider_resource_id="model-restricted",
model_type=ModelType.llm, model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]), owner=User("testuser", {"roles": ["admin"]}),
) )
await registry.register(model_public) await registry.register(model_public)
await registry.register(model_restricted) await registry.register(model_restricted)
mock_get_auth_attributes.return_value = None mock_get_authenticated_user.return_value = User("test-user", None)
model = await routing_table.get_model("model-public-2") model = await routing_table.get_model("model-public-2")
assert model.identifier == "model-public-2" assert model.identifier == "model-public-2"
@ -185,17 +201,17 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup): async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator.""" """Test that newly created resources inherit access attributes from their creator."""
registry, routing_table = test_setup registry, routing_table = test_setup
# Set creator's attributes # Set creator's attributes
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]} creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
mock_get_auth_attributes.return_value = creator_attributes mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
# Create model without explicit access attributes # Create model without explicit access attributes
model = ModelWithACL( model = ModelWithOwner(
identifier="auto-access-model", identifier="auto-access-model",
provider_id="test_provider", provider_id="test_provider",
provider_resource_id="auto-access-model", provider_resource_id="auto-access-model",
@ -205,21 +221,346 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
# Verify the model got creator's attributes # Verify the model got creator's attributes
registered_model = await routing_table.get_model("auto-access-model") registered_model = await routing_table.get_model("auto-access-model")
assert registered_model.access_attributes is not None assert registered_model.owner is not None
assert registered_model.access_attributes.roles == ["data-scientist"] assert registered_model.owner.attributes is not None
assert registered_model.access_attributes.teams == ["ml-team"] assert registered_model.owner.attributes["roles"] == ["data-scientist"]
assert registered_model.access_attributes.projects == ["llama-3"] assert registered_model.owner.attributes["teams"] == ["ml-team"]
assert registered_model.owner.attributes["projects"] == ["llama-3"]
# Verify another user without matching attributes can't access it # Verify another user without matching attributes can't access it
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]} mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("auto-access-model") await routing_table.get_model("auto-access-model")
# But a user with matching attributes can # But a user with matching attributes can
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": ["data-scientist", "engineer"], "test-user",
"teams": ["ml-team", "platform-team"], {
"projects": ["llama-3"], "roles": ["data-scientist", "engineer"],
} "teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
},
)
model = await routing_table.get_model("auto-access-model") model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model" assert model.identifier == "auto-access-model"
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
    """Yield a ``ModelsRoutingTable`` guarded by an explicit access policy.

    The policy lists, in order: full create/read/delete access for
    ``user-1``, read access to ``model::model-1`` for ``user-2``, read access
    to ``model::model-2`` for ``user-3``, and a final catch-all forbid.
    The inference provider is a mock whose register/unregister hooks are
    ``AsyncMock`` objects driven by ``_return_model``.
    """
    mock_inference = Mock()
    mock_inference.__provider_spec__ = MagicMock()
    mock_inference.__provider_spec__.api = Api.inference
    mock_inference.register_model = AsyncMock(side_effect=_return_model)
    mock_inference.unregister_model = AsyncMock(side_effect=_return_model)

    # NOTE(review): the nesting below was rebuilt from a listing in which all
    # keys sat at column 0 (scope keys under permit/forbid, description at
    # rule level) -- confirm against the AccessRule schema.
    config = """
    - permit:
        principal: user-1
        actions: [create, read, delete]
      description: user-1 has full access to all models
    - permit:
        principal: user-2
        actions: [read]
        resource: model::model-1
      description: user-2 has read access to model-1 only
    - permit:
        principal: user-3
        actions: [read]
        resource: model::model-2
      description: user-3 has read access to model-2 only
    - forbid:
        actions: [create, read, delete]
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    routing_table = ModelsRoutingTable(
        impls_by_provider_id={"test_provider": mock_inference},
        dist_registry=cached_disk_dist_registry,
        policy=policy,
    )
    yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
    """Walk three users through the routing table built by the policy fixture.

    user-1 registers and reads every model; user-2 and user-3 can each read
    exactly one model and may neither register nor unregister; finally user-1
    unregisters model-3, after which looking it up raises ``ValueError``.
    """
    table = test_setup_with_access_policy

    def act_as(principal, attributes):
        # Every lookup below is performed on behalf of the user set here.
        mock_get_authenticated_user.return_value = User(principal, attributes)

    async def expect_readable(model_id):
        found = await table.get_model(model_id)
        assert found.identifier == model_id

    async def expect_hidden(model_id):
        with pytest.raises(ValueError):
            await table.get_model(model_id)

    # user-1: registers three models and can read all of them.
    act_as("user-1", {"roles": ["admin"], "projects": ["foo", "bar"]})
    for model_id in ("model-1", "model-2", "model-3"):
        await table.register_model(model_id, provider_id="test_provider")
    for model_id in ("model-1", "model-2", "model-3"):
        await expect_readable(model_id)

    # user-2: sees model-1 only; create and delete are denied.
    act_as("user-2", {"roles": ["user"], "projects": ["foo"]})
    await expect_readable("model-1")
    await expect_hidden("model-2")
    await expect_hidden("model-3")
    with pytest.raises(AccessDeniedError):
        await table.register_model("model-4", provider_id="test_provider")
    with pytest.raises(AccessDeniedError):
        await table.unregister_model("model-1")

    # user-3: sees model-2 only; create and delete are denied.
    act_as("user-3", {"roles": ["user"], "projects": ["bar"]})
    await expect_readable("model-2")
    await expect_hidden("model-1")
    await expect_hidden("model-3")
    with pytest.raises(AccessDeniedError):
        await table.register_model("model-5", provider_id="test_provider")
    with pytest.raises(AccessDeniedError):
        await table.unregister_model("model-2")

    # Back to user-1: removing model-3 makes it unreachable afterwards.
    act_as("user-1", {"roles": ["admin"], "projects": ["foo", "bar"]})
    await table.unregister_model("model-3")
    await expect_hidden("model-3")
def test_permit_when():
    """A permit rule with ``when`` applies only if principal and condition match.

    The model's owner has namespace ``foo``. Read is allowed for user-1 with
    namespace ``foo`` and refused for user-1 with namespace ``bar`` and for
    user-2 with namespace ``foo``.
    """
    # NOTE(review): nesting rebuilt from a flattened listing (scope keys under
    # permit, condition at rule level) -- confirm against the AccessRule schema.
    config = """
    - permit:
        principal: user-1
        actions: [read]
      when: user in owners namespaces
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("testuser", {"namespaces": ["foo"]}),
    )
    assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
    assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_permit_unless():
    """A permit rule with a list of ``unless`` conditions.

    The rule targets user-1 reading ``model::*``. With an owner in namespace
    ``foo``, read is allowed for user-1 in ``foo`` and refused for user-1 in
    ``bar`` and for user-2 in ``foo``.
    """
    # NOTE(review): nesting rebuilt from a flattened listing (scope keys under
    # permit, condition list at rule level) -- confirm against the AccessRule
    # schema.
    config = """
    - permit:
        principal: user-1
        actions: [read]
        resource: model::*
      unless:
        - user not in owners namespaces
        - user in owners teams
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("testuser", {"namespaces": ["foo"]}),
    )
    assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
    assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_when():
    """A conditional forbid placed ahead of a blanket read permit.

    With an owner in namespace ``foo``, read is refused for user-1 in ``foo``
    and allowed for user-1 in ``bar`` and for user-2 in ``foo``.
    """
    # NOTE(review): nesting rebuilt from a flattened listing (scope keys under
    # forbid/permit, condition at rule level) -- confirm against the
    # AccessRule schema.
    config = """
    - forbid:
        principal: user-1
        actions: [read]
      when:
        user in owners namespaces
    - permit:
        actions: [read]
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("testuser", {"namespaces": ["foo"]}),
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
    assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_unless():
    """A forbid with an ``unless`` exemption placed ahead of a blanket read permit.

    With an owner in namespace ``foo``, read is allowed for user-1 in ``foo``,
    refused for user-1 in ``bar``, and allowed for user-2 in ``foo``.
    """
    # NOTE(review): nesting rebuilt from a flattened listing (scope keys under
    # forbid/permit, condition at rule level) -- confirm against the
    # AccessRule schema.
    config = """
    - forbid:
        principal: user-1
        actions: [read]
      unless:
        user in owners namespaces
    - permit:
        actions: [read]
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("testuser", {"namespaces": ["foo"]}),
    )
    assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_user_has_attribute():
    """``when: user with admin in roles`` permits read only for users with that role.

    Read is allowed for the user whose roles contain ``admin`` and refused for
    a user with another role, a user with no ``roles`` attribute, and a user
    whose attributes are ``None``.
    """
    # NOTE(review): nesting rebuilt from a flattened listing (actions under
    # permit, condition at rule level) -- confirm against the AccessRule schema.
    config = """
    - permit:
        actions: [read]
      when: user with admin in roles
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
    assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_user_does_not_have_attribute():
    """``unless: user with admin not in roles`` permits read only for admins.

    Read is allowed for the user whose roles contain ``admin`` and refused for
    a user with another role, a user with no ``roles`` attribute, and a user
    whose attributes are ``None``.
    """
    # NOTE(review): nesting rebuilt from a flattened listing (actions under
    # permit, condition at rule level) -- confirm against the AccessRule schema.
    config = """
    - permit:
        actions: [read]
      unless: user with admin not in roles
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
    assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_is_owner():
    """``when: user is owner`` permits read only for the resource's owner.

    The model is owned by user-2; read is allowed for user-2 and refused for
    user-1, user-3 and user-4 (whose attributes are ``None``).
    """
    # NOTE(review): nesting rebuilt from a flattened listing (actions under
    # permit, condition at rule level) -- confirm against the AccessRule schema.
    config = """
    - permit:
        actions: [read]
      when: user is owner
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("user-2", {"namespaces": ["foo"]}),
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
    assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_is_not_owner():
    """``unless: user is not owner`` permits read only for the resource's owner.

    The model is owned by user-2; read is allowed for user-2 and refused for
    user-1, user-3 and user-4 (whose attributes are ``None``).
    """
    # NOTE(review): nesting rebuilt from a flattened listing (actions under
    # permit, condition at rule level) -- confirm against the AccessRule schema.
    config = """
    - permit:
        actions: [read]
      unless: user is not owner
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("user-2", {"namespaces": ["foo"]}),
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
    assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_invalid_rule_permit_and_forbid_both_specified():
    """A single rule carrying both ``permit`` and ``forbid`` must fail validation."""
    # NOTE(review): nesting rebuilt from a flattened listing. permit and forbid
    # are kept as sibling keys of one rule, each with its own actions, so the
    # failure comes from model validation and not from YAML parsing --
    # confirm against the AccessRule schema.
    config = """
    - permit:
        actions: [read]
      forbid:
        actions: [create]
    """
    with pytest.raises(ValidationError):
        TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_neither_permit_or_forbid_specified():
    """A rule with conditions but neither ``permit`` nor ``forbid`` must fail validation."""
    # NOTE(review): indentation rebuilt from a flattened listing; both keys
    # belong to the same single rule -- confirm against the AccessRule schema.
    config = """
    - when: user is owner
      unless: user with admin in roles
    """
    with pytest.raises(ValidationError):
        TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_when_and_unless_both_specified():
    """A rule that sets both ``when`` and ``unless`` must fail validation."""
    # NOTE(review): nesting rebuilt from a flattened listing (actions under
    # permit, both conditions at rule level) -- confirm against the AccessRule
    # schema.
    config = """
    - permit:
        actions: [read]
      when: user is owner
      unless: user with admin in roles
    """
    with pytest.raises(ValidationError):
        TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_condition():
    """A ``when`` clause that is not a recognised condition must fail validation."""
    # NOTE(review): nesting rebuilt from a flattened listing (actions under
    # permit, condition at rule level) -- confirm against the AccessRule schema.
    config = """
    - permit:
        actions: [read]
      when: random words that are not valid
    """
    with pytest.raises(ValidationError):
        TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
@pytest.mark.parametrize(
    "condition",
    [
        "user is owner",
        "user is not owner",
        "user with dev in teams",
        "user with default not in namespaces",
        "user in owners roles",
        "user not in owners projects",
    ],
)
def test_condition_reprs(condition):
    """Parsing a condition and converting it back with ``str`` yields the input text."""
    from llama_stack.distribution.access_control.conditions import parse_condition

    parsed = parse_condition(condition)
    rendered = str(parsed)
    assert condition == rendered

View file

@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):
{ {
"message": "Authentication successful", "message": "Authentication successful",
"principal": "test-principal", "principal": "test-principal",
"access_attributes": { "attributes": {
"roles": ["admin", "user"], "roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"], "teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"], "projects": ["llama-3", "project-x"],
@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
{ {
"message": "Authentication successful", "message": "Authentication successful",
"principal": "test-principal", "principal": "test-principal",
"access_attributes": { "attributes": {
"roles": ["admin", "user"], "roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"], "teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"], "projects": ["llama-3", "project-x"],
@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
@pytest.mark.asyncio
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
"""Test middleware behavior with no access attributes"""
middleware, mock_app = mock_http_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client_instance = AsyncMock()
mock_client.return_value.__aenter__.return_value = mock_client_instance
mock_client_instance.post.return_value = MockResponse(
200,
{
"message": "Authentication successful"
# No access_attributes
},
)
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
attributes = mock_scope["user_attributes"]
assert "roles" in attributes
assert attributes["roles"] == ["test.jwt.token"]
# oauth2 token provider tests # oauth2 token provider tests
@ -380,16 +353,16 @@ def test_get_attributes_from_claims():
"aud": "llama-stack", "aud": "llama-stack",
} }
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"}) attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
assert attributes.roles == ["my-user"] assert attributes["roles"] == ["my-user"]
assert attributes.teams == ["group1", "group2"] assert attributes["teams"] == ["group1", "group2"]
claims = { claims = {
"sub": "my-user", "sub": "my-user",
"tenant": "my-tenant", "tenant": "my-tenant",
} }
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"}) attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
assert attributes.roles == ["my-user"] assert attributes["roles"] == ["my-user"]
assert attributes.namespaces == ["my-tenant"] assert attributes["namespaces"] == ["my-tenant"]
claims = { claims = {
"sub": "my-user", "sub": "my-user",
@ -408,9 +381,9 @@ def test_get_attributes_from_claims():
"groups": "teams", "groups": "teams",
}, },
) )
assert set(attributes.roles) == {"my-user", "my-username"} assert set(attributes["roles"]) == {"my-user", "my-username"}
assert set(attributes.teams) == {"my-team", "group1", "group2"} assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
assert attributes.namespaces == ["my-tenant"] assert attributes["namespaces"] == ["my-tenant"]
# TODO: add more tests for oauth2 token provider # TODO: add more tests for oauth2 token provider

View file

@ -100,9 +100,10 @@ async def test_resolve_impls_basic():
add_protocol_methods(SampleImpl, Inference) add_protocol_methods(SampleImpl, Inference)
mock_module.get_provider_impl = AsyncMock(return_value=impl) mock_module.get_provider_impl = AsyncMock(return_value=impl)
mock_module.get_provider_impl.__text_signature__ = "()"
sys.modules["test_module"] = mock_module sys.modules["test_module"] = mock_module
impls = await resolve_impls(run_config, provider_registry, dist_registry) impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
assert Api.inference in impls assert Api.inference in impls
assert isinstance(impls[Api.inference], InferenceRouter) assert isinstance(impls[Api.inference], InferenceRouter)