fix: handle missing authenticated user in tests

Skip access control when no user is authenticated (e.g., in unit tests).
Update test calls to match corrected parameter order for get/delete session methods.
This commit is contained in:
Ashwin Bharambe 2025-10-29 11:41:58 -07:00
parent f4012d7fde
commit a214c442ac
2 changed files with 21 additions and 16 deletions

View file

@ -38,6 +38,7 @@ class AgentInfo(AgentConfig):
@dataclass
class SessionResource:
"""Concrete implementation of ProtectedResource for session access control."""
type: str
identifier: str
owner: ProtocolUser # Use the protocol type for structural compatibility
@ -63,15 +64,15 @@ class AgentPersistence:
turns=[],
identifier=name, # should this be qualified in any way?
)
# Both identifier and owner are set above, safe to use for access control
assert session_info.identifier is not None and session_info.owner is not None
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
# Only perform access control if we have an authenticated user
if user is not None and session_info.identifier is not None:
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=user,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@ -100,18 +101,22 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
# Get current user - if None, skip access control (e.g., in tests)
user = get_authenticated_user()
if user is None:
return True
# Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
assert session_info.identifier is not None and session_info.owner is not None
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user())
return is_action_allowed(self.policy, Action.READ, resource, user)
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -192,18 +192,18 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
assert session_response.session_id is not None
# Verify the session was stored
session = await agents_impl.get_agents_session(agent_id, session_response.session_id)
session = await agents_impl.get_agents_session(session_response.session_id, agent_id)
assert session.session_name == "test_session"
assert session.session_id == session_response.session_id
assert session.started_at is not None
assert session.turns == []
# Delete the session
await agents_impl.delete_agents_session(agent_id, session_response.session_id)
await agents_impl.delete_agents_session(session_response.session_id, agent_id)
# Verify the session was deleted
with pytest.raises(ValueError):
await agents_impl.get_agents_session(agent_id, session_response.session_id)
await agents_impl.get_agents_session(session_response.session_id, agent_id)
@pytest.mark.parametrize("enable_session_persistence", [True, False])
@ -226,11 +226,11 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
assert session2.session_id in session_ids
# Delete one session
await agents_impl.delete_agents_session(agent_id, session1.session_id)
await agents_impl.delete_agents_session(session1.session_id, agent_id)
# Verify the session was deleted
with pytest.raises(ValueError):
await agents_impl.get_agents_session(agent_id, session1.session_id)
await agents_impl.get_agents_session(session1.session_id, agent_id)
# List sessions again
sessions = await agents_impl.list_agent_sessions(agent_id)