mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
Merge 61f6bd78d0 into 4237eb4aaa
This commit is contained in:
commit
18ff08071f
3 changed files with 231 additions and 5 deletions
|
|
@ -14,6 +14,9 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
|
||||||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from llama_stack.core.access_control.access_control import is_action_allowed
|
||||||
|
from llama_stack.core.datatypes import ModelWithOwner
|
||||||
|
from llama_stack.core.request_headers import get_authenticated_user
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||||
from llama_stack_api import (
|
from llama_stack_api import (
|
||||||
|
|
@ -93,15 +96,42 @@ class InferenceRouter(Inference):
|
||||||
provider = await self.routing_table.get_provider_impl(model.identifier)
|
provider = await self.routing_table.get_provider_impl(model.identifier)
|
||||||
return provider, model.provider_resource_id
|
return provider, model.provider_resource_id
|
||||||
|
|
||||||
|
# Handles cases where clients use the provider format directly
|
||||||
|
return await self._get_provider_by_fallback(model_id, expected_model_type)
|
||||||
|
|
||||||
|
async def _get_provider_by_fallback(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
|
||||||
|
"""
|
||||||
|
Handle fallback case where model_id is in provider_id/provider_resource_id format.
|
||||||
|
"""
|
||||||
splits = model_id.split("/", maxsplit=1)
|
splits = model_id.split("/", maxsplit=1)
|
||||||
if len(splits) != 2:
|
if len(splits) != 2:
|
||||||
raise ModelNotFoundError(model_id)
|
raise ModelNotFoundError(model_id)
|
||||||
|
|
||||||
provider_id, provider_resource_id = splits
|
provider_id, provider_resource_id = splits
|
||||||
|
|
||||||
|
# Check if provider exists
|
||||||
if provider_id not in self.routing_table.impls_by_provider_id:
|
if provider_id not in self.routing_table.impls_by_provider_id:
|
||||||
logger.warning(f"Provider {provider_id} not found for model {model_id}")
|
logger.warning(f"Provider {provider_id} not found for model {model_id}")
|
||||||
raise ModelNotFoundError(model_id)
|
raise ModelNotFoundError(model_id)
|
||||||
|
|
||||||
|
# Create a temporary model object for RBAC check
|
||||||
|
temp_model = ModelWithOwner(
|
||||||
|
identifier=model_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_resource_id=provider_resource_id,
|
||||||
|
model_type=expected_model_type,
|
||||||
|
type="model",
|
||||||
|
metadata={}, # Empty metadata for temporary object
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform RBAC check
|
||||||
|
user = get_authenticated_user()
|
||||||
|
if not is_action_allowed(self.routing_table.policy, "read", temp_model, user):
|
||||||
|
logger.debug(
|
||||||
|
f"Access denied to model '{model_id}' via fallback path for user {user.principal if user else 'anonymous'}"
|
||||||
|
)
|
||||||
|
raise ModelNotFoundError(model_id)
|
||||||
|
|
||||||
return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id
|
return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id
|
||||||
|
|
||||||
async def rerank(
|
async def rerank(
|
||||||
|
|
|
||||||
|
|
@ -7,11 +7,12 @@
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.core.access_control.access_control import is_action_allowed
|
||||||
from llama_stack.core.datatypes import (
|
from llama_stack.core.datatypes import (
|
||||||
ModelWithOwner,
|
ModelWithOwner,
|
||||||
RegistryEntrySource,
|
RegistryEntrySource,
|
||||||
)
|
)
|
||||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
|
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData, get_authenticated_user
|
||||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack_api import (
|
from llama_stack_api import (
|
||||||
|
|
@ -66,6 +67,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
dynamic_models = []
|
dynamic_models = []
|
||||||
|
user = get_authenticated_user()
|
||||||
|
|
||||||
for provider_id, provider in self.impls_by_provider_id.items():
|
for provider_id, provider in self.impls_by_provider_id.items():
|
||||||
# Check if this provider supports provider_data
|
# Check if this provider supports provider_data
|
||||||
|
|
@ -93,15 +95,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
if not models:
|
if not models:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Ensure models have fully qualified identifiers with provider_id prefix
|
# Ensure models have fully qualified identifiers and apply RBAC filtering
|
||||||
for model in models:
|
for model in models:
|
||||||
# Only add prefix if model identifier doesn't already have it
|
# Only add prefix if model identifier doesn't already have it
|
||||||
if not model.identifier.startswith(f"{provider_id}/"):
|
if not model.identifier.startswith(f"{provider_id}/"):
|
||||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||||
|
|
||||||
dynamic_models.append(model)
|
# Convert to ModelWithOwner for RBAC check
|
||||||
|
temp_model = ModelWithOwner(
|
||||||
|
identifier=model.identifier,
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_resource_id=model.provider_resource_id,
|
||||||
|
model_type=model.model_type,
|
||||||
|
type="model",
|
||||||
|
metadata=model.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
|
# Apply RBAC check - only include models user has read permission for
|
||||||
|
if is_action_allowed(self.policy, "read", temp_model, user):
|
||||||
|
dynamic_models.append(model)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"Access denied to dynamic model '{model.identifier}' for user {user.principal if user else 'anonymous'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Fetched {len(dynamic_models)} accessible models from provider {provider_id} using provider_data"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
|
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,9 @@ from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
|
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
|
||||||
|
from llama_stack.core.routers.inference import InferenceRouter
|
||||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||||
from llama_stack_api import Api, ModelType
|
from llama_stack_api import Api, Model, ModelNotFoundError, ModelType
|
||||||
|
|
||||||
|
|
||||||
class AsyncMock(MagicMock):
|
class AsyncMock(MagicMock):
|
||||||
|
|
@ -557,3 +558,178 @@ def test_condition_reprs(condition):
|
||||||
from llama_stack.core.access_control.conditions import parse_condition
|
from llama_stack.core.access_control.conditions import parse_condition
|
||||||
|
|
||||||
assert condition == str(parse_condition(condition))
|
assert condition == str(parse_condition(condition))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def restricted_user():
|
||||||
|
"""User with limited access."""
|
||||||
|
return User("restricted-user", {"roles": ["user"]})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin_user():
|
||||||
|
"""User with admin access."""
|
||||||
|
return User("admin-user", {"roles": ["admin"]})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def rbac_policy():
|
||||||
|
"""RBAC policy that restricts access to certain models."""
|
||||||
|
from llama_stack.core.access_control.datatypes import Action, Scope
|
||||||
|
|
||||||
|
return [
|
||||||
|
# Admins get full access
|
||||||
|
AccessRule(
|
||||||
|
permit=Scope(actions=list(Action)),
|
||||||
|
when=["user with admin in roles"],
|
||||||
|
),
|
||||||
|
# Regular users only get read access to their own resources
|
||||||
|
AccessRule(
|
||||||
|
permit=Scope(actions=[Action.READ]),
|
||||||
|
when=["user is owner"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestInferenceRouterRBACBypass:
|
||||||
|
"""Test RBAC bypass vulnerability in inference router fallback path."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_routing_table(self):
|
||||||
|
"""Create a mock routing table for testing."""
|
||||||
|
routing_table = AsyncMock()
|
||||||
|
routing_table.impls_by_provider_id = {"test-provider": AsyncMock()}
|
||||||
|
routing_table.policy = []
|
||||||
|
return routing_table
|
||||||
|
|
||||||
|
@patch("llama_stack.core.routers.inference.get_authenticated_user")
|
||||||
|
async def test_registry_path_and_fallback_path_consistent(
|
||||||
|
self, mock_get_user, mock_routing_table, restricted_user, admin_user, rbac_policy
|
||||||
|
):
|
||||||
|
"""Test that registry path and fallback path have consistent RBAC enforcement."""
|
||||||
|
mock_routing_table.policy = rbac_policy
|
||||||
|
|
||||||
|
# Create a model owned by admin
|
||||||
|
admin_model = ModelWithOwner(
|
||||||
|
identifier="admin-model",
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="admin-resource",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
type="model",
|
||||||
|
metadata={},
|
||||||
|
owner=admin_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup router
|
||||||
|
router = InferenceRouter(
|
||||||
|
routing_table=mock_routing_table,
|
||||||
|
store=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 1: Restricted user tries to access via registry (should fail)
|
||||||
|
mock_get_user.return_value = restricted_user
|
||||||
|
mock_routing_table.get_object_by_identifier.return_value = None # RBAC blocks it
|
||||||
|
with pytest.raises(ModelNotFoundError):
|
||||||
|
await router._get_model_provider("admin-model", "llm")
|
||||||
|
|
||||||
|
# Test 2: Restricted user tries to access via fallback path (should also fail)
|
||||||
|
mock_routing_table.get_object_by_identifier.return_value = None
|
||||||
|
with pytest.raises(ModelNotFoundError):
|
||||||
|
await router._get_model_provider("test-provider/admin-resource", "llm")
|
||||||
|
|
||||||
|
# Test 3: Admin user can access via registry
|
||||||
|
mock_get_user.return_value = admin_user
|
||||||
|
mock_routing_table.get_object_by_identifier.return_value = admin_model
|
||||||
|
provider_mock = AsyncMock()
|
||||||
|
mock_routing_table.get_provider_impl.return_value = provider_mock
|
||||||
|
|
||||||
|
provider, resource_id = await router._get_model_provider("admin-model", "llm")
|
||||||
|
assert provider == provider_mock
|
||||||
|
assert resource_id == "admin-resource"
|
||||||
|
|
||||||
|
# Test 4: Admin user can also access via fallback path
|
||||||
|
mock_routing_table.get_object_by_identifier.return_value = None
|
||||||
|
provider, resource_id = await router._get_model_provider("test-provider/admin-resource", "llm")
|
||||||
|
assert provider == mock_routing_table.impls_by_provider_id["test-provider"]
|
||||||
|
assert resource_id == "admin-resource"
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelListingRBACBypass:
|
||||||
|
"""Test RBAC bypass vulnerability in dynamic model listing via provider_data."""
|
||||||
|
|
||||||
|
@patch("llama_stack.core.routing_tables.models.instantiate_class_type")
|
||||||
|
@patch("llama_stack.core.routing_tables.models.PROVIDER_DATA_VAR")
|
||||||
|
@patch("llama_stack.core.routing_tables.models.get_authenticated_user")
|
||||||
|
@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
|
||||||
|
async def test_dynamic_models_respect_rbac(
|
||||||
|
self,
|
||||||
|
mock_get_user_common,
|
||||||
|
mock_get_user_models,
|
||||||
|
mock_provider_data,
|
||||||
|
mock_instantiate_class,
|
||||||
|
cached_disk_dist_registry,
|
||||||
|
rbac_policy,
|
||||||
|
admin_user,
|
||||||
|
restricted_user,
|
||||||
|
):
|
||||||
|
"""Test that models fetched via provider_data are filtered by RBAC."""
|
||||||
|
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||||
|
|
||||||
|
# Create a mock provider that supports provider_data
|
||||||
|
mock_provider = Mock(spec=NeedsRequestProviderData)
|
||||||
|
mock_provider.__provider_spec__ = MagicMock()
|
||||||
|
mock_provider.__provider_spec__.api = Api.inference
|
||||||
|
mock_provider.__provider_spec__.provider_data_validator = "dict"
|
||||||
|
|
||||||
|
# Mock the validator to always succeed
|
||||||
|
mock_validator = MagicMock(return_value={})
|
||||||
|
mock_instantiate_class.return_value = mock_validator
|
||||||
|
|
||||||
|
# Mock list_models to return dynamic models
|
||||||
|
# These are fetched via provider_data and don't have owners initially
|
||||||
|
dynamic_model1 = Model(
|
||||||
|
identifier="dynamic-model-1",
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="dynamic-model-1",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
dynamic_model2 = Model(
|
||||||
|
identifier="dynamic-model-2",
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="dynamic-model-2",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
mock_provider.list_models = AsyncMock(return_value=[dynamic_model1, dynamic_model2])
|
||||||
|
|
||||||
|
# Setup routing table with policy (no models pre-registered in registry)
|
||||||
|
routing_table = ModelsRoutingTable(
|
||||||
|
impls_by_provider_id={"test-provider": mock_provider},
|
||||||
|
dist_registry=cached_disk_dist_registry,
|
||||||
|
policy=rbac_policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up provider_data context (user has credentials for this provider)
|
||||||
|
mock_provider_data.get.return_value = {"api_key": "test-key"}
|
||||||
|
|
||||||
|
# Test 1: Admin user can see dynamic models
|
||||||
|
# Admin rule allows all actions, so they can see models even without ownership
|
||||||
|
mock_get_user_common.return_value = admin_user
|
||||||
|
mock_get_user_models.return_value = admin_user
|
||||||
|
|
||||||
|
result = await routing_table.list_models()
|
||||||
|
model_ids = [m.identifier for m in result.data]
|
||||||
|
assert "test-provider/dynamic-model-1" in model_ids
|
||||||
|
assert "test-provider/dynamic-model-2" in model_ids
|
||||||
|
|
||||||
|
# Test 2: Restricted user CANNOT see dynamic models
|
||||||
|
# Dynamic models have no owner, and policy requires either admin role OR ownership
|
||||||
|
# This demonstrates the fix: before, these would be returned without RBAC checks
|
||||||
|
mock_get_user_common.return_value = restricted_user
|
||||||
|
mock_get_user_models.return_value = restricted_user
|
||||||
|
|
||||||
|
result = await routing_table.list_models()
|
||||||
|
model_ids = [m.identifier for m in result.data]
|
||||||
|
# Restricted user should see no models (no ownership, not admin)
|
||||||
|
assert len(model_ids) == 0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue