mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
Merge 61f6bd78d0 into 4237eb4aaa
This commit is contained in:
commit
18ff08071f
3 changed files with 231 additions and 5 deletions
|
|
@ -14,6 +14,9 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
|
|||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.core.access_control.access_control import is_action_allowed
|
||||
from llama_stack.core.datatypes import ModelWithOwner
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack_api import (
|
||||
|
|
@ -93,15 +96,42 @@ class InferenceRouter(Inference):
|
|||
provider = await self.routing_table.get_provider_impl(model.identifier)
|
||||
return provider, model.provider_resource_id
|
||||
|
||||
# Handles cases where clients use the provider format directly
|
||||
return await self._get_provider_by_fallback(model_id, expected_model_type)
|
||||
|
||||
async def _get_provider_by_fallback(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
    """
    Resolve a model addressed directly as ``provider_id/provider_resource_id``.

    The model is looked up by provider rather than through a registered model
    object, so the access-control check is performed here against a temporary
    ``ModelWithOwner`` built from the requested identifier.

    :param model_id: identifier in ``provider_id/provider_resource_id`` form; only
        the first ``/`` separates the two parts, so the resource id may itself
        contain slashes.
    :param expected_model_type: model type recorded on the temporary model used
        for the access-control check.
    :returns: tuple of (provider implementation, provider_resource_id).
    :raises ModelNotFoundError: if ``model_id`` is not of the expected form, if
        either component is empty, if the provider is unknown, or if the current
        user is not allowed to read the model.
    """
    provider_id, separator, provider_resource_id = model_id.partition("/")

    # Reject ids without a separator as well as degenerate ones such as
    # "provider/" or "/resource": an empty component can never name a model.
    if not (separator and provider_id and provider_resource_id):
        raise ModelNotFoundError(model_id)

    # Check if provider exists
    if provider_id not in self.routing_table.impls_by_provider_id:
        logger.warning(f"Provider {provider_id} not found for model {model_id}")
        raise ModelNotFoundError(model_id)

    # Create a temporary model object for RBAC check
    temp_model = ModelWithOwner(
        identifier=model_id,
        provider_id=provider_id,
        provider_resource_id=provider_resource_id,
        model_type=expected_model_type,
        type="model",
        metadata={},  # Empty metadata for temporary object
    )

    # Perform RBAC check. A denial is reported as "not found", the same error
    # raised for unknown providers and malformed ids.
    user = get_authenticated_user()
    if not is_action_allowed(self.routing_table.policy, "read", temp_model, user):
        principal = user.principal if user else "anonymous"
        logger.debug(f"Access denied to model '{model_id}' via fallback path for user {principal}")
        raise ModelNotFoundError(model_id)

    return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id
|
||||
|
||||
async def rerank(
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@
|
|||
import time
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.access_control.access_control import is_action_allowed
|
||||
from llama_stack.core.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData, get_authenticated_user
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
|
|
@ -66,6 +67,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return []
|
||||
|
||||
dynamic_models = []
|
||||
user = get_authenticated_user()
|
||||
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
# Check if this provider supports provider_data
|
||||
|
|
@ -93,15 +95,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
if not models:
|
||||
continue
|
||||
|
||||
# Ensure models have fully qualified identifiers with provider_id prefix
|
||||
# Ensure models have fully qualified identifiers and apply RBAC filtering
|
||||
for model in models:
|
||||
# Only add prefix if model identifier doesn't already have it
|
||||
if not model.identifier.startswith(f"{provider_id}/"):
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
dynamic_models.append(model)
|
||||
# Convert to ModelWithOwner for RBAC check
|
||||
temp_model = ModelWithOwner(
|
||||
identifier=model.identifier,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=model.provider_resource_id,
|
||||
model_type=model.model_type,
|
||||
type="model",
|
||||
metadata=model.metadata,
|
||||
)
|
||||
|
||||
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
|
||||
# Apply RBAC check - only include models user has read permission for
|
||||
if is_action_allowed(self.policy, "read", temp_model, user):
|
||||
dynamic_models.append(model)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Access denied to dynamic model '{model.identifier}' for user {user.principal if user else 'anonymous'}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(dynamic_models)} accessible models from provider {provider_id} using provider_data"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ from pydantic import TypeAdapter, ValidationError
|
|||
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack_api import Api, ModelType
|
||||
from llama_stack_api import Api, Model, ModelNotFoundError, ModelType
|
||||
|
||||
|
||||
class AsyncMock(MagicMock):
|
||||
|
|
@ -557,3 +558,178 @@ def test_condition_reprs(condition):
|
|||
from llama_stack.core.access_control.conditions import parse_condition
|
||||
|
||||
assert condition == str(parse_condition(condition))
|
||||
|
||||
|
||||
@pytest.fixture
def restricted_user():
    """Provide a non-admin user whose only role is ``user``."""
    attributes = {"roles": ["user"]}
    return User("restricted-user", attributes)
|
||||
|
||||
|
||||
@pytest.fixture
def admin_user():
    """Provide a user that carries the ``admin`` role."""
    attributes = {"roles": ["admin"]}
    return User("admin-user", attributes)
|
||||
|
||||
|
||||
@pytest.fixture
def rbac_policy():
    """Provide a two-rule policy: admins may do anything, owners may read."""
    from llama_stack.core.access_control.datatypes import Action, Scope

    # Any user with "admin" among their roles is granted every action.
    admin_rule = AccessRule(
        permit=Scope(actions=list(Action)),
        when=["user with admin in roles"],
    )
    # Everyone else is limited to reading resources they own.
    owner_read_rule = AccessRule(
        permit=Scope(actions=[Action.READ]),
        when=["user is owner"],
    )
    return [admin_rule, owner_read_rule]
|
||||
|
||||
|
||||
class TestInferenceRouterRBACBypass:
    """Checks that the inference router's fallback lookup enforces RBAC."""

    @pytest.fixture
    def mock_routing_table(self):
        """Build a mocked routing table with one provider and an empty policy."""
        table = AsyncMock()
        table.policy = []
        table.impls_by_provider_id = {"test-provider": AsyncMock()}
        return table

    @patch("llama_stack.core.routers.inference.get_authenticated_user")
    async def test_registry_path_and_fallback_path_consistent(
        self, mock_get_user, mock_routing_table, restricted_user, admin_user, rbac_policy
    ):
        """Registry lookup and fallback lookup must reach the same access decision."""
        mock_routing_table.policy = rbac_policy

        # Registry entry that belongs to the admin user.
        owned_by_admin = ModelWithOwner(
            identifier="admin-model",
            provider_id="test-provider",
            provider_resource_id="admin-resource",
            model_type=ModelType.llm,
            type="model",
            metadata={},
            owner=admin_user,
        )

        router = InferenceRouter(routing_table=mock_routing_table, store=None)

        # Scenarios 1 and 2: the restricted user is rejected whether the model is
        # requested by its registry id or in provider/resource form. The registry
        # lookup is mocked to return None, as it would when RBAC hides the entry.
        mock_get_user.return_value = restricted_user
        for requested_id in ("admin-model", "test-provider/admin-resource"):
            mock_routing_table.get_object_by_identifier.return_value = None
            with pytest.raises(ModelNotFoundError):
                await router._get_model_provider(requested_id, "llm")

        # Scenario 3: the admin resolves the model through the registry.
        mock_get_user.return_value = admin_user
        mock_routing_table.get_object_by_identifier.return_value = owned_by_admin
        registry_provider = AsyncMock()
        mock_routing_table.get_provider_impl.return_value = registry_provider

        provider, resource_id = await router._get_model_provider("admin-model", "llm")
        assert provider == registry_provider
        assert resource_id == "admin-resource"

        # Scenario 4: with no registry hit, the admin is still served via the
        # fallback path, straight from the provider table.
        mock_routing_table.get_object_by_identifier.return_value = None
        provider, resource_id = await router._get_model_provider("test-provider/admin-resource", "llm")
        assert provider == mock_routing_table.impls_by_provider_id["test-provider"]
        assert resource_id == "admin-resource"
|
||||
|
||||
|
||||
class TestModelListingRBACBypass:
    """Checks that models discovered through provider_data are RBAC-filtered."""

    @patch("llama_stack.core.routing_tables.models.instantiate_class_type")
    @patch("llama_stack.core.routing_tables.models.PROVIDER_DATA_VAR")
    @patch("llama_stack.core.routing_tables.models.get_authenticated_user")
    @patch("llama_stack.core.routing_tables.common.get_authenticated_user")
    async def test_dynamic_models_respect_rbac(
        self,
        mock_get_user_common,
        mock_get_user_models,
        mock_provider_data,
        mock_instantiate_class,
        cached_disk_dist_registry,
        rbac_policy,
        admin_user,
        restricted_user,
    ):
        """Dynamic models are listed only for users the policy lets read them."""
        from llama_stack.core.request_headers import NeedsRequestProviderData

        # Provider stub that advertises provider_data support.
        provider = Mock(spec=NeedsRequestProviderData)
        provider.__provider_spec__ = MagicMock()
        provider.__provider_spec__.api = Api.inference
        provider.__provider_spec__.provider_data_validator = "dict"

        # The provider_data validator always succeeds.
        mock_instantiate_class.return_value = MagicMock(return_value={})

        # Models returned by the provider; plain Model objects carry no owner.
        unowned_models = [
            Model(
                identifier=name,
                provider_id="test-provider",
                provider_resource_id=name,
                model_type=ModelType.llm,
                metadata={},
            )
            for name in ("dynamic-model-1", "dynamic-model-2")
        ]
        provider.list_models = AsyncMock(return_value=unowned_models)

        # Routing table with the policy applied and nothing pre-registered.
        routing_table = ModelsRoutingTable(
            impls_by_provider_id={"test-provider": provider},
            dist_registry=cached_disk_dist_registry,
            policy=rbac_policy,
        )

        # The request carries credentials for this provider.
        mock_provider_data.get.return_value = {"api_key": "test-key"}

        async def listed_ids_for(user):
            """List models while both patched user lookups return ``user``."""
            mock_get_user_common.return_value = user
            mock_get_user_models.return_value = user
            listing = await routing_table.list_models()
            return [entry.identifier for entry in listing.data]

        # The admin rule permits every action, so ownership is not required.
        admin_view = await listed_ids_for(admin_user)
        assert "test-provider/dynamic-model-1" in admin_view
        assert "test-provider/dynamic-model-2" in admin_view

        # The restricted user is neither admin nor owner, so nothing is visible;
        # without RBAC filtering these models would have been returned.
        restricted_view = await listed_ids_for(restricted_user)
        assert len(restricted_view) == 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue