refactor: use concrete SessionResource instead of cast to ProtectedResource

Replace unsafe cast with explicit SessionResource dataclass that properly
implements the ProtectedResource protocol.
This commit is contained in:
Ashwin Bharambe 2025-10-29 11:20:41 -07:00
parent f0d365622a
commit 804d9420c9

View file

@ -6,12 +6,13 @@
import json import json
import uuid import uuid
from dataclasses import dataclass
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import cast
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.conditions import User as ProtocolUser
from llama_stack.core.access_control.datatypes import AccessRule, Action from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user from llama_stack.core.request_headers import get_authenticated_user
@ -34,6 +35,14 @@ class AgentInfo(AgentConfig):
created_at: datetime created_at: datetime
@dataclass
class SessionResource:
"""Concrete implementation of ProtectedResource for session access control."""
type: str
identifier: str
owner: ProtocolUser # Use the protocol type for structural compatibility
class AgentPersistence: class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]): def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id self.agent_id = agent_id
@ -56,8 +65,11 @@ class AgentPersistence:
) )
# Both identifier and owner are set above, safe to use for access control # Both identifier and owner are set above, safe to use for access control
assert session_info.identifier is not None and session_info.owner is not None assert session_info.identifier is not None and session_info.owner is not None
from llama_stack.core.access_control.conditions import ProtectedResource resource = SessionResource(
resource = cast(ProtectedResource, session_info) type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user): if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user) raise AccessDeniedError(Action.CREATE, resource, user)
@ -94,8 +106,11 @@ class AgentPersistence:
# At this point, both identifier and owner are guaranteed to be non-None # At this point, both identifier and owner are guaranteed to be non-None
assert session_info.identifier is not None and session_info.owner is not None assert session_info.identifier is not None and session_info.owner is not None
from llama_stack.core.access_control.conditions import ProtectedResource resource = SessionResource(
resource = cast(ProtectedResource, session_info) type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user()) return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: