async def _get_provider_by_fallback(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
    """
    Resolve a model that is addressed directly as ``provider_id/provider_resource_id``.

    The identifier is split at its first "/". The provider part must be a key of
    ``self.routing_table.impls_by_provider_id`` and the authenticated user must be
    allowed the "read" action, under ``self.routing_table.policy``, on a temporary
    model object built from the identifier.

    :param model_id: identifier expected to look like ``provider_id/provider_resource_id``
    :param expected_model_type: model type recorded on the temporary model used for the access check
    :returns: ``(provider implementation, provider_resource_id)``
    :raises ModelNotFoundError: if ``model_id`` has no "/", the provider is not
        registered, or the access check fails (a denial is reported the same way
        as a missing model)
    """
    provider_id, separator, provider_resource_id = model_id.partition("/")
    if not separator:
        # No "/" at all, so this cannot be the provider-qualified form.
        raise ModelNotFoundError(model_id)

    providers = self.routing_table.impls_by_provider_id
    if provider_id not in providers:
        logger.warning(f"Provider {provider_id} not found for model {model_id}")
        raise ModelNotFoundError(model_id)

    # Nothing was found in the registry for this identifier, so the access
    # check runs against a stand-in model carrying empty metadata.
    rbac_target = ModelWithOwner(
        identifier=model_id,
        provider_id=provider_id,
        provider_resource_id=provider_resource_id,
        model_type=expected_model_type,
        type="model",
        metadata={},
    )

    user = get_authenticated_user()
    if is_action_allowed(self.routing_table.policy, "read", rbac_target, user):
        return providers[provider_id], provider_resource_id

    logger.debug(
        f"Access denied to model '{model_id}' via fallback path for user {user.principal if user else 'anonymous'}"
    )
    raise ModelNotFoundError(model_id)
provider in self.impls_by_provider_id.items(): # Check if this provider supports provider_data @@ -93,15 +95,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if not models: continue - # Ensure models have fully qualified identifiers with provider_id prefix + # Ensure models have fully qualified identifiers and apply RBAC filtering for model in models: # Only add prefix if model identifier doesn't already have it if not model.identifier.startswith(f"{provider_id}/"): model.identifier = f"{provider_id}/{model.provider_resource_id}" - dynamic_models.append(model) + # Convert to ModelWithOwner for RBAC check + temp_model = ModelWithOwner( + identifier=model.identifier, + provider_id=provider_id, + provider_resource_id=model.provider_resource_id, + model_type=model.model_type, + type="model", + metadata=model.metadata, + ) - logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data") + # Apply RBAC check - only include models user has read permission for + if is_action_allowed(self.policy, "read", temp_model, user): + dynamic_models.append(model) + else: + logger.debug( + f"Access denied to dynamic model '{model.identifier}' for user {user.principal if user else 'anonymous'}" + ) + + logger.debug( + f"Fetched {len(dynamic_models)} accessible models from provider {provider_id} using provider_data" + ) except Exception as e: logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}") diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index 23a9636d5..bf6a24c90 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -12,8 +12,9 @@ from pydantic import TypeAdapter, ValidationError from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User +from llama_stack.core.routers.inference import 
@pytest.fixture
def restricted_user():
    """A user whose only role is "user"."""
    attributes = {"roles": ["user"]}
    return User("restricted-user", attributes)


@pytest.fixture
def admin_user():
    """A user holding the "admin" role."""
    attributes = {"roles": ["admin"]}
    return User("admin-user", attributes)


@pytest.fixture
def rbac_policy():
    """Two-rule policy: every action for users with the admin role, read for a resource's owner."""
    from llama_stack.core.access_control.datatypes import Action, Scope

    admin_full_access = AccessRule(
        permit=Scope(actions=list(Action)),
        when=["user with admin in roles"],
    )
    owner_read_only = AccessRule(
        permit=Scope(actions=[Action.READ]),
        when=["user is owner"],
    )
    return [admin_full_access, owner_read_only]


class TestInferenceRouterRBACBypass:
    """Checks that the inference router's provider-format fallback path enforces RBAC."""

    @pytest.fixture
    def mock_routing_table(self):
        """Mocked routing table exposing a single provider, "test-provider", and an empty policy."""
        table = AsyncMock()
        table.policy = []
        table.impls_by_provider_id = {"test-provider": AsyncMock()}
        return table

    @patch("llama_stack.core.routers.inference.get_authenticated_user")
    async def test_registry_path_and_fallback_path_consistent(
        self, mock_get_user, mock_routing_table, restricted_user, admin_user, rbac_policy
    ):
        """Registry lookups and provider-format lookups must grant or deny access alike."""
        registry_id = "admin-model"
        fallback_id = "test-provider/admin-resource"

        mock_routing_table.policy = rbac_policy
        registry_lookup = mock_routing_table.get_object_by_identifier
        router = InferenceRouter(
            routing_table=mock_routing_table,
            store=None,
        )

        # The model as the registry would hand it to its owner, the admin.
        admin_owned_model = ModelWithOwner(
            identifier="admin-model",
            provider_id="test-provider",
            provider_resource_id="admin-resource",
            model_type=ModelType.llm,
            type="model",
            metadata={},
            owner=admin_user,
        )

        # Restricted user: the registry lookup is mocked to return None (standing in
        # for RBAC hiding the model), so neither the registry id nor the
        # provider-format id may resolve.
        mock_get_user.return_value = restricted_user
        registry_lookup.return_value = None
        for model_id in (registry_id, fallback_id):
            with pytest.raises(ModelNotFoundError):
                await router._get_model_provider(model_id, "llm")

        # Admin user, registry path: the lookup now yields the model.
        mock_get_user.return_value = admin_user
        registry_lookup.return_value = admin_owned_model
        registry_provider = AsyncMock()
        mock_routing_table.get_provider_impl.return_value = registry_provider

        provider, resource_id = await router._get_model_provider(registry_id, "llm")
        assert provider == registry_provider
        assert resource_id == "admin-resource"

        # Admin user, fallback path: registry miss, resolved through the provider map.
        registry_lookup.return_value = None
        provider, resource_id = await router._get_model_provider(fallback_id, "llm")
        assert provider == mock_routing_table.impls_by_provider_id["test-provider"]
        assert resource_id == "admin-resource"
class TestModelListingRBACBypass:
    """Checks that models listed dynamically through provider_data are filtered by RBAC."""

    @patch("llama_stack.core.routing_tables.models.instantiate_class_type")
    @patch("llama_stack.core.routing_tables.models.PROVIDER_DATA_VAR")
    @patch("llama_stack.core.routing_tables.models.get_authenticated_user")
    @patch("llama_stack.core.routing_tables.common.get_authenticated_user")
    async def test_dynamic_models_respect_rbac(
        self,
        mock_get_user_common,
        mock_get_user_models,
        mock_provider_data,
        mock_instantiate_class,
        cached_disk_dist_registry,
        rbac_policy,
        admin_user,
        restricted_user,
    ):
        """An admin sees the provider's dynamic models; a restricted user sees none of them."""
        from llama_stack.core.request_headers import NeedsRequestProviderData

        provider_id = "test-provider"

        # Provider stub that advertises provider_data support through its spec.
        provider_spec = MagicMock()
        provider_spec.api = Api.inference
        provider_spec.provider_data_validator = "dict"
        mock_provider = Mock(spec=NeedsRequestProviderData)
        mock_provider.__provider_spec__ = provider_spec

        # The instantiated validator accepts anything and returns an empty dict.
        mock_instantiate_class.return_value = MagicMock(return_value={})

        # Models served dynamically by the provider; they are created without an owner.
        mock_provider.list_models = AsyncMock(
            return_value=[
                Model(
                    identifier=resource_id,
                    provider_id=provider_id,
                    provider_resource_id=resource_id,
                    model_type=ModelType.llm,
                    metadata={},
                )
                for resource_id in ("dynamic-model-1", "dynamic-model-2")
            ]
        )

        # Nothing is pre-registered; only the policy and the single provider are supplied.
        routing_table = ModelsRoutingTable(
            impls_by_provider_id={provider_id: mock_provider},
            dist_registry=cached_disk_dist_registry,
            policy=rbac_policy,
        )

        # The request carries provider_data for this provider.
        mock_provider_data.get.return_value = {"api_key": "test-key"}

        async def list_ids_as(user):
            # Both patched lookups must agree on who is calling.
            mock_get_user_common.return_value = user
            mock_get_user_models.return_value = user
            listing = await routing_table.list_models()
            return [entry.identifier for entry in listing.data]

        # The admin rule permits every action, so ownership is not required.
        admin_view = await list_ids_as(admin_user)
        assert "test-provider/dynamic-model-1" in admin_view
        assert "test-provider/dynamic-model-2" in admin_view

        # The restricted user is neither admin nor owner, so nothing may be listed.
        restricted_view = await list_ids_as(restricted_user)
        assert not restricted_view