mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-06 10:37:22 +00:00
fix: RBAC bypass vulnerabilities in model access (#4270)
Closes security gaps where RBAC checks could be bypassed: o Inference router: Added RBAC enforcement in the fallback path to ensure access control is applied consistently. o Model listing: Dynamic models fetched via provider_data were returned without RBAC checks. Added filtering to ensure users only see models they have permission to access. Both fixes create temporary ModelWithOwner objects for RBAC validation, maintaining security through consistent access control enforcement. Closes: #4269 Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
7f43051a63
commit
8940be23c4
3 changed files with 229 additions and 5 deletions
|
|
@ -14,6 +14,9 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
|
|||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.core.access_control.access_control import is_action_allowed
|
||||
from llama_stack.core.datatypes import ModelWithOwner
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack_api import (
|
||||
|
|
@ -93,15 +96,41 @@ class InferenceRouter(Inference):
|
|||
provider = await self.routing_table.get_provider_impl(model.identifier)
|
||||
return provider, model.provider_resource_id
|
||||
|
||||
# Handles cases where clients use the provider format directly
|
||||
return await self._get_provider_by_fallback(model_id, expected_model_type)
|
||||
|
||||
async def _get_provider_by_fallback(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
|
||||
"""
|
||||
Handle fallback case where model_id is in provider_id/provider_resource_id format.
|
||||
"""
|
||||
splits = model_id.split("/", maxsplit=1)
|
||||
if len(splits) != 2:
|
||||
raise ModelNotFoundError(model_id)
|
||||
|
||||
provider_id, provider_resource_id = splits
|
||||
|
||||
# Check if provider exists
|
||||
if provider_id not in self.routing_table.impls_by_provider_id:
|
||||
logger.warning(f"Provider {provider_id} not found for model {model_id}")
|
||||
raise ModelNotFoundError(model_id)
|
||||
|
||||
# Create a temporary model object for RBAC check
|
||||
temp_model = ModelWithOwner(
|
||||
identifier=model_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=provider_resource_id,
|
||||
model_type=expected_model_type,
|
||||
metadata={}, # Empty metadata for temporary object
|
||||
)
|
||||
|
||||
# Perform RBAC check
|
||||
user = get_authenticated_user()
|
||||
if not is_action_allowed(self.routing_table.policy, "read", temp_model, user):
|
||||
logger.debug(
|
||||
f"Access denied to model '{model_id}' via fallback path for user {user.principal if user else 'anonymous'}"
|
||||
)
|
||||
raise ModelNotFoundError(model_id)
|
||||
|
||||
return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id
|
||||
|
||||
async def rerank(
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@
|
|||
import time
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.access_control.access_control import is_action_allowed
|
||||
from llama_stack.core.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData, get_authenticated_user
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
|
|
@ -66,6 +67,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return []
|
||||
|
||||
dynamic_models = []
|
||||
user = get_authenticated_user()
|
||||
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
# Check if this provider supports provider_data
|
||||
|
|
@ -93,15 +95,32 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
if not models:
|
||||
continue
|
||||
|
||||
# Ensure models have fully qualified identifiers with provider_id prefix
|
||||
# Ensure models have fully qualified identifiers and apply RBAC filtering
|
||||
for model in models:
|
||||
# Only add prefix if model identifier doesn't already have it
|
||||
if not model.identifier.startswith(f"{provider_id}/"):
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
dynamic_models.append(model)
|
||||
# Convert to ModelWithOwner for RBAC check
|
||||
temp_model = ModelWithOwner(
|
||||
identifier=model.identifier,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=model.provider_resource_id,
|
||||
model_type=model.model_type,
|
||||
metadata=model.metadata,
|
||||
)
|
||||
|
||||
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
|
||||
# Apply RBAC check - only include models user has read permission for
|
||||
if is_action_allowed(self.policy, "read", temp_model, user):
|
||||
dynamic_models.append(model)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Access denied to dynamic model '{model.identifier}' for user {user.principal if user else 'anonymous'}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(dynamic_models)} accessible models from provider {provider_id} using provider_data"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue