feat: allow access attributes for resources to be configured

This allows a set of rules to be defined for determining the access attributes to apply to
a particular resource. It also checks that the attributes determined for a new resource to
be registered are matched by attributes associated with the request context.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
Gordon Sim 2025-05-06 18:54:58 +01:00
parent 0cc0731189
commit 490e77bffa
10 changed files with 402 additions and 19 deletions

View file

@ -17,6 +17,7 @@ from llama_stack.apis.models.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.routers.routing_tables import (
BenchmarksRoutingTable,
DatasetsRoutingTable,
@ -123,7 +124,9 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
table = ModelsRoutingTable(
{"test_provider": InferenceImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
)
await table.initialize()
# Register multiple models and verify listing
@ -165,7 +168,9 @@ async def test_models_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
table = ShieldsRoutingTable(
{"test_provider": SafetyImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
)
await table.initialize()
# Register multiple shields and verify listing
@ -181,10 +186,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
table = VectorDBsRoutingTable(
{"test_provider": VectorDBImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
)
await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
m_table = ModelsRoutingTable(
{"test_providere": InferenceImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
)
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
@ -211,7 +220,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]))
await table.initialize()
# Register multiple datasets and verify listing
@ -237,7 +246,9 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
table = ScoringFunctionsRoutingTable(
{"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
)
await table.initialize()
# Register multiple scoring functions and verify listing
@ -263,7 +274,9 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
table = BenchmarksRoutingTable(
{"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
)
await table.initialize()
# Register multiple benchmarks and verify listing
@ -281,7 +294,9 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
table = ToolGroupsRoutingTable(
{"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
)
await table.initialize()
# Register multiple tool groups and verify listing

View file

@ -0,0 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.datasets import URIDataSource
from llama_stack.distribution.datatypes import (
AuthenticationConfig,
DatasetWithACL,
ModelWithACL,
ShieldWithACL,
)
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes, match_access_attributes_rule
@pytest.fixture
def rules():
    """Provide five access-attribute rules parsed from a JSON authentication config.

    The rules are listed from most to least constrained:
      0. resource type "model", resource id "my-model", provider "my-provider"
      1. resource type "dataset", provider "my-provider"
      2. provider "my-provider" only
      3. resource type "model" only
      4. no constraints at all
    Rule N carries roles/teams/projects/namespaces suffixed with N+1
    (e.g. rule 0 uses "role1"), so tests can tell which rule was picked.
    """
    raw_config = """{
"provider_type": "custom",
"config": {},
"resource_attribute_rules": [
{
"resource_type": "model",
"resource_id": "my-model",
"provider_id": "my-provider",
"attributes": {
"roles": ["role1"],
"teams": ["team1"],
"projects": ["project1"],
"namespaces": ["namespace1"]
}
},
{
"resource_type": "dataset",
"provider_id": "my-provider",
"attributes": {
"roles": ["role2"],
"teams": ["team2"],
"projects": ["project2"],
"namespaces": ["namespace2"]
}
},
{
"provider_id": "my-provider",
"attributes": {
"roles": ["role3"],
"teams": ["team3"],
"projects": ["project3"],
"namespaces": ["namespace3"]
}
},
{
"resource_type": "model",
"attributes": {
"roles": ["role4"],
"teams": ["team4"],
"projects": ["project4"],
"namespaces": ["namespace4"]
}
},
{
"attributes": {
"roles": ["role5"],
"teams": ["team5"],
"projects": ["project5"],
"namespaces": ["namespace5"]
}
}
]
}"""
    parsed_config = AuthenticationConfig.model_validate_json(raw_config)
    return parsed_config.resource_attribute_rules
def test_match_access_attributes_rule(rules):
    """Check every fixture rule against resources that should and should not match it."""
    # Each entry: (rule index, resource type, resource id, provider id, expected match).
    cases = [
        # Rule 0: constrained on type, id and provider.
        (0, "model", "my-model", "my-provider", True),
        (0, "model", "another-model", "my-provider", False),
        (0, "model", "my-model", "another-provider", False),
        (0, "dataset", "my-model", "my-provider", False),
        # Rule 1: constrained on type and provider; any id matches.
        (1, "dataset", "my-data", "my-provider", True),
        (1, "dataset", "different-data", "my-provider", True),
        (1, "dataset", "any-data", "my-provider", True),
        (1, "model", "a-model", "my-provider", False),
        (1, "dataset", "any-data", "another-provider", False),
        (1, "model", "my-model", "my-provider", False),
        # Rule 2: constrained on provider only.
        (2, "dataset", "foo", "my-provider", True),
        (2, "model", "foo", "my-provider", True),
        (2, "vector_db", "bar", "my-provider", True),
        (2, "dataset", "foo", "another-provider", False),
        # Rule 3: constrained on type only.
        (3, "model", "foo", "my-provider", True),
        (3, "model", "bar", "my-provider", True),
        (3, "model", "foo", "another-provider", True),
        (3, "dataset", "bar", "my-provider", False),
        # Rule 4: unconstrained, matches everything.
        (4, "model", "foo", "my-provider", True),
        (4, "model", "bar", "my-provider", True),
        (4, "model", "foo", "another-provider", True),
        (4, "dataset", "bar", "my-provider", True),
        (4, "vector_db", "baz", "any-provider", True),
    ]
    for rule_index, resource_type, resource_id, provider_id, should_match in cases:
        outcome = match_access_attributes_rule(rules[rule_index], resource_type, resource_id, provider_id)
        # Only truthiness matters, mirroring plain `assert x` / `assert not x`.
        assert bool(outcome) is should_match
@pytest.fixture
def resource_access_attributes(rules):
    """Wrap the five-rule `rules` fixture in a ResourceAccessAttributes instance."""
    attributes = ResourceAccessAttributes(rules)
    return attributes
def dataset(identifier: str, provider_id: str) -> DatasetWithACL:
    """Build a DatasetWithACL for tests; only the identifier and provider vary."""
    # Purpose and source are fixed placeholder values supplied on every call.
    placeholder_source = URIDataSource(uri="https://a.com/a.jsonl")
    return DatasetWithACL(
        identifier=identifier,
        provider_id=provider_id,
        purpose="eval/question-answer",
        source=placeholder_source,
    )
@pytest.mark.parametrize(
    "resource,expected_roles",
    [
        # Exact type/id/provider rule.
        (ModelWithACL(identifier="my-model", provider_id="my-provider"), ["role1"]),
        # Provider-only rule is listed before the model-only rule.
        (ModelWithACL(identifier="another-model", provider_id="my-provider"), ["role3"]),
        # Model-only rule.
        (ModelWithACL(identifier="third-model", provider_id="another-provider"), ["role4"]),
        # Dataset + provider rule.
        (dataset(identifier="my-dataset", provider_id="my-provider"), ["role2"]),
        # Only the unconstrained rule applies.
        (dataset(identifier="another-dataset", provider_id="another-provider"), ["role5"]),
        # No shield-specific rule, so the provider-only rule applies.
        (ShieldWithACL(identifier="my-shield", provider_id="my-provider"), ["role3"]),
        (ShieldWithACL(identifier="another-shield", provider_id="another-provider"), ["role5"]),
    ],
)
def test_apply(resource_access_attributes, resource, expected_roles):
    """Applying the rules must succeed and set the roles of the expected rule."""
    applied = resource_access_attributes.apply(resource, None)
    assert applied
    assigned_roles = resource.access_attributes.roles
    assert assigned_roles == expected_roles
@pytest.fixture
def alternate_access_attributes_rules():
    """Provide a two-rule set that covers models only.

    Rule 0 targets model "my-model" on "my-provider" (roleA); rule 1 targets
    any model (roleB). No rule names any other resource type.
    """
    raw_config = """{
"provider_type": "custom",
"config": {},
"resource_attribute_rules": [
{
"resource_type": "model",
"resource_id": "my-model",
"provider_id": "my-provider",
"attributes": {
"roles": ["roleA"]
}
},
{
"resource_type": "model",
"attributes": {
"roles": ["roleB"]
}
}
]
}"""
    parsed_config = AuthenticationConfig.model_validate_json(raw_config)
    return parsed_config.resource_attribute_rules
@pytest.fixture
def alternate_access_attributes(alternate_access_attributes_rules):
    """Wrap the model-only rule set in a ResourceAccessAttributes instance."""
    attributes = ResourceAccessAttributes(alternate_access_attributes_rules)
    return attributes
@pytest.mark.parametrize(
    "resource,expected_roles",
    [
        (ModelWithACL(identifier="my-model", provider_id="my-provider"), ["roleA"]),
        (ModelWithACL(identifier="another-model", provider_id="another-provider"), ["roleB"]),
        # The rule set has no dataset rule, so nothing should be applied.
        (dataset(identifier="my-dataset", provider_id="my-provider"), None),
        (dataset(identifier="another-dataset", provider_id="another-provider"), None),
    ],
)
def test_apply_alternate(alternate_access_attributes, resource, expected_roles):
    """With model-only rules, models receive roles and datasets are left untouched."""
    if not expected_roles:
        # No rule expected to match: apply reports falsy and sets no attributes.
        assert not alternate_access_attributes.apply(resource, None)
        assert not resource.access_attributes
        return

    assert alternate_access_attributes.apply(resource, None)
    assert resource.access_attributes.roles == expected_roles
@pytest.fixture
def checked_attributes(alternate_access_attributes_rules):
    """Return the model-only rule set with access checks switched on."""
    checked = ResourceAccessAttributes(alternate_access_attributes_rules)
    checked.enable_access_checks()
    return checked
@pytest.mark.parametrize(
    "resource,user_attributes,fails,result",
    [
        # User holds the role required by the matching rule.
        (ModelWithACL(identifier="my-model", provider_id="my-provider"), {"roles": ["roleA"]}, False, True),
        # User lacks the required role: apply must raise.
        (ModelWithACL(identifier="my-model", provider_id="my-provider"), {"roles": ["somethingelse"]}, True, False),
        # No rule covers datasets, so apply returns False without raising.
        (dataset(identifier="my-dataset", provider_id="my-provider"), {"roles": ["somethingelse"]}, False, False),
        (dataset(identifier="my-dataset", provider_id="my-provider"), None, False, False),
    ],
)
def test_access_check_on_apply(checked_attributes, resource, user_attributes, fails, result):
    """Check apply() against user attributes when access checks are enabled.

    `fails` says whether apply() is expected to raise ValueError; otherwise
    `result` is the value apply() is expected to return.
    """
    if not fails:
        assert checked_attributes.apply(resource, user_attributes) == result
        return

    with pytest.raises(ValueError) as exc_info:
        checked_attributes.apply(resource, user_attributes)
    # Inspect the captured exception and the resource after the context exits.
    assert "Access denied" in str(exc_info.value)
    assert not resource.access_attributes

View file

@ -10,7 +10,8 @@ import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.datatypes import AccessAttributes, AccessAttributesRule, ModelWithACL
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
@ -32,6 +33,7 @@ async def test_setup(cached_disk_dist_registry):
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
resource_attributes=ResourceAccessAttributes([]),
)
yield cached_disk_dist_registry, routing_table
@ -223,3 +225,70 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
}
model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model"
@pytest.fixture
async def test_setup_with_resource_attribute_rules(cached_disk_dist_registry):
    """Yield a ModelsRoutingTable backed by a mock provider and three model rules.

    The rules are wrapped in a ResourceAccessAttributes with access checks enabled.
    """
    inference_provider = Mock()
    inference_provider.__provider_spec__ = MagicMock()
    inference_provider.__provider_spec__.api = Api.inference
    inference_provider.register_model = AsyncMock(side_effect=_return_model)

    # NOTE(review): the original comments also said admins can access model-1
    # and model-2; nothing visible here exercises that -- confirm.
    model_rules = [
        # model-1 is tagged with project foo
        AccessAttributesRule(
            resource_type="model", resource_id="model-1", attributes=AccessAttributes(projects=["foo"])
        ),
        # model-2 is tagged with project bar
        AccessAttributesRule(
            resource_type="model", resource_id="model-2", attributes=AccessAttributes(projects=["bar"])
        ),
        # any other model is tagged with the admin role
        AccessAttributesRule(resource_type="model", attributes=AccessAttributes(roles=["admin"])),
    ]
    access_rules = ResourceAccessAttributes(model_rules)
    access_rules.enable_access_checks()

    yield ModelsRoutingTable(
        impls_by_provider_id={"test_provider": inference_provider},
        dist_registry=cached_disk_dist_registry,
        resource_attributes=access_rules,
    )
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_with_resource_attribute_rules(mock_get_auth_attributes, test_setup_with_resource_attribute_rules):
    """Exercise registration and lookup of models under the rule-based fixture.

    An admin belonging to both projects registers three models. Two ordinary
    users, one per project, must then each see only their project's model and
    be refused when registering a new one.
    """
    routing_table = test_setup_with_resource_attribute_rules

    # Admin in both projects: every registration is expected to succeed.
    mock_get_auth_attributes.return_value = {
        "roles": ["admin"],
        "projects": ["foo", "bar"],
    }
    for model_id in ("model-1", "model-2", "model-3"):
        await routing_table.register_model(model_id, provider_id="test_provider")

    # Each entry: (user's project, readable model, unreadable models, model the user tries to add).
    scenarios = [
        ("foo", "model-1", ("model-2", "model-3"), "model-4"),
        ("bar", "model-2", ("model-1", "model-3"), "model-5"),
    ]
    for project, visible_id, hidden_ids, new_id in scenarios:
        mock_get_auth_attributes.return_value = {
            "roles": ["user"],
            "projects": [project],
        }

        model = await routing_table.get_model(visible_id)
        assert model.identifier == visible_id

        for hidden_id in hidden_ids:
            with pytest.raises(ValueError):
                await routing_table.get_model(hidden_id)

        # A non-admin user must not be able to register a model that falls under the catch-all rule.
        with pytest.raises(ValueError):
            await routing_table.register_model(new_id, provider_id="test_provider")

View file

@ -19,6 +19,7 @@ from llama_stack.distribution.datatypes import (
StackRunConfig,
)
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.routers.routers import InferenceRouter
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
@ -102,7 +103,7 @@ async def test_resolve_impls_basic():
mock_module.get_provider_impl = AsyncMock(return_value=impl)
sys.modules["test_module"] = mock_module
impls = await resolve_impls(run_config, provider_registry, dist_registry)
impls = await resolve_impls(run_config, provider_registry, dist_registry, ResourceAccessAttributes([]))
assert Api.inference in impls
assert isinstance(impls[Api.inference], InferenceRouter)