feat: fine grained access control policy

This allows a set of rules to be defined for determining access to resources.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
Gordon Sim 2025-05-06 18:54:58 +01:00
parent 9623d5d230
commit 01ad876012
20 changed files with 724 additions and 214 deletions

View file

@ -7,10 +7,14 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
import yaml
from pydantic import TypeAdapter
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@ -32,13 +36,14 @@ async def test_setup(cached_disk_dist_registry):
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy={},
)
yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-public",
@ -64,7 +69,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
await registry.register(model_admin_only)
await registry.register(model_data_scientist)
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
@ -75,7 +80,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public"
@ -86,7 +91,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
model_ids = [m.identifier for m in all_models.data]
@ -102,8 +107,8 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-updates",
@ -112,28 +117,37 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
model_type=ModelType.llm,
)
await registry.register(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["user"],
},
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"])
await registry.update(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["user"],
},
)
with pytest.raises(ValueError):
await routing_table.get_model("model-updates")
mock_get_auth_attributes.return_value = {
"roles": ["admin"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["admin"],
},
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model = ModelWithACL(
identifier="model-empty-attrs",
@ -143,9 +157,12 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
access_attributes=AccessAttributes(),
)
await registry.register(model)
mock_get_auth_attributes.return_value = {
"roles": [],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": [],
},
)
result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs"
all_models = await routing_table.list_models()
@ -154,8 +171,8 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
identifier="model-public-2",
@ -172,7 +189,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
)
await registry.register(model_public)
await registry.register(model_restricted)
mock_get_auth_attributes.return_value = None
mock_get_authenticated_user.return_value = User("test-user", None)
model = await routing_table.get_model("model-public-2")
assert model.identifier == "model-public-2"
@ -185,14 +202,14 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator."""
registry, routing_table = test_setup
# Set creator's attributes
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
mock_get_auth_attributes.return_value = creator_attributes
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
# Create model without explicit access attributes
model = ModelWithACL(
@ -211,15 +228,262 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
assert registered_model.access_attributes.projects == ["llama-3"]
# Verify another user without matching attributes can't access it
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
with pytest.raises(ValueError):
await routing_table.get_model("auto-access-model")
# But a user with matching attributes can
mock_get_auth_attributes.return_value = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
},
)
model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model"
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
mock_inference.unregister_model = AsyncMock(side_effect=_return_model)
config = """
- permit:
principal: user-1
actions: [create, read, delete]
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
principal: user-3
actions: [read]
resource: model::model-2
description: user-3 has read access to model-2 only
- forbid:
actions: [create, read, delete]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy=policy,
)
yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
routing_table = test_setup_with_access_policy
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.register_model("model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_id="test_provider")
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
mock_get_authenticated_user.return_value = User(
"user-2",
{
"roles": ["user"],
"projects": ["foo"],
},
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1")
mock_get_authenticated_user.return_value = User(
"user-3",
{
"roles": ["user"],
"projects": ["bar"],
},
)
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2")
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.unregister_model("model-3")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
def test_permit_when():
config = """
- permit:
principal: user-1
actions: [read]
when:
user_in: resource.namespaces
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_permit_unless():
config = """
- permit:
principal: user-1
actions: [read]
resource: model::*
unless:
- user_not_in: resource.namespaces
- user_in: resource.teams
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_when():
config = """
- forbid:
principal: user-1
actions: [read]
when:
user_in: resource.namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_unless():
config = """
- forbid:
principal: user-1
actions: [read]
unless:
user_in: resource.namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_condition_with_literal():
config = """
- permit:
actions: [read]
when:
user_in: role::admin
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_condition_with_unrecognised_literal():
config = """
- permit:
actions: [read]
when:
user_in: whatever
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
access_attributes=AccessAttributes(namespaces=["foo"]),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", None))
def test_empty_condition():
config = """
- permit:
actions: [read]
when: {}
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithACL(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
)
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", None))