diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 446a88ca0..635f63ad1 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -16,7 +16,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput from llama_stack.apis.eval import Eval from llama_stack.apis.inference import Inference from llama_stack.apis.models import Model, ModelInput -from llama_stack.apis.resource import Resource +from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput @@ -63,6 +63,15 @@ class AccessAttributes(BaseModel): ) +class AccessAttributesRule(BaseModel): + """Rule for associating AccessAttributes with particular resources""" + + resource_type: ResourceType | None = Field(default=None) + resource_id: str | None = Field(default=None) + provider_id: str | None = Field(default=None) + attributes: AccessAttributes + + class ResourceWithACL(Resource): """Extension of Resource that adds attribute-based access control capabilities. 
@@ -233,6 +242,9 @@ class AuthenticationConfig(BaseModel): ..., description="Provider-specific configuration", ) + resource_attribute_rules: list[AccessAttributesRule] = Field( + default=[], description="Rules for determining access attributes for resources" + ) class ServerConfig(BaseModel): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 257c495c3..f796cb703 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -34,6 +34,7 @@ from llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.resource_attributes import ResourceAccessAttributes from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger @@ -118,6 +119,7 @@ async def resolve_impls( run_config: StackRunConfig, provider_registry: ProviderRegistry, dist_registry: DistributionRegistry, + resource_attributes: ResourceAccessAttributes, ) -> dict[Api, Any]: """ Resolves provider implementations by: @@ -140,7 +142,7 @@ async def resolve_impls( sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) - return await instantiate_providers(sorted_providers, router_apis, dist_registry) + return await instantiate_providers(sorted_providers, router_apis, dist_registry, resource_attributes) def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: @@ -243,7 +245,10 @@ def sort_providers_by_deps( async def instantiate_providers( - sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry + sorted_providers: list[tuple[str, ProviderWithSpec]], + router_apis: set[Api], + dist_registry: DistributionRegistry, + resource_attributes: ResourceAccessAttributes, ) -> dict: 
"""Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} @@ -258,7 +263,7 @@ async def instantiate_providers( if isinstance(provider.spec, RoutingTableProviderSpec): inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] - impl = await instantiate_provider(provider, deps, inner_impls, dist_registry) + impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, resource_attributes) if api_str.startswith("inner-"): inner_impls_by_provider_id[api_str][provider.provider_id] = impl @@ -308,6 +313,7 @@ async def instantiate_provider( deps: dict[Api, Any], inner_impls: dict[str, Any], dist_registry: DistributionRegistry, + resource_attributes: ResourceAccessAttributes, ): provider_spec = provider.spec if not hasattr(provider_spec, "module"): @@ -332,7 +338,7 @@ async def instantiate_provider( method = "get_routing_table_impl" config = None - args = [provider_spec.api, inner_impls, deps, dist_registry] + args = [provider_spec.api, inner_impls, deps, dist_registry, resource_attributes] else: method = "get_provider_impl" diff --git a/llama_stack/distribution/resource_attributes.py b/llama_stack/distribution/resource_attributes.py new file mode 100644 index 000000000..304aa37ed --- /dev/null +++ b/llama_stack/distribution/resource_attributes.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from llama_stack.distribution.access_control import check_access +from llama_stack.distribution.datatypes import AccessAttributes, AccessAttributesRule, ResourceWithACL + + +def match_access_attributes_rule( + rule: AccessAttributesRule, resource_type: str, resource_id: str, provider_id: str +) -> bool: + if rule.resource_type and rule.resource_type.value != resource_type: + return False + if rule.resource_id and rule.resource_id != resource_id: + return False + if rule.provider_id and rule.provider_id != provider_id: + return False + return True + + +class ResourceAccessAttributes: + def __init__(self, rules: list[AccessAttributesRule]) -> None: + self.rules = rules + self.access_check_enabled = False + + def enable_access_checks(self): + self.access_check_enabled = True + + def get(self, resource_type: str, resource_id: str, provider_id: str) -> AccessAttributes | None: + for rule in self.rules: + if match_access_attributes_rule(rule, resource_type, resource_id, provider_id): + return rule.attributes + return None + + def apply(self, resource: ResourceWithACL, user_attributes: dict[str, list[str]] | None) -> bool: + """Sets the resource access attributes based on the specified rules. + + Returns True if a matching rule was found for this resource. + + If access checks have been enabled, also checks whether the user attributes allow the + resource to be created. 
+ """ + + resource_attributes = self.get(resource.type, resource.identifier, resource.provider_id) + if resource_attributes: + if self.access_check_enabled and not check_access( + resource.identifier, resource_attributes, user_attributes + ): + raise ValueError(f"Access denied: {resource.type} '{resource.identifier}'") + resource.access_attributes = resource_attributes + return True + return False diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index cd2a296f2..773fd4d90 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,6 +7,7 @@ from typing import Any from llama_stack.distribution.datatypes import RoutedProtocol +from llama_stack.distribution.resource_attributes import ResourceAccessAttributes from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable @@ -26,6 +27,7 @@ async def get_routing_table_impl( impls_by_provider_id: dict[str, RoutedProtocol], _deps, dist_registry: DistributionRegistry, + resource_attributes: ResourceAccessAttributes, ) -> Any: api_to_tables = { "vector_dbs": VectorDBsRoutingTable, @@ -40,7 +42,7 @@ async def get_routing_table_impl( if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_tables[api.value](impls_by_provider_id, dist_registry) + impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, resource_attributes) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index c04562197..52a1f5c1c 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -58,6 +58,7 @@ from llama_stack.distribution.datatypes import ( VectorDBWithACL, ) from llama_stack.distribution.request_headers import get_auth_attributes +from 
llama_stack.distribution.resource_attributes import ResourceAccessAttributes from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable @@ -114,9 +115,11 @@ class CommonRoutingTableImpl(RoutingTable): self, impls_by_provider_id: dict[str, RoutedProtocol], dist_registry: DistributionRegistry, + resource_attributes: ResourceAccessAttributes, ) -> None: self.impls_by_provider_id = impls_by_provider_id self.dist_registry = dist_registry + self.resource_attributes = resource_attributes async def initialize(self) -> None: async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: @@ -219,8 +222,12 @@ class CommonRoutingTableImpl(RoutingTable): p = self.impls_by_provider_id[obj.provider_id] - # If object supports access control but no attributes set, use creator's attributes - if not obj.access_attributes: + if self.resource_attributes.apply(obj, get_auth_attributes()): + logger.info( + f"Setting access attributes for {obj.type} '{obj.identifier}' based on resource attribute rules" + ) + else: + # If object supports access control but no attributes set, use creator's attributes creator_attributes = get_auth_attributes() if creator_attributes: obj.access_attributes = AccessAttributes(**creator_attributes) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index fc68dc016..7200831b3 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -38,6 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls +from llama_stack.distribution.resource_attributes import ResourceAccessAttributes from llama_stack.distribution.store.registry import 
create_dist_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger @@ -223,12 +224,19 @@ async def construct_stack( run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None ) -> dict[Api, Any]: dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) - impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) + if run_config.server.auth: + resource_attributes = ResourceAccessAttributes(run_config.server.auth.resource_attribute_rules) + else: + resource_attributes = ResourceAccessAttributes([]) + impls = await resolve_impls( + run_config, provider_registry or get_provider_registry(run_config), dist_registry, resource_attributes + ) # Add internal implementations after all other providers are resolved add_internal_implementations(impls, run_config) await register_resources(run_config, impls) + resource_attributes.enable_access_checks() return impls diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 4e6585ad6..f509b9dac 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -17,6 +17,7 @@ from llama_stack.apis.models.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter from llama_stack.apis.vector_dbs.vector_dbs import VectorDB +from llama_stack.distribution.resource_attributes import ResourceAccessAttributes from llama_stack.distribution.routers.routing_tables import ( BenchmarksRoutingTable, DatasetsRoutingTable, @@ -123,7 +124,9 @@ class ToolGroupsImpl(Impl): @pytest.mark.asyncio async def test_models_routing_table(cached_disk_dist_registry): - table = ModelsRoutingTable({"test_provider": InferenceImpl()}, 
cached_disk_dist_registry) + table = ModelsRoutingTable( + {"test_provider": InferenceImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]) + ) await table.initialize() # Register multiple models and verify listing @@ -165,7 +168,9 @@ async def test_models_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_shields_routing_table(cached_disk_dist_registry): - table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry) + table = ShieldsRoutingTable( + {"test_provider": SafetyImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]) + ) await table.initialize() # Register multiple shields and verify listing @@ -181,10 +186,14 @@ async def test_shields_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_vectordbs_routing_table(cached_disk_dist_registry): - table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry) + table = VectorDBsRoutingTable( + {"test_provider": VectorDBImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]) + ) await table.initialize() - m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry) + m_table = ModelsRoutingTable( + {"test_providere": InferenceImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]) + ) await m_table.initialize() await m_table.register_model( model_id="test-model", @@ -211,7 +220,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry): async def test_datasets_routing_table(cached_disk_dist_registry): - table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry) + table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])) await table.initialize() # Register multiple datasets and verify listing @@ -237,7 +246,9 @@ async def test_datasets_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def 
test_scoring_functions_routing_table(cached_disk_dist_registry): - table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry) + table = ScoringFunctionsRoutingTable( + {"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]) + ) await table.initialize() # Register multiple scoring functions and verify listing @@ -263,7 +274,9 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_benchmarks_routing_table(cached_disk_dist_registry): - table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry) + table = BenchmarksRoutingTable( + {"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]) + ) await table.initialize() # Register multiple benchmarks and verify listing @@ -281,7 +294,9 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry): @pytest.mark.asyncio async def test_tool_groups_routing_table(cached_disk_dist_registry): - table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry) + table = ToolGroupsRoutingTable( + {"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]) + ) await table.initialize() # Register multiple tool groups and verify listing diff --git a/tests/unit/distribution/test_resource_attributes.py b/tests/unit/distribution/test_resource_attributes.py new file mode 100644 index 000000000..126993d9f --- /dev/null +++ b/tests/unit/distribution/test_resource_attributes.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import pytest + +from llama_stack.apis.datasets import URIDataSource +from llama_stack.distribution.datatypes import ( + AuthenticationConfig, + DatasetWithACL, + ModelWithACL, + ShieldWithACL, +) +from llama_stack.distribution.resource_attributes import ResourceAccessAttributes, match_access_attributes_rule + + +@pytest.fixture +def rules(): + config = """{ + "provider_type": "custom", + "config": {}, + "resource_attribute_rules": [ + { + "resource_type": "model", + "resource_id": "my-model", + "provider_id": "my-provider", + "attributes": { + "roles": ["role1"], + "teams": ["team1"], + "projects": ["project1"], + "namespaces": ["namespace1"] + } + }, + { + "resource_type": "dataset", + "provider_id": "my-provider", + "attributes": { + "roles": ["role2"], + "teams": ["team2"], + "projects": ["project2"], + "namespaces": ["namespace2"] + } + }, + { + "provider_id": "my-provider", + "attributes": { + "roles": ["role3"], + "teams": ["team3"], + "projects": ["project3"], + "namespaces": ["namespace3"] + } + }, + { + "resource_type": "model", + "attributes": { + "roles": ["role4"], + "teams": ["team4"], + "projects": ["project4"], + "namespaces": ["namespace4"] + } + }, + { + "attributes": { + "roles": ["role5"], + "teams": ["team5"], + "projects": ["project5"], + "namespaces": ["namespace5"] + } + } + ] +}""" + return AuthenticationConfig.model_validate_json(config).resource_attribute_rules + + +def test_match_access_attributes_rule(rules): + assert match_access_attributes_rule(rules[0], "model", "my-model", "my-provider") + assert not match_access_attributes_rule(rules[0], "model", "another-model", "my-provider") + assert not match_access_attributes_rule(rules[0], "model", "my-model", "another-provider") + assert not match_access_attributes_rule(rules[0], "dataset", "my-model", "my-provider") + + assert match_access_attributes_rule(rules[1], "dataset", "my-data", "my-provider") + assert match_access_attributes_rule(rules[1], "dataset", "different-data", 
"my-provider") + assert match_access_attributes_rule(rules[1], "dataset", "any-data", "my-provider") + assert not match_access_attributes_rule(rules[1], "model", "a-model", "my-provider") + assert not match_access_attributes_rule(rules[1], "dataset", "any-data", "another-provider") + assert not match_access_attributes_rule(rules[1], "model", "my-model", "my-provider") + + assert match_access_attributes_rule(rules[2], "dataset", "foo", "my-provider") + assert match_access_attributes_rule(rules[2], "model", "foo", "my-provider") + assert match_access_attributes_rule(rules[2], "vector_db", "bar", "my-provider") + assert not match_access_attributes_rule(rules[2], "dataset", "foo", "another-provider") + + assert match_access_attributes_rule(rules[3], "model", "foo", "my-provider") + assert match_access_attributes_rule(rules[3], "model", "bar", "my-provider") + assert match_access_attributes_rule(rules[3], "model", "foo", "another-provider") + assert not match_access_attributes_rule(rules[3], "dataset", "bar", "my-provider") + + assert match_access_attributes_rule(rules[4], "model", "foo", "my-provider") + assert match_access_attributes_rule(rules[4], "model", "bar", "my-provider") + assert match_access_attributes_rule(rules[4], "model", "foo", "another-provider") + assert match_access_attributes_rule(rules[4], "dataset", "bar", "my-provider") + assert match_access_attributes_rule(rules[4], "vector_db", "baz", "any-provider") + + +@pytest.fixture +def resource_access_attributes(rules): + return ResourceAccessAttributes(rules) + + +def dataset(identifier: str, provider_id: str) -> DatasetWithACL: + return DatasetWithACL( + identifier=identifier, + provider_id=provider_id, + purpose="eval/question-answer", + source=URIDataSource(uri="https://a.com/a.jsonl"), + ) + + +@pytest.mark.parametrize( + "resource,expected_roles", + [ + (ModelWithACL(identifier="my-model", provider_id="my-provider"), ["role1"]), + (ModelWithACL(identifier="another-model", provider_id="my-provider"), 
["role3"]), + (ModelWithACL(identifier="third-model", provider_id="another-provider"), ["role4"]), + (dataset(identifier="my-dataset", provider_id="my-provider"), ["role2"]), + (dataset(identifier="another-dataset", provider_id="another-provider"), ["role5"]), + (ShieldWithACL(identifier="my-shield", provider_id="my-provider"), ["role3"]), + (ShieldWithACL(identifier="another-shield", provider_id="another-provider"), ["role5"]), + ], +) +def test_apply(resource_access_attributes, resource, expected_roles): + assert resource_access_attributes.apply(resource, None) + assert resource.access_attributes.roles == expected_roles + + +@pytest.fixture +def alternate_access_attributes_rules(): + config = """{ + "provider_type": "custom", + "config": {}, + "resource_attribute_rules": [ + { + "resource_type": "model", + "resource_id": "my-model", + "provider_id": "my-provider", + "attributes": { + "roles": ["roleA"] + } + }, + { + "resource_type": "model", + "attributes": { + "roles": ["roleB"] + } + } + ] +}""" + return AuthenticationConfig.model_validate_json(config).resource_attribute_rules + + +@pytest.fixture +def alternate_access_attributes(alternate_access_attributes_rules): + return ResourceAccessAttributes(alternate_access_attributes_rules) + + +@pytest.mark.parametrize( + "resource,expected_roles", + [ + (ModelWithACL(identifier="my-model", provider_id="my-provider"), ["roleA"]), + (ModelWithACL(identifier="another-model", provider_id="another-provider"), ["roleB"]), + (dataset(identifier="my-dataset", provider_id="my-provider"), None), + (dataset(identifier="another-dataset", provider_id="another-provider"), None), + ], +) +def test_apply_alternate(alternate_access_attributes, resource, expected_roles): + if expected_roles: + assert alternate_access_attributes.apply(resource, None) + assert resource.access_attributes.roles == expected_roles + else: + assert not alternate_access_attributes.apply(resource, None) + assert not resource.access_attributes + + 
+@pytest.fixture +def checked_attributes(alternate_access_attributes_rules): + attributes = ResourceAccessAttributes(alternate_access_attributes_rules) + attributes.enable_access_checks() + return attributes + + +@pytest.mark.parametrize( + "resource,user_attributes,fails,result", + [ + (ModelWithACL(identifier="my-model", provider_id="my-provider"), {"roles": ["roleA"]}, False, True), + (ModelWithACL(identifier="my-model", provider_id="my-provider"), {"roles": ["somethingelse"]}, True, False), + (dataset(identifier="my-dataset", provider_id="my-provider"), {"roles": ["somethingelse"]}, False, False), + (dataset(identifier="my-dataset", provider_id="my-provider"), None, False, False), + ], +) +def test_access_check_on_apply(checked_attributes, resource, user_attributes, fails, result): + if fails: + with pytest.raises(ValueError) as e: + checked_attributes.apply(resource, user_attributes) + assert "Access denied" in str(e.value) + assert not resource.access_attributes + else: + assert checked_attributes.apply(resource, user_attributes) == result diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index b5e9c2698..f3067ebe9 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -10,7 +10,8 @@ import pytest from llama_stack.apis.datatypes import Api from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL +from llama_stack.distribution.datatypes import AccessAttributes, AccessAttributesRule, ModelWithACL +from llama_stack.distribution.resource_attributes import ResourceAccessAttributes from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable @@ -32,6 +33,7 @@ async def test_setup(cached_disk_dist_registry): routing_table = ModelsRoutingTable( impls_by_provider_id={"test_provider": mock_inference}, dist_registry=cached_disk_dist_registry, + 
resource_attributes=ResourceAccessAttributes([]), ) yield cached_disk_dist_registry, routing_table @@ -223,3 +225,70 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup) } model = await routing_table.get_model("auto-access-model") assert model.identifier == "auto-access-model" + + +@pytest.fixture +async def test_setup_with_resource_attribute_rules(cached_disk_dist_registry): + mock_inference = Mock() + mock_inference.__provider_spec__ = MagicMock() + mock_inference.__provider_spec__.api = Api.inference + mock_inference.register_model = AsyncMock(side_effect=_return_model) + resource_attributes = ResourceAccessAttributes( + [ + # model-1 can be accessed by users in project foo or those who have admin role + AccessAttributesRule( + resource_type="model", resource_id="model-1", attributes=AccessAttributes(projects=["foo"]) + ), + # model-2 can be accessed by users in project bar or those who have admin role + AccessAttributesRule( + resource_type="model", resource_id="model-2", attributes=AccessAttributes(projects=["bar"]) + ), + # any other model can be accessed or created only by those who have admin role + AccessAttributesRule(resource_type="model", attributes=AccessAttributes(roles=["admin"])), + ] + ) + resource_attributes.enable_access_checks() + routing_table = ModelsRoutingTable( + impls_by_provider_id={"test_provider": mock_inference}, + dist_registry=cached_disk_dist_registry, + resource_attributes=resource_attributes, + ) + yield routing_table + + +@pytest.mark.asyncio +@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +async def test_with_resource_attribute_rules(mock_get_auth_attributes, test_setup_with_resource_attribute_rules): + routing_table = test_setup_with_resource_attribute_rules + mock_get_auth_attributes.return_value = { + "roles": ["admin"], + "projects": ["foo", "bar"], + } + await routing_table.register_model("model-1", provider_id="test_provider") + await 
routing_table.register_model("model-2", provider_id="test_provider") + await routing_table.register_model("model-3", provider_id="test_provider") + mock_get_auth_attributes.return_value = { + "roles": ["user"], + "projects": ["foo"], + } + model = await routing_table.get_model("model-1") + assert model.identifier == "model-1" + with pytest.raises(ValueError): + await routing_table.get_model("model-2") + with pytest.raises(ValueError): + await routing_table.get_model("model-3") + with pytest.raises(ValueError): + await routing_table.register_model("model-4", provider_id="test_provider") + + mock_get_auth_attributes.return_value = { + "roles": ["user"], + "projects": ["bar"], + } + model = await routing_table.get_model("model-2") + assert model.identifier == "model-2" + with pytest.raises(ValueError): + await routing_table.get_model("model-1") + with pytest.raises(ValueError): + await routing_table.get_model("model-3") + with pytest.raises(ValueError): + await routing_table.register_model("model-5", provider_id="test_provider") diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index 3af9535a0..e763bda87 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -19,6 +19,7 @@ from llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.resource_attributes import ResourceAccessAttributes from llama_stack.distribution.routers.routers import InferenceRouter from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec @@ -102,7 +103,7 @@ async def test_resolve_impls_basic(): mock_module.get_provider_impl = AsyncMock(return_value=impl) sys.modules["test_module"] = mock_module - impls = await resolve_impls(run_config, provider_registry, dist_registry) + impls = await resolve_impls(run_config, provider_registry, 
dist_registry, ResourceAccessAttributes([])) assert Api.inference in impls assert isinstance(impls[Api.inference], InferenceRouter)