mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-27 23:22:01 +00:00
feat: allow access attributes for resources to be configured
This allows a set of rules to be defined for determining the access attributes to apply to a particular resource. It also checks that the attributes determined for a new resource to be registered are matched by attributes associated with the request context. Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
0cc0731189
commit
490e77bffa
10 changed files with 402 additions and 19 deletions
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.models.models import Model, ModelType
|
|||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
|
||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
|
||||
from llama_stack.distribution.routers.routing_tables import (
|
||||
BenchmarksRoutingTable,
|
||||
DatasetsRoutingTable,
|
||||
|
|
@ -123,7 +124,9 @@ class ToolGroupsImpl(Impl):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
|
||||
table = ModelsRoutingTable(
|
||||
{"test_provider": InferenceImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple models and verify listing
|
||||
|
|
@ -165,7 +168,9 @@ async def test_models_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shields_routing_table(cached_disk_dist_registry):
|
||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
|
||||
table = ShieldsRoutingTable(
|
||||
{"test_provider": SafetyImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple shields and verify listing
|
||||
|
|
@ -181,10 +186,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
|
||||
table = VectorDBsRoutingTable(
|
||||
{"test_provider": VectorDBImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
|
||||
m_table = ModelsRoutingTable(
|
||||
{"test_providere": InferenceImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
|
|
@ -211,7 +220,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
|
||||
|
||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([]))
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple datasets and verify listing
|
||||
|
|
@ -237,7 +246,9 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
|
||||
table = ScoringFunctionsRoutingTable(
|
||||
{"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple scoring functions and verify listing
|
||||
|
|
@ -263,7 +274,9 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
|
||||
table = BenchmarksRoutingTable(
|
||||
{"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple benchmarks and verify listing
|
||||
|
|
@ -281,7 +294,9 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
|
||||
table = ToolGroupsRoutingTable(
|
||||
{"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, ResourceAccessAttributes([])
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple tool groups and verify listing
|
||||
|
|
|
|||
209
tests/unit/distribution/test_resource_attributes.py
Normal file
209
tests/unit/distribution/test_resource_attributes.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.datasets import URIDataSource
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
DatasetWithACL,
|
||||
ModelWithACL,
|
||||
ShieldWithACL,
|
||||
)
|
||||
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes, match_access_attributes_rule
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rules():
|
||||
config = """{
|
||||
"provider_type": "custom",
|
||||
"config": {},
|
||||
"resource_attribute_rules": [
|
||||
{
|
||||
"resource_type": "model",
|
||||
"resource_id": "my-model",
|
||||
"provider_id": "my-provider",
|
||||
"attributes": {
|
||||
"roles": ["role1"],
|
||||
"teams": ["team1"],
|
||||
"projects": ["project1"],
|
||||
"namespaces": ["namespace1"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"resource_type": "dataset",
|
||||
"provider_id": "my-provider",
|
||||
"attributes": {
|
||||
"roles": ["role2"],
|
||||
"teams": ["team2"],
|
||||
"projects": ["project2"],
|
||||
"namespaces": ["namespace2"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"provider_id": "my-provider",
|
||||
"attributes": {
|
||||
"roles": ["role3"],
|
||||
"teams": ["team3"],
|
||||
"projects": ["project3"],
|
||||
"namespaces": ["namespace3"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"resource_type": "model",
|
||||
"attributes": {
|
||||
"roles": ["role4"],
|
||||
"teams": ["team4"],
|
||||
"projects": ["project4"],
|
||||
"namespaces": ["namespace4"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"attributes": {
|
||||
"roles": ["role5"],
|
||||
"teams": ["team5"],
|
||||
"projects": ["project5"],
|
||||
"namespaces": ["namespace5"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}"""
|
||||
return AuthenticationConfig.model_validate_json(config).resource_attribute_rules
|
||||
|
||||
|
||||
def test_match_access_attributes_rule(rules):
|
||||
assert match_access_attributes_rule(rules[0], "model", "my-model", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[0], "model", "another-model", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[0], "model", "my-model", "another-provider")
|
||||
assert not match_access_attributes_rule(rules[0], "dataset", "my-model", "my-provider")
|
||||
|
||||
assert match_access_attributes_rule(rules[1], "dataset", "my-data", "my-provider")
|
||||
assert match_access_attributes_rule(rules[1], "dataset", "different-data", "my-provider")
|
||||
assert match_access_attributes_rule(rules[1], "dataset", "any-data", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[1], "model", "a-model", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[1], "dataset", "any-data", "another-provider")
|
||||
assert not match_access_attributes_rule(rules[1], "model", "my-model", "my-provider")
|
||||
|
||||
assert match_access_attributes_rule(rules[2], "dataset", "foo", "my-provider")
|
||||
assert match_access_attributes_rule(rules[2], "model", "foo", "my-provider")
|
||||
assert match_access_attributes_rule(rules[2], "vector_db", "bar", "my-provider")
|
||||
assert not match_access_attributes_rule(rules[2], "dataset", "foo", "another-provider")
|
||||
|
||||
assert match_access_attributes_rule(rules[3], "model", "foo", "my-provider")
|
||||
assert match_access_attributes_rule(rules[3], "model", "bar", "my-provider")
|
||||
assert match_access_attributes_rule(rules[3], "model", "foo", "another-provider")
|
||||
assert not match_access_attributes_rule(rules[3], "dataset", "bar", "my-provider")
|
||||
|
||||
assert match_access_attributes_rule(rules[4], "model", "foo", "my-provider")
|
||||
assert match_access_attributes_rule(rules[4], "model", "bar", "my-provider")
|
||||
assert match_access_attributes_rule(rules[4], "model", "foo", "another-provider")
|
||||
assert match_access_attributes_rule(rules[4], "dataset", "bar", "my-provider")
|
||||
assert match_access_attributes_rule(rules[4], "vector_db", "baz", "any-provider")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resource_access_attributes(rules):
|
||||
return ResourceAccessAttributes(rules)
|
||||
|
||||
|
||||
def dataset(identifier: str, provider_id: str) -> DatasetWithACL:
|
||||
return DatasetWithACL(
|
||||
identifier=identifier,
|
||||
provider_id=provider_id,
|
||||
purpose="eval/question-answer",
|
||||
source=URIDataSource(uri="https://a.com/a.jsonl"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"resource,expected_roles",
|
||||
[
|
||||
(ModelWithACL(identifier="my-model", provider_id="my-provider"), ["role1"]),
|
||||
(ModelWithACL(identifier="another-model", provider_id="my-provider"), ["role3"]),
|
||||
(ModelWithACL(identifier="third-model", provider_id="another-provider"), ["role4"]),
|
||||
(dataset(identifier="my-dataset", provider_id="my-provider"), ["role2"]),
|
||||
(dataset(identifier="another-dataset", provider_id="another-provider"), ["role5"]),
|
||||
(ShieldWithACL(identifier="my-shield", provider_id="my-provider"), ["role3"]),
|
||||
(ShieldWithACL(identifier="another-shield", provider_id="another-provider"), ["role5"]),
|
||||
],
|
||||
)
|
||||
def test_apply(resource_access_attributes, resource, expected_roles):
|
||||
assert resource_access_attributes.apply(resource, None)
|
||||
assert resource.access_attributes.roles == expected_roles
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def alternate_access_attributes_rules():
|
||||
config = """{
|
||||
"provider_type": "custom",
|
||||
"config": {},
|
||||
"resource_attribute_rules": [
|
||||
{
|
||||
"resource_type": "model",
|
||||
"resource_id": "my-model",
|
||||
"provider_id": "my-provider",
|
||||
"attributes": {
|
||||
"roles": ["roleA"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"resource_type": "model",
|
||||
"attributes": {
|
||||
"roles": ["roleB"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}"""
|
||||
return AuthenticationConfig.model_validate_json(config).resource_attribute_rules
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def alternate_access_attributes(alternate_access_attributes_rules):
|
||||
return ResourceAccessAttributes(alternate_access_attributes_rules)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"resource,expected_roles",
|
||||
[
|
||||
(ModelWithACL(identifier="my-model", provider_id="my-provider"), ["roleA"]),
|
||||
(ModelWithACL(identifier="another-model", provider_id="another-provider"), ["roleB"]),
|
||||
(dataset(identifier="my-dataset", provider_id="my-provider"), None),
|
||||
(dataset(identifier="another-dataset", provider_id="another-provider"), None),
|
||||
],
|
||||
)
|
||||
def test_apply_alternate(alternate_access_attributes, resource, expected_roles):
|
||||
if expected_roles:
|
||||
assert alternate_access_attributes.apply(resource, None)
|
||||
assert resource.access_attributes.roles == expected_roles
|
||||
else:
|
||||
assert not alternate_access_attributes.apply(resource, None)
|
||||
assert not resource.access_attributes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def checked_attributes(alternate_access_attributes_rules):
|
||||
attributes = ResourceAccessAttributes(alternate_access_attributes_rules)
|
||||
attributes.enable_access_checks()
|
||||
return attributes
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"resource,user_attributes,fails,result",
|
||||
[
|
||||
(ModelWithACL(identifier="my-model", provider_id="my-provider"), {"roles": ["roleA"]}, False, True),
|
||||
(ModelWithACL(identifier="my-model", provider_id="my-provider"), {"roles": ["somethingelse"]}, True, False),
|
||||
(dataset(identifier="my-dataset", provider_id="my-provider"), {"roles": ["somethingelse"]}, False, False),
|
||||
(dataset(identifier="my-dataset", provider_id="my-provider"), None, False, False),
|
||||
],
|
||||
)
|
||||
def test_access_check_on_apply(checked_attributes, resource, user_attributes, fails, result):
|
||||
if fails:
|
||||
with pytest.raises(ValueError) as e:
|
||||
checked_attributes.apply(resource, user_attributes)
|
||||
assert "Access denied" in str(e.value)
|
||||
assert not resource.access_attributes
|
||||
else:
|
||||
assert checked_attributes.apply(resource, user_attributes) == result
|
||||
Loading…
Add table
Add a link
Reference in a new issue