diff --git a/llama_stack/distribution/access_control.py b/llama_stack/distribution/access_control.py
index 7c7f12937..0651ab6eb 100644
--- a/llama_stack/distribution/access_control.py
+++ b/llama_stack/distribution/access_control.py
@@ -6,13 +6,17 @@
 
 from typing import Any, Dict, Optional
 
-from llama_stack.distribution.datatypes import RoutableObjectWithProvider
+from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="core")
 
 
-def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
+def check_access(
+    obj_identifier: str,
+    obj_attributes: Optional[AccessAttributes],
+    user_attributes: Optional[Dict[str, Any]] = None,
+) -> bool:
     """Check if the current user has access to the given object, based on access attributes.
 
     Access control algorithm:
@@ -43,39 +47,40 @@ def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict
         # - The extra "projects" attribute is ignored
 
     Args:
-        obj: The resource object to check access for
+        obj_identifier: The identifier of the resource object to check access for
+        obj_attributes: The access attributes of the resource object
+        user_attributes: The attributes of the current user
 
     Returns:
         bool: True if access is granted, False if denied
     """
     # If object has no access attributes, allow access by default
-    if not hasattr(obj, "access_attributes") or not obj.access_attributes:
+    if not obj_attributes:
         return True
 
     # If no user attributes, deny access to objects with access control
     if not user_attributes:
         return False
 
-    obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
-    if not obj_attributes:
+    dict_attribs = obj_attributes.model_dump(exclude_none=True)
+    if not dict_attribs:
         return True
 
     # Check each attribute category (requires ALL categories to match)
-    for attr_key, required_values in obj_attributes.items():
+    # TODO: formalize this into a proper ABAC policy
+    for attr_key, required_values in dict_attribs.items():
         user_values = user_attributes.get(attr_key, [])
 
         if not user_values:
-            logger.debug(
-                f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
-            )
+            logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
             return False
 
         if not any(val in user_values for val in required_values):
             logger.debug(
-                f"Access denied to {obj.type} '{obj.identifier}': "
+                f"Access denied to {obj_identifier}: "
                 f"no match for attribute '{attr_key}', required one of {required_values}"
             )
             return False
 
-    logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
+    logger.debug(f"Access granted to {obj_identifier}")
     return True
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index a2bc10fc1..d444b03a3 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -198,7 +198,7 @@ class CommonRoutingTableImpl(RoutingTable):
             return None
 
         # Check if user has permission to access this object
-        if not check_access(obj, get_auth_attributes()):
+        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
             logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
             return None
 
@@ -241,7 +241,12 @@ class CommonRoutingTableImpl(RoutingTable):
 
         # Apply attribute-based access control filtering
         if filtered_objs:
-            filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())]
+            auth_attributes = get_auth_attributes()
+            filtered_objs = [
+                obj
+                for obj in filtered_objs
+                if check_access(obj.identifier, getattr(obj, "access_attributes", None), auth_attributes)
+            ]
 
         return filtered_objs
 
diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py
index e7d7d1828..202d43609 100644
--- a/llama_stack/providers/inline/agents/meta_reference/persistence.py
+++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py
@@ -13,6 +13,9 @@ from typing import List, Optional
 from pydantic import BaseModel
 
 from llama_stack.apis.agents import ToolExecutionStep, Turn
+from llama_stack.distribution.access_control import check_access
+from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.distribution.request_headers import get_auth_attributes
 from llama_stack.providers.utils.kvstore import KVStore
 
 log = logging.getLogger(__name__)
@@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
     # TODO: is this used anywhere?
     vector_db_id: Optional[str] = None
     started_at: datetime
+    access_attributes: Optional[AccessAttributes] = None
 
 
 class AgentPersistence:
@@ -33,11 +37,18 @@ class AgentPersistence:
 
     async def create_session(self, name: str) -> str:
         session_id = str(uuid.uuid4())
+
+        # Get current user's auth attributes for new sessions
+        auth_attributes = get_auth_attributes()
+        access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
+
         session_info = AgentSessionInfo(
             session_id=session_id,
             session_name=name,
             started_at=datetime.now(timezone.utc),
+            access_attributes=access_attributes,
         )
+
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
             value=session_info.model_dump_json(),
@@ -51,12 +62,34 @@ class AgentPersistence:
         if not value:
             return None
 
-        return AgentSessionInfo(**json.loads(value))
+        session_info = AgentSessionInfo(**json.loads(value))
+
+        # Check access to session
+        if not self._check_session_access(session_info):
+            return None
+
+        return session_info
+
+    def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
+        """Check if current user has access to the session.
+
+        Sessions stored before access control existed deserialize with
+        ``access_attributes=None``; ``check_access`` grants access to those.
+        """
+        return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
+
+    async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
+        """Get session info if the user has access to it. For internal use by sub-session methods.
+
+        Returns None when the session does not exist or access is denied;
+        ``get_session_info`` already performs the access check.
+        """
+        return await self.get_session_info(session_id)
 
     async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
-        session_info = await self.get_session_info(session_id)
+        session_info = await self.get_session_if_accessible(session_id)
         if session_info is None:
-            raise ValueError(f"Session {session_id} not found")
+            raise ValueError(f"Session {session_id} not found or access denied")
 
         session_info.vector_db_id = vector_db_id
         await self.kvstore.set(
@@ -65,12 +98,18 @@ class AgentPersistence:
         )
 
     async def add_turn_to_session(self, session_id: str, turn: Turn):
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
             value=turn.model_dump_json(),
         )
 
     async def get_session_turns(self, session_id: str) -> List[Turn]:
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         values = await self.kvstore.range(
             start_key=f"session:{self.agent_id}:{session_id}:",
             end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
@@ -87,6 +126,9 @@ class AgentPersistence:
         return turns
 
     async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         value = await self.kvstore.get(
             key=f"session:{self.agent_id}:{session_id}:{turn_id}",
         )
@@ -95,24 +137,36 @@ class AgentPersistence:
         return Turn(**json.loads(value))
 
     async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         await self.kvstore.set(
             key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
             value=step.model_dump_json(),
         )
 
     async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
+        if not await self.get_session_if_accessible(session_id):
+            return None
+
         value = await self.kvstore.get(
             key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
         )
         return ToolExecutionStep(**json.loads(value)) if value else None
 
     async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         await self.kvstore.set(
             key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
             value=str(num_infer_iters),
         )
 
     async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
+        if not await self.get_session_if_accessible(session_id):
+            return None
+
         value = await self.kvstore.get(
             key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
         )
diff --git a/tests/unit/providers/agents/test_persistence_access_control.py b/tests/unit/providers/agents/test_persistence_access_control.py
new file mode 100644
index 000000000..ab181a4ae
--- /dev/null
+++ b/tests/unit/providers/agents/test_persistence_access_control.py
@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import shutil
+import tempfile
+import uuid
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+import pytest
+
+from llama_stack.apis.agents import Turn
+from llama_stack.apis.inference import CompletionMessage, StopReason
+from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
+
+
+@pytest.fixture
+async def test_setup():
+    temp_dir = tempfile.mkdtemp()
+    db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
+    kvstore_config = SqliteKVStoreConfig(db_path=db_path)
+    kvstore = SqliteKVStoreImpl(kvstore_config)
+    await kvstore.initialize()
+    agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore)
+    yield agent_persistence
+    shutil.rmtree(temp_dir)
+
+
+@pytest.mark.asyncio
+@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
+async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
+    agent_persistence = test_setup
+
+    # Set creator's attributes for the session
+    creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
+    mock_get_auth_attributes.return_value = creator_attributes
+
+    # Create a session
+    session_id = await agent_persistence.create_session("Test Session")
+
+    # Get the session and verify access attributes were set
+    session_info = await agent_persistence.get_session_info(session_id)
+    assert session_info is not None
+    assert session_info.access_attributes is not None
+    assert session_info.access_attributes.roles == ["researcher"]
+    assert session_info.access_attributes.teams == ["ai-team"]
+
+
+@pytest.mark.asyncio
+@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
+async def test_session_access_control(mock_get_auth_attributes, test_setup):
+    agent_persistence = test_setup
+
+    # Create a session with specific access attributes
+    session_id = str(uuid.uuid4())
+    session_info = AgentSessionInfo(
+        session_id=session_id,
+        session_name="Restricted Session",
+        started_at=datetime.now(timezone.utc),
+        access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
+    )
+
+    await agent_persistence.kvstore.set(
+        key=f"session:{agent_persistence.agent_id}:{session_id}",
+        value=session_info.model_dump_json(),
+    )
+
+    # User with matching attributes can access
+    mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
+    retrieved_session = await agent_persistence.get_session_info(session_id)
+    assert retrieved_session is not None
+    assert retrieved_session.session_id == session_id
+
+    # User without matching attributes cannot access
+    mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
+    retrieved_session = await agent_persistence.get_session_info(session_id)
+    assert retrieved_session is None
+
+
+@pytest.mark.asyncio
+@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
+async def test_turn_access_control(mock_get_auth_attributes, test_setup):
+    agent_persistence = test_setup
+
+    # Create a session with restricted access
+    session_id = str(uuid.uuid4())
+    session_info = AgentSessionInfo(
+        session_id=session_id,
+        session_name="Restricted Session",
+        started_at=datetime.now(timezone.utc),
+        access_attributes=AccessAttributes(roles=["admin"]),
+    )
+
+    await agent_persistence.kvstore.set(
+        key=f"session:{agent_persistence.agent_id}:{session_id}",
+        value=session_info.model_dump_json(),
+    )
+
+    # Create a turn for this session
+    turn_id = str(uuid.uuid4())
+    turn = Turn(
+        session_id=session_id,
+        turn_id=turn_id,
+        steps=[],
+        started_at=datetime.now(timezone.utc),
+        input_messages=[],
+        output_message=CompletionMessage(
+            content="Hello",
+            stop_reason=StopReason.end_of_turn,
+        ),
+    )
+
+    # Admin can add turn
+    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
+    await agent_persistence.add_turn_to_session(session_id, turn)
+
+    # Admin can get turn
+    retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
+    assert retrieved_turn is not None
+    assert retrieved_turn.turn_id == turn_id
+
+    # Regular user cannot get turn
+    mock_get_auth_attributes.return_value = {"roles": ["user"]}
+    with pytest.raises(ValueError):
+        await agent_persistence.get_session_turn(session_id, turn_id)
+
+    # Regular user cannot get turns for session
+    with pytest.raises(ValueError):
+        await agent_persistence.get_session_turns(session_id)
+
+
+@pytest.mark.asyncio
+@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
+async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
+    agent_persistence = test_setup
+
+    # Create a session with restricted access
+    session_id = str(uuid.uuid4())
+    session_info = AgentSessionInfo(
+        session_id=session_id,
+        session_name="Restricted Session",
+        started_at=datetime.now(timezone.utc),
+        access_attributes=AccessAttributes(roles=["admin"]),
+    )
+
+    await agent_persistence.kvstore.set(
+        key=f"session:{agent_persistence.agent_id}:{session_id}",
+        value=session_info.model_dump_json(),
+    )
+
+    turn_id = str(uuid.uuid4())
+
+    # Admin user can set inference iterations
+    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
+    await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
+
+    # Admin user can get inference iterations
+    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
+    assert infer_iters == 5
+
+    # Regular user cannot get inference iterations
+    mock_get_auth_attributes.return_value = {"roles": ["user"]}
+    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
+    assert infer_iters is None
+
+    # Regular user cannot set inference iterations (should raise ValueError)
+    with pytest.raises(ValueError):
+        await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)