chore(mypy): part-04 resolve mypy errors in meta_reference agents (#3969)

## Summary
Fixes all mypy type errors in `providers/inline/agents/meta_reference/`
and removes exclusions from pyproject.toml.

## Changes
- Fix type annotations for Safety API message parameters
(OpenAIMessageParam)
- Add Action enum usage in access control checks
- Correct method signatures to match API supertype (parameter ordering)
- Handle optional return types with proper None checks
- Remove 3 meta_reference exclusions from mypy config

**Files fixed:** 25 errors across 3 files (safety.py, persistence.py,
agents.py)
This commit is contained in:
Ashwin Bharambe 2025-10-29 13:37:28 -07:00 committed by GitHub
parent e6b27db30a
commit c9d4b6c54f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 69 additions and 31 deletions

View file

@ -11,7 +11,7 @@ import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
from typing import Any, cast
import httpx
@ -363,7 +363,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.input_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
):
if isinstance(res, bool):
return
@ -392,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.output_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
@ -404,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:

View file

@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
Document,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
created_at=agent_info.created_at.isoformat(),
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
if turn is None:
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_session(
self,
agent_id: str,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids]
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
started_at=session_info.started_at,
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
created_at=datetime.fromisoformat(chat_agent.created_at),
)
return agent
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
response_id: str,
) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.get_openai_response(response_id)
async def create_openai_response(
@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response(
input,
model,
prompt,
@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters,
guardrails,
)
return result # type: ignore[no-any-return]
async def list_openai_responses(
self,
@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
async def list_openai_response_input_items(
@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_response_input_items(
response_id, after, before, include, limit, order
)
async def delete_openai_response(self, response_id: str) -> None:
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.delete_openai_response(response_id)

View file

@ -6,12 +6,14 @@
import json
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.access_control.conditions import User as ProtocolUser
from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
created_at: datetime
@dataclass
class SessionResource:
"""Concrete implementation of ProtectedResource for session access control."""
type: str
identifier: str
owner: ProtocolUser # Use the protocol type for structural compatibility
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id
@ -53,8 +64,15 @@ class AgentPersistence:
turns=[],
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError("create", session_info, user)
# Only perform access control if we have an authenticated user
if user is not None and session_info.identifier is not None:
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=user,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@ -62,7 +80,7 @@ class AgentPersistence:
)
return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
@ -83,7 +101,22 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
# Get current user - if None, skip access control (e.g., in tests)
user = get_authenticated_user()
if user is None:
return True
# Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
return is_action_allowed(self.policy, Action.READ, resource, user)
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -254,7 +254,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrailSpec] | None = None,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

View file

@ -6,7 +6,7 @@
import asyncio
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields
self.output_shields = output_shields
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(