feat: allow access attributes for resources to be configured

This allows a set of rules to be defined for determining the access attributes to apply to
a particular resource. It also checks that the attributes determined for a new resource to
be registered are matched by attributes associated with the request context.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
Gordon Sim 2025-05-06 18:54:58 +01:00
parent 0cc0731189
commit 490e77bffa
10 changed files with 402 additions and 19 deletions

View file

@ -16,7 +16,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
@ -63,6 +63,15 @@ class AccessAttributes(BaseModel):
)
class AccessAttributesRule(BaseModel):
"""Rule for associating AccessAttributes with particular resources"""
resource_type: ResourceType | None = Field(default=None)
resource_id: str | None = Field(default=None)
provider_id: str | None = Field(default=None)
attributes: AccessAttributes
class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities.
@ -233,6 +242,9 @@ class AuthenticationConfig(BaseModel):
...,
description="Provider-specific configuration",
)
resource_attribute_rules: list[AccessAttributesRule] = Field(
default=[], description="Rules for determining access attributes for resources"
)
class ServerConfig(BaseModel):

View file

@ -34,6 +34,7 @@ from llama_stack.distribution.datatypes import (
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -118,6 +119,7 @@ async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
resource_attributes: ResourceAccessAttributes,
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
@ -140,7 +142,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, resource_attributes)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -243,7 +245,10 @@ def sort_providers_by_deps(
async def instantiate_providers(
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
sorted_providers: list[tuple[str, ProviderWithSpec]],
router_apis: set[Api],
dist_registry: DistributionRegistry,
resource_attributes: ResourceAccessAttributes,
) -> dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
@ -258,7 +263,7 @@ async def instantiate_providers(
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, resource_attributes)
if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@ -308,6 +313,7 @@ async def instantiate_provider(
deps: dict[Api, Any],
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
resource_attributes: ResourceAccessAttributes,
):
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
@ -332,7 +338,7 @@ async def instantiate_provider(
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, inner_impls, deps, dist_registry]
args = [provider_spec.api, inner_impls, deps, dist_registry, resource_attributes]
else:
method = "get_provider_impl"

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes, AccessAttributesRule, ResourceWithACL
def match_access_attributes_rule(
rule: AccessAttributesRule, resource_type: str, resource_id: str, provider_id: str
) -> bool:
if rule.resource_type and rule.resource_type.value != resource_type:
return False
if rule.resource_id and rule.resource_id != resource_id:
return False
if rule.provider_id and rule.provider_id != provider_id:
return False
return True
class ResourceAccessAttributes:
def __init__(self, rules: list[AccessAttributesRule]) -> None:
self.rules = rules
self.access_check_enabled = False
def enable_access_checks(self):
self.access_check_enabled = True
def get(self, resource_type: str, resource_id: str, provider_id: str) -> AccessAttributes | None:
for rule in self.rules:
if match_access_attributes_rule(rule, resource_type, resource_id, provider_id):
return rule.attributes
return None
def apply(self, resource: ResourceWithACL, user_attributes: dict[str, list[str]] | None) -> bool:
"""Sets the resource access attributes based on the specified rules.
Returns True if a matching rule was found for this resource.
If access checks have been enable, also checks whether the user attributes allow the
resource to be created.
"""
resource_attributes = self.get(resource.type, resource.identifier, resource.provider_id)
if resource_attributes:
if self.access_check_enabled and not check_access(
resource.identifier, resource_attributes, user_attributes
):
raise ValueError(f"Access denied: {resource.type} '{resource.identifier}'")
resource.access_attributes = resource_attributes
return True
return False

View file

@ -7,6 +7,7 @@
from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -26,6 +27,7 @@ async def get_routing_table_impl(
impls_by_provider_id: dict[str, RoutedProtocol],
_deps,
dist_registry: DistributionRegistry,
resource_attributes: ResourceAccessAttributes,
) -> Any:
api_to_tables = {
"vector_dbs": VectorDBsRoutingTable,
@ -40,7 +42,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, resource_attributes)
await impl.initialize()
return impl

View file

@ -58,6 +58,7 @@ from llama_stack.distribution.datatypes import (
VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -114,9 +115,11 @@ class CommonRoutingTableImpl(RoutingTable):
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
resource_attributes: ResourceAccessAttributes,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
self.resource_attributes = resource_attributes
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
@ -219,8 +222,12 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
if self.resource_attributes.apply(obj, get_auth_attributes()):
logger.info(
f"Setting access attributes for {obj.type} '{obj.identifier}' based on resource attribute rules"
)
else:
# If object supports access control but no attributes set, use creator's attributes
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)

View file

@ -38,6 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -223,12 +224,19 @@ async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
if run_config.server.auth:
resource_attributes = ResourceAccessAttributes(run_config.server.auth.resource_attribute_rules)
else:
resource_attributes = ResourceAccessAttributes([])
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(run_config), dist_registry, resource_attributes
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
await register_resources(run_config, impls)
resource_attributes.enable_access_checks()
return impls