feat: make sure agent sessions are under access control (#1737)

This builds on top of #1703.

Agent sessions are now properly access controlled.

## Test Plan

Added unit tests
This commit is contained in:
Ashwin Bharambe 2025-03-21 07:31:16 -07:00 committed by GitHub
parent d7a6d92466
commit 03b5c61bfc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 255 additions and 17 deletions

View file

@ -6,13 +6,17 @@
from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
@ -43,39 +47,40 @@ def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict
# - The extra "projects" attribute is ignored
Args:
obj: The resource object to check access for
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
if not obj_attributes:
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
for attr_key, required_values in obj_attributes.items():
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
)
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': "
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
logger.debug(f"Access granted to {obj_identifier}")
return True