llama-stack-mirror/tests/unit/distribution/routers/test_routing_tables.py
grs 7c1998db25
feat: fine grained access control policy (#2264)
This allows a set of rules to be defined for determining access to
resources. The rules are (loosely) based on the Cedar policy format.

A rule defines a list of actions to either permit or forbid. It may
specify a principal or a resource that must match for the rule to take
effect. It may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies.

A list of rules is held for each type of resource to be protected, and
the rules are tried in order until one matches. If a match is found, the
request is permitted or forbidden depending on the type of rule. If no
match is found, the request is denied. If no rules are specified for a
given type, a rule is added that allows any action as long as the
resource attributes match the user attributes (i.e. the previous
behaviour is the default); the matching loop is sketched after the
examples below.

Some examples in yaml:

```yaml
    model:
    - permit:
      principal: user-1
      actions: [create, read, delete]
      comment: user-1 has full access to all models
    - permit:
      principal: user-2
      actions: [read]
      resource: model-1
      comment: user-2 has read access to model-1 only
    - permit:
      actions: [read]
      when:
        user_in: resource.namespaces
      comment: any user has read access to models with matching attributes
    vector_db:
    - forbid:
      actions: [create, read, delete]
      unless:
        user_in: role::admin
      comment: only user with admin role can use vector_db resources
```
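
To make the evaluation order concrete, here is a minimal, self-contained sketch of the first-match loop described above. It is illustrative only: the names `AccessRule`, `matches` and `is_allowed` are hypothetical, and the 'when'/'unless' conditions are omitted for brevity.

```python
from dataclasses import dataclass


@dataclass
class User:
    principal: str


@dataclass
class Resource:
    identifier: str


@dataclass
class AccessRule:
    permit: bool  # True for a permit rule, False for a forbid rule
    actions: list[str]
    principal: str | None = None  # if set, rule applies only to this principal
    resource: str | None = None  # if set, rule applies only to this resource

    def matches(self, user: User, action: str, resource: Resource) -> bool:
        if action not in self.actions:
            return False
        if self.principal is not None and self.principal != user.principal:
            return False
        if self.resource is not None and self.resource != resource.identifier:
            return False
        return True


def is_allowed(rules: list[AccessRule], user: User, action: str, resource: Resource) -> bool:
    # Rules are tried in order; the first matching rule decides the outcome.
    for rule in rules:
        if rule.matches(user, action, resource):
            return rule.permit
    # No rule matched: deny by default.
    return False


# The first two model rules from the yaml above:
rules = [
    AccessRule(permit=True, actions=["create", "read", "delete"], principal="user-1"),
    AccessRule(permit=True, actions=["read"], principal="user-2", resource="model-1"),
]
assert is_allowed(rules, User("user-2"), "read", Resource("model-1"))  # second rule matches
assert not is_allowed(rules, User("user-2"), "delete", Resource("model-1"))  # no match: denied
```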

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
2025-06-03 14:51:12 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Unit tests for the routing tables

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
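
# Minimal fake provider implementations: each exposes only the registration
# hooks its routing table needs, plus a __provider_spec__ mock that advertises
# which Api the provider serves.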
class Impl:
    def __init__(self, api: Api):
        self.api = api

    @property
    def __provider_spec__(self):
        _provider_spec = AsyncMock()
        _provider_spec.api = self.api
        return _provider_spec


class InferenceImpl(Impl):
    def __init__(self):
        super().__init__(Api.inference)

    async def register_model(self, model: Model):
        return model

    async def unregister_model(self, model_id: str):
        return model_id


class SafetyImpl(Impl):
    def __init__(self):
        super().__init__(Api.safety)

    async def register_shield(self, shield: Shield):
        return shield


class VectorDBImpl(Impl):
    def __init__(self):
        super().__init__(Api.vector_io)

    async def register_vector_db(self, vector_db: VectorDB):
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str):
        return vector_db_id


class DatasetsImpl(Impl):
    def __init__(self):
        super().__init__(Api.datasetio)

    async def register_dataset(self, dataset: Dataset):
        return dataset

    async def unregister_dataset(self, dataset_id: str):
        return dataset_id


class ScoringFunctionsImpl(Impl):
    def __init__(self):
        super().__init__(Api.scoring)

    async def list_scoring_functions(self):
        return []

    async def register_scoring_function(self, scoring_fn):
        return scoring_fn


class BenchmarksImpl(Impl):
    def __init__(self):
        super().__init__(Api.eval)

    async def register_benchmark(self, benchmark):
        return benchmark


class ToolGroupsImpl(Impl):
    def __init__(self):
        super().__init__(Api.tool_runtime)

    async def register_toolgroup(self, toolgroup: ToolGroup):
        return toolgroup

    async def unregister_toolgroup(self, toolgroup_id: str):
        return toolgroup_id

    async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
        return ListToolDefsResponse(
            data=[
                ToolDef(
                    name="test-tool",
                    description="Test tool",
                    parameters=[ToolParameter(name="test-param", description="Test param", parameter_type="string")],
                )
            ]
        )
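
# Each test below wires a routing table to one of the fake impls above and
# exercises register/list/lookup/unregister. The `cached_disk_dist_registry`
# argument is a pytest fixture providing the registry that backs each table.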
@pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry):
    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Register multiple models and verify listing
    await table.register_model(model_id="test-model", provider_id="test_provider")
    await table.register_model(model_id="test-model-2", provider_id="test_provider")

    models = await table.list_models()
    assert len(models.data) == 2
    model_ids = {m.identifier for m in models.data}
    assert "test-model" in model_ids
    assert "test-model-2" in model_ids

    # Test openai list models
    openai_models = await table.openai_list_models()
    assert len(openai_models.data) == 2
    openai_model_ids = {m.id for m in openai_models.data}
    assert "test-model" in openai_model_ids
    assert "test-model-2" in openai_model_ids

    # Test get_object_by_identifier
    model = await table.get_object_by_identifier("model", "test-model")
    assert model is not None
    assert model.identifier == "test-model"

    # Test get_object_by_identifier on non-existent object
    non_existent = await table.get_object_by_identifier("model", "non-existent-model")
    assert non_existent is None

    await table.unregister_model(model_id="test-model")
    await table.unregister_model(model_id="test-model-2")

    models = await table.list_models()
    assert len(models.data) == 0

    # Test openai list models
    openai_models = await table.openai_list_models()
    assert len(openai_models.data) == 0


@pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry):
    table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Register multiple shields and verify listing
    await table.register_shield(shield_id="test-shield", provider_id="test_provider")
    await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")

    shields = await table.list_shields()
    assert len(shields.data) == 2
    shield_ids = {s.identifier for s in shields.data}
    assert "test-shield" in shield_ids
    assert "test-shield-2" in shield_ids


@pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry):
    table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await m_table.initialize()
    await m_table.register_model(
        model_id="test-model",
        provider_id="test_provider",
        metadata={"embedding_dimension": 128},
        model_type=ModelType.embedding,
    )

    # Register multiple vector databases and verify listing
    await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
    await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")

    vector_dbs = await table.list_vector_dbs()
    assert len(vector_dbs.data) == 2
    vector_db_ids = {v.identifier for v in vector_dbs.data}
    assert "test-vectordb" in vector_db_ids
    assert "test-vectordb-2" in vector_db_ids

    await table.unregister_vector_db(vector_db_id="test-vectordb")
    await table.unregister_vector_db(vector_db_id="test-vectordb-2")

    vector_dbs = await table.list_vector_dbs()
    assert len(vector_dbs.data) == 0
@pytest.mark.asyncio
async def test_datasets_routing_table(cached_disk_dist_registry):
    table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Register multiple datasets and verify listing
    await table.register_dataset(
        dataset_id="test-dataset", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource(uri="test-uri")
    )
    await table.register_dataset(
        dataset_id="test-dataset-2", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource(uri="test-uri-2")
    )

    datasets = await table.list_datasets()
    assert len(datasets.data) == 2
    dataset_ids = {d.identifier for d in datasets.data}
    assert "test-dataset" in dataset_ids
    assert "test-dataset-2" in dataset_ids

    await table.unregister_dataset(dataset_id="test-dataset")
    await table.unregister_dataset(dataset_id="test-dataset-2")

    datasets = await table.list_datasets()
    assert len(datasets.data) == 0
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
    table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Register multiple scoring functions and verify listing
    await table.register_scoring_function(
        scoring_fn_id="test-scoring-fn",
        provider_id="test_provider",
        description="Test scoring function",
        return_type=NumberType(),
    )
    await table.register_scoring_function(
        scoring_fn_id="test-scoring-fn-2",
        provider_id="test_provider",
        description="Another test scoring function",
        return_type=NumberType(),
    )

    scoring_functions = await table.list_scoring_functions()
    assert len(scoring_functions.data) == 2
    scoring_fn_ids = {fn.identifier for fn in scoring_functions.data}
    assert "test-scoring-fn" in scoring_fn_ids
    assert "test-scoring-fn-2" in scoring_fn_ids


@pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry):
    table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Register multiple benchmarks and verify listing
    await table.register_benchmark(
        benchmark_id="test-benchmark",
        dataset_id="test-dataset",
        scoring_functions=["test-scoring-fn", "test-scoring-fn-2"],
    )

    benchmarks = await table.list_benchmarks()
    assert len(benchmarks.data) == 1
    benchmark_ids = {b.identifier for b in benchmarks.data}
    assert "test-benchmark" in benchmark_ids


@pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry):
    table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Register multiple tool groups and verify listing
    await table.register_tool_group(
        toolgroup_id="test-toolgroup",
        provider_id="test_provider",
    )

    tool_groups = await table.list_tool_groups()
    assert len(tool_groups.data) == 1
    tool_group_ids = {tg.identifier for tg in tool_groups.data}
    assert "test-toolgroup" in tool_group_ids

    await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
    tool_groups = await table.list_tool_groups()
    assert len(tool_groups.data) == 0