From c9d4b6c54faeb8cb3ae92e0aed7c117cd06389c9 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 29 Oct 2025 13:37:28 -0700 Subject: [PATCH] chore(mypy): part-04 resolve mypy errors in meta_reference agents (#3969) ## Summary Fixes all mypy type errors in `providers/inline/agents/meta_reference/` and removes exclusions from pyproject.toml. ## Changes - Fix type annotations for Safety API message parameters (OpenAIMessageParam) - Add Action enum usage in access control checks - Correct method signatures to match API supertype (parameter ordering) - Handle optional return types with proper None checks - Remove 3 meta_reference exclusions from mypy config **Files fixed:** 25 errors across 3 files (safety.py, persistence.py, agents.py) --- pyproject.toml | 6 --- .../agents/meta_reference/agent_instance.py | 8 ++-- .../inline/agents/meta_reference/agents.py | 27 ++++++++---- .../agents/meta_reference/persistence.py | 43 ++++++++++++++++--- .../responses/openai_responses.py | 2 +- .../inline/agents/meta_reference/safety.py | 4 +- .../agent/test_meta_reference_agent.py | 10 ++--- 7 files changed, 69 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 999c3d9a3..e99299dab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -284,12 +284,6 @@ exclude = [ "^src/llama_stack/models/llama/llama3/interface\\.py$", "^src/llama_stack/models/llama/llama3/tokenizer\\.py$", "^src/llama_stack/models/llama/llama3/tool_utils\\.py$", - "^src/llama_stack/providers/inline/agents/meta_reference/agents\\.py$", - "^src/llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", - "^src/llama_stack/providers/inline/agents/meta_reference/config\\.py$", - "^src/llama_stack/providers/inline/agents/meta_reference/persistence\\.py$", - "^src/llama_stack/providers/inline/agents/meta_reference/safety\\.py$", - "^src/llama_stack/providers/inline/agents/meta_reference/__init__\\.py$", "^src/llama_stack/providers/inline/datasetio/localfs/", "^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$", diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 55bf31f57..b6fad553a 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -11,7 +11,7 @@ import uuid import warnings from collections.abc import AsyncGenerator from datetime import UTC, datetime -from typing import Any +from typing import Any, cast import httpx @@ -363,7 +363,7 @@ class ChatAgent(ShieldRunnerMixin): if self.input_shields: async for res in self.run_multiple_shields_wrapper( - turn_id, input_messages, self.input_shields, "user-input" + turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input" ): if isinstance(res, bool): return @@ -392,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin): if self.output_shields: async for res in self.run_multiple_shields_wrapper( - turn_id, messages, self.output_shields, "assistant-output" + turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output" ): if isinstance(res, bool): return @@ -404,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin): async def run_multiple_shields_wrapper( self, turn_id: str, - messages: list[Message], + messages: list[OpenAIMessageParam], shields: list[str], touchpoint: str, ) -> AsyncGenerator: diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py index b4b77bacd..85c6cb251 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -21,6 +21,7 @@ from llama_stack.apis.agents import ( Document, ListOpenAIResponseInputItem, ListOpenAIResponseObject, + OpenAIDeleteResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, @@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents): persistence_store=( self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store ), - created_at=agent_info.created_at, + created_at=agent_info.created_at.isoformat(), policy=self.policy, telemetry_enabled=self.telemetry_enabled, ) @@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents): agent_id: str, session_id: str, messages: list[UserMessage | ToolResponseMessage], - toolgroups: list[AgentToolGroup] | None = None, - documents: list[Document] | None = None, stream: bool | None = False, + documents: list[Document] | None = None, + toolgroups: list[AgentToolGroup] | None = None, tool_config: ToolConfig | None = None, ) -> AsyncGenerator: request = AgentTurnCreateRequest( @@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents): async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: agent = await self._get_agent_impl(agent_id) turn = await agent.storage.get_session_turn(session_id, turn_id) + if turn is None: + raise ValueError(f"Turn {turn_id} not found in session {session_id}") return turn async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse: @@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents): async def get_agents_session( self, - agent_id: str, session_id: str, + agent_id: str, turn_ids: list[str] | None = None, ) -> Session: agent = await self._get_agent_impl(agent_id) session_info = await agent.storage.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") turns = await agent.storage.get_session_turns(session_id) if turn_ids: turns = [turn for turn in turns if turn.turn_id in turn_ids] @@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents): started_at=session_info.started_at, ) - async def delete_agents_session(self, agent_id: str, session_id: str) -> None: + async def delete_agents_session(self, session_id: str, agent_id: str) -> None: agent = await self._get_agent_impl(agent_id) # Delete turns first, then the session @@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents): agent = Agent( agent_id=agent_id, agent_config=chat_agent.agent_config, - created_at=chat_agent.created_at, + created_at=datetime.fromisoformat(chat_agent.created_at), ) return agent @@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents): self, response_id: str, ) -> OpenAIResponseObject: + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" return await self.openai_responses_impl.get_openai_response(response_id) async def create_openai_response( @@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents): max_infer_iters: int | None = 10, guardrails: list[ResponseGuardrail] | None = None, ) -> OpenAIResponseObject: - return await self.openai_responses_impl.create_openai_response( + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" + result = await self.openai_responses_impl.create_openai_response( input, model, prompt, @@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents): max_infer_iters, guardrails, ) + return result # type: ignore[no-any-return] async def list_openai_responses( self, @@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents): model: str | None = None, order: Order | None = Order.desc, ) -> ListOpenAIResponseObject: + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" return await self.openai_responses_impl.list_openai_responses(after, limit, model, order) async def list_openai_response_input_items( @@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents): limit: int | None = 20, order: Order | None = Order.desc, ) -> ListOpenAIResponseInputItem: + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" return await self.openai_responses_impl.list_openai_response_input_items( response_id, after, before, include, limit, order ) - async def delete_openai_response(self, response_id: str) -> None: + async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" return await self.openai_responses_impl.delete_openai_response(response_id) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/persistence.py b/src/llama_stack/providers/inline/agents/meta_reference/persistence.py index 26a2151e3..9e0598bf1 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -6,12 +6,14 @@ import json import uuid +from dataclasses import dataclass from datetime import UTC, datetime from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.common.errors import SessionNotFoundError from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed -from llama_stack.core.access_control.datatypes import AccessRule +from llama_stack.core.access_control.conditions import User as ProtocolUser +from llama_stack.core.access_control.datatypes import AccessRule, Action from llama_stack.core.datatypes import User from llama_stack.core.request_headers import get_authenticated_user from llama_stack.log import get_logger @@ -33,6 +35,15 @@ class AgentInfo(AgentConfig): created_at: datetime +@dataclass +class SessionResource: + """Concrete implementation of ProtectedResource for session access control.""" + + type: str + identifier: str + owner: ProtocolUser # Use the protocol type for structural compatibility + + class AgentPersistence: def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]): self.agent_id = agent_id @@ -53,8 +64,15 @@ class AgentPersistence: turns=[], identifier=name, # should this be qualified in any way? ) - if not is_action_allowed(self.policy, "create", session_info, user): - raise AccessDeniedError("create", session_info, user) + # Only perform access control if we have an authenticated user + if user is not None and session_info.identifier is not None: + resource = SessionResource( + type=session_info.type, + identifier=session_info.identifier, + owner=user, + ) + if not is_action_allowed(self.policy, Action.CREATE, resource, user): + raise AccessDeniedError(Action.CREATE, resource, user) await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", @@ -62,7 +80,7 @@ class AgentPersistence: ) return session_id - async def get_session_info(self, session_id: str) -> AgentSessionInfo: + async def get_session_info(self, session_id: str) -> AgentSessionInfo | None: value = await self.kvstore.get( key=f"session:{self.agent_id}:{session_id}", ) @@ -83,7 +101,22 @@ class AgentPersistence: if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): return True - return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) + # Get current user - if None, skip access control (e.g., in tests) + user = get_authenticated_user() + if user is None: + return True + + # Access control requires identifier and owner to be set + if session_info.identifier is None or session_info.owner is None: + return True + + # At this point, both identifier and owner are guaranteed to be non-None + resource = SessionResource( + type=session_info.type, + identifier=session_info.identifier, + owner=session_info.owner, + ) + return is_action_allowed(self.policy, Action.READ, resource, user) async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: """Get session info if the user has access to it. For internal use by sub-session methods.""" diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index f6769e838..933cfe963 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -254,7 +254,7 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, - guardrails: list[ResponseGuardrailSpec] | None = None, + guardrails: list[str | ResponseGuardrailSpec] | None = None, ): stream = bool(stream) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text diff --git a/src/llama_stack/providers/inline/agents/meta_reference/safety.py b/src/llama_stack/providers/inline/agents/meta_reference/safety.py index 9baf5a14d..f0ae51423 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -6,7 +6,7 @@ import asyncio -from llama_stack.apis.inference import Message +from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.core.telemetry import tracing from llama_stack.log import get_logger @@ -31,7 +31,7 @@ class ShieldRunnerMixin: self.input_shields = input_shields self.output_shields = output_shields - async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None: + async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None: async def run_shield_with_span(identifier: str): async with tracing.span(f"run_shield_{identifier}"): return await self.safety_api.run_shield( diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py index dfd9b6d52..c4f90661c 100644 --- a/tests/unit/providers/agent/test_meta_reference_agent.py +++ b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -192,18 +192,18 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config assert session_response.session_id is not None # Verify the session was stored - session = await agents_impl.get_agents_session(agent_id, session_response.session_id) + session = await agents_impl.get_agents_session(session_response.session_id, agent_id) assert session.session_name == "test_session" assert session.session_id == session_response.session_id assert session.started_at is not None assert session.turns == [] # Delete the session - await agents_impl.delete_agents_session(agent_id, session_response.session_id) + await agents_impl.delete_agents_session(session_response.session_id, agent_id) # Verify the session was deleted with pytest.raises(ValueError): - await agents_impl.get_agents_session(agent_id, session_response.session_id) + await agents_impl.get_agents_session(session_response.session_id, agent_id) @pytest.mark.parametrize("enable_session_persistence", [True, False]) @@ -226,11 +226,11 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, assert session2.session_id in session_ids # Delete one session - await agents_impl.delete_agents_session(agent_id, session1.session_id) + await agents_impl.delete_agents_session(session1.session_id, agent_id) # Verify the session was deleted with pytest.raises(ValueError): - await agents_impl.get_agents_session(agent_id, session1.session_id) + await agents_impl.get_agents_session(session1.session_id, agent_id) # List sessions again sessions = await agents_impl.list_agent_sessions(agent_id)