chore(mypy): part-04 resolve mypy errors in meta_reference agents (#3969)

## Summary
Fixes all mypy type errors in `providers/inline/agents/meta_reference/`
and removes exclusions from pyproject.toml.

## Changes
- Fix type annotations for Safety API message parameters
(OpenAIMessageParam)
- Add Action enum usage in access control checks
- Correct method signatures to match API supertype (parameter ordering)
- Handle optional return types with proper None checks
- Remove 3 meta_reference exclusions from mypy config

**Files fixed:** 25 errors across 3 files (safety.py, persistence.py,
agents.py)
This commit is contained in:
Ashwin Bharambe 2025-10-29 13:37:28 -07:00 committed by GitHub
parent e6b27db30a
commit c9d4b6c54f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 69 additions and 31 deletions

View file

@ -284,12 +284,6 @@ exclude = [
"^src/llama_stack/models/llama/llama3/interface\\.py$", "^src/llama_stack/models/llama/llama3/interface\\.py$",
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$", "^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
"^src/llama_stack/models/llama/llama3/tool_utils\\.py$", "^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/config\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/persistence\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/__init__\\.py$",
"^src/llama_stack/providers/inline/datasetio/localfs/", "^src/llama_stack/providers/inline/datasetio/localfs/",
"^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$", "^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",

View file

@ -11,7 +11,7 @@ import uuid
import warnings import warnings
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any, cast
import httpx import httpx
@ -363,7 +363,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.input_shields: if self.input_shields:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input" turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@ -392,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.output_shields: if self.output_shields:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output" turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@ -404,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper( async def run_multiple_shields_wrapper(
self, self,
turn_id: str, turn_id: str,
messages: list[Message], messages: list[OpenAIMessageParam],
shields: list[str], shields: list[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:

View file

@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
Document, Document,
ListOpenAIResponseInputItem, ListOpenAIResponseInputItem,
ListOpenAIResponseObject, ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput, OpenAIResponseInput,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseObject, OpenAIResponseObject,
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
persistence_store=( persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
), ),
created_at=agent_info.created_at, created_at=agent_info.created_at.isoformat(),
policy=self.policy, policy=self.policy,
telemetry_enabled=self.telemetry_enabled, telemetry_enabled=self.telemetry_enabled,
) )
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
messages: list[UserMessage | ToolResponseMessage], messages: list[UserMessage | ToolResponseMessage],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False, stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None, tool_config: ToolConfig | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = AgentTurnCreateRequest( request = AgentTurnCreateRequest(
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id) turn = await agent.storage.get_session_turn(session_id, turn_id)
if turn is None:
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse: async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_session( async def get_agents_session(
self, self,
agent_id: str,
session_id: str, session_id: str,
agent_id: str,
turn_ids: list[str] | None = None, turn_ids: list[str] | None = None,
) -> Session: ) -> Session:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id) session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id) turns = await agent.storage.get_session_turns(session_id)
if turn_ids: if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids] turns = [turn for turn in turns if turn.turn_id in turn_ids]
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
started_at=session_info.started_at, started_at=session_info.started_at,
) )
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session # Delete turns first, then the session
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
agent = Agent( agent = Agent(
agent_id=agent_id, agent_id=agent_id,
agent_config=chat_agent.agent_config, agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at, created_at=datetime.fromisoformat(chat_agent.created_at),
) )
return agent return agent
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
response_id: str, response_id: str,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.get_openai_response(response_id) return await self.openai_responses_impl.get_openai_response(response_id)
async def create_openai_response( async def create_openai_response(
@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None, guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response( assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response(
input, input,
model, model,
prompt, prompt,
@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters, max_infer_iters,
guardrails, guardrails,
) )
return result # type: ignore[no-any-return]
async def list_openai_responses( async def list_openai_responses(
self, self,
@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str | None = None, model: str | None = None,
order: Order | None = Order.desc, order: Order | None = Order.desc,
) -> ListOpenAIResponseObject: ) -> ListOpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order) return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
async def list_openai_response_input_items( async def list_openai_response_input_items(
@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
limit: int | None = 20, limit: int | None = 20,
order: Order | None = Order.desc, order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem: ) -> ListOpenAIResponseInputItem:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_response_input_items( return await self.openai_responses_impl.list_openai_response_input_items(
response_id, after, before, include, limit, order response_id, after, before, include, limit, order
) )
async def delete_openai_response(self, response_id: str) -> None: async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.delete_openai_response(response_id) return await self.openai_responses_impl.delete_openai_response(response_id)

View file

@ -6,12 +6,14 @@
import json import json
import uuid import uuid
from dataclasses import dataclass
from datetime import UTC, datetime from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.access_control.conditions import User as ProtocolUser
from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
created_at: datetime created_at: datetime
@dataclass
class SessionResource:
"""Concrete implementation of ProtectedResource for session access control."""
type: str
identifier: str
owner: ProtocolUser # Use the protocol type for structural compatibility
class AgentPersistence: class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]): def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id self.agent_id = agent_id
@ -53,8 +64,15 @@ class AgentPersistence:
turns=[], turns=[],
identifier=name, # should this be qualified in any way? identifier=name, # should this be qualified in any way?
) )
if not is_action_allowed(self.policy, "create", session_info, user): # Only perform access control if we have an authenticated user
raise AccessDeniedError("create", session_info, user) if user is not None and session_info.identifier is not None:
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=user,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
@ -62,7 +80,7 @@ class AgentPersistence:
) )
return session_id return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo: async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
) )
@ -83,7 +101,22 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) # Get current user - if None, skip access control (e.g., in tests)
user = get_authenticated_user()
if user is None:
return True
# Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
return is_action_allowed(self.policy, Action.READ, resource, user)
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods.""" """Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -254,7 +254,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None, tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None, include: list[str] | None = None,
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrailSpec] | None = None, guardrails: list[str | ResponseGuardrailSpec] | None = None,
): ):
stream = bool(stream) stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

View file

@ -6,7 +6,7 @@
import asyncio import asyncio
from llama_stack.apis.inference import Message from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields self.input_shields = input_shields
self.output_shields = output_shields self.output_shields = output_shields
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None: async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str): async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"): async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield( return await self.safety_api.run_shield(

View file

@ -192,18 +192,18 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
assert session_response.session_id is not None assert session_response.session_id is not None
# Verify the session was stored # Verify the session was stored
session = await agents_impl.get_agents_session(agent_id, session_response.session_id) session = await agents_impl.get_agents_session(session_response.session_id, agent_id)
assert session.session_name == "test_session" assert session.session_name == "test_session"
assert session.session_id == session_response.session_id assert session.session_id == session_response.session_id
assert session.started_at is not None assert session.started_at is not None
assert session.turns == [] assert session.turns == []
# Delete the session # Delete the session
await agents_impl.delete_agents_session(agent_id, session_response.session_id) await agents_impl.delete_agents_session(session_response.session_id, agent_id)
# Verify the session was deleted # Verify the session was deleted
with pytest.raises(ValueError): with pytest.raises(ValueError):
await agents_impl.get_agents_session(agent_id, session_response.session_id) await agents_impl.get_agents_session(session_response.session_id, agent_id)
@pytest.mark.parametrize("enable_session_persistence", [True, False]) @pytest.mark.parametrize("enable_session_persistence", [True, False])
@ -226,11 +226,11 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
assert session2.session_id in session_ids assert session2.session_id in session_ids
# Delete one session # Delete one session
await agents_impl.delete_agents_session(agent_id, session1.session_id) await agents_impl.delete_agents_session(session1.session_id, agent_id)
# Verify the session was deleted # Verify the session was deleted
with pytest.raises(ValueError): with pytest.raises(ValueError):
await agents_impl.get_agents_session(agent_id, session1.session_id) await agents_impl.get_agents_session(session1.session_id, agent_id)
# List sessions again # List sessions again
sessions = await agents_impl.list_agent_sessions(agent_id) sessions = await agents_impl.list_agent_sessions(agent_id)