diff --git a/llama_stack/distribution/access_control.py b/llama_stack/distribution/access_control.py new file mode 100644 index 000000000..7c7f12937 --- /dev/null +++ b/llama_stack/distribution/access_control.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, Optional + +from llama_stack.distribution.datatypes import RoutableObjectWithProvider +from llama_stack.log import get_logger + +logger = get_logger(__name__, category="core") + + +def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool: + """Check if the current user has access to the given object, based on access attributes. + + Access control algorithm: + 1. If the resource has no access_attributes, access is GRANTED to all authenticated users + 2. If the user has no attributes, access is DENIED to any object with access_attributes defined + 3. For each attribute category in the resource's access_attributes: + a. If the user lacks that category, access is DENIED + b. If the user has the category but none of the required values, access is DENIED + c. If the user has at least one matching value in each required category, access is GRANTED + + Example: + # Resource requires: + access_attributes = AccessAttributes( + roles=["admin", "data-scientist"], + teams=["ml-team"] + ) + + # User has: + user_attributes = { + "roles": ["data-scientist", "engineer"], + "teams": ["ml-team", "infra-team"], + "projects": ["llama-3"] + } + + # Result: Access GRANTED + # - User has the "data-scientist" role (matches one of the required roles) + # - AND user is part of the "ml-team" (matches the required team) + # - The extra "projects" attribute is ignored + + Args: + obj: The resource object to check access for + + Returns: + bool: True if access is granted, False if denied + """ + # If object has no access attributes, allow access by default + if not hasattr(obj, "access_attributes") or not obj.access_attributes: + return True + + # If no user attributes, deny access to objects with access control + if not user_attributes: + return False + + obj_attributes = obj.access_attributes.model_dump(exclude_none=True) + if not obj_attributes: + return True + + # Check each attribute category (requires ALL categories to match) + for attr_key, required_values in obj_attributes.items(): + user_values = user_attributes.get(attr_key, []) + + if not user_values: + logger.debug( + f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'" + ) + return False + + if not any(val in user_values for val in required_values): + logger.debug( + f"Access denied to {obj.type} '{obj.identifier}': " + f"no match for attribute '{attr_key}', required one of {required_values}" + ) + return False + + logger.debug(f"Access granted to {obj.type} '{obj.identifier}'") + return True diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 560fa92b9..f756c8621 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -39,6 +39,7 @@ from llama_stack.apis.tools import ( ToolHost, ) from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs +from llama_stack.distribution.access_control import check_access from llama_stack.distribution.datatypes import ( AccessAttributes, RoutableObject, @@ -187,7 +188,7 @@ class CommonRoutingTableImpl(RoutingTable): return None # Check if user has permission to access this object - if not self._check_access(obj): + if not check_access(obj, get_auth_attributes()): logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") return None @@ -230,81 +231,10 @@ class CommonRoutingTableImpl(RoutingTable): # Apply attribute-based access control filtering if filtered_objs: - filtered_objs = [obj for obj in filtered_objs if self._check_access(obj)] + filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())] return filtered_objs - def _check_access(self, obj: RoutableObjectWithProvider) -> bool: - """Check if the current user has access to the given object, based on access attributes. - - Access control algorithm: - 1. If the resource has no access_attributes, access is GRANTED to all authenticated users - 2. If the user has no attributes, access is DENIED to any object with access_attributes defined - 3. For each attribute category in the resource's access_attributes: - a. If the user lacks that category, access is DENIED - b. If the user has the category but none of the required values, access is DENIED - c. If the user has at least one matching value in each required category, access is GRANTED - - Example: - # Resource requires: - access_attributes = AccessAttributes( - roles=["admin", "data-scientist"], - teams=["ml-team"] - ) - - # User has: - user_attributes = { - "roles": ["data-scientist", "engineer"], - "teams": ["ml-team", "infra-team"], - "projects": ["llama-3"] - } - - # Result: Access GRANTED - # - User has the "data-scientist" role (matches one of the required roles) - # - AND user is part of the "ml-team" (matches the required team) - # - The extra "projects" attribute is ignored - - Args: - obj: The resource object to check access for - - Returns: - bool: True if access is granted, False if denied - """ - # If object has no access attributes, allow access by default - if not hasattr(obj, "access_attributes") or not obj.access_attributes: - return True - - # Get user attributes from request context - user_attributes = get_auth_attributes() - - # If no user attributes, deny access to objects with access control - if not user_attributes: - return False - - obj_attributes = obj.access_attributes.model_dump(exclude_none=True) - if not obj_attributes: - return True - - # Check each attribute category (requires ALL categories to match) - for attr_key, required_values in obj_attributes.items(): - user_values = user_attributes.get(attr_key, []) - - if not user_values: - logger.debug( - f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'" - ) - return False - - if not any(val in user_values for val in required_values): - logger.debug( - f"Access denied to {obj.type} '{obj.identifier}': " - f"no match for attribute '{attr_key}', required one of {required_values}" - ) - return False - - logger.debug(f"Access granted to {obj.type} '{obj.identifier}'") - return True - class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> ListModelsResponse: