mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
feat: allow access attributes for resources to be configured
This allows a set of rules to be defined for determining the access attributes to apply to a particular resource. It also checks that the attributes determined for a new resource to be registered are matched by attributes associated with the request context. Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
0cc0731189
commit
490e77bffa
10 changed files with 402 additions and 19 deletions
|
@ -16,7 +16,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
|
|||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.resource import Resource
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
|
@ -63,6 +63,15 @@ class AccessAttributes(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class AccessAttributesRule(BaseModel):
|
||||
"""Rule for associating AccessAttributes with particular resources"""
|
||||
|
||||
resource_type: ResourceType | None = Field(default=None)
|
||||
resource_id: str | None = Field(default=None)
|
||||
provider_id: str | None = Field(default=None)
|
||||
attributes: AccessAttributes
|
||||
|
||||
|
||||
class ResourceWithACL(Resource):
|
||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||
|
||||
|
@ -233,6 +242,9 @@ class AuthenticationConfig(BaseModel):
|
|||
...,
|
||||
description="Provider-specific configuration",
|
||||
)
|
||||
resource_attribute_rules: list[AccessAttributesRule] = Field(
|
||||
default=[], description="Rules for determining access attributes for resources"
|
||||
)
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
|
|
|
@ -34,6 +34,7 @@ from llama_stack.distribution.datatypes import (
|
|||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -118,6 +119,7 @@ async def resolve_impls(
|
|||
run_config: StackRunConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
resource_attributes: ResourceAccessAttributes,
|
||||
) -> dict[Api, Any]:
|
||||
"""
|
||||
Resolves provider implementations by:
|
||||
|
@ -140,7 +142,7 @@ async def resolve_impls(
|
|||
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, resource_attributes)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
|
@ -243,7 +245,10 @@ def sort_providers_by_deps(
|
|||
|
||||
|
||||
async def instantiate_providers(
|
||||
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
|
||||
sorted_providers: list[tuple[str, ProviderWithSpec]],
|
||||
router_apis: set[Api],
|
||||
dist_registry: DistributionRegistry,
|
||||
resource_attributes: ResourceAccessAttributes,
|
||||
) -> dict:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
|
@ -258,7 +263,7 @@ async def instantiate_providers(
|
|||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, resource_attributes)
|
||||
|
||||
if api_str.startswith("inner-"):
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
|
@ -308,6 +313,7 @@ async def instantiate_provider(
|
|||
deps: dict[Api, Any],
|
||||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
resource_attributes: ResourceAccessAttributes,
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
|
@ -332,7 +338,7 @@ async def instantiate_provider(
|
|||
method = "get_routing_table_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, deps, dist_registry]
|
||||
args = [provider_spec.api, inner_impls, deps, dist_registry, resource_attributes]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
|
|
54
llama_stack/distribution/resource_attributes.py
Normal file
54
llama_stack/distribution/resource_attributes.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, AccessAttributesRule, ResourceWithACL
|
||||
|
||||
|
||||
def match_access_attributes_rule(
|
||||
rule: AccessAttributesRule, resource_type: str, resource_id: str, provider_id: str
|
||||
) -> bool:
|
||||
if rule.resource_type and rule.resource_type.value != resource_type:
|
||||
return False
|
||||
if rule.resource_id and rule.resource_id != resource_id:
|
||||
return False
|
||||
if rule.provider_id and rule.provider_id != provider_id:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ResourceAccessAttributes:
|
||||
def __init__(self, rules: list[AccessAttributesRule]) -> None:
|
||||
self.rules = rules
|
||||
self.access_check_enabled = False
|
||||
|
||||
def enable_access_checks(self):
|
||||
self.access_check_enabled = True
|
||||
|
||||
def get(self, resource_type: str, resource_id: str, provider_id: str) -> AccessAttributes | None:
|
||||
for rule in self.rules:
|
||||
if match_access_attributes_rule(rule, resource_type, resource_id, provider_id):
|
||||
return rule.attributes
|
||||
return None
|
||||
|
||||
def apply(self, resource: ResourceWithACL, user_attributes: dict[str, list[str]] | None) -> bool:
|
||||
"""Sets the resource access attributes based on the specified rules.
|
||||
|
||||
Returns True if a matching rule was found for this resource.
|
||||
|
||||
If access checks have been enable, also checks whether the user attributes allow the
|
||||
resource to be created.
|
||||
"""
|
||||
|
||||
resource_attributes = self.get(resource.type, resource.identifier, resource.provider_id)
|
||||
if resource_attributes:
|
||||
if self.access_check_enabled and not check_access(
|
||||
resource.identifier, resource_attributes, user_attributes
|
||||
):
|
||||
raise ValueError(f"Access denied: {resource.type} '{resource.identifier}'")
|
||||
resource.access_attributes = resource_attributes
|
||||
return True
|
||||
return False
|
|
@ -7,6 +7,7 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
@ -26,6 +27,7 @@ async def get_routing_table_impl(
|
|||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
dist_registry: DistributionRegistry,
|
||||
resource_attributes: ResourceAccessAttributes,
|
||||
) -> Any:
|
||||
api_to_tables = {
|
||||
"vector_dbs": VectorDBsRoutingTable,
|
||||
|
@ -40,7 +42,7 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, resource_attributes)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ from llama_stack.distribution.datatypes import (
|
|||
VectorDBWithACL,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
@ -114,9 +115,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self,
|
||||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
dist_registry: DistributionRegistry,
|
||||
resource_attributes: ResourceAccessAttributes,
|
||||
) -> None:
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.dist_registry = dist_registry
|
||||
self.resource_attributes = resource_attributes
|
||||
|
||||
async def initialize(self) -> None:
|
||||
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
|
@ -219,8 +222,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
if self.resource_attributes.apply(obj, get_auth_attributes()):
|
||||
logger.info(
|
||||
f"Setting access attributes for {obj.type} '{obj.identifier}' based on resource attribute rules"
|
||||
)
|
||||
else:
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
|
|
|
@ -38,6 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -223,12 +224,19 @@ async def construct_stack(
|
|||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
if run_config.server.auth:
|
||||
resource_attributes = ResourceAccessAttributes(run_config.server.auth.resource_attribute_rules)
|
||||
else:
|
||||
resource_attributes = ResourceAccessAttributes([])
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(run_config), dist_registry, resource_attributes
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
resource_attributes.enable_access_checks()
|
||||
return impls
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.models.models import Model, ModelType
|
|||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
|
||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
|
||||
from llama_stack.distribution.routers.routing_tables import (
|
||||
BenchmarksRoutingTable,
|
||||
DatasetsRoutingTable,
|
||||
|
@ -123,7 +124,9 @@ class ToolGroupsImpl(Impl):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
|
||||
table = ModelsRoutingTable(
|
||||
{"test_provider": InferenceImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple models and verify listing
|
||||
|
@ -165,7 +168,9 @@ async def test_models_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shields_routing_table(cached_disk_dist_registry):
|
||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
|
||||
table = ShieldsRoutingTable(
|
||||
{"test_provider": SafetyImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple shields and verify listing
|
||||
|
@ -181,10 +186,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
|
||||
table = VectorDBsRoutingTable(
|
||||
{"test_provider": VectorDBImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
|
||||
m_table = ModelsRoutingTable(
|
||||
{"test_providere": InferenceImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
|
@ -211,7 +220,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
|
||||
|
||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]))
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple datasets and verify listing
|
||||
|
@ -237,7 +246,9 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
|
||||
table = ScoringFunctionsRoutingTable(
|
||||
{"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple scoring functions and verify listing
|
||||
|
@ -263,7 +274,9 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
|
||||
table = BenchmarksRoutingTable(
|
||||
{"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple benchmarks and verify listing
|
||||
|
@ -281,7 +294,9 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
|
||||
table = ToolGroupsRoutingTable(
|
||||
{"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple tool groups and verify listing
|
||||
|
|
209
tests/unit/distribution/test_resource_attributes.py
Normal file
209
tests/unit/distribution/test_resource_attributes.py
Normal file
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.datasets import URIDataSource
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
DatasetWithACL,
|
||||
ModelWithACL,
|
||||
ShieldWithACL,
|
||||
)
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes, match_access_attributes_rule
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rules():
|
||||
config = """{
|
||||
"provider_type": "custom",
|
||||
"config": {},
|
||||
"resource_attribute_rules": [
|
||||
{
|
||||
"resource_type": "model",
|
||||
"resource_id": "my-model",
|
||||
"provider_id": "my-provider",
|
||||
"attributes": {
|
||||
"roles": ["role1"],
|
||||
"teams": ["team1"],
|
||||
"projects": ["project1"],
|
||||
"namespaces": ["namespace1"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"resource_type": "dataset",
|
||||
"provider_id": "my-provider",
|
||||
"attributes": {
|
||||
"roles": ["role2"],
|
||||
"teams": ["team2"],
|
||||
"projects": ["project2"],
|
||||
"namespaces": ["namespace2"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"provider_id": "my-provider",
|
||||
"attributes": {
|
||||
"roles": ["role3"],
|
||||
"teams": ["team3"],
|
||||
"projects": ["project3"],
|
||||
"namespaces": ["namespace3"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"resource_type": "model",
|
||||
"attributes": {
|
||||
"roles": ["role4"],
|
||||
"teams": ["team4"],
|
||||
"projects": ["project4"],
|
||||
"namespaces": ["namespace4"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"attributes": {
|
||||
"roles": ["role5"],
|
||||
"teams": ["team5"],
|
||||
"projects": ["project5"],
|
||||
"namespaces": ["namespace5"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}"""
|
||||
return AuthenticationConfig.model_validate_json(config).resource_attribute_rules
|
||||
|
||||
|
||||
def test_match_access_attributes_rule(rules):
|
||||
assert match_access_attributes_rule(rules[0], "model", "my-model", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[0], "model", "another-model", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[0], "model", "my-model", "another-provider")
|
||||
assert not match_access_attributes_rule(rules[0], "dataset", "my-model", "my-provider")
|
||||
|
||||
assert match_access_attributes_rule(rules[1], "dataset", "my-data", "my-provider")
|
||||
assert match_access_attributes_rule(rules[1], "dataset", "different-data", "my-provider")
|
||||
assert match_access_attributes_rule(rules[1], "dataset", "any-data", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[1], "model", "a-model", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[1], "dataset", "any-data", "another-provider")
|
||||
assert not match_access_attributes_rule(rules[1], "model", "my-model", "my-provider")
|
||||
|
||||
assert match_access_attributes_rule(rules[2], "dataset", "foo", "my-provider")
|
||||
assert match_access_attributes_rule(rules[2], "model", "foo", "my-provider")
|
||||
assert match_access_attributes_rule(rules[2], "vector_db", "bar", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[2], "dataset", "foo", "another-provider")
|
||||
|
||||
assert match_access_attributes_rule(rules[3], "model", "foo", "my-provider")
|
||||
assert match_access_attributes_rule(rules[3], "model", "bar", "my-provider")
|
||||
assert match_access_attributes_rule(rules[3], "model", "foo", "another-provider")
|
||||
assert not match_access_attributes_rule(rules[3], "dataset", "bar", "my-provider")
|
||||
|
||||
assert match_access_attributes_rule(rules[4], "model", "foo", "my-provider")
|
||||
assert match_access_attributes_rule(rules[4], "model", "bar", "my-provider")
|
||||
assert match_access_attributes_rule(rules[4], "model", "foo", "another-provider")
|
||||
assert match_access_attributes_rule(rules[4], "dataset", "bar", "my-provider")
|
||||
assert match_access_attributes_rule(rules[4], "vector_db", "baz", "any-provider")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resource_access_attributes(rules):
|
||||
return ResourceAccessAttributes(rules)
|
||||
|
||||
|
||||
def dataset(identifier: str, provider_id: str) -> DatasetWithACL:
|
||||
return DatasetWithACL(
|
||||
identifier=identifier,
|
||||
provider_id=provider_id,
|
||||
purpose="eval/question-answer",
|
||||
source=URIDataSource(uri="https://a.com/a.jsonl"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"resource,expected_roles",
|
||||
[
|
||||
(ModelWithACL(identifier="my-model", provider_id="my-provider"), ["role1"]),
|
||||
(ModelWithACL(identifier="another-model", provider_id="my-provider"), ["role3"]),
|
||||
(ModelWithACL(identifier="third-model", provider_id="another-provider"), ["role4"]),
|
||||
(dataset(identifier="my-dataset", provider_id="my-provider"), ["role2"]),
|
||||
(dataset(identifier="another-dataset", provider_id="another-provider"), ["role5"]),
|
||||
(ShieldWithACL(identifier="my-shield", provider_id="my-provider"), ["role3"]),
|
||||
(ShieldWithACL(identifier="another-shield", provider_id="another-provider"), ["role5"]),
|
||||
],
|
||||
)
|
||||
def test_apply(resource_access_attributes, resource, expected_roles):
|
||||
assert resource_access_attributes.apply(resource, None)
|
||||
assert resource.access_attributes.roles == expected_roles
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def alternate_access_attributes_rules():
|
||||
config = """{
|
||||
"provider_type": "custom",
|
||||
"config": {},
|
||||
"resource_attribute_rules": [
|
||||
{
|
||||
"resource_type": "model",
|
||||
"resource_id": "my-model",
|
||||
"provider_id": "my-provider",
|
||||
"attributes": {
|
||||
"roles": ["roleA"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"resource_type": "model",
|
||||
"attributes": {
|
||||
"roles": ["roleB"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}"""
|
||||
return AuthenticationConfig.model_validate_json(config).resource_attribute_rules
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def alternate_access_attributes(alternate_access_attributes_rules):
|
||||
return ResourceAccessAttributes(alternate_access_attributes_rules)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"resource,expected_roles",
|
||||
[
|
||||
(ModelWithACL(identifier="my-model", provider_id="my-provider"), ["roleA"]),
|
||||
(ModelWithACL(identifier="another-model", provider_id="another-provider"), ["roleB"]),
|
||||
(dataset(identifier="my-dataset", provider_id="my-provider"), None),
|
||||
(dataset(identifier="another-dataset", provider_id="another-provider"), None),
|
||||
],
|
||||
)
|
||||
def test_apply_alternate(alternate_access_attributes, resource, expected_roles):
|
||||
if expected_roles:
|
||||
assert alternate_access_attributes.apply(resource, None)
|
||||
assert resource.access_attributes.roles == expected_roles
|
||||
else:
|
||||
assert not alternate_access_attributes.apply(resource, None)
|
||||
assert not resource.access_attributes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def checked_attributes(alternate_access_attributes_rules):
|
||||
attributes = ResourceAccessAttributes(alternate_access_attributes_rules)
|
||||
attributes.enable_access_checks()
|
||||
return attributes
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"resource,user_attributes,fails,result",
|
||||
[
|
||||
(ModelWithACL(identifier="my-model", provider_id="my-provider"), {"roles": ["roleA"]}, False, True),
|
||||
(ModelWithACL(identifier="my-model", provider_id="my-provider"), {"roles": ["somethingelse"]}, True, False),
|
||||
(dataset(identifier="my-dataset", provider_id="my-provider"), {"roles": ["somethingelse"]}, False, False),
|
||||
(dataset(identifier="my-dataset", provider_id="my-provider"), None, False, False),
|
||||
],
|
||||
)
|
||||
def test_access_check_on_apply(checked_attributes, resource, user_attributes, fails, result):
|
||||
if fails:
|
||||
with pytest.raises(ValueError) as e:
|
||||
checked_attributes.apply(resource, user_attributes)
|
||||
assert "Access denied" in str(e.value)
|
||||
assert not resource.access_attributes
|
||||
else:
|
||||
assert checked_attributes.apply(resource, user_attributes) == result
|
|
@ -10,7 +10,8 @@ import pytest
|
|||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, AccessAttributesRule, ModelWithACL
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
|
||||
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||
|
||||
|
||||
|
@ -32,6 +33,7 @@ async def test_setup(cached_disk_dist_registry):
|
|||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=cached_disk_dist_registry,
|
||||
resource_attributes=ResourceAccessAttributes([]),
|
||||
)
|
||||
yield cached_disk_dist_registry, routing_table
|
||||
|
||||
|
@ -223,3 +225,70 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
|
|||
}
|
||||
model = await routing_table.get_model("auto-access-model")
|
||||
assert model.identifier == "auto-access-model"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_setup_with_resource_attribute_rules(cached_disk_dist_registry):
|
||||
mock_inference = Mock()
|
||||
mock_inference.__provider_spec__ = MagicMock()
|
||||
mock_inference.__provider_spec__.api = Api.inference
|
||||
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||
resource_attributes = ResourceAccessAttributes(
|
||||
[
|
||||
# model-1 can be accessed by users in project foo or those who have admin role
|
||||
AccessAttributesRule(
|
||||
resource_type="model", resource_id="model-1", attributes=AccessAttributes(projects=["foo"])
|
||||
),
|
||||
# model-2 can be accessed by users in project bar or those who have admin role
|
||||
AccessAttributesRule(
|
||||
resource_type="model", resource_id="model-2", attributes=AccessAttributes(projects=["bar"])
|
||||
),
|
||||
# any other model can be accessed or created only by those who have admin role
|
||||
AccessAttributesRule(resource_type="model", attributes=AccessAttributes(roles=["admin"])),
|
||||
]
|
||||
)
|
||||
resource_attributes.enable_access_checks()
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=cached_disk_dist_registry,
|
||||
resource_attributes=resource_attributes,
|
||||
)
|
||||
yield routing_table
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_with_resource_attribute_rules(mock_get_auth_attributes, test_setup_with_resource_attribute_rules):
|
||||
routing_table = test_setup_with_resource_attribute_rules
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["admin"],
|
||||
"projects": ["foo", "bar"],
|
||||
}
|
||||
await routing_table.register_model("model-1", provider_id="test_provider")
|
||||
await routing_table.register_model("model-2", provider_id="test_provider")
|
||||
await routing_table.register_model("model-3", provider_id="test_provider")
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
"projects": ["foo"],
|
||||
}
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-2")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.register_model("model-4", provider_id="test_provider")
|
||||
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
"projects": ["bar"],
|
||||
}
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-1")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.register_model("model-5", provider_id="test_provider")
|
||||
|
|
|
@ -19,6 +19,7 @@ from llama_stack.distribution.datatypes import (
|
|||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
|
||||
from llama_stack.distribution.routers.routers import InferenceRouter
|
||||
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
|
||||
|
@ -102,7 +103,7 @@ async def test_resolve_impls_basic():
|
|||
mock_module.get_provider_impl = AsyncMock(return_value=impl)
|
||||
sys.modules["test_module"] = mock_module
|
||||
|
||||
impls = await resolve_impls(run_config, provider_registry, dist_registry)
|
||||
impls = await resolve_impls(run_config, provider_registry, dist_registry, ResourceAccessAttributes([]))
|
||||
|
||||
assert Api.inference in impls
|
||||
assert isinstance(impls[Api.inference], InferenceRouter)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue