feat(server): add attribute based access control for resources

This commit is contained in:
Ashwin Bharambe 2025-03-19 08:10:56 -07:00
parent 7c0448456e
commit b937a49436
8 changed files with 862 additions and 35 deletions

View file

@ -40,10 +40,12 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
AccessAttributes,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -184,6 +186,11 @@ class CommonRoutingTableImpl(RoutingTable):
if not obj:
return None
# Check if user has permission to access this object
if not self._check_access(obj):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
@ -200,6 +207,13 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
@ -212,7 +226,89 @@ class CommonRoutingTableImpl(RoutingTable):
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [obj for obj in filtered_objs if self._check_access(obj)]
return filtered_objs
def _check_access(self, obj: RoutableObjectWithProvider) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
3. For each attribute category in the resource's access_attributes:
a. If the user lacks that category, access is DENIED
b. If the user has the category but none of the required values, access is DENIED
c. If the user has at least one matching value in each required category, access is GRANTED
Example:
# Resource requires:
access_attributes = AccessAttributes(
roles=["admin", "data-scientist"],
teams=["ml-team"]
)
# User has:
user_attributes = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "infra-team"],
"projects": ["llama-3"]
}
# Result: Access GRANTED
# - User has the "data-scientist" role (matches one of the required roles)
# - AND user is part of the "ml-team" (matches the required team)
# - The extra "projects" attribute is ignored
Args:
obj: The resource object to check access for
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
return True
# Get user attributes from context
user_attributes = get_auth_attributes()
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
# Convert AccessAttributes to dictionary for checking
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
# If the model_dump is empty (all fields are None), allow access
if not obj_attributes:
return True
# Check each attribute category (requires ALL categories to match)
for attr_key, required_values in obj_attributes.items():
user_values = user_attributes.get(attr_key, [])
# No values for this category in user attributes
if not user_values:
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
)
return False
# None of the values in this category match (need at least one match per category)
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
return True
class ModelsRoutingTable(CommonRoutingTableImpl, Models):