mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-19 11:20:03 +00:00
feat: fine grained access control policy
This allows a set of rules to be defined for determining access to resources. Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
9623d5d230
commit
01ad876012
20 changed files with 724 additions and 214 deletions
|
@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple models and verify listing
|
||||
|
@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shields_routing_table(cached_disk_dist_registry):
|
||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
|
||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple shields and verify listing
|
||||
|
@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_providere",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
|
||||
|
||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple datasets and verify listing
|
||||
|
@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple scoring functions and verify listing
|
||||
|
@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple benchmarks and verify listing
|
||||
|
@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple tool groups and verify listing
|
||||
|
|
|
@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
|
|||
mock_apis["safety_api"],
|
||||
mock_apis["tool_runtime_api"],
|
||||
mock_apis["tool_groups_api"],
|
||||
{},
|
||||
)
|
||||
await impl.initialize()
|
||||
yield impl
|
||||
|
|
|
@ -13,23 +13,24 @@ import pytest
|
|||
from llama_stack.apis.agents import Turn
|
||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_setup(sqlite_kvstore):
|
||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
|
||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
|
||||
yield agent_persistence
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Set creator's attributes for the session
|
||||
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
|
||||
mock_get_auth_attributes.return_value = creator_attributes
|
||||
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
|
||||
|
||||
# Create a session
|
||||
session_id = await agent_persistence.create_session("Test Session")
|
||||
|
@ -43,8 +44,8 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with specific access attributes
|
||||
|
@ -55,6 +56,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
|||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
|
@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
|||
)
|
||||
|
||||
# User with matching attributes can access
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
|
||||
)
|
||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||
assert retrieved_session is not None
|
||||
assert retrieved_session.session_id == session_id
|
||||
|
||||
# User without matching attributes cannot access
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
|
||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||
assert retrieved_session is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with restricted access
|
||||
|
@ -87,6 +91,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
|||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
|
@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
|||
)
|
||||
|
||||
# Admin can add turn
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
|
||||
await agent_persistence.add_turn_to_session(session_id, turn)
|
||||
|
||||
# Admin can get turn
|
||||
|
@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
|||
assert retrieved_turn.turn_id == turn_id
|
||||
|
||||
# Regular user cannot get turn
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
|
||||
with pytest.raises(ValueError):
|
||||
await agent_persistence.get_session_turn(session_id, turn_id)
|
||||
|
||||
|
@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with restricted access
|
||||
|
@ -140,6 +145,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
|||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
|
@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
|||
turn_id = str(uuid.uuid4())
|
||||
|
||||
# Admin user can set inference iterations
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
|
||||
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
|
||||
|
||||
# Admin user can get inference iterations
|
||||
|
@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
|||
assert infer_iters == 5
|
||||
|
||||
# Regular user cannot get inference iterations
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
|
||||
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
|
||||
assert infer_iters is None
|
||||
|
||||
|
|
|
@ -7,10 +7,14 @@
|
|||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||
|
||||
|
||||
|
@ -32,13 +36,14 @@ async def test_setup(cached_disk_dist_registry):
|
|||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=cached_disk_dist_registry,
|
||||
policy={},
|
||||
)
|
||||
yield cached_disk_dist_registry, routing_table
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public",
|
||||
|
@ -64,7 +69,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
await registry.register(model_admin_only)
|
||||
await registry.register(model_data_scientist)
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
|
||||
|
@ -75,7 +80,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public"
|
||||
|
@ -86,7 +91,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
|
@ -102,8 +107,8 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-updates",
|
||||
|
@ -112,28 +117,37 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
|||
model_type=ModelType.llm,
|
||||
)
|
||||
await registry.register(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["user"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
||||
await registry.update(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["user"],
|
||||
},
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-updates")
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["admin"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["admin"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-attrs",
|
||||
|
@ -143,9 +157,12 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
|
|||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
await registry.register(model)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": [],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": [],
|
||||
},
|
||||
)
|
||||
result = await routing_table.get_model("model-empty-attrs")
|
||||
assert result.identifier == "model-empty-attrs"
|
||||
all_models = await routing_table.list_models()
|
||||
|
@ -154,8 +171,8 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public-2",
|
||||
|
@ -172,7 +189,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
|||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_restricted)
|
||||
mock_get_auth_attributes.return_value = None
|
||||
mock_get_authenticated_user.return_value = User("test-user", None)
|
||||
model = await routing_table.get_model("model-public-2")
|
||||
assert model.identifier == "model-public-2"
|
||||
|
||||
|
@ -185,14 +202,14 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
|
||||
"""Test that newly created resources inherit access attributes from their creator."""
|
||||
registry, routing_table = test_setup
|
||||
|
||||
# Set creator's attributes
|
||||
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
|
||||
mock_get_auth_attributes.return_value = creator_attributes
|
||||
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
|
||||
|
||||
# Create model without explicit access attributes
|
||||
model = ModelWithACL(
|
||||
|
@ -211,15 +228,262 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
|
|||
assert registered_model.access_attributes.projects == ["llama-3"]
|
||||
|
||||
# Verify another user without matching attributes can't access it
|
||||
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("auto-access-model")
|
||||
|
||||
# But a user with matching attributes can
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "platform-team"],
|
||||
"projects": ["llama-3"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "platform-team"],
|
||||
"projects": ["llama-3"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("auto-access-model")
|
||||
assert model.identifier == "auto-access-model"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_setup_with_access_policy(cached_disk_dist_registry):
|
||||
mock_inference = Mock()
|
||||
mock_inference.__provider_spec__ = MagicMock()
|
||||
mock_inference.__provider_spec__.api = Api.inference
|
||||
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||
mock_inference.unregister_model = AsyncMock(side_effect=_return_model)
|
||||
|
||||
config = """
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [create, read, delete]
|
||||
description: user-1 has full access to all models
|
||||
- permit:
|
||||
principal: user-2
|
||||
actions: [read]
|
||||
resource: model::model-1
|
||||
description: user-2 has read access to model-1 only
|
||||
- permit:
|
||||
principal: user-3
|
||||
actions: [read]
|
||||
resource: model::model-2
|
||||
description: user-3 has read access to model-2 only
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=cached_disk_dist_registry,
|
||||
policy=policy,
|
||||
)
|
||||
yield routing_table
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
|
||||
routing_table = test_setup_with_access_policy
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-1",
|
||||
{
|
||||
"roles": ["admin"],
|
||||
"projects": ["foo", "bar"],
|
||||
},
|
||||
)
|
||||
await routing_table.register_model("model-1", provider_id="test_provider")
|
||||
await routing_table.register_model("model-2", provider_id="test_provider")
|
||||
await routing_table.register_model("model-3", provider_id="test_provider")
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
model = await routing_table.get_model("model-3")
|
||||
assert model.identifier == "model-3"
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-2",
|
||||
{
|
||||
"roles": ["user"],
|
||||
"projects": ["foo"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-2")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.register_model("model-4", provider_id="test_provider")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.unregister_model("model-1")
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-3",
|
||||
{
|
||||
"roles": ["user"],
|
||||
"projects": ["bar"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-1")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.register_model("model-5", provider_id="test_provider")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.unregister_model("model-2")
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-1",
|
||||
{
|
||||
"roles": ["admin"],
|
||||
"projects": ["foo", "bar"],
|
||||
},
|
||||
)
|
||||
await routing_table.unregister_model("model-3")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
|
||||
|
||||
def test_permit_when():
|
||||
config = """
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: resource.namespaces
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_permit_unless():
|
||||
config = """
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
resource: model::*
|
||||
unless:
|
||||
- user_not_in: resource.namespaces
|
||||
- user_in: resource.teams
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_forbid_when():
|
||||
config = """
|
||||
- forbid:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: resource.namespaces
|
||||
- permit:
|
||||
actions: [read]
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_forbid_unless():
|
||||
config = """
|
||||
- forbid:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
unless:
|
||||
user_in: resource.namespaces
|
||||
- permit:
|
||||
actions: [read]
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_condition_with_literal():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: role::admin
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||
|
||||
|
||||
def test_condition_with_unrecognised_literal():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: whatever
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-2", None))
|
||||
|
||||
|
||||
def test_empty_condition():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when: {}
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", None))
|
||||
|
|
|
@ -100,9 +100,10 @@ async def test_resolve_impls_basic():
|
|||
add_protocol_methods(SampleImpl, Inference)
|
||||
|
||||
mock_module.get_provider_impl = AsyncMock(return_value=impl)
|
||||
mock_module.get_provider_impl.__text_signature__ = "()"
|
||||
sys.modules["test_module"] = mock_module
|
||||
|
||||
impls = await resolve_impls(run_config, provider_registry, dist_registry)
|
||||
impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
|
||||
|
||||
assert Api.inference in impls
|
||||
assert isinstance(impls[Api.inference], InferenceRouter)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue