diff --git a/src/llama_stack/providers/inline/agents/meta_reference/persistence.py b/src/llama_stack/providers/inline/agents/meta_reference/persistence.py index 63857cebe..9e0598bf1 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -38,6 +38,7 @@ class AgentInfo(AgentConfig): @dataclass class SessionResource: """Concrete implementation of ProtectedResource for session access control.""" + type: str identifier: str owner: ProtocolUser # Use the protocol type for structural compatibility @@ -63,15 +64,15 @@ class AgentPersistence: turns=[], identifier=name, # should this be qualified in any way? ) - # Both identifier and owner are set above, safe to use for access control - assert session_info.identifier is not None and session_info.owner is not None - resource = SessionResource( - type=session_info.type, - identifier=session_info.identifier, - owner=session_info.owner, - ) - if not is_action_allowed(self.policy, Action.CREATE, resource, user): - raise AccessDeniedError(Action.CREATE, resource, user) + # Only perform access control if we have an authenticated user + if user is not None and session_info.identifier is not None: + resource = SessionResource( + type=session_info.type, + identifier=session_info.identifier, + owner=user, + ) + if not is_action_allowed(self.policy, Action.CREATE, resource, user): + raise AccessDeniedError(Action.CREATE, resource, user) await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", @@ -100,18 +101,22 @@ class AgentPersistence: if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): return True + # Get current user - if None, skip access control (e.g., in tests) + user = get_authenticated_user() + if user is None: + return True + # Access control requires identifier and owner to be set if session_info.identifier is None or session_info.owner is None: return True # At this point, both identifier and owner are guaranteed to be non-None - assert session_info.identifier is not None and session_info.owner is not None resource = SessionResource( type=session_info.type, identifier=session_info.identifier, owner=session_info.owner, ) - return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user()) + return is_action_allowed(self.policy, Action.READ, resource, user) async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: """Get session info if the user has access to it. For internal use by sub-session methods.""" diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py index dfd9b6d52..c4f90661c 100644 --- a/tests/unit/providers/agent/test_meta_reference_agent.py +++ b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -192,18 +192,18 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config assert session_response.session_id is not None # Verify the session was stored - session = await agents_impl.get_agents_session(agent_id, session_response.session_id) + session = await agents_impl.get_agents_session(session_response.session_id, agent_id) assert session.session_name == "test_session" assert session.session_id == session_response.session_id assert session.started_at is not None assert session.turns == [] # Delete the session - await agents_impl.delete_agents_session(agent_id, session_response.session_id) + await agents_impl.delete_agents_session(session_response.session_id, agent_id) # Verify the session was deleted with pytest.raises(ValueError): - await agents_impl.get_agents_session(agent_id, session_response.session_id) + await agents_impl.get_agents_session(session_response.session_id, agent_id) @pytest.mark.parametrize("enable_session_persistence", [True, False]) @@ -226,11 +226,11 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, assert session2.session_id in session_ids # Delete one session - await agents_impl.delete_agents_session(agent_id, session1.session_id) + await agents_impl.delete_agents_session(session1.session_id, agent_id) # Verify the session was deleted with pytest.raises(ValueError): - await agents_impl.get_agents_session(agent_id, session1.session_id) + await agents_impl.get_agents_session(session1.session_id, agent_id) # List sessions again sessions = await agents_impl.list_agent_sessions(agent_id)