feat: fine grained access control policy

This allows a set of rules to be defined for determining access to resources.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
Gordon Sim 2025-05-06 18:54:58 +01:00
parent 9623d5d230
commit 01ad876012
20 changed files with 724 additions and 214 deletions

View file

@ -1,86 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(
obj_identifier: str,
obj_attributes: AccessAttributes | None,
user_attributes: dict[str, Any] | None = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
3. For each attribute category in the resource's access_attributes:
a. If the user lacks that category, access is DENIED
b. If the user has the category but none of the required values, access is DENIED
c. If the user has at least one matching value in each required category, access is GRANTED
Example:
# Resource requires:
access_attributes = AccessAttributes(
roles=["admin", "data-scientist"],
teams=["ml-team"]
)
# User has:
user_attributes = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "infra-team"],
"projects": ["llama-3"]
}
# Result: Access GRANTED
# - User has the "data-scientist" role (matches one of the required roles)
# - AND user is part of the "ml-team" (matches the required team)
# - The extra "projects" attribute is ignored
Args:
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj_identifier}")
return True

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol
from llama_stack.distribution.request_headers import User
from .datatypes import (
AccessAttributes,
AccessRule,
Action,
AttributeReference,
Condition,
Scope,
)
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
if resource_scope == actual_resource:
return True
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
def matches_scope(
scope: Scope,
action: Action,
resource: str,
user: str | None,
) -> bool:
if scope.resource and not matches_resource(scope.resource, resource):
return False
if scope.principal and scope.principal != user:
return False
return action in scope.actions
def user_in_literal(
literal: str,
user_attributes: dict[str, list[str]] | None,
) -> bool:
for qualifier in ["role::", "team::", "project::", "namespace::"]:
if literal.startswith(qualifier):
if not user_attributes:
return False
ref = qualifier.replace("::", "s")
if ref in user_attributes:
value = literal.removeprefix(qualifier)
return value in user_attributes[ref]
else:
return False
return False
def user_in(
ref: AttributeReference | str,
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
if not ref.startswith("resource."):
return user_in_literal(ref, user_attributes)
name = ref.removeprefix("resource.")
required = resource_attributes and getattr(resource_attributes, name)
if not required:
return True
if not user_attributes or name not in user_attributes:
return False
actual = user_attributes[name]
for value in required:
if value in actual:
return True
return False
def as_list(obj: Any) -> list[Any]:
if isinstance(obj, list):
return obj
return [obj]
def matches_conditions(
conditions: list[Condition],
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
for condition in conditions:
# must match all conditions
if not matches_condition(condition, resource_attributes, user_attributes):
return False
return True
def matches_condition(
condition: Condition | list[Condition],
resource_attributes: AccessAttributes | None,
user_attributes: dict[str, list[str]] | None,
) -> bool:
if isinstance(condition, list):
return matches_conditions(as_list(condition), resource_attributes, user_attributes)
if condition.user_in:
for ref in as_list(condition.user_in):
# if multiple references are specified, all must match
if not user_in(ref, resource_attributes, user_attributes):
return False
return True
if condition.user_not_in:
for ref in as_list(condition.user_not_in):
# if multiple references are specified, none must match
if user_in(ref, resource_attributes, user_attributes):
return False
return True
return True
def default_policy() -> list[AccessRule]:
# for backwards compatibility, if no rules are provided , assume
# full access to all subject to attribute matching rules
return [
AccessRule(
permit=Scope(actions=list(Action)),
when=Condition(user_in=list(AttributeReference)),
)
]
class ProtectedResource(Protocol):
type: str
identifier: str
access_attributes: AccessAttributes
def is_action_allowed(
policy: list[AccessRule],
action: Action,
resource: ProtectedResource,
user: User | None,
) -> bool:
# If user is not set, assume authentication is not enabled
if not user:
return True
if not len(policy):
policy = default_policy()
resource_attributes = AccessAttributes()
if hasattr(resource, "access_attributes"):
resource_attributes = resource.access_attributes
qualified_resource_id = resource.type + "::" + resource.identifier
for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when:
if matches_condition(rule.when, resource_attributes, user.attributes):
return False
elif rule.unless:
if not matches_condition(rule.unless, resource_attributes, user.attributes):
return False
else:
return False
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
if rule.when:
if matches_condition(rule.when, resource_attributes, user.attributes):
return True
elif rule.unless:
if not matches_condition(rule.unless, resource_attributes, user.attributes):
return True
else:
return True
# assume access is denied unless we find a rule that permits access
return False
class AccessDeniedError(RuntimeError):
pass

View file

@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class Action(str, Enum):
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
class Scope(BaseModel):
principal: str | None = None
actions: Action | list[Action]
resource: str | None = None
def _mutually_exclusive(obj, a: str, b: str):
if getattr(obj, a) and getattr(obj, b):
raise ValueError(f"{a} and {b} are mutually exclusive")
def _require_one_of(obj, a: str, b: str):
if not getattr(obj, a) and not getattr(obj, b):
raise ValueError(f"on of {a} or {b} is required")
class AttributeReference(str, Enum):
RESOURCE_ROLES = "resource.roles"
RESOURCE_TEAMS = "resource.teams"
RESOURCE_PROJECTS = "resource.projects"
RESOURCE_NAMESPACES = "resource.namespaces"
class Condition(BaseModel):
user_in: AttributeReference | list[AttributeReference] | str | None = None
user_not_in: AttributeReference | list[AttributeReference] | str | None = None
class AccessRule(BaseModel):
"""Access rule based loosely on cedar policy language
A rule defines a list of action either to permit or to forbid. It may specify a
principal or a resource that must match for the rule to take effect. The resource
to match should be specified in the form of a type qualified identifier, e.g.
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
e.g. model::*. If the principal or resource are not specified, they will match all
requests.
A rule may also specify a condition, either a 'when' or an 'unless', with additional
constraints as to where the rule applies. The constraints at present are whether the
user requesting access is in or not in some set. This set can either be a particular
set of attributes on the resource e.g. resource.roles or a literal value of some
notion of group, e.g. role::admin or namespace::foo.
Rules are tested in order to find a match. If a match is found, the request is
permitted or forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified, a rule that allows any action as
long as the resource attributes match the user attributes is added
(i.e. the previous behaviour is the default).
Some examples in yaml:
- permit:
principal: user-1
actions: [create, read, delete]
resource: model::*
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
actions: [read]
when:
user_in: resource.namespaces
description: any user has read access to any resource with matching attributes
- forbid:
actions: [create, read, delete]
resource: vector_db::*
unless:
user_in: role::admin
description: only user with admin role can use vector_db resources
"""
permit: Scope | None = None
forbid: Scope | None = None
when: Condition | list[Condition] | None = None
unless: Condition | list[Condition] | None = None
description: str | None = None
@model_validator(mode="after")
def validate_rule_format(self) -> Self:
_require_one_of(self, "permit", "forbid")
_mutually_exclusive(self, "permit", "forbid")
_mutually_exclusive(self, "when", "unless")
return self

View file

@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
@ -35,35 +36,6 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = str | list[str]
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities.
@ -234,6 +206,7 @@ class AuthenticationConfig(BaseModel):
...,
description="Provider-specific configuration",
)
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
class AuthenticationRequiredError(Exception):

View file

@ -18,15 +18,23 @@ log = logging.getLogger(__name__)
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class User:
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]]
def __init__(self, principal: str, attributes: dict[str, list[str]]):
self.principal = principal
self.attributes = attributes
class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""
def __init__(
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
):
def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
self.provider_data = provider_data or {}
if auth_attributes:
self.provider_data["__auth_attributes"] = auth_attributes
if user:
self.provider_data["__authenticated_user"] = user
self.token = None
@ -95,9 +103,9 @@ def request_provider_data_context(
return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> dict[str, list[str]] | None:
def get_authenticated_user() -> User | None:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__auth_attributes")
return provider_data.get("__authenticated_user")

View file

@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AccessRule,
AutoRoutedProviderSpec,
Provider,
RoutingTableProviderSpec,
@ -118,6 +119,7 @@ async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
@ -140,7 +142,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -247,6 +249,7 @@ async def instantiate_providers(
router_apis: set[Api],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
) -> dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
@ -261,7 +264,7 @@ async def instantiate_providers(
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@ -312,6 +315,7 @@ async def instantiate_provider(
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
):
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
@ -336,13 +340,15 @@ async def instantiate_provider(
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, inner_impls, deps, dist_registry]
args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
else:
method = "get_provider_impl"
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
fn = getattr(module, method)
impl = await fn(*args)

View file

@ -6,7 +6,7 @@
from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol
from llama_stack.distribution.stack import StackRunConfig
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -18,6 +18,7 @@ async def get_routing_table_impl(
impls_by_provider_id: dict[str, RoutedProtocol],
_deps,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> Any:
from ..routing_tables.benchmarks import BenchmarksRoutingTable
from ..routing_tables.datasets import DatasetsRoutingTable
@ -40,7 +41,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
await impl.initialize()
return impl

View file

@ -8,14 +8,15 @@ from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import (
AccessAttributes,
AccessRule,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
@ -73,9 +74,11 @@ class CommonRoutingTableImpl(RoutingTable):
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
self.policy = policy
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
@ -166,13 +169,15 @@ class CommonRoutingTableImpl(RoutingTable):
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}'")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
raise AccessDeniedError()
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
@ -187,11 +192,12 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
creator = get_authenticated_user()
if not is_action_allowed(self.policy, "create", obj, creator):
raise AccessDeniedError()
if creator and creator.attributes:
obj.access_attributes = AccessAttributes(**creator.attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
@ -210,9 +216,7 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
]
return filtered_objs

View file

@ -30,10 +30,7 @@ from pydantic import BaseModel, ValidationError
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
@ -213,11 +210,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal, user_attributes)
await log_request_pre_validation(request)
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
with request_provider_data_context(request.headers, user):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:

View file

@ -223,7 +223,10 @@ async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
policy = run_config.server.auth.access_policy if run_config.server.auth else []
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)

View file

@ -6,12 +6,12 @@
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import AccessRule, Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety],
deps[Api.tool_runtime],
deps[Api.tool_groups],
policy,
)
await impl.initialize()
return impl

View file

@ -60,6 +60,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin):
vector_io_api: VectorIO,
persistence_store: KVStore,
created_at: str,
policy: list[AccessRule],
):
self.agent_id = agent_id
self.agent_config = agent_config
self.inference_api = inference_api
self.safety_api = safety_api
self.vector_io_api = vector_io_api
self.storage = AgentPersistence(agent_id, persistence_store)
self.storage = AgentPersistence(agent_id, persistence_store, policy)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at

View file

@ -40,6 +40,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -61,6 +62,7 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
policy: list[AccessRule],
):
self.config = config
self.inference_api = inference_api
@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents):
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
self.policy = policy
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
@ -129,6 +132,7 @@ class MetaReferenceAgentsImpl(Agents):
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
policy=self.policy,
)
async def create_agent_session(

View file

@ -10,9 +10,9 @@ import uuid
from datetime import datetime, timezone
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -23,6 +23,8 @@ class AgentSessionInfo(Session):
vector_db_id: str | None = None
started_at: datetime
access_attributes: AccessAttributes | None = None
identifier: str | None = None
type: str = "session"
class AgentInfo(AgentConfig):
@ -30,15 +32,17 @@ class AgentInfo(AgentConfig):
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore):
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id
self.kvstore = kvstore
self.policy = policy
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
user = get_authenticated_user()
auth_attributes = user and user.attributes
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo(
@ -47,7 +51,10 @@ class AgentPersistence:
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
turns=[],
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError()
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@ -76,7 +83,7 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes"):
return True
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple models and verify listing
@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple shields and verify listing
@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_providere",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple datasets and verify listing
@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple scoring functions and verify listing
@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple benchmarks and verify listing
@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple tool groups and verify listing

View file

@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
{},
)
await impl.initialize()
yield impl

View file

@ -13,23 +13,24 @@ import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Set creator's attributes for the session
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
mock_get_auth_attributes.return_value = creator_attributes
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
# Create a session
session_id = await agent_persistence.create_session("Test Session")
@ -43,8 +44,8 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with specific access attributes
@ -55,6 +56,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
)
# User with matching attributes can access
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
mock_get_authenticated_user.return_value = User(
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
)
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is not None
assert retrieved_session.session_id == session_id
# User without matching attributes cannot access
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
@ -87,6 +91,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
)
# Admin can add turn
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.add_turn_to_session(session_id, turn)
# Admin can get turn
@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
assert retrieved_turn.turn_id == turn_id
# Regular user cannot get turn
mock_get_auth_attributes.return_value = {"roles": ["user"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
with pytest.raises(ValueError):
await agent_persistence.get_session_turn(session_id, turn_id)
@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
@ -140,6 +145,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
turn_id = str(uuid.uuid4())
# Admin user can set inference iterations
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
# Admin user can get inference iterations
@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
assert infer_iters == 5
# Regular user cannot get inference iterations
mock_get_auth_attributes.return_value = {"roles": ["user"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters is None

View file

@ -7,10 +7,14 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
import yaml
from pydantic import TypeAdapter
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@ -32,13 +36,14 @@ async def test_setup(cached_disk_dist_registry):
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy={},
)
yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-public",
@ -64,7 +69,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
await registry.register(model_admin_only)
await registry.register(model_data_scientist)
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
@ -75,7 +80,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public"
@ -86,7 +91,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
model_ids = [m.identifier for m in all_models.data]
@ -102,8 +107,8 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-updates",
@ -112,28 +117,37 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
model_type=ModelType.llm,
)
await registry.register(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["user"],
},
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"])
await registry.update(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["user"],
},
)
with pytest.raises(ValueError):
await routing_table.get_model("model-updates")
mock_get_auth_attributes.return_value = {
"roles": ["admin"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["admin"],
},
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model = ModelWithACL(
identifier="model-empty-attrs",
@ -143,9 +157,12 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
access_attributes=AccessAttributes(),
)
await registry.register(model)
mock_get_auth_attributes.return_value = {
"roles": [],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": [],
},
)
result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs"
all_models = await routing_table.list_models()
@ -154,8 +171,8 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-public-2",
@ -172,7 +189,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
)
await registry.register(model_public)
await registry.register(model_restricted)
mock_get_auth_attributes.return_value = None
mock_get_authenticated_user.return_value = User("test-user", None)
model = await routing_table.get_model("model-public-2")
assert model.identifier == "model-public-2"
@ -185,14 +202,14 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator."""
registry, routing_table = test_setup
# Set creator's attributes
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
mock_get_auth_attributes.return_value = creator_attributes
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
# Create model without explicit access attributes
model = ModelWithACL(
@ -211,15 +228,262 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
assert registered_model.access_attributes.projects == ["llama-3"]
# Verify another user without matching attributes can't access it
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
with pytest.raises(ValueError):
await routing_table.get_model("auto-access-model")
# But a user with matching attributes can
mock_get_auth_attributes.return_value = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
},
)
model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model"
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
mock_inference.unregister_model = AsyncMock(side_effect=_return_model)
config = """
- permit:
principal: user-1
actions: [create, read, delete]
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
principal: user-3
actions: [read]
resource: model::model-2
description: user-3 has read access to model-2 only
- forbid:
actions: [create, read, delete]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy=policy,
)
yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
routing_table = test_setup_with_access_policy
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.register_model("model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_id="test_provider")
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
mock_get_authenticated_user.return_value = User(
"user-2",
{
"roles": ["user"],
"projects": ["foo"],
},
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1")
mock_get_authenticated_user.return_value = User(
"user-3",
{
"roles": ["user"],
"projects": ["bar"],
},
)
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2")
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.unregister_model("model-3")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
def test_permit_when():
config = """
- permit:
principal: user-1
actions: [read]
when:
user_in: resource.namespaces
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_permit_unless():
config = """
- permit:
principal: user-1
actions: [read]
resource: model::*
unless:
- user_not_in: resource.namespaces
- user_in: resource.teams
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_when():
config = """
- forbid:
principal: user-1
actions: [read]
when:
user_in: resource.namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_unless():
config = """
- forbid:
principal: user-1
actions: [read]
unless:
user_in: resource.namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_condition_with_literal():
config = """
- permit:
actions: [read]
when:
user_in: role::admin
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_condition_with_unrecognised_literal():
config = """
- permit:
actions: [read]
when:
user_in: whatever
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", None))
def test_empty_condition():
config = """
- permit:
actions: [read]
when: {}
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
)
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", None))

View file

@ -100,9 +100,10 @@ async def test_resolve_impls_basic():
add_protocol_methods(SampleImpl, Inference)
mock_module.get_provider_impl = AsyncMock(return_value=impl)
mock_module.get_provider_impl.__text_signature__ = "()"
sys.modules["test_module"] = mock_module
impls = await resolve_impls(run_config, provider_registry, dist_registry)
impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
assert Api.inference in impls
assert isinstance(impls[Api.inference], InferenceRouter)