mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
This PR introduces a way to implement Attribute Based Access Control (ABAC) for the Llama Stack server.

The rough design is:

- https://github.com/meta-llama/llama-stack/pull/1626 added a way for the Llama Stack server to query an authenticator.
- We build upon that and expect "access attributes" as part of the response. These attributes indicate the scopes available for the request.
- We use these attributes to perform access control for registered resources as well as for constructing the default access control policies for newly created resources.
- By default, if you support authentication but don't return access attributes, we will add a unique namespace pointing to the API_KEY. That way, all resources by default will be scoped to API_KEYs.

An important aspect of this design is that Llama Stack stays out of the business of credential management or the CRUD for attributes. How you manage your namespaces or projects is entirely up to you. The design only implements access control checks for the metadata / book-keeping information that the Stack tracks.

### Limitations

- Currently, read vs. write vs. admin permissions aren't made explicit, but this can be easily extended by adding appropriate attributes to the `AccessAttributes` data structure.
- This design does not apply to agent instances since they are not considered resources the Stack knows about. Agent instances are completely within the scope of the Agents API provider.

### Test Plan

Added unit tests, existing integration tests
240 lines
9.3 KiB
Python
240 lines
9.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.datatypes import Api
|
|
from llama_stack.apis.models import ModelType
|
|
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
|
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
|
|
|
|
|
class AsyncMock(MagicMock):
    """``MagicMock`` subclass whose call operator is a coroutine function.

    Calling an instance yields an awaitable; awaiting it delegates to
    ``MagicMock.__call__`` with the same arguments and returns its result.
    """

    async def __call__(self, *args, **kwargs):
        # Delegate to the synchronous MagicMock call machinery and hand its
        # result back through the coroutine.
        outcome = super().__call__(*args, **kwargs)
        return outcome
|
|
|
|
|
|
def _return_model(model):
    """Identity helper: hand back ``model`` exactly as it was received.

    Serves as the ``side_effect`` of the mocked ``register_model`` call in
    this module's fixture, so the mock echoes its argument.
    """
    return model
|
|
|
|
|
|
@pytest.fixture
async def test_setup():
    """Build a disk-backed registry and a models routing table for one test.

    A SQLite key-value store is created inside a fresh temporary directory,
    wrapped in a ``CachedDiskDistributionRegistry`` and wired into a
    ``ModelsRoutingTable`` whose only provider (``"test_provider"``) is a mock
    inference implementation.

    Yields:
        tuple: ``(registry, routing_table)``.

    The temporary directory is removed when the fixture finishes, including
    when any of the setup steps below raises before the ``yield``.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        db_path = os.path.join(temp_dir, "test_access_control.db")
        kvstore_config = SqliteKVStoreConfig(db_path=db_path)
        kvstore = SqliteKVStoreImpl(kvstore_config)
        await kvstore.initialize()
        registry = CachedDiskDistributionRegistry(kvstore)
        await registry.initialize()

        # Mock provider: advertises the inference API and echoes back whatever
        # model it is asked to register.
        mock_inference = Mock()
        mock_inference.__provider_spec__ = MagicMock()
        mock_inference.__provider_spec__.api = Api.inference
        mock_inference.register_model = AsyncMock(side_effect=_return_model)
        routing_table = ModelsRoutingTable(
            impls_by_provider_id={"test_provider": mock_inference},
            dist_registry=registry,
        )
        yield registry, routing_table
    finally:
        # Runs on normal teardown and on setup failure alike, so the temp
        # directory never outlives the fixture.
        shutil.rmtree(temp_dir)
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
    """Exercise ``list_models`` / ``get_model`` as three different callers.

    Three models are registered directly in the registry: one without access
    attributes, one with ``roles=["admin"]`` and one with
    ``roles=["data-scientist", "researcher"], teams=["ml-team"]``. The patched
    ``get_auth_attributes`` is then switched between callers and the visible
    set of models is asserted for each.
    """
    registry, routing_table = test_setup

    def build_model(name, **extra):
        # Every model in this test shares provider and type; only the
        # identifier and the optional access attributes vary.
        return ModelWithACL(
            identifier=name,
            provider_id="test_provider",
            provider_resource_id=name,
            model_type=ModelType.llm,
            **extra,
        )

    model_public = build_model("model-public")
    model_admin_only = build_model(
        "model-admin",
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    model_data_scientist = build_model(
        "model-data-scientist",
        access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
    )
    await registry.register(model_public)
    await registry.register(model_admin_only)
    await registry.register(model_data_scientist)

    # Caller 1: admin role, "management" team.
    mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
    listing = await routing_table.list_models()
    assert len(listing.data) == 2

    for allowed_id in ("model-public", "model-admin"):
        fetched = await routing_table.get_model(allowed_id)
        assert fetched.identifier == allowed_id
    with pytest.raises(ValueError):
        await routing_table.get_model("model-data-scientist")

    # Caller 2: data-scientist role, but on "other-team".
    mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
    listing = await routing_table.list_models()
    assert len(listing.data) == 1
    assert listing.data[0].identifier == "model-public"
    fetched = await routing_table.get_model("model-public")
    assert fetched.identifier == "model-public"
    for denied_id in ("model-admin", "model-data-scientist"):
        with pytest.raises(ValueError):
            await routing_table.get_model(denied_id)

    # Caller 3: data-scientist role on "ml-team".
    mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
    listing = await routing_table.list_models()
    assert len(listing.data) == 2
    listed_ids = [entry.identifier for entry in listing.data]
    assert "model-public" in listed_ids
    assert "model-data-scientist" in listed_ids
    assert "model-admin" not in listed_ids
    for allowed_id in ("model-public", "model-data-scientist"):
        fetched = await routing_table.get_model(allowed_id)
        assert fetched.identifier == allowed_id
    with pytest.raises(ValueError):
        await routing_table.get_model("model-admin")
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
    """Check that updating a model's access attributes changes who can fetch it.

    The model starts without access attributes, is fetched by a ``user``-role
    caller, is then updated in the registry to ``roles=["admin"]``, and is
    fetched again as ``user`` (expected ``ValueError``) and as ``admin``.
    """
    registry, routing_table = test_setup

    target_id = "model-updates"
    tracked_model = ModelWithACL(
        identifier=target_id,
        provider_id="test_provider",
        provider_resource_id=target_id,
        model_type=ModelType.llm,
    )
    await registry.register(tracked_model)

    # Before the update: a "user" caller can fetch the model.
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    fetched = await routing_table.get_model(target_id)
    assert fetched.identifier == "model-updates"

    # Attach admin-only attributes and persist the change through the registry.
    tracked_model.access_attributes = AccessAttributes(roles=["admin"])
    await registry.update(tracked_model)

    # After the update: the same "user" caller is rejected ...
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    with pytest.raises(ValueError):
        await routing_table.get_model(target_id)

    # ... while an "admin" caller gets the model back.
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    fetched = await routing_table.get_model(target_id)
    assert fetched.identifier == "model-updates"
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
    """Check that a model registered with an empty ``AccessAttributes()`` stays reachable.

    The caller's attributes are ``{"roles": []}``; the model must be returned
    by ``get_model`` and appear in ``list_models``.
    """
    registry, routing_table = test_setup

    target_id = "model-empty-attrs"
    await registry.register(
        ModelWithACL(
            identifier=target_id,
            provider_id="test_provider",
            provider_resource_id=target_id,
            model_type=ModelType.llm,
            access_attributes=AccessAttributes(),
        )
    )

    mock_get_auth_attributes.return_value = {"roles": []}

    fetched = await routing_table.get_model(target_id)
    assert fetched.identifier == "model-empty-attrs"

    listing = await routing_table.list_models()
    assert any(entry.identifier == "model-empty-attrs" for entry in listing.data)
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
    """Check visibility when ``get_auth_attributes`` returns ``None``.

    A model without access attributes must be returned by ``get_model`` and be
    the only entry in ``list_models``; a model with ``roles=["admin"]`` must
    raise ``ValueError`` on ``get_model``.
    """
    registry, routing_table = test_setup

    open_id = "model-public-2"
    locked_id = "model-restricted"
    shared_fields = {"provider_id": "test_provider", "model_type": ModelType.llm}

    open_model = ModelWithACL(identifier=open_id, provider_resource_id=open_id, **shared_fields)
    locked_model = ModelWithACL(
        identifier=locked_id,
        provider_resource_id=locked_id,
        access_attributes=AccessAttributes(roles=["admin"]),
        **shared_fields,
    )
    await registry.register(open_model)
    await registry.register(locked_model)

    # The caller carries no attributes at all.
    mock_get_auth_attributes.return_value = None

    fetched = await routing_table.get_model(open_id)
    assert fetched.identifier == "model-public-2"

    with pytest.raises(ValueError):
        await routing_table.get_model(locked_id)

    listing = await routing_table.list_models()
    assert len(listing.data) == 1
    assert listing.data[0].identifier == "model-public-2"
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
    """A model registered via the routing table picks up the creator's attributes.

    The model is created without explicit access attributes while the patched
    ``get_auth_attributes`` reports the creator's roles/teams/projects; the
    stored model is then expected to carry those same values and to be
    fetchable only by a caller whose attributes match.
    """
    registry, routing_table = test_setup

    # Attributes reported for the caller who registers the model.
    mock_get_auth_attributes.return_value = {
        "roles": ["data-scientist"],
        "teams": ["ml-team"],
        "projects": ["llama-3"],
    }

    # No access_attributes passed here on purpose.
    target_id = "auto-access-model"
    new_model = ModelWithACL(
        identifier=target_id,
        provider_id="test_provider",
        provider_resource_id=target_id,
        model_type=ModelType.llm,
    )
    await routing_table.register_object(new_model)

    # The stored model should now carry the creator's attributes.
    stored = await routing_table.get_model(target_id)
    assert stored.access_attributes is not None
    assert stored.access_attributes.roles == ["data-scientist"]
    assert stored.access_attributes.teams == ["ml-team"]
    assert stored.access_attributes.projects == ["llama-3"]

    # A caller with unrelated role/team is rejected.
    mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
    with pytest.raises(ValueError):
        await routing_table.get_model(target_id)

    # A caller whose attributes include the creator's values is let through.
    mock_get_auth_attributes.return_value = {
        "roles": ["data-scientist", "engineer"],
        "teams": ["ml-team", "platform-team"],
        "projects": ["llama-3"],
    }
    fetched = await routing_table.get_model(target_id)
    assert fetched.identifier == "auto-access-model"
|