feat: allow access attributes for resources to be configured

This allows a set of rules to be defined for determining the access attributes to apply to
a particular resource. It also checks that the attributes determined for a new resource to
be registered are matched by attributes associated with the request context.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
Gordon Sim 2025-05-06 18:54:58 +01:00
parent 0cc0731189
commit 490e77bffa
10 changed files with 402 additions and 19 deletions

View file

@ -10,7 +10,8 @@ import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.datatypes import AccessAttributes, AccessAttributesRule, ModelWithACL
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
@ -32,6 +33,7 @@ async def test_setup(cached_disk_dist_registry):
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
resource_attributes=ResourceAccessAttributes([]),
)
yield cached_disk_dist_registry, routing_table
@ -223,3 +225,70 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
}
model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model"
@pytest.fixture
async def test_setup_with_resource_attribute_rules(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
resource_attributes = ResourceAccessAttributes(
[
# model-1 can be accessed by users in project foo or those who have admin role
AccessAttributesRule(
resource_type="model", resource_id="model-1", attributes=AccessAttributes(projects=["foo"])
),
# model-2 can be accessed by users in project bar or those who have admin role
AccessAttributesRule(
resource_type="model", resource_id="model-2", attributes=AccessAttributes(projects=["bar"])
),
# any other model can be accessed or created only by those who have admin role
AccessAttributesRule(resource_type="model", attributes=AccessAttributes(roles=["admin"])),
]
)
resource_attributes.enable_access_checks()
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
resource_attributes=resource_attributes,
)
yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_with_resource_attribute_rules(mock_get_auth_attributes, test_setup_with_resource_attribute_rules):
routing_table = test_setup_with_resource_attribute_rules
mock_get_auth_attributes.return_value = {
"roles": ["admin"],
"projects": ["foo", "bar"],
}
await routing_table.register_model("model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_id="test_provider")
mock_get_auth_attributes.return_value = {
"roles": ["user"],
"projects": ["foo"],
}
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(ValueError):
await routing_table.register_model("model-4", provider_id="test_provider")
mock_get_auth_attributes.return_value = {
"roles": ["user"],
"projects": ["bar"],
}
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(ValueError):
await routing_table.register_model("model-5", provider_id="test_provider")