mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
refactor: use concrete SessionResource instead of cast to ProtectedResource
Replace unsafe cast with explicit SessionResource dataclass that properly implements the ProtectedResource protocol.
This commit is contained in:
parent
f0d365622a
commit
804d9420c9
1 changed files with 20 additions and 5 deletions
|
|
@ -6,12 +6,13 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
|
from llama_stack.core.access_control.conditions import User as ProtocolUser
|
||||||
from llama_stack.core.access_control.datatypes import AccessRule, Action
|
from llama_stack.core.access_control.datatypes import AccessRule, Action
|
||||||
from llama_stack.core.datatypes import User
|
from llama_stack.core.datatypes import User
|
||||||
from llama_stack.core.request_headers import get_authenticated_user
|
from llama_stack.core.request_headers import get_authenticated_user
|
||||||
|
|
@ -34,6 +35,14 @@ class AgentInfo(AgentConfig):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionResource:
|
||||||
|
"""Concrete implementation of ProtectedResource for session access control."""
|
||||||
|
type: str
|
||||||
|
identifier: str
|
||||||
|
owner: ProtocolUser # Use the protocol type for structural compatibility
|
||||||
|
|
||||||
|
|
||||||
class AgentPersistence:
|
class AgentPersistence:
|
||||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
|
|
@ -56,8 +65,11 @@ class AgentPersistence:
|
||||||
)
|
)
|
||||||
# Both identifier and owner are set above, safe to use for access control
|
# Both identifier and owner are set above, safe to use for access control
|
||||||
assert session_info.identifier is not None and session_info.owner is not None
|
assert session_info.identifier is not None and session_info.owner is not None
|
||||||
from llama_stack.core.access_control.conditions import ProtectedResource
|
resource = SessionResource(
|
||||||
resource = cast(ProtectedResource, session_info)
|
type=session_info.type,
|
||||||
|
identifier=session_info.identifier,
|
||||||
|
owner=session_info.owner,
|
||||||
|
)
|
||||||
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
|
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
|
||||||
raise AccessDeniedError(Action.CREATE, resource, user)
|
raise AccessDeniedError(Action.CREATE, resource, user)
|
||||||
|
|
||||||
|
|
@ -94,8 +106,11 @@ class AgentPersistence:
|
||||||
|
|
||||||
# At this point, both identifier and owner are guaranteed to be non-None
|
# At this point, both identifier and owner are guaranteed to be non-None
|
||||||
assert session_info.identifier is not None and session_info.owner is not None
|
assert session_info.identifier is not None and session_info.owner is not None
|
||||||
from llama_stack.core.access_control.conditions import ProtectedResource
|
resource = SessionResource(
|
||||||
resource = cast(ProtectedResource, session_info)
|
type=session_info.type,
|
||||||
|
identifier=session_info.identifier,
|
||||||
|
owner=session_info.owner,
|
||||||
|
)
|
||||||
return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user())
|
return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user())
|
||||||
|
|
||||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue