feat: fine grained access control policy

This allows a set of rules to be defined for determining access to resources.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
Gordon Sim 2025-05-06 18:54:58 +01:00
parent 9623d5d230
commit 01ad876012
20 changed files with 724 additions and 214 deletions

View file

@ -1,86 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(
obj_identifier: str,
obj_attributes: AccessAttributes | None,
user_attributes: dict[str, Any] | None = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
3. For each attribute category in the resource's access_attributes:
a. If the user lacks that category, access is DENIED
b. If the user has the category but none of the required values, access is DENIED
c. If the user has at least one matching value in each required category, access is GRANTED
Example:
# Resource requires:
access_attributes = AccessAttributes(
roles=["admin", "data-scientist"],
teams=["ml-team"]
)
# User has:
user_attributes = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "infra-team"],
"projects": ["llama-3"]
}
# Result: Access GRANTED
# - User has the "data-scientist" role (matches one of the required roles)
# - AND user is part of the "ml-team" (matches the required team)
# - The extra "projects" attribute is ignored
Args:
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj_identifier}")
return True

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol
from llama_stack.distribution.request_headers import User
from .datatypes import (
AccessAttributes,
AccessRule,
Action,
AttributeReference,
Condition,
Scope,
)
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
if resource_scope == actual_resource:
return True
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
def matches_scope(
scope: Scope,
action: Action,
resource: str,
user: str | None,
) -> bool:
if scope.resource and not matches_resource(scope.resource, resource):
return False
if scope.principal and scope.principal != user:
return False
return action in scope.actions
def user_in_literal(
literal: str,
user_attributes: dict[str, list[str]] | None,
) -> bool:
for qualifier in ["role::", "team::", "project::", "namespace::"]:
if literal.startswith(qualifier):
if not user_attributes:
return False
ref = qualifier.replace("::", "s")
if ref in user_attributes:
value = literal.removeprefix(qualifier)
return value in user_attributes[ref]
else:
return False
return False
def user_in(
ref: AttributeReference | str,
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
if not ref.startswith("resource."):
return user_in_literal(ref, user_attributes)
name = ref.removeprefix("resource.")
required = resource_attributes and getattr(resource_attributes, name)
if not required:
return True
if not user_attributes or name not in user_attributes:
return False
actual = user_attributes[name]
for value in required:
if value in actual:
return True
return False
def as_list(obj: Any) -> list[Any]:
if isinstance(obj, list):
return obj
return [obj]
def matches_conditions(
conditions: list[Condition],
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
for condition in conditions:
# must match all conditions
if not matches_condition(condition, resource_attributes, user_attributes):
return False
return True
def matches_condition(
condition: Condition | list[Condition],
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
if isinstance(condition, list):
return matches_conditions(as_list(condition), resource_attributes, user_attributes)
if condition.user_in:
for ref in as_list(condition.user_in):
# if multiple references are specified, all must match
if not user_in(ref, resource_attributes, user_attributes):
return False
return True
if condition.user_not_in:
for ref in as_list(condition.user_not_in):
# if multiple references are specified, none must match
if user_in(ref, resource_attributes, user_attributes):
return False
return True
return True
def default_policy() -> list[AccessRule]:
# for backwards compatibility, if no rules are provided , assume
# full access to all subject to attribute matching rules
return [
AccessRule(
permit=Scope(actions=list(Action)),
when=Condition(user_in=list(AttributeReference)),
)
]
class ProtectedResource(Protocol):
type: str
identifier: str
access_attributes: AccessAttributes
def is_action_allowed(
policy: list[AccessRule],
action: Action,
resource: ProtectedResource,
user: User | None,
) -> bool:
# If user is not set, assume authentication is not enabled
if not user:
return True
if not len(policy):
policy = default_policy()
resource_attributes = AccessAttributes()
if hasattr(resource, "access_attributes"):
resource_attributes = resource.access_attributes
qualified_resource_id = resource.type + "::" + resource.identifier
for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when:
if matches_condition(rule.when, resource_attributes, user.attributes):
return False
elif rule.unless:
if not matches_condition(rule.unless, resource_attributes, user.attributes):
return False
else:
return False
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
if rule.when:
if matches_condition(rule.when, resource_attributes, user.attributes):
return True
elif rule.unless:
if not matches_condition(rule.unless, resource_attributes, user.attributes):
return True
else:
return True
# assume access is denied unless we find a rule that permits access
return False
class AccessDeniedError(RuntimeError):
pass

View file

@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class Action(str, Enum):
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
class Scope(BaseModel):
principal: str | None = None
actions: Action | list[Action]
resource: str | None = None
def _mutually_exclusive(obj, a: str, b: str):
if getattr(obj, a) and getattr(obj, b):
raise ValueError(f"{a} and {b} are mutually exclusive")
def _require_one_of(obj, a: str, b: str):
if not getattr(obj, a) and not getattr(obj, b):
raise ValueError(f"on of {a} or {b} is required")
class AttributeReference(str, Enum):
RESOURCE_ROLES = "resource.roles"
RESOURCE_TEAMS = "resource.teams"
RESOURCE_PROJECTS = "resource.projects"
RESOURCE_NAMESPACES = "resource.namespaces"
class Condition(BaseModel):
user_in: AttributeReference | list[AttributeReference] | str | None = None
user_not_in: AttributeReference | list[AttributeReference] | str | None = None
class AccessRule(BaseModel):
"""Access rule based loosely on cedar policy language
A rule defines a list of action either to permit or to forbid. It may specify a
principal or a resource that must match for the rule to take effect. The resource
to match should be specified in the form of a type qualified identifier, e.g.
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
e.g. model::*. If the principal or resource are not specified, they will match all
requests.
A rule may also specify a condition, either a 'when' or an 'unless', with additional
constraints as to where the rule applies. The constraints at present are whether the
user requesting access is in or not in some set. This set can either be a particular
set of attributes on the resource e.g. resource.roles or a literal value of some
notion of group, e.g. role::admin or namespace::foo.
Rules are tested in order to find a match. If a match is found, the request is
permitted or forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified, a rule that allows any action as
long as the resource attributes match the user attributes is added
(i.e. the previous behaviour is the default).
Some examples in yaml:
- permit:
principal: user-1
actions: [create, read, delete]
resource: model::*
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
actions: [read]
when:
user_in: resource.namespaces
description: any user has read access to any resource with matching attributes
- forbid:
actions: [create, read, delete]
resource: vector_db::*
unless:
user_in: role::admin
description: only user with admin role can use vector_db resources
"""
permit: Scope | None = None
forbid: Scope | None = None
when: Condition | list[Condition] | None = None
unless: Condition | list[Condition] | None = None
description: str | None = None
@model_validator(mode="after")
def validate_rule_format(self) -> Self:
_require_one_of(self, "permit", "forbid")
_mutually_exclusive(self, "permit", "forbid")
_mutually_exclusive(self, "when", "unless")
return self

View file

@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
@ -35,35 +36,6 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = str | list[str] RoutingKey = str | list[str]
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class ResourceWithACL(Resource): class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities. """Extension of Resource that adds attribute-based access control capabilities.
@ -234,6 +206,7 @@ class AuthenticationConfig(BaseModel):
..., ...,
description="Provider-specific configuration", description="Provider-specific configuration",
) )
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
class AuthenticationRequiredError(Exception): class AuthenticationRequiredError(Exception):

View file

@ -18,15 +18,23 @@ log = logging.getLogger(__name__)
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class User:
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]]
def __init__(self, principal: str, attributes: dict[str, list[str]]):
self.principal = principal
self.attributes = attributes
class RequestProviderDataContext(AbstractContextManager): class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data""" """Context manager for request provider data"""
def __init__( def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
):
self.provider_data = provider_data or {} self.provider_data = provider_data or {}
if auth_attributes: if user:
self.provider_data["__auth_attributes"] = auth_attributes self.provider_data["__authenticated_user"] = user
self.token = None self.token = None
@ -95,9 +103,9 @@ def request_provider_data_context(
return RequestProviderDataContext(provider_data, auth_attributes) return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> dict[str, list[str]] | None: def get_authenticated_user() -> User | None:
"""Helper to retrieve auth attributes from the provider data context""" """Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get() provider_data = PROVIDER_DATA_VAR.get()
if not provider_data: if not provider_data:
return None return None
return provider_data.get("__auth_attributes") return provider_data.get("__authenticated_user")

View file

@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AccessRule,
AutoRoutedProviderSpec, AutoRoutedProviderSpec,
Provider, Provider,
RoutingTableProviderSpec, RoutingTableProviderSpec,
@ -118,6 +119,7 @@ async def resolve_impls(
run_config: StackRunConfig, run_config: StackRunConfig,
provider_registry: ProviderRegistry, provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> dict[Api, Any]: ) -> dict[Api, Any]:
""" """
Resolves provider implementations by: Resolves provider implementations by:
@ -140,7 +142,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config) return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -247,6 +249,7 @@ async def instantiate_providers(
router_apis: set[Api], router_apis: set[Api],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
run_config: StackRunConfig, run_config: StackRunConfig,
policy: list[AccessRule],
) -> dict: ) -> dict:
"""Instantiates providers asynchronously while managing dependencies.""" """Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {} impls: dict[Api, Any] = {}
@ -261,7 +264,7 @@ async def instantiate_providers(
if isinstance(provider.spec, RoutingTableProviderSpec): if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config) impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
if api_str.startswith("inner-"): if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@ -312,6 +315,7 @@ async def instantiate_provider(
inner_impls: dict[str, Any], inner_impls: dict[str, Any],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
run_config: StackRunConfig, run_config: StackRunConfig,
policy: list[AccessRule],
): ):
provider_spec = provider.spec provider_spec = provider.spec
if not hasattr(provider_spec, "module"): if not hasattr(provider_spec, "module"):
@ -336,13 +340,15 @@ async def instantiate_provider(
method = "get_routing_table_impl" method = "get_routing_table_impl"
config = None config = None
args = [provider_spec.api, inner_impls, deps, dist_registry] args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
else: else:
method = "get_provider_impl" method = "get_provider_impl"
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config) config = config_type(**provider.config)
args = [config, deps] args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
fn = getattr(module, method) fn = getattr(module, method)
impl = await fn(*args) impl = await fn(*args)

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol
from llama_stack.distribution.stack import StackRunConfig from llama_stack.distribution.stack import StackRunConfig
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.datatypes import Api, RoutingTable
@ -18,6 +18,7 @@ async def get_routing_table_impl(
impls_by_provider_id: dict[str, RoutedProtocol], impls_by_provider_id: dict[str, RoutedProtocol],
_deps, _deps,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> Any: ) -> Any:
from ..routing_tables.benchmarks import BenchmarksRoutingTable from ..routing_tables.benchmarks import BenchmarksRoutingTable
from ..routing_tables.datasets import DatasetsRoutingTable from ..routing_tables.datasets import DatasetsRoutingTable
@ -40,7 +41,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables: if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map") raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry) impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -8,14 +8,15 @@ from typing import Any
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AccessAttributes, AccessAttributes,
AccessRule,
RoutableObject, RoutableObject,
RoutableObjectWithProvider, RoutableObjectWithProvider,
RoutedProtocol, RoutedProtocol,
) )
from llama_stack.distribution.request_headers import get_auth_attributes from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.datatypes import Api, RoutingTable
@ -73,9 +74,11 @@ class CommonRoutingTableImpl(RoutingTable):
self, self,
impls_by_provider_id: dict[str, RoutedProtocol], impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> None: ) -> None:
self.impls_by_provider_id = impls_by_provider_id self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry self.dist_registry = dist_registry
self.policy = policy
async def initialize(self) -> None: async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
@ -166,13 +169,15 @@ class CommonRoutingTableImpl(RoutingTable):
return None return None
# Check if user has permission to access this object # Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") logger.debug(f"Access denied to {type} '{identifier}'")
return None return None
return obj return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
raise AccessDeniedError()
await self.dist_registry.delete(obj.type, obj.identifier) await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
@ -187,11 +192,12 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id] p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes # If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes: creator = get_authenticated_user()
creator_attributes = get_auth_attributes() if not is_action_allowed(self.policy, "create", obj, creator):
if creator_attributes: raise AccessDeniedError()
obj.access_attributes = AccessAttributes(**creator_attributes) if creator and creator.attributes:
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") obj.access_attributes = AccessAttributes(**creator.attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p) registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object # TODO: This needs to be fixed for all APIs once they return the registered object
@ -210,9 +216,7 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering # Apply attribute-based access control filtering
if filtered_objs: if filtered_objs:
filtered_objs = [ filtered_objs = [
obj obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
] ]
return filtered_objs return filtered_objs

View file

@ -30,10 +30,7 @@ from pydantic import BaseModel, ValidationError
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import ( from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import ( from llama_stack.distribution.server.endpoints import (
find_matching_endpoint, find_matching_endpoint,
@ -213,11 +210,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
# Get auth attributes from the request scope # Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {}) user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal, user_attributes)
await log_request_pre_validation(request) await log_request_pre_validation(request)
# Use context manager with both provider data and auth attributes # Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes): with request_provider_data_context(request.headers, user):
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try: try:

View file

@ -223,7 +223,10 @@ async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]: ) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) policy = run_config.server.auth.access_policy if run_config.server.auth else []
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved # Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config) add_internal_implementations(impls, run_config)

View file

@ -6,12 +6,12 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import AccessRule, Api
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]): async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
from .agents import MetaReferenceAgentsImpl from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl( impl = MetaReferenceAgentsImpl(
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety], deps[Api.safety],
deps[Api.tool_runtime], deps[Api.tool_runtime],
deps[Api.tool_groups], deps[Api.tool_groups],
policy,
) )
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -60,6 +60,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin):
vector_io_api: VectorIO, vector_io_api: VectorIO,
persistence_store: KVStore, persistence_store: KVStore,
created_at: str, created_at: str,
policy: list[AccessRule],
): ):
self.agent_id = agent_id self.agent_id = agent_id
self.agent_config = agent_config self.agent_config = agent_config
self.inference_api = inference_api self.inference_api = inference_api
self.safety_api = safety_api self.safety_api = safety_api
self.vector_io_api = vector_io_api self.vector_io_api = vector_io_api
self.storage = AgentPersistence(agent_id, persistence_store) self.storage = AgentPersistence(agent_id, persistence_store, policy)
self.tool_runtime_api = tool_runtime_api self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api self.tool_groups_api = tool_groups_api
self.created_at = created_at self.created_at = created_at

View file

@ -40,6 +40,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -61,6 +62,7 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety, safety_api: Safety,
tool_runtime_api: ToolRuntime, tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups, tool_groups_api: ToolGroups,
policy: list[AccessRule],
): ):
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents):
self.in_memory_store = InmemoryKVStoreImpl() self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None self.openai_responses_impl: OpenAIResponsesImpl | None = None
self.policy = policy
async def initialize(self) -> None: async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store) self.persistence_store = await kvstore_impl(self.config.persistence_store)
@ -129,6 +132,7 @@ class MetaReferenceAgentsImpl(Agents):
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
), ),
created_at=agent_info.created_at, created_at=agent_info.created_at,
policy=self.policy,
) )
async def create_agent_session( async def create_agent_session(

View file

@ -10,9 +10,9 @@ import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
from llama_stack.distribution.request_headers import get_auth_attributes from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -23,6 +23,8 @@ class AgentSessionInfo(Session):
vector_db_id: str | None = None vector_db_id: str | None = None
started_at: datetime started_at: datetime
access_attributes: AccessAttributes | None = None access_attributes: AccessAttributes | None = None
identifier: str | None = None
type: str = "session"
class AgentInfo(AgentConfig): class AgentInfo(AgentConfig):
@ -30,15 +32,17 @@ class AgentInfo(AgentConfig):
class AgentPersistence: class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore): def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id self.agent_id = agent_id
self.kvstore = kvstore self.kvstore = kvstore
self.policy = policy
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions # Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes() user = get_authenticated_user()
auth_attributes = user and user.attributes
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo( session_info = AgentSessionInfo(
@ -47,7 +51,10 @@ class AgentPersistence:
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
access_attributes=access_attributes, access_attributes=access_attributes,
turns=[], turns=[],
identifier=name, # should this be qualified in any way?
) )
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError()
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
@ -76,7 +83,7 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes"): if not hasattr(session_info, "access_attributes"):
return True return True
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes()) return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods.""" """Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry): async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry) table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple models and verify listing # Register multiple models and verify listing
@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry): async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry) table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple shields and verify listing # Register multiple shields and verify listing
@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry): async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry) table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry) m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize() await m_table.initialize()
await m_table.register_model( await m_table.register_model(
model_id="test-model", model_id="test-model",
provider_id="test_providere", provider_id="test_provider",
metadata={"embedding_dimension": 128}, metadata={"embedding_dimension": 128},
model_type=ModelType.embedding, model_type=ModelType.embedding,
) )
@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
async def test_datasets_routing_table(cached_disk_dist_registry): async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry) table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple datasets and verify listing # Register multiple datasets and verify listing
@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry): async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry) table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple scoring functions and verify listing # Register multiple scoring functions and verify listing
@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry): async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry) table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple benchmarks and verify listing # Register multiple benchmarks and verify listing
@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry): async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry) table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
# Register multiple tool groups and verify listing # Register multiple tool groups and verify listing

View file

@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"], mock_apis["safety_api"],
mock_apis["tool_runtime_api"], mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"], mock_apis["tool_groups_api"],
{},
) )
await impl.initialize() await impl.initialize()
yield impl yield impl

View file

@ -13,23 +13,24 @@ import pytest
from llama_stack.apis.agents import Turn from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture @pytest.fixture
async def test_setup(sqlite_kvstore): async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore) agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence yield agent_persistence
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup): async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
# Set creator's attributes for the session # Set creator's attributes for the session
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]} creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
mock_get_auth_attributes.return_value = creator_attributes mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
# Create a session # Create a session
session_id = await agent_persistence.create_session("Test Session") session_id = await agent_persistence.create_session("Test Session")
@ -43,8 +44,8 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_auth_attributes, test_setup): async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
# Create a session with specific access attributes # Create a session with specific access attributes
@ -55,6 +56,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
turns=[], turns=[],
identifier="Restricted Session",
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(
@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
) )
# User with matching attributes can access # User with matching attributes can access
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]} mock_get_authenticated_user.return_value = User(
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
)
retrieved_session = await agent_persistence.get_session_info(session_id) retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is not None assert retrieved_session is not None
assert retrieved_session.session_id == session_id assert retrieved_session.session_id == session_id
# User without matching attributes cannot access # User without matching attributes cannot access
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
retrieved_session = await agent_persistence.get_session_info(session_id) retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is None assert retrieved_session is None
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_auth_attributes, test_setup): async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
# Create a session with restricted access # Create a session with restricted access
@ -87,6 +91,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]), access_attributes=AccessAttributes(roles=["admin"]),
turns=[], turns=[],
identifier="Restricted Session",
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(
@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
) )
# Admin can add turn # Admin can add turn
mock_get_auth_attributes.return_value = {"roles": ["admin"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.add_turn_to_session(session_id, turn) await agent_persistence.add_turn_to_session(session_id, turn)
# Admin can get turn # Admin can get turn
@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
assert retrieved_turn.turn_id == turn_id assert retrieved_turn.turn_id == turn_id
# Regular user cannot get turn # Regular user cannot get turn
mock_get_auth_attributes.return_value = {"roles": ["user"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
with pytest.raises(ValueError): with pytest.raises(ValueError):
await agent_persistence.get_session_turn(session_id, turn_id) await agent_persistence.get_session_turn(session_id, turn_id)
@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup): async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
# Create a session with restricted access # Create a session with restricted access
@ -140,6 +145,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]), access_attributes=AccessAttributes(roles=["admin"]),
turns=[], turns=[],
identifier="Restricted Session",
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(
@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
# Admin user can set inference iterations # Admin user can set inference iterations
mock_get_auth_attributes.return_value = {"roles": ["admin"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5) await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
# Admin user can get inference iterations # Admin user can get inference iterations
@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
assert infer_iters == 5 assert infer_iters == 5
# Regular user cannot get inference iterations # Regular user cannot get inference iterations
mock_get_auth_attributes.return_value = {"roles": ["user"]} mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id) infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters is None assert infer_iters is None

View file

@ -7,10 +7,14 @@
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
import yaml
from pydantic import TypeAdapter
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@ -32,13 +36,14 @@ async def test_setup(cached_disk_dist_registry):
routing_table = ModelsRoutingTable( routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference}, impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry, dist_registry=cached_disk_dist_registry,
policy={},
) )
yield cached_disk_dist_registry, routing_table yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithACL(
identifier="model-public", identifier="model-public",
@ -64,7 +69,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
await registry.register(model_admin_only) await registry.register(model_admin_only)
await registry.register(model_data_scientist) await registry.register(model_data_scientist)
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]} mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
all_models = await routing_table.list_models() all_models = await routing_table.list_models()
assert len(all_models.data) == 2 assert len(all_models.data) == 2
@ -75,7 +80,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist") await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]} mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
all_models = await routing_table.list_models() all_models = await routing_table.list_models()
assert len(all_models.data) == 1 assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public" assert all_models.data[0].identifier == "model-public"
@ -86,7 +91,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist") await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]} mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
all_models = await routing_table.list_models() all_models = await routing_table.list_models()
assert len(all_models.data) == 2 assert len(all_models.data) == 2
model_ids = [m.identifier for m in all_models.data] model_ids = [m.identifier for m in all_models.data]
@ -102,8 +107,8 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithACL(
identifier="model-updates", identifier="model-updates",
@ -112,28 +117,37 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
model_type=ModelType.llm, model_type=ModelType.llm,
) )
await registry.register(model_public) await registry.register(model_public)
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": ["user"], "test-user",
} {
"roles": ["user"],
},
)
model = await routing_table.get_model("model-updates") model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates" assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"]) model_public.access_attributes = AccessAttributes(roles=["admin"])
await registry.update(model_public) await registry.update(model_public)
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": ["user"], "test-user",
} {
"roles": ["user"],
},
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-updates") await routing_table.get_model("model-updates")
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": ["admin"], "test-user",
} {
"roles": ["admin"],
},
)
model = await routing_table.get_model("model-updates") model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates" assert model.identifier == "model-updates"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup): async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model = ModelWithACL( model = ModelWithACL(
identifier="model-empty-attrs", identifier="model-empty-attrs",
@ -143,9 +157,12 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
access_attributes=AccessAttributes(), access_attributes=AccessAttributes(),
) )
await registry.register(model) await registry.register(model)
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": [], "test-user",
} {
"roles": [],
},
)
result = await routing_table.get_model("model-empty-attrs") result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs" assert result.identifier == "model-empty-attrs"
all_models = await routing_table.list_models() all_models = await routing_table.list_models()
@ -154,8 +171,8 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup): async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
model_public = ModelWithACL( model_public = ModelWithACL(
identifier="model-public-2", identifier="model-public-2",
@ -172,7 +189,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
) )
await registry.register(model_public) await registry.register(model_public)
await registry.register(model_restricted) await registry.register(model_restricted)
mock_get_auth_attributes.return_value = None mock_get_authenticated_user.return_value = User("test-user", None)
model = await routing_table.get_model("model-public-2") model = await routing_table.get_model("model-public-2")
assert model.identifier == "model-public-2" assert model.identifier == "model-public-2"
@ -185,14 +202,14 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup): async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator.""" """Test that newly created resources inherit access attributes from their creator."""
registry, routing_table = test_setup registry, routing_table = test_setup
# Set creator's attributes # Set creator's attributes
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]} creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
mock_get_auth_attributes.return_value = creator_attributes mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
# Create model without explicit access attributes # Create model without explicit access attributes
model = ModelWithACL( model = ModelWithACL(
@ -211,15 +228,262 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
assert registered_model.access_attributes.projects == ["llama-3"] assert registered_model.access_attributes.projects == ["llama-3"]
# Verify another user without matching attributes can't access it # Verify another user without matching attributes can't access it
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]} mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("auto-access-model") await routing_table.get_model("auto-access-model")
# But a user with matching attributes can # But a user with matching attributes can
mock_get_auth_attributes.return_value = { mock_get_authenticated_user.return_value = User(
"roles": ["data-scientist", "engineer"], "test-user",
"teams": ["ml-team", "platform-team"], {
"projects": ["llama-3"], "roles": ["data-scientist", "engineer"],
} "teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
},
)
model = await routing_table.get_model("auto-access-model") model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model" assert model.identifier == "auto-access-model"
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
mock_inference.unregister_model = AsyncMock(side_effect=_return_model)
config = """
- permit:
principal: user-1
actions: [create, read, delete]
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
principal: user-3
actions: [read]
resource: model::model-2
description: user-3 has read access to model-2 only
- forbid:
actions: [create, read, delete]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy=policy,
)
yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
routing_table = test_setup_with_access_policy
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.register_model("model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_id="test_provider")
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
mock_get_authenticated_user.return_value = User(
"user-2",
{
"roles": ["user"],
"projects": ["foo"],
},
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1")
mock_get_authenticated_user.return_value = User(
"user-3",
{
"roles": ["user"],
"projects": ["bar"],
},
)
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2")
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.unregister_model("model-3")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
def test_permit_when():
config = """
- permit:
principal: user-1
actions: [read]
when:
user_in: resource.namespaces
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_permit_unless():
config = """
- permit:
principal: user-1
actions: [read]
resource: model::*
unless:
- user_not_in: resource.namespaces
- user_in: resource.teams
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_when():
config = """
- forbid:
principal: user-1
actions: [read]
when:
user_in: resource.namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_unless():
config = """
- forbid:
principal: user-1
actions: [read]
unless:
user_in: resource.namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_condition_with_literal():
config = """
- permit:
actions: [read]
when:
user_in: role::admin
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_condition_with_unrecognised_literal():
config = """
- permit:
actions: [read]
when:
user_in: whatever
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", None))
def test_empty_condition():
config = """
- permit:
actions: [read]
when: {}
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
)
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", None))

View file

@ -100,9 +100,10 @@ async def test_resolve_impls_basic():
add_protocol_methods(SampleImpl, Inference) add_protocol_methods(SampleImpl, Inference)
mock_module.get_provider_impl = AsyncMock(return_value=impl) mock_module.get_provider_impl = AsyncMock(return_value=impl)
mock_module.get_provider_impl.__text_signature__ = "()"
sys.modules["test_module"] = mock_module sys.modules["test_module"] = mock_module
impls = await resolve_impls(run_config, provider_registry, dist_registry) impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
assert Api.inference in impls assert Api.inference in impls
assert isinstance(impls[Api.inference], InferenceRouter) assert isinstance(impls[Api.inference], InferenceRouter)