fix: resolve mypy errors in meta_reference agents

- Fix type annotations for Safety API message parameters
- Add Action enum usage in access control checks
- Correct method signatures to match API supertype
- Handle optional return types with proper None checks
- Remove meta_reference exclusions from mypy config
This commit is contained in:
Ashwin Bharambe 2025-10-29 11:03:10 -07:00
parent b90c6a2c8b
commit f0d365622a
5 changed files with 48 additions and 26 deletions

View file

@@ -284,12 +284,6 @@ exclude = [
"^src/llama_stack/models/llama/llama3/interface\\.py$", "^src/llama_stack/models/llama/llama3/interface\\.py$",
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$", "^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
"^src/llama_stack/models/llama/llama3/tool_utils\\.py$", "^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/config\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/persistence\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/__init__\\.py$",
"^src/llama_stack/providers/inline/datasetio/localfs/", "^src/llama_stack/providers/inline/datasetio/localfs/",
"^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$", "^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",

View file

@@ -11,7 +11,7 @@ import uuid
import warnings import warnings
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any, cast
import httpx import httpx
@@ -363,7 +363,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.input_shields: if self.input_shields:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input" turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@@ -392,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.output_shields: if self.output_shields:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output" turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@@ -404,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper( async def run_multiple_shields_wrapper(
self, self,
turn_id: str, turn_id: str,
messages: list[Message], messages: list[OpenAIMessageParam],
shields: list[str], shields: list[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:

View file

@@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
Document, Document,
ListOpenAIResponseInputItem, ListOpenAIResponseInputItem,
ListOpenAIResponseObject, ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput, OpenAIResponseInput,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseObject, OpenAIResponseObject,
@@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
persistence_store=( persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
), ),
created_at=agent_info.created_at, created_at=agent_info.created_at.isoformat(),
policy=self.policy, policy=self.policy,
telemetry_enabled=self.telemetry_enabled, telemetry_enabled=self.telemetry_enabled,
) )
@@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
messages: list[UserMessage | ToolResponseMessage], messages: list[UserMessage | ToolResponseMessage],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False, stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None, tool_config: ToolConfig | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = AgentTurnCreateRequest( request = AgentTurnCreateRequest(
@@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id) turn = await agent.storage.get_session_turn(session_id, turn_id)
if turn is None:
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse: async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
@@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_session( async def get_agents_session(
self, self,
agent_id: str,
session_id: str, session_id: str,
agent_id: str,
turn_ids: list[str] | None = None, turn_ids: list[str] | None = None,
) -> Session: ) -> Session:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id) session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id) turns = await agent.storage.get_session_turns(session_id)
if turn_ids: if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids] turns = [turn for turn in turns if turn.turn_id in turn_ids]
@@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
started_at=session_info.started_at, started_at=session_info.started_at,
) )
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session # Delete turns first, then the session
@@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
agent = Agent( agent = Agent(
agent_id=agent_id, agent_id=agent_id,
agent_config=chat_agent.agent_config, agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at, created_at=datetime.fromisoformat(chat_agent.created_at),
) )
return agent return agent
@@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
response_id: str, response_id: str,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.get_openai_response(response_id) return await self.openai_responses_impl.get_openai_response(response_id)
async def create_openai_response( async def create_openai_response(
@@ -342,7 +348,12 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None, guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response( assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from typing import cast as typing_cast
# Cast guardrails to the more specific type expected by the implementation
guardrails_spec = typing_cast(list[ResponseGuardrailSpec] | None, guardrails)
result = await self.openai_responses_impl.create_openai_response(
input, input,
model, model,
prompt, prompt,
@@ -356,8 +367,9 @@ class MetaReferenceAgentsImpl(Agents):
tools, tools,
include, include,
max_infer_iters, max_infer_iters,
guardrails, guardrails_spec,
) )
return typing_cast(OpenAIResponseObject, result)
async def list_openai_responses( async def list_openai_responses(
self, self,
@@ -366,6 +378,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str | None = None, model: str | None = None,
order: Order | None = Order.desc, order: Order | None = Order.desc,
) -> ListOpenAIResponseObject: ) -> ListOpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order) return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
async def list_openai_response_input_items( async def list_openai_response_input_items(
@@ -377,9 +390,11 @@ class MetaReferenceAgentsImpl(Agents):
limit: int | None = 20, limit: int | None = 20,
order: Order | None = Order.desc, order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem: ) -> ListOpenAIResponseInputItem:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_response_input_items( return await self.openai_responses_impl.list_openai_response_input_items(
response_id, after, before, include, limit, order response_id, after, before, include, limit, order
) )
async def delete_openai_response(self, response_id: str) -> None: async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.delete_openai_response(response_id) return await self.openai_responses_impl.delete_openai_response(response_id)

View file

@@ -7,11 +7,12 @@
import json import json
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import cast
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger from llama_stack.log import get_logger
@@ -53,8 +54,12 @@ class AgentPersistence:
turns=[], turns=[],
identifier=name, # should this be qualified in any way? identifier=name, # should this be qualified in any way?
) )
if not is_action_allowed(self.policy, "create", session_info, user): # Both identifier and owner are set above, safe to use for access control
raise AccessDeniedError("create", session_info, user) assert session_info.identifier is not None and session_info.owner is not None
from llama_stack.core.access_control.conditions import ProtectedResource
resource = cast(ProtectedResource, session_info)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
@@ -62,7 +67,7 @@ class AgentPersistence:
) )
return session_id return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo: async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
) )
@@ -83,7 +88,15 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) # Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
assert session_info.identifier is not None and session_info.owner is not None
from llama_stack.core.access_control.conditions import ProtectedResource
resource = cast(ProtectedResource, session_info)
return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods.""" """Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@@ -6,7 +6,7 @@
import asyncio import asyncio
from llama_stack.apis.inference import Message from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger from llama_stack.log import get_logger
@@ -31,7 +31,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields self.input_shields = input_shields
self.output_shields = output_shields self.output_shields = output_shields
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None: async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str): async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"): async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield( return await self.safety_api.run_shield(