mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Add fixtures for SqliteKVStore, DiskDistributionRegistry and CachedDiskDistributionRegistry, and use them in the tests that had all been duplicating similar setups. ## Test Plan: unit tests continue to run. Signed-off-by: Derek Higgins <derekh@redhat.com>
225 lines
8.7 KiB
Python
225 lines
8.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.datatypes import Api
|
|
from llama_stack.apis.models import ModelType
|
|
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
|
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
|
|
|
|
|
class AsyncMock(MagicMock):
    """``MagicMock`` variant whose call must be awaited.

    Calling an instance returns a coroutine; awaiting that coroutine runs the
    regular ``MagicMock`` call, so ``return_value`` and ``side_effect`` are
    honoured and the call is recorded at await time.
    """

    async def __call__(self, *args, **kwargs):
        # Delegate to MagicMock's synchronous call machinery from inside the coroutine.
        outcome = super().__call__(*args, **kwargs)
        return outcome
|
|
|
|
|
|
def _return_model(model):
    """Identity helper: hand back the given model unchanged."""
    return model
|
|
|
|
|
|
@pytest.fixture
async def test_setup(cached_disk_dist_registry):
    """Yield ``(registry, routing_table)`` backed by a mocked inference provider.

    The routing table is a ``ModelsRoutingTable`` with a single provider,
    ``"test_provider"``, whose ``register_model`` echoes back the model it is
    given. The ``cached_disk_dist_registry`` fixture is used as the table's
    distribution registry and is also yielded so tests can register objects
    on it directly.
    """
    provider_spec = MagicMock()
    provider_spec.api = Api.inference

    inference_impl = Mock()
    inference_impl.__provider_spec__ = provider_spec
    inference_impl.register_model = AsyncMock(side_effect=_return_model)

    table = ModelsRoutingTable(
        impls_by_provider_id={"test_provider": inference_impl},
        dist_registry=cached_disk_dist_registry,
    )
    yield cached_disk_dist_registry, table
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
    """Check which models each caller can list and fetch via the routing table.

    Three models are registered: one without access attributes, one carrying
    ``roles=["admin"]``, and one carrying data-science roles plus
    ``teams=["ml-team"]``. The patched ``get_auth_attributes`` then stands in
    for three different callers in turn.
    """
    registry, routing_table = test_setup

    public_model = ModelWithACL(
        identifier="model-public",
        provider_id="test_provider",
        provider_resource_id="model-public",
        model_type=ModelType.llm,
    )
    admin_model = ModelWithACL(
        identifier="model-admin",
        provider_id="test_provider",
        provider_resource_id="model-admin",
        model_type=ModelType.llm,
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    scientist_model = ModelWithACL(
        identifier="model-data-scientist",
        provider_id="test_provider",
        provider_resource_id="model-data-scientist",
        model_type=ModelType.llm,
        access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
    )
    for entry in (public_model, admin_model, scientist_model):
        await registry.register(entry)

    # Caller 1: admin role, management team.
    mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
    listing = await routing_table.list_models()
    assert len(listing.data) == 2

    fetched = await routing_table.get_model("model-public")
    assert fetched.identifier == "model-public"
    fetched = await routing_table.get_model("model-admin")
    assert fetched.identifier == "model-admin"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-data-scientist")

    # Caller 2: data-scientist role but a team that does not match.
    mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
    listing = await routing_table.list_models()
    assert len(listing.data) == 1
    assert listing.data[0].identifier == "model-public"
    fetched = await routing_table.get_model("model-public")
    assert fetched.identifier == "model-public"
    for hidden_id in ("model-admin", "model-data-scientist"):
        with pytest.raises(ValueError):
            await routing_table.get_model(hidden_id)

    # Caller 3: data-scientist role and the matching ml-team.
    mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
    listing = await routing_table.list_models()
    assert len(listing.data) == 2
    visible_ids = [entry.identifier for entry in listing.data]
    assert "model-public" in visible_ids
    assert "model-data-scientist" in visible_ids
    assert "model-admin" not in visible_ids
    fetched = await routing_table.get_model("model-public")
    assert fetched.identifier == "model-public"
    fetched = await routing_table.get_model("model-data-scientist")
    assert fetched.identifier == "model-data-scientist"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-admin")
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
    """Check that updating a model's access attributes changes who can fetch it.

    The model starts without access attributes and is fetched by a ``user``
    caller. After ``roles=["admin"]`` is set and pushed through
    ``registry.update``, the same caller gets ``ValueError`` while an
    ``admin`` caller succeeds.
    """
    registry, routing_table = test_setup

    target = ModelWithACL(
        identifier="model-updates",
        provider_id="test_provider",
        provider_resource_id="model-updates",
        model_type=ModelType.llm,
    )
    await registry.register(target)

    # Before the update: a plain user can fetch the model.
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    fetched = await routing_table.get_model("model-updates")
    assert fetched.identifier == "model-updates"

    # Attach admin-only attributes and persist the change.
    target.access_attributes = AccessAttributes(roles=["admin"])
    await registry.update(target)

    # After the update: the same user is rejected...
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    with pytest.raises(ValueError):
        await routing_table.get_model("model-updates")

    # ...while an admin gets through.
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    fetched = await routing_table.get_model("model-updates")
    assert fetched.identifier == "model-updates"
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
    """Check that a model with an empty ``AccessAttributes()`` stays reachable.

    A caller whose attributes are ``{"roles": []}`` must be able both to
    fetch the model by id and to see it in the listing.
    """
    registry, routing_table = test_setup

    unrestricted = ModelWithACL(
        identifier="model-empty-attrs",
        provider_id="test_provider",
        provider_resource_id="model-empty-attrs",
        model_type=ModelType.llm,
        access_attributes=AccessAttributes(),
    )
    await registry.register(unrestricted)

    mock_get_auth_attributes.return_value = {"roles": []}

    fetched = await routing_table.get_model("model-empty-attrs")
    assert fetched.identifier == "model-empty-attrs"

    listing = await routing_table.list_models()
    listed_ids = [entry.identifier for entry in listing.data]
    assert "model-empty-attrs" in listed_ids
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
    """Check behaviour when ``get_auth_attributes`` returns ``None``.

    The model without access attributes must remain fetchable and listed;
    the model carrying ``roles=["admin"]`` must raise ``ValueError`` on fetch
    and be absent from the listing.
    """
    registry, routing_table = test_setup

    open_model = ModelWithACL(
        identifier="model-public-2",
        provider_id="test_provider",
        provider_resource_id="model-public-2",
        model_type=ModelType.llm,
    )
    locked_model = ModelWithACL(
        identifier="model-restricted",
        provider_id="test_provider",
        provider_resource_id="model-restricted",
        model_type=ModelType.llm,
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    for entry in (open_model, locked_model):
        await registry.register(entry)

    # No attributes at all for the current caller.
    mock_get_auth_attributes.return_value = None

    fetched = await routing_table.get_model("model-public-2")
    assert fetched.identifier == "model-public-2"

    with pytest.raises(ValueError):
        await routing_table.get_model("model-restricted")

    listing = await routing_table.list_models()
    assert len(listing.data) == 1
    assert listing.data[0].identifier == "model-public-2"
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
    """Check that a resource registered without ACLs inherits its creator's attributes.

    The model is registered through ``routing_table.register_object`` while
    the patched ``get_auth_attributes`` reports the creator's roles, teams and
    projects. The stored model must carry those values, reject a caller with
    unrelated attributes, and accept a caller whose attributes include them.
    """
    registry, routing_table = test_setup

    # Attributes of the caller who creates the model.
    owner_attrs = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
    mock_get_auth_attributes.return_value = owner_attrs

    # Note: no access_attributes are passed explicitly.
    new_model = ModelWithACL(
        identifier="auto-access-model",
        provider_id="test_provider",
        provider_resource_id="auto-access-model",
        model_type=ModelType.llm,
    )
    await routing_table.register_object(new_model)

    # The stored model should now carry the creator's attributes.
    stored = await routing_table.get_model("auto-access-model")
    inherited = stored.access_attributes
    assert inherited is not None
    assert stored.access_attributes.roles == ["data-scientist"]
    assert stored.access_attributes.teams == ["ml-team"]
    assert stored.access_attributes.projects == ["llama-3"]

    # A caller with non-matching attributes is refused.
    mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
    with pytest.raises(ValueError):
        await routing_table.get_model("auto-access-model")

    # A caller whose attributes include the creator's ones is allowed.
    mock_get_auth_attributes.return_value = {
        "roles": ["data-scientist", "engineer"],
        "teams": ["ml-team", "platform-team"],
        "projects": ["llama-3"],
    }
    fetched = await routing_table.get_model("auto-access-model")
    assert fetched.identifier == "auto-access-model"
|