This commit is contained in:
Derek Higgins 2025-12-03 01:04:15 +00:00 committed by GitHub
commit 18ff08071f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 231 additions and 5 deletions

View file

@ -14,6 +14,9 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter from pydantic import TypeAdapter
from llama_stack.core.access_control.access_control import is_action_allowed
from llama_stack.core.datatypes import ModelWithOwner
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import ( from llama_stack_api import (
@ -93,15 +96,42 @@ class InferenceRouter(Inference):
provider = await self.routing_table.get_provider_impl(model.identifier) provider = await self.routing_table.get_provider_impl(model.identifier)
return provider, model.provider_resource_id return provider, model.provider_resource_id
# Handles cases where clients use the provider format directly
return await self._get_provider_by_fallback(model_id, expected_model_type)
async def _get_provider_by_fallback(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
    """
    Resolve a model addressed directly as ``provider_id/provider_resource_id``.

    The identifier is split on its first "/" only, so the resource part may
    itself contain slashes. A temporary ``ModelWithOwner`` is built for the
    request and checked against the routing table's policy for the "read"
    action before the provider implementation is returned.

    :param model_id: identifier supplied by the client.
    :param expected_model_type: model type recorded on the temporary model used for the access check.
    :returns: the provider implementation and the provider-side resource id.
    :raises ModelNotFoundError: if ``model_id`` contains no "/", the provider is
        not known to the routing table, or the access check denies "read".
    """
    provider_id, separator, provider_resource_id = model_id.partition("/")
    if not separator:
        # No "/" present, so this is not a provider-qualified identifier.
        raise ModelNotFoundError(model_id)

    # The provider must be known to the routing table.
    if provider_id not in self.routing_table.impls_by_provider_id:
        logger.warning(f"Provider {provider_id} not found for model {model_id}")
        raise ModelNotFoundError(model_id)

    # Throwaway model record that exists only so the policy has something to evaluate.
    candidate = ModelWithOwner(
        identifier=model_id,
        provider_id=provider_id,
        provider_resource_id=provider_resource_id,
        model_type=expected_model_type,
        type="model",
        metadata={},  # nothing is known about the model beyond its ids
    )

    user = get_authenticated_user()
    if is_action_allowed(self.routing_table.policy, "read", candidate, user):
        return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id

    # Denials are reported with the same error as a missing model.
    logger.debug(
        f"Access denied to model '{model_id}' via fallback path for user {user.principal if user else 'anonymous'}"
    )
    raise ModelNotFoundError(model_id)
async def rerank( async def rerank(

View file

@ -7,11 +7,12 @@
import time import time
from typing import Any from typing import Any
from llama_stack.core.access_control.access_control import is_action_allowed
from llama_stack.core.datatypes import ( from llama_stack.core.datatypes import (
ModelWithOwner, ModelWithOwner,
RegistryEntrySource, RegistryEntrySource,
) )
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData, get_authenticated_user
from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack_api import ( from llama_stack_api import (
@ -66,6 +67,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return [] return []
dynamic_models = [] dynamic_models = []
user = get_authenticated_user()
for provider_id, provider in self.impls_by_provider_id.items(): for provider_id, provider in self.impls_by_provider_id.items():
# Check if this provider supports provider_data # Check if this provider supports provider_data
@ -93,15 +95,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if not models: if not models:
continue continue
# Ensure models have fully qualified identifiers with provider_id prefix # Ensure models have fully qualified identifiers and apply RBAC filtering
for model in models: for model in models:
# Only add prefix if model identifier doesn't already have it # Only add prefix if model identifier doesn't already have it
if not model.identifier.startswith(f"{provider_id}/"): if not model.identifier.startswith(f"{provider_id}/"):
model.identifier = f"{provider_id}/{model.provider_resource_id}" model.identifier = f"{provider_id}/{model.provider_resource_id}"
dynamic_models.append(model) # Convert to ModelWithOwner for RBAC check
temp_model = ModelWithOwner(
identifier=model.identifier,
provider_id=provider_id,
provider_resource_id=model.provider_resource_id,
model_type=model.model_type,
type="model",
metadata=model.metadata,
)
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data") # Apply RBAC check - only include models user has read permission for
if is_action_allowed(self.policy, "read", temp_model, user):
dynamic_models.append(model)
else:
logger.debug(
f"Access denied to dynamic model '{model.identifier}' for user {user.principal if user else 'anonymous'}"
)
logger.debug(
f"Fetched {len(dynamic_models)} accessible models from provider {provider_id} using provider_data"
)
except Exception as e: except Exception as e:
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}") logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")

View file

@ -12,8 +12,9 @@ from pydantic import TypeAdapter, ValidationError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack_api import Api, ModelType from llama_stack_api import Api, Model, ModelNotFoundError, ModelType
class AsyncMock(MagicMock): class AsyncMock(MagicMock):
@ -557,3 +558,178 @@ def test_condition_reprs(condition):
from llama_stack.core.access_control.conditions import parse_condition from llama_stack.core.access_control.conditions import parse_condition
assert condition == str(parse_condition(condition)) assert condition == str(parse_condition(condition))
@pytest.fixture
def restricted_user():
    """Provide a user whose only role is ``user``."""
    attributes = {"roles": ["user"]}
    return User("restricted-user", attributes)
@pytest.fixture
def admin_user():
    """Provide a user that holds the ``admin`` role."""
    attributes = {"roles": ["admin"]}
    return User("admin-user", attributes)
@pytest.fixture
def rbac_policy():
    """Provide a two-rule policy: admins may do anything, owners may read."""
    from llama_stack.core.access_control.datatypes import Action, Scope

    # Every action is permitted for users with the admin role.
    admin_full_access = AccessRule(
        permit=Scope(actions=list(Action)),
        when=["user with admin in roles"],
    )
    # Read is permitted for the owner of a resource.
    owner_read_only = AccessRule(
        permit=Scope(actions=[Action.READ]),
        when=["user is owner"],
    )
    return [admin_full_access, owner_read_only]
class TestInferenceRouterRBACBypass:
    """Checks that the inference router's provider-qualified fallback lookup enforces RBAC."""

    @pytest.fixture
    def mock_routing_table(self):
        """Build a mocked routing table exposing one provider and an empty policy."""
        table = AsyncMock()
        table.impls_by_provider_id = {"test-provider": AsyncMock()}
        table.policy = []
        return table

    @patch("llama_stack.core.routers.inference.get_authenticated_user")
    async def test_registry_path_and_fallback_path_consistent(
        self, mock_get_user, mock_routing_table, restricted_user, admin_user, rbac_policy
    ):
        """Registry lookups and fallback lookups must give the same answer for a given user."""
        mock_routing_table.policy = rbac_policy
        registry_lookup = mock_routing_table.get_object_by_identifier

        # A registered model whose owner is the admin user.
        owned_by_admin = ModelWithOwner(
            identifier="admin-model",
            provider_id="test-provider",
            provider_resource_id="admin-resource",
            model_type=ModelType.llm,
            type="model",
            metadata={},
            owner=admin_user,
        )

        router = InferenceRouter(
            routing_table=mock_routing_table,
            store=None,
        )

        # Restricted user: neither the registry name nor the provider-qualified
        # name may resolve. The registry lookup is mocked to return nothing,
        # standing in for a result filtered out by RBAC.
        mock_get_user.return_value = restricted_user
        for requested_id in ("admin-model", "test-provider/admin-resource"):
            registry_lookup.return_value = None
            with pytest.raises(ModelNotFoundError):
                await router._get_model_provider(requested_id, "llm")

        # Admin user, registry path: the registered model is found and its provider returned.
        mock_get_user.return_value = admin_user
        registry_lookup.return_value = owned_by_admin
        registry_provider = AsyncMock()
        mock_routing_table.get_provider_impl.return_value = registry_provider

        resolved = await router._get_model_provider("admin-model", "llm")
        provider, resource_id = resolved
        assert provider == registry_provider
        assert resource_id == "admin-resource"

        # Admin user, fallback path: with no registry entry the provider is
        # taken straight from impls_by_provider_id.
        registry_lookup.return_value = None
        resolved = await router._get_model_provider("test-provider/admin-resource", "llm")
        provider, resource_id = resolved
        assert provider == mock_routing_table.impls_by_provider_id["test-provider"]
        assert resource_id == "admin-resource"
class TestModelListingRBACBypass:
    """Checks that models listed dynamically through provider_data are filtered by RBAC."""

    @patch("llama_stack.core.routing_tables.models.instantiate_class_type")
    @patch("llama_stack.core.routing_tables.models.PROVIDER_DATA_VAR")
    @patch("llama_stack.core.routing_tables.models.get_authenticated_user")
    @patch("llama_stack.core.routing_tables.common.get_authenticated_user")
    async def test_dynamic_models_respect_rbac(
        self,
        mock_get_user_common,
        mock_get_user_models,
        mock_provider_data,
        mock_instantiate_class,
        cached_disk_dist_registry,
        rbac_policy,
        admin_user,
        restricted_user,
    ):
        """Admins see provider-listed models; a non-owner without the admin role sees none."""
        from llama_stack.core.request_headers import NeedsRequestProviderData

        # Provider stand-in that advertises provider_data support.
        provider_spec = MagicMock()
        provider_spec.api = Api.inference
        provider_spec.provider_data_validator = "dict"
        mock_provider = Mock(spec=NeedsRequestProviderData)
        mock_provider.__provider_spec__ = provider_spec

        # The instantiated validator accepts anything it is given.
        mock_instantiate_class.return_value = MagicMock(return_value={})

        # Models the provider reports on demand; they carry no owner.
        provider_models = [
            Model(
                identifier=name,
                provider_id="test-provider",
                provider_resource_id=name,
                model_type=ModelType.llm,
                metadata={},
            )
            for name in ("dynamic-model-1", "dynamic-model-2")
        ]
        mock_provider.list_models = AsyncMock(return_value=provider_models)

        # Routing table with the policy applied and nothing pre-registered.
        routing_table = ModelsRoutingTable(
            impls_by_provider_id={"test-provider": mock_provider},
            dist_registry=cached_disk_dist_registry,
            policy=rbac_policy,
        )

        # Pretend the request carried credentials for the provider.
        mock_provider_data.get.return_value = {"api_key": "test-key"}

        async def listed_ids_for(current_user):
            """Switch the authenticated user in both patched modules, then list model ids."""
            mock_get_user_common.return_value = current_user
            mock_get_user_models.return_value = current_user
            listing = await routing_table.list_models()
            return [entry.identifier for entry in listing.data]

        # Admin: the admin rule permits every action, so ownerless models are visible.
        admin_view = await listed_ids_for(admin_user)
        assert "test-provider/dynamic-model-1" in admin_view
        assert "test-provider/dynamic-model-2" in admin_view

        # Restricted user: not an admin and owns nothing, so nothing may be listed.
        restricted_view = await listed_ids_for(restricted_user)
        assert len(restricted_view) == 0