feat: fine grained access control policy (#2264)

This allows a set of rules to be defined for determining access to
resources. The rules are (loosely) based on the cedar policy format.

A rule defines a list of action either to permit or to forbid. It may
specify a principal or a resource that must match for the rule to take
effect. It may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies.

A list of rules is held for each type to be protected and tried in order
to find a match. If a match is found, the request is permitted or
forbidden depening on the type of rule. If no match is found, the
request is denied. If no rules are specified for a given type, a rule
that allows any action as long as the resource attributes match the user
attributes is added (i.e. the previous behaviour is the default.

Some examples in yaml:

```
    model:
    - permit:
      principal: user-1
      actions: [create, read, delete]
      comment: user-1 has full access to all models
    - permit:
      principal: user-2
      actions: [read]
      resource: model-1
      comment: user-2 has read access to model-1 only
    - permit:
      actions: [read]
      when:
        user_in: resource.namespaces
      comment: any user has read access to models with matching attributes
    vector_db:
    - forbid:
      actions: [create, read, delete]
      unless:
        user_in: role::admin
      comment: only user with admin role can use vector_db resources
```

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
grs 2025-06-03 17:51:12 -04:00 committed by GitHub
parent 8bee2954be
commit 7c1998db25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 956 additions and 450 deletions

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import (
BenchmarkWithACL,
BenchmarkWithOwner,
)
from llama_stack.log import get_logger
@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL(
benchmark = BenchmarkWithOwner(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,

View file

@ -8,14 +8,14 @@ from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import (
AccessAttributes,
AccessRule,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
@ -73,9 +73,11 @@ class CommonRoutingTableImpl(RoutingTable):
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
self.policy = policy
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
@ -166,13 +168,15 @@ class CommonRoutingTableImpl(RoutingTable):
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}'")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
raise AccessDeniedError()
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
@ -187,11 +191,12 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
creator = get_authenticated_user()
if not is_action_allowed(self.policy, "create", obj, creator):
raise AccessDeniedError()
if creator:
obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
@ -210,9 +215,7 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
]
return filtered_objs

View file

@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
)
from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import (
DatasetWithACL,
DatasetWithOwner,
)
from llama_stack.log import get_logger
@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
if metadata is None:
metadata = {}
dataset = DatasetWithACL(
dataset = DatasetWithOwner(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithACL,
ModelWithOwner,
)
from llama_stack.log import get_logger
@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = ModelWithACL(
model = ModelWithOwner(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,

View file

@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
ScoringFnWithACL,
ScoringFnWithOwner,
)
from llama_stack.log import get_logger
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithACL(
scoring_fn = ScoringFnWithOwner(
identifier=scoring_fn_id,
description=description,
return_type=return_type,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import (
ShieldWithACL,
ShieldWithOwner,
)
from llama_stack.log import get_logger
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
if params is None:
params = {}
shield = ShieldWithACL(
shield = ShieldWithOwner(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.distribution.datatypes import ToolGroupWithOwner
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
@ -106,7 +106,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
toolgroup = ToolGroupWithACL(
toolgroup = ToolGroupWithOwner(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,

View file

@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
VectorDBWithACL,
VectorDBWithOwner,
)
from llama_stack.log import get_logger
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db