feat: fine grained access control policy (#2264)

This allows a set of rules to be defined for determining access to
resources. The rules are (loosely) based on the cedar policy format.

A rule defines a list of action either to permit or to forbid. It may
specify a principal or a resource that must match for the rule to take
effect. It may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies.

A list of rules is held for each type to be protected and tried in order
to find a match. If a match is found, the request is permitted or
forbidden depening on the type of rule. If no match is found, the
request is denied. If no rules are specified for a given type, a rule
that allows any action as long as the resource attributes match the user
attributes is added (i.e. the previous behaviour is the default.

Some examples in yaml:

```
    model:
    - permit:
      principal: user-1
      actions: [create, read, delete]
      comment: user-1 has full access to all models
    - permit:
      principal: user-2
      actions: [read]
      resource: model-1
      comment: user-2 has read access to model-1 only
    - permit:
      actions: [read]
      when:
        user_in: resource.namespaces
      comment: any user has read access to models with matching attributes
    vector_db:
    - forbid:
      actions: [create, read, delete]
      unless:
        user_in: role::admin
      comment: only user with admin role can use vector_db resources
```

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
grs 2025-06-03 17:51:12 -04:00 committed by GitHub
parent 8bee2954be
commit 7c1998db25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 956 additions and 450 deletions

View file

@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple models and verify listing
@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple shields and verify listing
@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_providere",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple datasets and verify listing
@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple scoring functions and verify listing
@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple benchmarks and verify listing
@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple tool groups and verify listing

View file

@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
{},
)
await impl.initialize()
yield impl

View file

@ -12,24 +12,24 @@ import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Set creator's attributes for the session
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
mock_get_auth_attributes.return_value = creator_attributes
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
# Create a session
session_id = await agent_persistence.create_session("Test Session")
@ -37,14 +37,15 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
# Get the session and verify access attributes were set
session_info = await agent_persistence.get_session_info(session_id)
assert session_info is not None
assert session_info.access_attributes is not None
assert session_info.access_attributes.roles == ["researcher"]
assert session_info.access_attributes.teams == ["ai-team"]
assert session_info.owner is not None
assert session_info.owner.attributes is not None
assert session_info.owner.attributes["roles"] == ["researcher"]
assert session_info.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with specific access attributes
@ -53,8 +54,9 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
)
# User with matching attributes can access
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
mock_get_authenticated_user.return_value = User(
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
)
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is not None
assert retrieved_session.session_id == session_id
# User without matching attributes cannot access
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
@ -85,8 +89,9 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
)
# Admin can add turn
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.add_turn_to_session(session_id, turn)
# Admin can get turn
@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
assert retrieved_turn.turn_id == turn_id
# Regular user cannot get turn
mock_get_auth_attributes.return_value = {"roles": ["user"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
with pytest.raises(ValueError):
await agent_persistence.get_session_turn(session_id, turn_id)
@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
@ -138,8 +143,9 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
turn_id = str(uuid.uuid4())
# Admin user can set inference iterations
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
# Admin user can get inference iterations
@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
assert infer_iters == 5
# Regular user cannot get inference iterations
mock_get_auth_attributes.return_value = {"roles": ["user"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters is None

View file

@ -8,19 +8,18 @@
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.server.auth_providers import AccessAttributes
from llama_stack.distribution.datatypes import ModelWithOwner, User
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-acl",
provider_id="test-provider",
provider_resource_id="model-acl-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
)
success = await cached_disk_dist_registry.register(model)
@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
assert cached_model is not None
assert cached_model.identifier == "model-acl"
assert cached_model.access_attributes.roles == ["admin"]
assert cached_model.access_attributes.teams == ["ai-team"]
assert cached_model.owner.principal == "testuser"
assert cached_model.owner.attributes["roles"] == ["admin"]
assert cached_model.owner.attributes["teams"] == ["ai-team"]
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
assert fetched_model is not None
assert fetched_model.identifier == "model-acl"
assert fetched_model.access_attributes.roles == ["admin"]
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
await cached_disk_dist_registry.update(model)
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
assert updated_cached is not None
assert updated_cached.access_attributes.roles == ["admin", "user"]
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
assert fetched_model.owner.attributes["roles"] == ["admin"]
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize()
@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
new_model = await new_registry.get("model", "model-acl")
assert new_model is not None
assert new_model.identifier == "model-acl"
assert new_model.access_attributes.roles == ["admin", "user"]
assert new_model.access_attributes.projects == ["project-x"]
assert new_model.access_attributes.teams is None
assert new_model.owner.principal == "testuser"
assert new_model.owner.attributes["roles"] == ["admin"]
assert new_model.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-empty-acl",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(),
owner=User("testuser", None),
)
await cached_disk_dist_registry.register(model)
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
assert cached_model is not None
assert cached_model.access_attributes is not None
assert cached_model.access_attributes.roles is None
assert cached_model.access_attributes.teams is None
assert cached_model.access_attributes.projects is None
assert cached_model.access_attributes.namespaces is None
assert cached_model.owner is not None
assert cached_model.owner.attributes is None
all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 1
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-no-acl",
provider_id="test-provider",
provider_resource_id="model-resource-2",
@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
assert cached_model is not None
assert cached_model.access_attributes is None
assert cached_model.owner is None
all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 2
@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_registry_serialization(cached_disk_dist_registry):
attributes = AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
projects=["project-a", "project-b"],
namespaces=["prod", "staging"],
)
attributes = {
"roles": ["admin", "researcher"],
"teams": ["ai-team", "ml-team"],
"projects": ["project-a", "project-b"],
"namespaces": ["prod", "staging"],
}
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-serialize",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=attributes,
owner=User("bob", attributes),
)
await cached_disk_dist_registry.register(model)
@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
loaded_model = await new_registry.get("model", "model-serialize")
assert loaded_model is not None
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]

View file

@ -7,10 +7,13 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
import yaml
from pydantic import TypeAdapter, ValidationError
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@ -32,39 +35,40 @@ async def test_setup(cached_disk_dist_registry):
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy={},
)
yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-public",
provider_id="test_provider",
provider_resource_id="model-public",
model_type=ModelType.llm,
)
model_admin_only = ModelWithACL(
model_admin_only = ModelWithOwner(
identifier="model-admin",
provider_id="test_provider",
provider_resource_id="model-admin",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("testuser", {"roles": ["admin"]}),
)
model_data_scientist = ModelWithACL(
model_data_scientist = ModelWithOwner(
identifier="model-data-scientist",
provider_id="test_provider",
provider_resource_id="model-data-scientist",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
)
await registry.register(model_public)
await registry.register(model_admin_only)
await registry.register(model_data_scientist)
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
@ -75,7 +79,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public"
@ -86,7 +90,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
model_ids = [m.identifier for m in all_models.data]
@ -102,50 +106,62 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-updates",
provider_id="test_provider",
provider_resource_id="model-updates",
model_type=ModelType.llm,
)
await registry.register(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["user"],
},
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"])
model_public.owner = User("testuser", {"roles": ["admin"]})
await registry.update(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["user"],
},
)
with pytest.raises(ValueError):
await routing_table.get_model("model-updates")
mock_get_auth_attributes.return_value = {
"roles": ["admin"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["admin"],
},
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-empty-attrs",
provider_id="test_provider",
provider_resource_id="model-empty-attrs",
model_type=ModelType.llm,
access_attributes=AccessAttributes(),
owner=User("testuser", {}),
)
await registry.register(model)
mock_get_auth_attributes.return_value = {
"roles": [],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": [],
},
)
result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs"
all_models = await routing_table.list_models()
@ -154,25 +170,25 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-public-2",
provider_id="test_provider",
provider_resource_id="model-public-2",
model_type=ModelType.llm,
)
model_restricted = ModelWithACL(
model_restricted = ModelWithOwner(
identifier="model-restricted",
provider_id="test_provider",
provider_resource_id="model-restricted",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("testuser", {"roles": ["admin"]}),
)
await registry.register(model_public)
await registry.register(model_restricted)
mock_get_auth_attributes.return_value = None
mock_get_authenticated_user.return_value = User("test-user", None)
model = await routing_table.get_model("model-public-2")
assert model.identifier == "model-public-2"
@ -185,17 +201,17 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator."""
registry, routing_table = test_setup
# Set creator's attributes
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
mock_get_auth_attributes.return_value = creator_attributes
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
# Create model without explicit access attributes
model = ModelWithACL(
model = ModelWithOwner(
identifier="auto-access-model",
provider_id="test_provider",
provider_resource_id="auto-access-model",
@ -205,21 +221,346 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
# Verify the model got creator's attributes
registered_model = await routing_table.get_model("auto-access-model")
assert registered_model.access_attributes is not None
assert registered_model.access_attributes.roles == ["data-scientist"]
assert registered_model.access_attributes.teams == ["ml-team"]
assert registered_model.access_attributes.projects == ["llama-3"]
assert registered_model.owner is not None
assert registered_model.owner.attributes is not None
assert registered_model.owner.attributes["roles"] == ["data-scientist"]
assert registered_model.owner.attributes["teams"] == ["ml-team"]
assert registered_model.owner.attributes["projects"] == ["llama-3"]
# Verify another user without matching attributes can't access it
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
with pytest.raises(ValueError):
await routing_table.get_model("auto-access-model")
# But a user with matching attributes can
mock_get_auth_attributes.return_value = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
},
)
model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model"
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
mock_inference.unregister_model = AsyncMock(side_effect=_return_model)
config = """
- permit:
principal: user-1
actions: [create, read, delete]
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
principal: user-3
actions: [read]
resource: model::model-2
description: user-3 has read access to model-2 only
- forbid:
actions: [create, read, delete]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy=policy,
)
yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
routing_table = test_setup_with_access_policy
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.register_model("model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_id="test_provider")
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
mock_get_authenticated_user.return_value = User(
"user-2",
{
"roles": ["user"],
"projects": ["foo"],
},
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1")
mock_get_authenticated_user.return_value = User(
"user-3",
{
"roles": ["user"],
"projects": ["bar"],
},
)
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2")
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.unregister_model("model-3")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
def test_permit_when():
config = """
- permit:
principal: user-1
actions: [read]
when: user in owners namespaces
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_permit_unless():
config = """
- permit:
principal: user-1
actions: [read]
resource: model::*
unless:
- user not in owners namespaces
- user in owners teams
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_when():
config = """
- forbid:
principal: user-1
actions: [read]
when:
user in owners namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_unless():
config = """
- forbid:
principal: user-1
actions: [read]
unless:
user in owners namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_user_has_attribute():
config = """
- permit:
actions: [read]
when: user with admin in roles
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_user_does_not_have_attribute():
config = """
- permit:
actions: [read]
unless: user with admin not in roles
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_is_owner():
config = """
- permit:
actions: [read]
when: user is owner
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("user-2", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_is_not_owner():
config = """
- permit:
actions: [read]
unless: user is not owner
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("user-2", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_invalid_rule_permit_and_forbid_both_specified():
config = """
- permit:
actions: [read]
forbid:
actions: [create]
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_neither_permit_or_forbid_specified():
config = """
- when: user is owner
unless: user with admin in roles
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_when_and_unless_both_specified():
config = """
- permit:
actions: [read]
when: user is owner
unless: user with admin in roles
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_condition():
config = """
- permit:
actions: [read]
when: random words that are not valid
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
@pytest.mark.parametrize(
"condition",
[
"user is owner",
"user is not owner",
"user with dev in teams",
"user with default not in namespaces",
"user in owners roles",
"user not in owners projects",
],
)
def test_condition_reprs(condition):
from llama_stack.distribution.access_control.conditions import parse_condition
assert condition == str(parse_condition(condition))

View file

@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):
{
"message": "Authentication successful",
"principal": "test-principal",
"access_attributes": {
"attributes": {
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
{
"message": "Authentication successful",
"principal": "test-principal",
"access_attributes": {
"attributes": {
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
@pytest.mark.asyncio
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
"""Test middleware behavior with no access attributes"""
middleware, mock_app = mock_http_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client_instance = AsyncMock()
mock_client.return_value.__aenter__.return_value = mock_client_instance
mock_client_instance.post.return_value = MockResponse(
200,
{
"message": "Authentication successful"
# No access_attributes
},
)
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
attributes = mock_scope["user_attributes"]
assert "roles" in attributes
assert attributes["roles"] == ["test.jwt.token"]
# oauth2 token provider tests
@ -380,16 +353,16 @@ def test_get_attributes_from_claims():
"aud": "llama-stack",
}
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
assert attributes.roles == ["my-user"]
assert attributes.teams == ["group1", "group2"]
assert attributes["roles"] == ["my-user"]
assert attributes["teams"] == ["group1", "group2"]
claims = {
"sub": "my-user",
"tenant": "my-tenant",
}
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
assert attributes.roles == ["my-user"]
assert attributes.namespaces == ["my-tenant"]
assert attributes["roles"] == ["my-user"]
assert attributes["namespaces"] == ["my-tenant"]
claims = {
"sub": "my-user",
@ -408,9 +381,9 @@ def test_get_attributes_from_claims():
"groups": "teams",
},
)
assert set(attributes.roles) == {"my-user", "my-username"}
assert set(attributes.teams) == {"my-team", "group1", "group2"}
assert attributes.namespaces == ["my-tenant"]
assert set(attributes["roles"]) == {"my-user", "my-username"}
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
assert attributes["namespaces"] == ["my-tenant"]
# TODO: add more tests for oauth2 token provider

View file

@ -100,9 +100,10 @@ async def test_resolve_impls_basic():
add_protocol_methods(SampleImpl, Inference)
mock_module.get_provider_impl = AsyncMock(return_value=impl)
mock_module.get_provider_impl.__text_signature__ = "()"
sys.modules["test_module"] = mock_module
impls = await resolve_impls(run_config, provider_registry, dist_registry)
impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
assert Api.inference in impls
assert isinstance(impls[Api.inference], InferenceRouter)