# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Access-control tests for ``ModelsRoutingTable``.

Each test patches ``get_auth_attributes`` (as looked up from the
``routing_tables`` module) to impersonate a caller, then checks which models
``get_model`` / ``list_models`` / ``register_model`` expose to that caller.
"""

from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, AccessAttributesRule, ModelWithACL
from llama_stack.distribution.resource_attributes import ResourceAccessAttributes
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable


def _return_model(model):
    """Echo the model back; used as the mocked provider's ``register_model`` result."""
    return model


def _make_mock_inference():
    """Build a mock inference provider whose ``register_model`` awaits to its argument."""
    mock_inference = Mock()
    mock_inference.__provider_spec__ = MagicMock()
    mock_inference.__provider_spec__.api = Api.inference
    # The stdlib AsyncMock returns an awaitable that resolves to the value
    # produced by a synchronous ``side_effect`` callable.
    mock_inference.register_model = AsyncMock(side_effect=_return_model)
    return mock_inference


@pytest.fixture
async def test_setup(cached_disk_dist_registry):
    """Yield ``(registry, routing_table)`` with no resource attribute rules configured."""
    routing_table = ModelsRoutingTable(
        impls_by_provider_id={"test_provider": _make_mock_inference()},
        dist_registry=cached_disk_dist_registry,
        resource_attributes=ResourceAccessAttributes([]),
    )
    yield cached_disk_dist_registry, routing_table


@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
    """Check visibility of three models for three different callers.

    One model has no access attributes, one is restricted to the ``admin``
    role, and one is restricted by both roles and teams.
    """
    registry, routing_table = test_setup

    model_public = ModelWithACL(
        identifier="model-public",
        provider_id="test_provider",
        provider_resource_id="model-public",
        model_type=ModelType.llm,
    )
    model_admin_only = ModelWithACL(
        identifier="model-admin",
        provider_id="test_provider",
        provider_resource_id="model-admin",
        model_type=ModelType.llm,
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    model_data_scientist = ModelWithACL(
        identifier="model-data-scientist",
        provider_id="test_provider",
        provider_resource_id="model-data-scientist",
        model_type=ModelType.llm,
        access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
    )
    for acl_model in (model_public, model_admin_only, model_data_scientist):
        await registry.register(acl_model)

    # Caller 1: admin role, unrelated team -> public and admin-only models.
    mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
    all_models = await routing_table.list_models()
    assert len(all_models.data) == 2
    model_ids = [m.identifier for m in all_models.data]
    assert "model-public" in model_ids
    assert "model-admin" in model_ids
    assert "model-data-scientist" not in model_ids

    model = await routing_table.get_model("model-public")
    assert model.identifier == "model-public"
    model = await routing_table.get_model("model-admin")
    assert model.identifier == "model-admin"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-data-scientist")

    # Caller 2: matching role but wrong team -> only the public model.
    mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
    all_models = await routing_table.list_models()
    assert len(all_models.data) == 1
    assert all_models.data[0].identifier == "model-public"

    model = await routing_table.get_model("model-public")
    assert model.identifier == "model-public"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-admin")
    with pytest.raises(ValueError):
        await routing_table.get_model("model-data-scientist")

    # Caller 3: matching role and team -> public and data-scientist models.
    mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
    all_models = await routing_table.list_models()
    assert len(all_models.data) == 2
    model_ids = [m.identifier for m in all_models.data]
    assert "model-public" in model_ids
    assert "model-data-scientist" in model_ids
    assert "model-admin" not in model_ids

    model = await routing_table.get_model("model-public")
    assert model.identifier == "model-public"
    model = await routing_table.get_model("model-data-scientist")
    assert model.identifier == "model-data-scientist"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-admin")


@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
    """Check that restricting a model via ``registry.update`` changes who can fetch it."""
    registry, routing_table = test_setup

    model_public = ModelWithACL(
        identifier="model-updates",
        provider_id="test_provider",
        provider_resource_id="model-updates",
        model_type=ModelType.llm,
    )
    await registry.register(model_public)

    # Before the update the model carries no access attributes.
    mock_get_auth_attributes.return_value = {
        "roles": ["user"],
    }
    model = await routing_table.get_model("model-updates")
    assert model.identifier == "model-updates"

    # Restrict the model to admins and persist the change.
    model_public.access_attributes = AccessAttributes(roles=["admin"])
    await registry.update(model_public)

    # The same plain user is now expected to be rejected.
    mock_get_auth_attributes.return_value = {
        "roles": ["user"],
    }
    with pytest.raises(ValueError):
        await routing_table.get_model("model-updates")

    mock_get_auth_attributes.return_value = {
        "roles": ["admin"],
    }
    model = await routing_table.get_model("model-updates")
    assert model.identifier == "model-updates"


@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
    """Check that a model with an empty ``AccessAttributes()`` is visible to a caller with no roles."""
    registry, routing_table = test_setup

    model = ModelWithACL(
        identifier="model-empty-attrs",
        provider_id="test_provider",
        provider_resource_id="model-empty-attrs",
        model_type=ModelType.llm,
        access_attributes=AccessAttributes(),
    )
    await registry.register(model)

    mock_get_auth_attributes.return_value = {
        "roles": [],
    }
    result = await routing_table.get_model("model-empty-attrs")
    assert result.identifier == "model-empty-attrs"

    all_models = await routing_table.list_models()
    model_ids = [m.identifier for m in all_models.data]
    assert "model-empty-attrs" in model_ids


@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
    """Check that a caller whose auth attributes are ``None`` only sees the unrestricted model."""
    registry, routing_table = test_setup

    model_public = ModelWithACL(
        identifier="model-public-2",
        provider_id="test_provider",
        provider_resource_id="model-public-2",
        model_type=ModelType.llm,
    )
    model_restricted = ModelWithACL(
        identifier="model-restricted",
        provider_id="test_provider",
        provider_resource_id="model-restricted",
        model_type=ModelType.llm,
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    await registry.register(model_public)
    await registry.register(model_restricted)

    mock_get_auth_attributes.return_value = None

    model = await routing_table.get_model("model-public-2")
    assert model.identifier == "model-public-2"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-restricted")

    all_models = await routing_table.list_models()
    assert len(all_models.data) == 1
    assert all_models.data[0].identifier == "model-public-2"


@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
    """Test that newly created resources inherit access attributes from their creator."""
    _, routing_table = test_setup

    # Set creator's attributes
    creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
    mock_get_auth_attributes.return_value = creator_attributes

    # Create model without explicit access attributes
    new_model = ModelWithACL(
        identifier="auto-access-model",
        provider_id="test_provider",
        provider_resource_id="auto-access-model",
        model_type=ModelType.llm,
    )
    await routing_table.register_object(new_model)

    # Verify the model got creator's attributes
    registered_model = await routing_table.get_model("auto-access-model")
    assert registered_model.access_attributes is not None
    assert registered_model.access_attributes.roles == ["data-scientist"]
    assert registered_model.access_attributes.teams == ["ml-team"]
    assert registered_model.access_attributes.projects == ["llama-3"]

    # Verify another user without matching attributes can't access it
    mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
    with pytest.raises(ValueError):
        await routing_table.get_model("auto-access-model")

    # But a user with matching attributes can
    mock_get_auth_attributes.return_value = {
        "roles": ["data-scientist", "engineer"],
        "teams": ["ml-team", "platform-team"],
        "projects": ["llama-3"],
    }
    fetched_model = await routing_table.get_model("auto-access-model")
    assert fetched_model.identifier == "auto-access-model"


@pytest.fixture
async def test_setup_with_resource_attribute_rules(cached_disk_dist_registry):
    """Yield a routing table configured with per-model rules and access checks enabled."""
    resource_attributes = ResourceAccessAttributes(
        [
            # model-1 can be accessed by users in project foo or those who have admin role
            AccessAttributesRule(
                resource_type="model", resource_id="model-1", attributes=AccessAttributes(projects=["foo"])
            ),
            # model-2 can be accessed by users in project bar or those who have admin role
            AccessAttributesRule(
                resource_type="model", resource_id="model-2", attributes=AccessAttributes(projects=["bar"])
            ),
            # any other model can be accessed or created only by those who have admin role
            AccessAttributesRule(resource_type="model", attributes=AccessAttributes(roles=["admin"])),
        ]
    )
    resource_attributes.enable_access_checks()

    routing_table = ModelsRoutingTable(
        impls_by_provider_id={"test_provider": _make_mock_inference()},
        dist_registry=cached_disk_dist_registry,
        resource_attributes=resource_attributes,
    )
    yield routing_table


@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_with_resource_attribute_rules(mock_get_auth_attributes, test_setup_with_resource_attribute_rules):
    """Check that the configured rules gate both model lookup and model registration."""
    routing_table = test_setup_with_resource_attribute_rules

    # An admin who is also in both projects registers all three models.
    mock_get_auth_attributes.return_value = {
        "roles": ["admin"],
        "projects": ["foo", "bar"],
    }
    for model_id in ("model-1", "model-2", "model-3"):
        await routing_table.register_model(model_id, provider_id="test_provider")

    # A plain user in project foo: only model-1 is reachable, and creating a
    # new model is expected to be rejected.
    mock_get_auth_attributes.return_value = {
        "roles": ["user"],
        "projects": ["foo"],
    }
    model = await routing_table.get_model("model-1")
    assert model.identifier == "model-1"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-2")
    with pytest.raises(ValueError):
        await routing_table.get_model("model-3")
    with pytest.raises(ValueError):
        await routing_table.register_model("model-4", provider_id="test_provider")

    # A plain user in project bar: only model-2 is reachable.
    mock_get_auth_attributes.return_value = {
        "roles": ["user"],
        "projects": ["bar"],
    }
    model = await routing_table.get_model("model-2")
    assert model.identifier == "model-2"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-1")
    with pytest.raises(ValueError):
        await routing_table.get_model("model-3")
    with pytest.raises(ValueError):
        await routing_table.register_model("model-5", provider_id="test_provider")