forked from phoenix-oss/llama-stack-mirror
feat: make sure agent sessions are under access control (#1737)
This builds on top of #1703. Agent sessions are now properly access controlled. ## Test Plan Added unit tests
This commit is contained in:
parent
d7a6d92466
commit
03b5c61bfc
4 changed files with 255 additions and 17 deletions
|
@ -6,13 +6,17 @@
|
|||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
|
||||
def check_access(
|
||||
obj_identifier: str,
|
||||
obj_attributes: Optional[AccessAttributes],
|
||||
user_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> bool:
|
||||
"""Check if the current user has access to the given object, based on access attributes.
|
||||
|
||||
Access control algorithm:
|
||||
|
@ -43,39 +47,40 @@ def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict
|
|||
# - The extra "projects" attribute is ignored
|
||||
|
||||
Args:
|
||||
obj: The resource object to check access for
|
||||
obj_identifier: The identifier of the resource object to check access for
|
||||
obj_attributes: The access attributes of the resource object
|
||||
user_attributes: The attributes of the current user
|
||||
|
||||
Returns:
|
||||
bool: True if access is granted, False if denied
|
||||
"""
|
||||
# If object has no access attributes, allow access by default
|
||||
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
|
||||
if not obj_attributes:
|
||||
dict_attribs = obj_attributes.model_dump(exclude_none=True)
|
||||
if not dict_attribs:
|
||||
return True
|
||||
|
||||
# Check each attribute category (requires ALL categories to match)
|
||||
for attr_key, required_values in obj_attributes.items():
|
||||
# TODO: formalize this into a proper ABAC policy
|
||||
for attr_key, required_values in dict_attribs.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
if not user_values:
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
|
||||
)
|
||||
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
|
||||
return False
|
||||
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': "
|
||||
f"Access denied to {obj_identifier}: "
|
||||
f"no match for attribute '{attr_key}', required one of {required_values}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
|
||||
logger.debug(f"Access granted to {obj_identifier}")
|
||||
return True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue