mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 08:40:05 +00:00
feat(server): add attribute based access control for resources
This commit is contained in:
parent
7c0448456e
commit
b937a49436
8 changed files with 862 additions and 35 deletions
143
tests/unit/registry/test_registry_acl.py
Normal file
143
tests/unit/registry/test_registry_acl.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import ModelWithACL
|
||||
from llama_stack.distribution.server.auth import AccessAttributes
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def kvstore():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_registry_acl.db")
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
await kvstore.initialize()
|
||||
yield kvstore
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def registry(kvstore):
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
async def test_registry_cache_with_acl(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-acl-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
||||
)
|
||||
|
||||
success = await registry.register(model)
|
||||
assert success
|
||||
|
||||
cached_model = registry.get_cached("model", "model-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.identifier == "model-acl"
|
||||
assert cached_model.access_attributes.roles == ["admin"]
|
||||
assert cached_model.access_attributes.teams == ["ai-team"]
|
||||
|
||||
fetched_model = await registry.get("model", "model-acl")
|
||||
assert fetched_model is not None
|
||||
assert fetched_model.identifier == "model-acl"
|
||||
assert fetched_model.access_attributes.roles == ["admin"]
|
||||
|
||||
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
||||
await registry.update(model)
|
||||
|
||||
updated_cached = registry.get_cached("model", "model-acl")
|
||||
assert updated_cached is not None
|
||||
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
||||
assert updated_cached.access_attributes.projects == ["project-x"]
|
||||
assert updated_cached.access_attributes.teams is None
|
||||
|
||||
new_registry = CachedDiskDistributionRegistry(registry._kvstore)
|
||||
await new_registry.initialize()
|
||||
|
||||
new_model = await new_registry.get("model", "model-acl")
|
||||
assert new_model is not None
|
||||
assert new_model.identifier == "model-acl"
|
||||
assert new_model.access_attributes.roles == ["admin", "user"]
|
||||
assert new_model.access_attributes.projects == ["project-x"]
|
||||
assert new_model.access_attributes.teams is None
|
||||
|
||||
|
||||
async def test_registry_empty_acl(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
|
||||
cached_model = registry.get_cached("model", "model-empty-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.access_attributes is not None
|
||||
assert cached_model.access_attributes.roles is None
|
||||
assert cached_model.access_attributes.teams is None
|
||||
assert cached_model.access_attributes.projects is None
|
||||
assert cached_model.access_attributes.namespaces is None
|
||||
|
||||
all_models = await registry.get_all()
|
||||
assert len(all_models) == 1
|
||||
|
||||
model = ModelWithACL(
|
||||
identifier="model-no-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource-2",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
|
||||
cached_model = registry.get_cached("model", "model-no-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.access_attributes is None
|
||||
|
||||
all_models = await registry.get_all()
|
||||
assert len(all_models) == 2
|
||||
|
||||
|
||||
async def test_registry_serialization(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-serialize",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["admin", "researcher"],
|
||||
teams=["ai-team", "ml-team"],
|
||||
projects=["project-a", "project-b"],
|
||||
namespaces=["prod", "staging"],
|
||||
),
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
registry.cache.clear()
|
||||
|
||||
loaded_model = await registry.get("model", "model-serialize")
|
||||
|
||||
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
|
||||
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
|
||||
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
|
||||
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
|
||||
236
tests/unit/server/test_access_control.py
Normal file
236
tests/unit/server/test_access_control.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
class AsyncMock(MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return super(AsyncMock, self).__call__(*args, **kwargs)
|
||||
|
||||
|
||||
def _return_model(model):
|
||||
return model
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_setup():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_access_control.db")
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
|
||||
mock_inference = Mock()
|
||||
mock_inference.__provider_spec__ = MagicMock()
|
||||
mock_inference.__provider_spec__.api = "inference"
|
||||
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=registry,
|
||||
)
|
||||
yield registry, routing_table
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_admin_only = ModelWithACL(
|
||||
identifier="model-admin",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-admin",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
)
|
||||
model_data_scientist = ModelWithACL(
|
||||
identifier="model-data-scientist",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-data-scientist",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_admin_only)
|
||||
await registry.register(model_data_scientist)
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 3
|
||||
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
model = await routing_table.get_model("model-admin")
|
||||
assert model.identifier == "model-admin"
|
||||
model = await routing_table.get_model("model-data-scientist")
|
||||
assert model.identifier == "model-data-scientist"
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public"
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-admin")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
assert "model-public" in model_ids
|
||||
assert "model-data-scientist" in model_ids
|
||||
assert "model-admin" not in model_ids
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
model = await routing_table.get_model("model-data-scientist")
|
||||
assert model.identifier == "model-data-scientist"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-admin")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-updates",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-updates",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await registry.register(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
||||
await registry.update(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-updates")
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["admin"],
|
||||
}
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-attrs",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-empty-attrs",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
await registry.register(model)
|
||||
mock_get_auth_attributes.return_value = {}
|
||||
result = await routing_table.get_model("model-empty-attrs")
|
||||
assert result.identifier == "model-empty-attrs"
|
||||
all_models = await routing_table.list_models()
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
assert "model-empty-attrs" in model_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public-2",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public-2",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_restricted = ModelWithACL(
|
||||
identifier="model-restricted",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-restricted",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_restricted)
|
||||
mock_get_auth_attributes.return_value = None
|
||||
model = await routing_table.get_model("model-public-2")
|
||||
assert model.identifier == "model-public-2"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-restricted")
|
||||
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public-2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
|
||||
"""Test that newly created resources inherit access attributes from their creator."""
|
||||
registry, routing_table = test_setup
|
||||
|
||||
# Set creator's attributes
|
||||
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
|
||||
mock_get_auth_attributes.return_value = creator_attributes
|
||||
|
||||
# Create model without explicit access attributes
|
||||
model = ModelWithACL(
|
||||
identifier="auto-access-model",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="auto-access-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await routing_table.register_object(model)
|
||||
|
||||
# Verify the model got creator's attributes
|
||||
registered_model = await routing_table.get_model("auto-access-model")
|
||||
assert registered_model.access_attributes is not None
|
||||
assert registered_model.access_attributes.roles == ["data-scientist"]
|
||||
assert registered_model.access_attributes.teams == ["ml-team"]
|
||||
assert registered_model.access_attributes.projects == ["llama-3"]
|
||||
|
||||
# Verify another user without matching attributes can't access it
|
||||
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("auto-access-model")
|
||||
|
||||
# But a user with matching attributes can
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "platform-team"],
|
||||
}
|
||||
model = await routing_table.get_model("auto-access-model")
|
||||
assert model.identifier == "auto-access-model"
|
||||
|
|
@ -13,6 +13,15 @@ from fastapi.testclient import TestClient
|
|||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_endpoint():
|
||||
return "http://mock-auth-service/validate"
|
||||
|
|
@ -45,6 +54,26 @@ def client(app):
|
|||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scope():
|
||||
return {
|
||||
"type": "http",
|
||||
"path": "/models/list",
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"authorization", b"Bearer test-api-key"),
|
||||
(b"user-agent", b"test-user-agent"),
|
||||
],
|
||||
"query_string": b"limit=100&offset=0",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
return AuthenticationMiddleware(mock_app, mock_auth_endpoint), mock_app
|
||||
|
||||
|
||||
async def mock_post_success(*args, **kwargs):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
|
|
@ -122,3 +151,59 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
|||
assert "authorization" in payload["request"]["headers"]
|
||||
assert "param1" in payload["request"]["params"]
|
||||
assert "param2" in payload["request"]["params"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_middleware_with_access_attributes(mock_middleware, mock_scope):
|
||||
middleware, mock_app = mock_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.post.return_value = MockResponse(
|
||||
200,
|
||||
{
|
||||
"access_attributes": {
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["project-x", "project-y"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
assert mock_scope["user_attributes"]["roles"] == ["admin", "user"]
|
||||
assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
|
||||
assert mock_scope["user_attributes"]["projects"] == ["project-x", "project-y"]
|
||||
|
||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
|
||||
"""Test middleware behavior with no access attributes"""
|
||||
middleware, mock_app = mock_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.post.return_value = MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful"
|
||||
# No access_attributes
|
||||
},
|
||||
)
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
assert mock_scope["user_attributes"] == {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue