fix(mypy): part-03 fully resolve meta-reference responses impl typing issues (#3951)

## Summary
Resolves all mypy errors in the meta-reference agent's OpenAI Responses
implementation by adding proper type narrowing, None checks, and
Sequence type support.

## Changes
- Fixed `streaming.py`, `openai_responses.py`, `utils.py`, `tool_executor.py`,
  and `agent_instance.py`
- Added Sequence type support to the schema generator (ensures correct JSON
  schema generation)
- Applied union type narrowing and None checks throughout (see the
  list-invariance sketch below)
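For context on the invariance fixes in the diffs below: mypy treats `list` as invariant, so a `list[Subtype]` is not accepted where a `list[Union]` is expected, and an accumulator that mypy first infers as `list[Subtype]` cannot later take other union members. A minimal standalone sketch (toy `TextItem`/`ImageItem` types, not from the codebase):

```python
from typing import Union


class TextItem: ...


class ImageItem: ...


InputItem = Union[TextItem, ImageItem]


def send(items: list[InputItem]) -> None: ...


texts: list[TextItem] = [TextItem()]
# send(texts)  # mypy error: "list[TextItem]" is not "list[InputItem]" (invariance)

# Fix used throughout this PR: annotate the accumulator with the full union
# up front, so appends of any member type check cleanly.
items: list[InputItem] = []
items.append(TextItem())
items.append(ImageItem())
send(items)  # OK
```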

## Test plan
- All modified files pass mypy type checking (0 errors)
- The schema generator produces the correct `type: array` for Sequence types
  (sketched below)
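A minimal sketch of the mapping the test plan checks, as a hypothetical reduced schema generator (not the project's actual one) in which both `list[...]` and `Sequence[...]` produce a JSON `array`:

```python
from collections.abc import Sequence
from typing import get_args, get_origin


def json_type_for(annotation: object) -> dict:
    """Hypothetical helper: map a Python annotation to a JSON schema fragment."""
    origin = get_origin(annotation)
    # Both list[...] and Sequence[...] should generate {"type": "array"}.
    if origin in (list, Sequence):
        (item_type,) = get_args(annotation)
        return {"type": "array", "items": json_type_for(item_type)}
    if annotation is str:
        return {"type": "string"}
    if annotation is int:
        return {"type": "integer"}
    raise NotImplementedError(f"unsupported annotation: {annotation!r}")


assert json_type_for(Sequence[str]) == {"type": "array", "items": {"type": "string"}}
assert json_type_for(list[int]) == {"type": "array", "items": {"type": "integer"}}
```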

---------

Co-authored-by: Claude <noreply@anthropic.com>
Commit a4f97559d1 (parent e5c27dbcbf)
Authored by Ashwin Bharambe on 2025-10-29 08:07:15 -07:00; committed by GitHub
9 changed files with 174 additions and 78 deletions


@@ -91,7 +91,8 @@ class OpenAIResponsesImpl:
input: str | list[OpenAIResponseInput],
previous_response: _OpenAIResponseObjectWithInputAndMessages,
):
new_input_items = previous_response.input.copy()
# Convert Sequence to list for mutation
new_input_items = list(previous_response.input)
new_input_items.extend(previous_response.output)
if isinstance(input, str):
@@ -107,7 +108,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
conversation: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]:
"""Process input with optional previous response context.
Returns:
@@ -208,6 +209,9 @@ class OpenAIResponsesImpl:
messages: list[OpenAIMessageParam],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
# Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues
input_items_data: list[OpenAIResponseInput] = []
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
@@ -219,7 +223,6 @@ class OpenAIResponsesImpl:
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
@@ -289,16 +292,19 @@ class OpenAIResponsesImpl:
failed_response = None
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
match stream_chunk.type:
case "response.completed" | "response.incomplete":
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
case "response.failed":
failed_response = stream_chunk.response
case _:
pass # Other event types don't have .response
if failed_response is not None:
error_message = (
@@ -326,6 +332,11 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
assert text is not None, "text must not be None"
assert max_infer_iters is not None, "max_infer_iters must not be None"
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id, conversation
@@ -368,16 +379,19 @@ class OpenAIResponsesImpl:
final_response = None
failed_response = None
output_items = []
# Type as ConversationItem to avoid list invariance issues
output_items: list[ConversationItem] = []
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type in {"response.completed", "response.incomplete"}:
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
if stream_chunk.type == "response.output_item.done":
item = stream_chunk.item
output_items.append(item)
match stream_chunk.type:
case "response.completed" | "response.incomplete":
final_response = stream_chunk.response
case "response.failed":
failed_response = stream_chunk.response
case "response.output_item.done":
item = stream_chunk.item
output_items.append(item)
case _:
pass # Other event types
# Store and sync before yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
@@ -410,7 +424,8 @@ class OpenAIResponsesImpl:
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []
# Type as ConversationItem union to avoid list invariance issues
conversation_items: list[ConversationItem] = []
if isinstance(input, str):
conversation_items.append(

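Note on the `match` rewrites above: mypy narrows a union by its Literal `type` discriminator inside `match`/`case`, which is what lets the terminal branches access `.response` while other event types (which lack that attribute) fall through to `case _`. A standalone sketch with toy event classes (not the real stream types):

```python
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class Completed:
    type: Literal["response.completed"]
    response: str


@dataclass
class Failed:
    type: Literal["response.failed"]
    response: str


@dataclass
class Delta:
    type: Literal["response.delta"]
    text: str  # no .response attribute


Event = Union[Completed, Failed, Delta]


def handle(ev: Event) -> None:
    match ev.type:
        case "response.completed":
            print(ev.response)  # ev narrowed to Completed
        case "response.failed":
            print(ev.response)  # ev narrowed to Failed
        case _:
            pass  # ev is Delta here; accessing .response would be a mypy error
```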

@@ -111,7 +111,7 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
instructions: str,
instructions: str | None,
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
@@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
self.prompt = prompt
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
ctx.tool_context.previous_tools if ctx.tool_context else {}
)
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
@@ -229,7 +231,8 @@ class StreamingResponseOrchestrator:
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
# Pydantic models are dict-compatible but mypy treats them as distinct types
tools=self.ctx.chat_tools, # type: ignore[arg-type]
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
@@ -272,7 +275,12 @@ class StreamingResponseOrchestrator:
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
has_tool_calls = (
isinstance(choice.message, OpenAIAssistantMessageParam)
and choice.message.tool_calls
and self.ctx.response_tools
)
if not has_tool_calls:
output_messages.append(
await convert_chat_choice_to_response_message(
choice,
@@ -722,7 +730,10 @@ class StreamingResponseOrchestrator:
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
if not is_new_tool_call and response_tool_call is not None:
# Both should have functions since we're inside the tool_call.function check above
assert response_tool_call.function is not None
assert tool_call.function is not None
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
@@ -747,10 +758,13 @@ class StreamingResponseOrchestrator:
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
# Ensure that arguments, if sent back to the inference provider, are not None
tool_call.function.arguments = tool_call.function.arguments or "{}"
if tool_call.function:
tool_call.function.arguments = tool_call.function.arguments or "{}"
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = tool_call.function.arguments
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}"
func = chat_response_tool_calls[tool_call_index].function
tool_call_name = func.name if func else ""
# Check if this is an MCP tool call
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
@@ -894,12 +908,11 @@ class StreamingResponseOrchestrator:
self.sequence_number += 1
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
item = OpenAIResponseOutputMessageMCPCall(
item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
arguments="",
name=tool_call.function.name,
id=matching_item_id,
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
status="in_progress",
)
elif tool_call.function.name == "web_search":
item = OpenAIResponseOutputMessageWebSearchToolCall(
@@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator:
description=tool.description,
input_schema=tool.input_schema,
)
return convert_tooldef_to_openai_tool(tool_def)
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict
# Initialize chat_tools if not already set
if self.ctx.chat_tools is None:
@@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator:
for input_tool in tools:
if input_tool.type == "function":
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor
@@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator:
if isinstance(mcp_tool.allowed_tools, list):
always_allowed = mcp_tool.allowed_tools
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
always_allowed = mcp_tool.allowed_tools.always
never_allowed = mcp_tool.allowed_tools.never
# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
always_allowed = mcp_tool.allowed_tools.tool_names
# Call list_mcp_tools
tool_defs = None
@@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_chat_tool(t)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
# Add to MCP tool mapping
if t.name in self.mcp_tool_to_server:
@@ -1120,13 +1133,17 @@ class StreamingResponseOrchestrator:
self, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle all mcp tool lists from previous response that are still valid:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
yield stream_event
# tool_context can be None when no tools are provided in the response request
if self.ctx.tool_context:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(
self.ctx.tool_context.tools_to_process, output_messages
):
yield stream_event
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
@@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",

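Note on the streaming hunks above: several fixes use `assert x is not None` purely to narrow nested Optional attributes that the surrounding logic already guarantees are set. A minimal sketch with toy dataclasses (not the real chunk types):

```python
from dataclasses import dataclass


@dataclass
class Function:
    name: str | None = None
    arguments: str | None = None


@dataclass
class ToolCall:
    function: Function | None = None


def accumulate(call: ToolCall, chunk: ToolCall) -> None:
    # Without these asserts, mypy flags `.function` as possibly None below.
    assert call.function is not None
    assert chunk.function is not None
    # Both attributes are narrowed to Function for the rest of the scope.
    call.function.arguments = (call.function.arguments or "") + (chunk.function.arguments or "")
```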

@@ -5,6 +5,7 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import cast
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
@@ -100,17 +101,19 @@ class ToolContext(BaseModel):
if isinstance(tool, OpenAIResponseToolMCP):
previous_tools_by_label[tool.server_label] = tool
# collect tool definitions which are the same in current and previous requests:
tools_to_process = []
tools_to_process: list[OpenAIResponseInputTool] = []
matched: dict[str, OpenAIResponseInputToolMCP] = {}
for tool in self.current_tools:
# Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union)
# which differ only in MCP type (InputToolMCP vs ToolMCP). Code is correct.
for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment]
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
previous_tool = previous_tools_by_label[tool.server_label]
if previous_tool.allowed_tools == tool.allowed_tools:
matched[tool.server_label] = tool
else:
tools_to_process.append(tool)
tools_to_process.append(tool) # type: ignore[arg-type]
else:
tools_to_process.append(tool)
tools_to_process.append(tool) # type: ignore[arg-type]
# tools that are not the same or were not previously defined need to be processed:
self.tools_to_process = tools_to_process
# for all matched definitions, get the mcp_list_tools objects from the previous output:
@@ -119,9 +122,11 @@ class ToolContext(BaseModel):
]
# reconstruct the tool to server mappings that can be reused:
for listing in self.previous_tool_listings:
# listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool]
definition = matched[listing.server_label]
for tool in listing.tools:
self.previous_tools[tool.name] = definition
for mcp_tool in listing.tools:
# mcp_tool is MCPListToolsTool which has a name: str field
self.previous_tools[mcp_tool.name] = definition
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.current_tools:
@@ -139,6 +144,8 @@ class ToolContext(BaseModel):
server_label=tool.server_label,
allowed_tools=tool.allowed_tools,
)
# Exhaustive check - all tool types should be handled above
raise AssertionError(f"Unexpected tool type: {type(tool)}")
return [convert_tool(tool) for tool in self.current_tools]

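Note on the `cast` above: the input and output tool unions differ only in their MCP member, so re-labeling values known to have come from a request is safe at runtime even though mypy cannot prove it. A toy sketch of the pattern (illustrative types only):

```python
from typing import Union, cast


class FunctionTool: ...


class InputMCP: ...


class OutputMCP: ...


InputTool = Union[FunctionTool, InputMCP]    # what requests carry
OutputTool = Union[FunctionTool, OutputMCP]  # what responses report

stored: list[OutputTool] = [FunctionTool()]

# The unions share every member except MCP; when the stored values are known
# to originate from a request, a cast re-labels the list with no runtime cost.
request_tools: list[InputTool] = cast(list[InputTool], stored)
```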

@@ -7,6 +7,7 @@
import asyncio
import re
import uuid
from collections.abc import Sequence
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
@@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message(
return OpenAIResponseMessage(
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
status="completed",
role="assistant",
)
async def convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
@@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content(
if isinstance(content, str):
return content
converted_parts = []
# Type with union to avoid list invariance issues
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
@@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages(
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
# Output can be None, use empty string as fallback
output_content = input_item.output if input_item.output is not None else ""
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
content=output_content,
tool_call_id=input_item.id,
)
)
@@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages(
):
# these are handled by the responses impl itself and not pass through to chat completions
pass
else:
elif isinstance(input_item, OpenAIResponseMessage):
# Narrow type to OpenAIResponseMessage which has content and role attributes
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
@@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages(
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
# Dynamic message type call - different message types have different content expectations
messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type]
if len(tool_call_results):
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
@@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format(
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
# Assert name exists for json_schema format
assert text.format.get("name"), "json_schema format requires a name"
schema_name: str = text.format["name"] # type: ignore[assignment]
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
@@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass
def _extract_citations_from_text(
@@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
shields_list = await safety_api.routing_table.list_shields()
# TODO: list_shields not in Safety interface but available at runtime via API routing
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
for guardrail_id in guardrail_ids:
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
@@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
for result in response.results:
if result.flagged:
message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
flagged_categories = (
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
)
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
if flagged_categories:
@@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
return message
# No violations found
return None
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""