mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
fix: RBAC bypass vulnerabilities in model access
Closes security gaps where RBAC checks could be bypassed: o Inference router: Added RBAC enforcement in the fallback path to ensure access control is applied consistently. o Model listing: Dynamic models fetched via provider_data were returned without RBAC checks. Added filtering to ensure users only see models they have permission to access. Both fixes create temporary ModelWithOwner objects for RBAC validation, maintaining security through consistent access control enforcement. Closes: #4269 Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
ee107aadd6
commit
61f6bd78d0
3 changed files with 231 additions and 5 deletions
|
|
@ -12,8 +12,9 @@ from pydantic import TypeAdapter, ValidationError
|
|||
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack_api import Api, ModelType
|
||||
from llama_stack_api import Api, Model, ModelNotFoundError, ModelType
|
||||
|
||||
|
||||
class AsyncMock(MagicMock):
|
||||
|
|
@ -557,3 +558,178 @@ def test_condition_reprs(condition):
|
|||
from llama_stack.core.access_control.conditions import parse_condition
|
||||
|
||||
assert condition == str(parse_condition(condition))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restricted_user():
|
||||
"""User with limited access."""
|
||||
return User("restricted-user", {"roles": ["user"]})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user():
|
||||
"""User with admin access."""
|
||||
return User("admin-user", {"roles": ["admin"]})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rbac_policy():
|
||||
"""RBAC policy that restricts access to certain models."""
|
||||
from llama_stack.core.access_control.datatypes import Action, Scope
|
||||
|
||||
return [
|
||||
# Admins get full access
|
||||
AccessRule(
|
||||
permit=Scope(actions=list(Action)),
|
||||
when=["user with admin in roles"],
|
||||
),
|
||||
# Regular users only get read access to their own resources
|
||||
AccessRule(
|
||||
permit=Scope(actions=[Action.READ]),
|
||||
when=["user is owner"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestInferenceRouterRBACBypass:
|
||||
"""Test RBAC bypass vulnerability in inference router fallback path."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_routing_table(self):
|
||||
"""Create a mock routing table for testing."""
|
||||
routing_table = AsyncMock()
|
||||
routing_table.impls_by_provider_id = {"test-provider": AsyncMock()}
|
||||
routing_table.policy = []
|
||||
return routing_table
|
||||
|
||||
@patch("llama_stack.core.routers.inference.get_authenticated_user")
|
||||
async def test_registry_path_and_fallback_path_consistent(
|
||||
self, mock_get_user, mock_routing_table, restricted_user, admin_user, rbac_policy
|
||||
):
|
||||
"""Test that registry path and fallback path have consistent RBAC enforcement."""
|
||||
mock_routing_table.policy = rbac_policy
|
||||
|
||||
# Create a model owned by admin
|
||||
admin_model = ModelWithOwner(
|
||||
identifier="admin-model",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="admin-resource",
|
||||
model_type=ModelType.llm,
|
||||
type="model",
|
||||
metadata={},
|
||||
owner=admin_user,
|
||||
)
|
||||
|
||||
# Setup router
|
||||
router = InferenceRouter(
|
||||
routing_table=mock_routing_table,
|
||||
store=None,
|
||||
)
|
||||
|
||||
# Test 1: Restricted user tries to access via registry (should fail)
|
||||
mock_get_user.return_value = restricted_user
|
||||
mock_routing_table.get_object_by_identifier.return_value = None # RBAC blocks it
|
||||
with pytest.raises(ModelNotFoundError):
|
||||
await router._get_model_provider("admin-model", "llm")
|
||||
|
||||
# Test 2: Restricted user tries to access via fallback path (should also fail)
|
||||
mock_routing_table.get_object_by_identifier.return_value = None
|
||||
with pytest.raises(ModelNotFoundError):
|
||||
await router._get_model_provider("test-provider/admin-resource", "llm")
|
||||
|
||||
# Test 3: Admin user can access via registry
|
||||
mock_get_user.return_value = admin_user
|
||||
mock_routing_table.get_object_by_identifier.return_value = admin_model
|
||||
provider_mock = AsyncMock()
|
||||
mock_routing_table.get_provider_impl.return_value = provider_mock
|
||||
|
||||
provider, resource_id = await router._get_model_provider("admin-model", "llm")
|
||||
assert provider == provider_mock
|
||||
assert resource_id == "admin-resource"
|
||||
|
||||
# Test 4: Admin user can also access via fallback path
|
||||
mock_routing_table.get_object_by_identifier.return_value = None
|
||||
provider, resource_id = await router._get_model_provider("test-provider/admin-resource", "llm")
|
||||
assert provider == mock_routing_table.impls_by_provider_id["test-provider"]
|
||||
assert resource_id == "admin-resource"
|
||||
|
||||
|
||||
class TestModelListingRBACBypass:
|
||||
"""Test RBAC bypass vulnerability in dynamic model listing via provider_data."""
|
||||
|
||||
@patch("llama_stack.core.routing_tables.models.instantiate_class_type")
|
||||
@patch("llama_stack.core.routing_tables.models.PROVIDER_DATA_VAR")
|
||||
@patch("llama_stack.core.routing_tables.models.get_authenticated_user")
|
||||
@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
|
||||
async def test_dynamic_models_respect_rbac(
|
||||
self,
|
||||
mock_get_user_common,
|
||||
mock_get_user_models,
|
||||
mock_provider_data,
|
||||
mock_instantiate_class,
|
||||
cached_disk_dist_registry,
|
||||
rbac_policy,
|
||||
admin_user,
|
||||
restricted_user,
|
||||
):
|
||||
"""Test that models fetched via provider_data are filtered by RBAC."""
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
||||
# Create a mock provider that supports provider_data
|
||||
mock_provider = Mock(spec=NeedsRequestProviderData)
|
||||
mock_provider.__provider_spec__ = MagicMock()
|
||||
mock_provider.__provider_spec__.api = Api.inference
|
||||
mock_provider.__provider_spec__.provider_data_validator = "dict"
|
||||
|
||||
# Mock the validator to always succeed
|
||||
mock_validator = MagicMock(return_value={})
|
||||
mock_instantiate_class.return_value = mock_validator
|
||||
|
||||
# Mock list_models to return dynamic models
|
||||
# These are fetched via provider_data and don't have owners initially
|
||||
dynamic_model1 = Model(
|
||||
identifier="dynamic-model-1",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="dynamic-model-1",
|
||||
model_type=ModelType.llm,
|
||||
metadata={},
|
||||
)
|
||||
dynamic_model2 = Model(
|
||||
identifier="dynamic-model-2",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="dynamic-model-2",
|
||||
model_type=ModelType.llm,
|
||||
metadata={},
|
||||
)
|
||||
mock_provider.list_models = AsyncMock(return_value=[dynamic_model1, dynamic_model2])
|
||||
|
||||
# Setup routing table with policy (no models pre-registered in registry)
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test-provider": mock_provider},
|
||||
dist_registry=cached_disk_dist_registry,
|
||||
policy=rbac_policy,
|
||||
)
|
||||
|
||||
# Set up provider_data context (user has credentials for this provider)
|
||||
mock_provider_data.get.return_value = {"api_key": "test-key"}
|
||||
|
||||
# Test 1: Admin user can see dynamic models
|
||||
# Admin rule allows all actions, so they can see models even without ownership
|
||||
mock_get_user_common.return_value = admin_user
|
||||
mock_get_user_models.return_value = admin_user
|
||||
|
||||
result = await routing_table.list_models()
|
||||
model_ids = [m.identifier for m in result.data]
|
||||
assert "test-provider/dynamic-model-1" in model_ids
|
||||
assert "test-provider/dynamic-model-2" in model_ids
|
||||
|
||||
# Test 2: Restricted user CANNOT see dynamic models
|
||||
# Dynamic models have no owner, and policy requires either admin role OR ownership
|
||||
# This demonstrates the fix: before, these would be returned without RBAC checks
|
||||
mock_get_user_common.return_value = restricted_user
|
||||
mock_get_user_models.return_value = restricted_user
|
||||
|
||||
result = await routing_table.list_models()
|
||||
model_ids = [m.identifier for m in result.data]
|
||||
# Restricted user should see no models (no ownership, not admin)
|
||||
assert len(model_ids) == 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue