chore(api): add mypy coverage to meta_reference

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-08 20:55:17 +02:00
parent 1d8c00635c
commit f617a28164
3 changed files with 55 additions and 29 deletions

View file

@ -11,7 +11,8 @@ from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.distribution.access_control.conditions import User as ProtectedResourceUser
from llama_stack.distribution.access_control.datatypes import AccessRule, Action
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore
@ -23,8 +24,8 @@ class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_db_id: str | None = None
started_at: datetime
owner: User | None = None
identifier: str | None = None
owner: ProtectedResourceUser
identifier: str
type: str = "session"
@ -43,17 +44,21 @@ class AgentPersistence:
# Get current user's auth attributes for new sessions
user = get_authenticated_user()
# If no user is authenticated, create a default user for backward compatibility
if user is None:
user = User(principal="anonymous", attributes=None)
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(UTC),
owner=user,
owner=user, # User from datatypes is compatible with ProtectedResourceUser protocol
turns=[],
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError("create", session_info, user)
if not is_action_allowed(self.policy, Action.CREATE, session_info, user):
raise AccessDeniedError(Action.CREATE, session_info, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@ -68,7 +73,17 @@ class AgentPersistence:
if not value:
return None
session_info = AgentSessionInfo(**json.loads(value))
session_data = json.loads(value)
# Handle backward compatibility for sessions without owner field
if "owner" not in session_data or session_data["owner"] is None:
session_data["owner"] = User(principal="anonymous", attributes=None)
# Handle backward compatibility for sessions without identifier field
if "identifier" not in session_data or session_data["identifier"] is None:
session_data["identifier"] = session_data.get("session_name", "unknown")
session_info = AgentSessionInfo(**session_data)
# Check access to session
if not self._check_session_access(session_info):
@ -79,10 +94,10 @@ class AgentPersistence:
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
if not hasattr(session_info, "access_attributes"):
return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
return is_action_allowed(self.policy, Action.READ, session_info, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""