mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat: make sure agent sessions are under access control (#1737)
This builds on top of #1703. Agent sessions are now properly access controlled. ## Test Plan Added unit tests
This commit is contained in:
parent
d7a6d92466
commit
03b5c61bfc
4 changed files with 255 additions and 17 deletions
|
@ -13,6 +13,9 @@ from typing import List, Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
|
|||
# TODO: is this used anywhere?
|
||||
vector_db_id: Optional[str] = None
|
||||
started_at: datetime
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
|
@ -33,11 +37,18 @@ class AgentPersistence:
|
|||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Get current user's auth attributes for new sessions
|
||||
auth_attributes = get_auth_attributes()
|
||||
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
access_attributes=access_attributes,
|
||||
)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
|
@ -51,12 +62,34 @@ class AgentPersistence:
|
|||
if not value:
|
||||
return None
|
||||
|
||||
return AgentSessionInfo(**json.loads(value))
|
||||
session_info = AgentSessionInfo(**json.loads(value))
|
||||
|
||||
# Check access to session
|
||||
if not self._check_session_access(session_info):
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||
"""Check if current user has access to the session."""
|
||||
# Handle backward compatibility for old sessions without access control
|
||||
if not hasattr(session_info, "access_attributes"):
|
||||
return True
|
||||
|
||||
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if not session_info:
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
|
||||
session_info = await self.get_session_info(session_id)
|
||||
session_info = await self.get_session_if_accessible(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
session_info.vector_db_id = vector_db_id
|
||||
await self.kvstore.set(
|
||||
|
@ -65,12 +98,18 @@ class AgentPersistence:
|
|||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=turn.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
values = await self.kvstore.range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
|
@ -87,6 +126,9 @@ class AgentPersistence:
|
|||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
|
@ -95,24 +137,36 @@ class AgentPersistence:
|
|||
return Turn(**json.loads(value))
|
||||
|
||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue