feat: make sure agent sessions are under access control (#1737)

This builds on top of #1703.

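Agent sessions are now access-controlled: a new session records the creating user's auth attributes, and session lookups (plus the turn and per-turn state accessors built on them) check the current user's attributes against the session's before returning or writing data.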

## Test Plan

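Added unit tests covering session creation with access attributes, session lookup for matching and non-matching users, turn read/write access, and per-turn inference-iteration state access.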
This commit is contained in:
Ashwin Bharambe 2025-03-21 07:31:16 -07:00 committed by GitHub
parent d7a6d92466
commit 03b5c61bfc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 255 additions and 17 deletions

View file

@ -6,13 +6,17 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import RoutableObjectWithProvider from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(__name__, category="core") logger = get_logger(__name__, category="core")
def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool: def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes. """Check if the current user has access to the given object, based on access attributes.
Access control algorithm: Access control algorithm:
@ -43,39 +47,40 @@ def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict
# - The extra "projects" attribute is ignored # - The extra "projects" attribute is ignored
Args: Args:
obj: The resource object to check access for obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns: Returns:
bool: True if access is granted, False if denied bool: True if access is granted, False if denied
""" """
# If object has no access attributes, allow access by default # If object has no access attributes, allow access by default
if not hasattr(obj, "access_attributes") or not obj.access_attributes: if not obj_attributes:
return True return True
# If no user attributes, deny access to objects with access control # If no user attributes, deny access to objects with access control
if not user_attributes: if not user_attributes:
return False return False
obj_attributes = obj.access_attributes.model_dump(exclude_none=True) dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not obj_attributes: if not dict_attribs:
return True return True
# Check each attribute category (requires ALL categories to match) # Check each attribute category (requires ALL categories to match)
for attr_key, required_values in obj_attributes.items(): # TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, []) user_values = user_attributes.get(attr_key, [])
if not user_values: if not user_values:
logger.debug( logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
)
return False return False
if not any(val in user_values for val in required_values): if not any(val in user_values for val in required_values):
logger.debug( logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': " f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}" f"no match for attribute '{attr_key}', required one of {required_values}"
) )
return False return False
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'") logger.debug(f"Access granted to {obj_identifier}")
return True return True

View file

@ -198,7 +198,7 @@ class CommonRoutingTableImpl(RoutingTable):
return None return None
# Check if user has permission to access this object # Check if user has permission to access this object
if not check_access(obj, get_auth_attributes()): if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None return None
@ -241,7 +241,11 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering # Apply attribute-based access control filtering
if filtered_objs: if filtered_objs:
filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())] filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs return filtered_objs

View file

@ -13,6 +13,9 @@ from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: Optional[str] = None vector_db_id: Optional[str] = None
started_at: datetime started_at: datetime
access_attributes: Optional[AccessAttributes] = None
class AgentPersistence: class AgentPersistence:
@ -33,11 +37,18 @@ class AgentPersistence:
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo( session_info = AgentSessionInfo(
session_id=session_id, session_id=session_id,
session_name=name, session_name=name,
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
) )
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(), value=session_info.model_dump_json(),
@ -51,12 +62,34 @@ class AgentPersistence:
if not value: if not value:
return None return None
return AgentSessionInfo(**json.loads(value)) session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
    """Return whether the current user may access the given session.

    Sessions stored before access control existed carry no access
    attributes; for those, ``None`` is handed to ``check_access``, which
    allows access by default when an object has no access attributes.

    Args:
        session_info: The stored session record to check.

    Returns:
        bool: The verdict of ``check_access`` for the current user's auth attributes.
    """
    # ``access_attributes`` is a declared field with a ``None`` default, so a
    # ``hasattr`` probe can never fail on an AgentSessionInfo. ``getattr`` with a
    # default keeps the same tolerance without an unreachable branch.
    access_attributes = getattr(session_info, "access_attributes", None)
    return check_access(session_info.session_id, access_attributes, get_auth_attributes())
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
    """Fetch a session's info, or None when the lookup yields nothing.

    Internal helper for the sub-session methods; it delegates the lookup
    to ``get_session_info``.
    """
    info = await self.get_session_info(session_id)
    return info if info else None
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str): async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_info(session_id) session_info = await self.get_session_if_accessible(session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {session_id} not found") raise ValueError(f"Session {session_id} not found or access denied")
session_info.vector_db_id = vector_db_id session_info.vector_db_id = vector_db_id
await self.kvstore.set( await self.kvstore.set(
@ -65,12 +98,18 @@ class AgentPersistence:
) )
async def add_turn_to_session(self, session_id: str, turn: Turn): async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(), value=turn.model_dump_json(),
) )
async def get_session_turns(self, session_id: str) -> List[Turn]: async def get_session_turns(self, session_id: str) -> List[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range( values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:", start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
@ -87,6 +126,9 @@ class AgentPersistence:
return turns return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]: async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn_id}",
) )
@ -95,24 +137,36 @@ class AgentPersistence:
return Turn(**json.loads(value)) return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(), value=step.model_dump_json(),
) )
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
) )
return ToolExecutionStep(**json.loads(value)) if value else None return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int): async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters), value=str(num_infer_iters),
) )
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]: async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
) )

View file

@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
import tempfile
import uuid
from datetime import datetime
from unittest.mock import patch
import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
async def test_setup():
    """Yield an AgentPersistence backed by a throwaway SQLite KV store.

    The database file lives in a temporary directory that is removed when the
    fixture finishes, including when setup fails before the test body runs.
    """
    # The context manager removes the directory even if kvstore initialization
    # raises before the yield, which a trailing rmtree call would never reach.
    with tempfile.TemporaryDirectory() as temp_dir:
        db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
        kvstore = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=db_path))
        await kvstore.initialize()
        yield AgentPersistence(agent_id="test_agent", kvstore=kvstore)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
    """A newly created session is stamped with the creating user's auth attributes."""
    persistence = test_setup

    # The patched auth lookup reports these attributes for the session's creator.
    mock_get_auth_attributes.return_value = {"roles": ["researcher"], "teams": ["ai-team"]}

    new_session_id = await persistence.create_session("Test Session")
    stored = await persistence.get_session_info(new_session_id)

    # The stored record must exist and carry the creator's attributes.
    assert stored is not None
    stamped = stored.access_attributes
    assert stamped is not None
    assert stamped.roles == ["researcher"]
    assert stamped.teams == ["ai-team"]
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_access_control(mock_get_auth_attributes, test_setup):
    """Session lookup succeeds only for users matching every attribute category."""
    agent_persistence = test_setup

    # Store a session directly in the KV store with specific access attributes
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
    )
    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    # User with matching attributes can access
    mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is not None
    assert retrieved_session.session_id == session_id

    # User without matching attributes cannot access
    mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is None

    # A caller with no auth attributes at all is denied access to a restricted session
    mock_get_auth_attributes.return_value = None
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
    """Turn reads and writes work for an admin and raise ValueError for other users."""
    agent_persistence = test_setup

    # Create a session with restricted access
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    # Create a turn for this session
    turn_id = str(uuid.uuid4())
    turn = Turn(
        session_id=session_id,
        turn_id=turn_id,
        steps=[],
        started_at=datetime.now(),
        input_messages=[],
        output_message=CompletionMessage(
            content="Hello",
            stop_reason=StopReason.end_of_turn,
        ),
    )

    # Admin can add turn
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    await agent_persistence.add_turn_to_session(session_id, turn)

    # Admin can get turn
    retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
    assert retrieved_turn is not None
    assert retrieved_turn.turn_id == turn_id

    # Regular user cannot get turn
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    with pytest.raises(ValueError):
        await agent_persistence.get_session_turn(session_id, turn_id)

    # Regular user cannot get turns for session
    with pytest.raises(ValueError):
        await agent_persistence.get_session_turns(session_id)

    # Regular user cannot add a turn either: the write path is guarded as well
    with pytest.raises(ValueError):
        await agent_persistence.add_turn_to_session(session_id, turn)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
    """Per-turn inference-iteration state is readable and writable only with session access."""
    agent_persistence = test_setup

    # Create a session with restricted access
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    turn_id = str(uuid.uuid4())

    # Admin user can set inference iterations
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)

    # Admin user can get inference iterations
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters == 5

    # Regular user cannot get inference iterations
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters is None

    # Regular user cannot set inference iterations (should raise ValueError)
    with pytest.raises(ValueError):
        await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)

    # The rejected write must not have changed what the admin stored
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters == 5