fix: resolve mypy errors in meta_reference agents

- Fix type annotations for Safety API message parameters
- Add Action enum usage in access control checks
- Correct method signatures to match API supertype
- Handle optional return types with proper None checks
- Remove meta_reference exclusions from mypy config
This commit is contained in:
Ashwin Bharambe 2025-10-29 11:03:10 -07:00
parent b90c6a2c8b
commit f0d365622a
5 changed files with 48 additions and 26 deletions

View file

@@ -284,12 +284,6 @@ exclude = [
"^src/llama_stack/models/llama/llama3/interface\\.py$",
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
"^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/config\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/persistence\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/__init__\\.py$",
"^src/llama_stack/providers/inline/datasetio/localfs/",
"^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",

View file

@@ -11,7 +11,7 @@ import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
from typing import Any, cast
import httpx
@@ -363,7 +363,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.input_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
):
if isinstance(res, bool):
return
@@ -392,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.output_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
@@ -404,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:

View file

@@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
Document,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
@@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
created_at=agent_info.created_at.isoformat(),
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
@@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
@@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
if turn is None:
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
@@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_session(
self,
agent_id: str,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids]
@@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
started_at=session_info.started_at,
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session
@@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
created_at=datetime.fromisoformat(chat_agent.created_at),
)
return agent
@@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
response_id: str,
) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.get_openai_response(response_id)
async def create_openai_response(
@@ -342,7 +348,12 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from typing import cast as typing_cast
# Cast guardrails to the more specific type expected by the implementation
guardrails_spec = typing_cast(list[ResponseGuardrailSpec] | None, guardrails)
result = await self.openai_responses_impl.create_openai_response(
input,
model,
prompt,
@@ -356,8 +367,9 @@ class MetaReferenceAgentsImpl(Agents):
tools,
include,
max_infer_iters,
guardrails,
guardrails_spec,
)
return typing_cast(OpenAIResponseObject, result)
async def list_openai_responses(
self,
@@ -366,6 +378,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
async def list_openai_response_input_items(
@@ -377,9 +390,11 @@ class MetaReferenceAgentsImpl(Agents):
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_response_input_items(
response_id, after, before, include, limit, order
)
async def delete_openai_response(self, response_id: str) -> None:
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.delete_openai_response(response_id)

View file

@@ -7,11 +7,12 @@
import json
import uuid
from datetime import UTC, datetime
from typing import cast
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
@@ -53,8 +54,12 @@ class AgentPersistence:
turns=[],
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError("create", session_info, user)
# Both identifier and owner are set above, safe to use for access control
assert session_info.identifier is not None and session_info.owner is not None
from llama_stack.core.access_control.conditions import ProtectedResource
resource = cast(ProtectedResource, session_info)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@@ -62,7 +67,7 @@ class AgentPersistence:
)
return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
@@ -83,7 +88,15 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
# Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
assert session_info.identifier is not None and session_info.owner is not None
from llama_stack.core.access_control.conditions import ProtectedResource
resource = cast(ProtectedResource, session_info)
return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@@ -6,7 +6,7 @@
import asyncio
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
@@ -31,7 +31,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields
self.output_shields = output_shields
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(