feat: make sure agent sessions are under access control (#1737)

This builds on top of #1703.

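Agent sessions are now access controlled: each new session stores its creator's access attributes, and reads and writes of the session, its turns, and its per-turn state are checked against the caller's attributes.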

## Test Plan

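Added unit tests covering session creation with access attributes, session-level access checks, turn-level access checks, and access checks on per-turn inference-iteration counts.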
This commit is contained in:
Ashwin Bharambe 2025-03-21 07:31:16 -07:00 committed by GitHub
parent d7a6d92466
commit 03b5c61bfc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 255 additions and 17 deletions

View file

@ -6,13 +6,17 @@
from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
@ -43,39 +47,40 @@ def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict
# - The extra "projects" attribute is ignored
Args:
obj: The resource object to check access for
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
if not obj_attributes:
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
for attr_key, required_values in obj_attributes.items():
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
)
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': "
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
logger.debug(f"Access granted to {obj_identifier}")
return True

View file

@ -198,7 +198,7 @@ class CommonRoutingTableImpl(RoutingTable):
return None
# Check if user has permission to access this object
if not check_access(obj, get_auth_attributes()):
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
@ -241,7 +241,11 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())]
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs

View file

@ -13,6 +13,9 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
# TODO: is this used anywhere?
vector_db_id: Optional[str] = None
started_at: datetime
access_attributes: Optional[AccessAttributes] = None
class AgentPersistence:
@ -33,11 +37,18 @@ class AgentPersistence:
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
@ -51,12 +62,34 @@ class AgentPersistence:
if not value:
return None
return AgentSessionInfo(**json.loads(value))
session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
    """Check if the current user has access to the session.

    Args:
        session_info: The session record whose access attributes are evaluated.

    Returns:
        bool: The result of ``check_access`` for the session's access
        attributes against the current user's auth attributes.
    """
    # Sessions persisted before access control existed deserialize with
    # access_attributes=None (the field's declared default), and check_access
    # allows access when the object has no access attributes. A separate
    # hasattr() guard can never be false for a declared model field, so the
    # attribute is read defensively here and handed straight to check_access.
    session_attributes = getattr(session_info, "access_attributes", None)
    return check_access(session_info.session_id, session_attributes, get_auth_attributes())
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
    """Return the session info when it exists and the caller may access it, else None.

    Delegates to ``get_session_info`` and normalises any falsy result to None.
    For internal use by sub-session methods.
    """
    found = await self.get_session_info(session_id)
    return found if found else None
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_info(session_id)
session_info = await self.get_session_if_accessible(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
raise ValueError(f"Session {session_id} not found or access denied")
session_info.vector_db_id = vector_db_id
await self.kvstore.set(
@ -65,12 +98,18 @@ class AgentPersistence:
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
    """Persist a turn under the given session.

    Raises:
        ValueError: If the session does not exist or the caller may not access it.
    """
    session = await self.get_session_if_accessible(session_id)
    if not session:
        raise ValueError(f"Session {session_id} not found or access denied")

    # Turns are stored beneath the session's own key, suffixed with the turn id.
    turn_key = f"session:{self.agent_id}:{session_id}:{turn.turn_id}"
    await self.kvstore.set(key=turn_key, value=turn.model_dump_json())
async def get_session_turns(self, session_id: str) -> List[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
@ -87,6 +126,9 @@ class AgentPersistence:
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
@ -95,24 +137,36 @@ class AgentPersistence:
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
    """Store the in-progress tool-call step for a turn of the given session.

    Raises:
        ValueError: If the session does not exist or the caller may not access it.
    """
    session = await self.get_session_if_accessible(session_id)
    if not session:
        raise ValueError(f"Session {session_id} not found or access denied")

    step_key = f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}"
    await self.kvstore.set(key=step_key, value=step.model_dump_json())
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
    """Fetch the in-progress tool-call step for a turn.

    Returns None, rather than raising, when the session is missing or
    inaccessible to the caller, and also when no step is stored for the turn.
    """
    session = await self.get_session_if_accessible(session_id)
    if not session:
        return None

    step_key = f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}"
    raw = await self.kvstore.get(key=step_key)
    if not raw:
        return None
    return ToolExecutionStep(**json.loads(raw))
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
    """Record the number of inference iterations for a turn, stored as a string.

    Raises:
        ValueError: If the session does not exist or the caller may not access it.
    """
    session = await self.get_session_if_accessible(session_id)
    if not session:
        raise ValueError(f"Session {session_id} not found or access denied")

    counter_key = f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}"
    await self.kvstore.set(key=counter_key, value=str(num_infer_iters))
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)

View file

@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
import tempfile
import uuid
from datetime import datetime
from unittest.mock import patch
import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
async def test_setup():
    """Yield an AgentPersistence backed by a SQLite KV store in a throwaway directory.

    The temporary directory is removed when the fixture is torn down, and also
    when setup itself fails (for example if initializing the KV store raises),
    so a broken setup does not leave directories behind.
    """
    # NOTE(review): a plain @pytest.fixture on an async generator presumably
    # relies on the project's pytest-asyncio configuration -- confirm.
    temp_dir = tempfile.mkdtemp()
    try:
        db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
        kvstore_config = SqliteKVStoreConfig(db_path=db_path)
        kvstore = SqliteKVStoreImpl(kvstore_config)
        await kvstore.initialize()
        agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore)
        yield agent_persistence
    finally:
        shutil.rmtree(temp_dir)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
    """A newly created session records the creating user's auth attributes."""
    persistence = test_setup

    # The patched lookup makes the "current user" a researcher on ai-team.
    mock_get_auth_attributes.return_value = {"roles": ["researcher"], "teams": ["ai-team"]}

    new_session_id = await persistence.create_session("Test Session")

    # Reading back under the same identity must succeed and expose the stored attributes.
    stored = await persistence.get_session_info(new_session_id)
    assert stored is not None
    assert stored.access_attributes is not None
    assert stored.access_attributes.roles == ["researcher"]
    assert stored.access_attributes.teams == ["ai-team"]
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_access_control(mock_get_auth_attributes, test_setup):
    """get_session_info returns a restricted session only to users whose attributes match."""
    persistence = test_setup

    # Write a restricted session straight into the KV store, bypassing create_session,
    # so its access attributes are independent of the mocked caller.
    sid = str(uuid.uuid4())
    restricted = AgentSessionInfo(
        session_id=sid,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
    )
    session_key = f"session:{persistence.agent_id}:{sid}"
    await persistence.kvstore.set(key=session_key, value=restricted.model_dump_json())

    # A caller holding a matching role and a matching team gets the session back.
    mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
    visible = await persistence.get_session_info(sid)
    assert visible is not None
    assert visible.session_id == sid

    # A caller matching neither category gets None.
    mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
    hidden = await persistence.get_session_info(sid)
    assert hidden is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
    """Turn reads and writes are gated by the owning session's access attributes."""
    persistence = test_setup

    # Seed an admin-only session directly in the KV store.
    sid = str(uuid.uuid4())
    restricted = AgentSessionInfo(
        session_id=sid,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    session_key = f"session:{persistence.agent_id}:{sid}"
    await persistence.kvstore.set(key=session_key, value=restricted.model_dump_json())

    # Build a minimal turn belonging to that session.
    tid = str(uuid.uuid4())
    reply = CompletionMessage(
        content="Hello",
        stop_reason=StopReason.end_of_turn,
    )
    turn = Turn(
        session_id=sid,
        turn_id=tid,
        steps=[],
        started_at=datetime.now(),
        input_messages=[],
        output_message=reply,
    )

    # As admin: the turn can be written and read back.
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    await persistence.add_turn_to_session(sid, turn)

    fetched = await persistence.get_session_turn(sid, tid)
    assert fetched is not None
    assert fetched.turn_id == tid

    # As a regular user: single-turn and whole-session reads both raise.
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    with pytest.raises(ValueError):
        await persistence.get_session_turn(sid, tid)

    with pytest.raises(ValueError):
        await persistence.get_session_turns(sid)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
    """Per-turn inference-iteration counts are gated by the session's access attributes."""
    # NOTE(review): despite the name, only the num_infer_iters accessors are
    # exercised here; the in-progress tool-call step accessors are not called.
    persistence = test_setup

    # Seed an admin-only session directly in the KV store.
    sid = str(uuid.uuid4())
    restricted = AgentSessionInfo(
        session_id=sid,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"]),
    )
    session_key = f"session:{persistence.agent_id}:{sid}"
    await persistence.kvstore.set(key=session_key, value=restricted.model_dump_json())

    tid = str(uuid.uuid4())

    # As admin: the count can be stored and read back.
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    await persistence.set_num_infer_iters_in_turn(sid, tid, 5)

    seen_by_admin = await persistence.get_num_infer_iters_in_turn(sid, tid)
    assert seen_by_admin == 5

    # As a regular user: reads yield None instead of the stored value...
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    seen_by_user = await persistence.get_num_infer_iters_in_turn(sid, tid)
    assert seen_by_user is None

    # ...and writes are rejected with ValueError.
    with pytest.raises(ValueError):
        await persistence.set_num_infer_iters_in_turn(sid, tid, 10)