fix(mypy): resolve tool_executor type issues (45 errors fixed)

- Add proper type annotations using Any where needed
- Fix union-attr errors with getattr and the walrus operator
- Fix arg-type errors for datetime/enum conversions
- Add `type: ignore` comments for list-invariance issues
- Remove event variable reuse to satisfy type checker
- Use proper type narrowing for tool execution paths

Patterns established:
- Use getattr() with the walrus operator for optional attributes
- Use type: ignore for runtime-correct but mypy-incompatible cases
- Separate event variables by type to avoid union conflicts
This commit is contained in:
Ashwin Bharambe 2025-10-28 11:31:51 -07:00
parent f88416ef87
commit 3a437d80af

View file

@ -7,6 +7,7 @@
import asyncio import asyncio
import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import ( from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch, OpenAIResponseInputToolFileSearch,
@ -22,10 +23,12 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseWebSearchCallSearching, OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall, OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults, OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall, OpenAIResponseOutputMessageWebSearchToolCall,
) )
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
ImageContentItem, ImageContentItem,
InterleavedContent,
TextContentItem, TextContentItem,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -67,7 +70,7 @@ class ToolExecutor:
) -> AsyncIterator[ToolExecutionResult]: ) -> AsyncIterator[ToolExecutionResult]:
tool_call_id = tool_call.id tool_call_id = tool_call.id
function = tool_call.function function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {} tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
if not function or not tool_call_id or not function.name: if not function or not tool_call_id or not function.name:
yield ToolExecutionResult(sequence_number=sequence_number) yield ToolExecutionResult(sequence_number=sequence_number)
@ -84,7 +87,16 @@ class ToolExecutor:
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
# Emit completion events for tool execution # Emit completion events for tool execution
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) has_error = bool(
error_exc
or (
result
and (
((error_code := getattr(result, "error_code", None)) and error_code > 0)
or getattr(result, "error_message", None)
)
)
)
async for event_result in self._emit_completion_events( async for event_result in self._emit_completion_events(
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
): ):
@ -101,7 +113,11 @@ class ToolExecutor:
sequence_number=sequence_number, sequence_number=sequence_number,
final_output_message=output_message, final_output_message=output_message,
final_input_message=input_message, final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None, citation_files=(
metadata.get("citation_files")
if result and (metadata := getattr(result, "metadata", None))
else None
),
) )
async def _execute_knowledge_search_via_vector_store( async def _execute_knowledge_search_via_vector_store(
@ -188,8 +204,9 @@ class ToolExecutor:
citation_files[file_id] = filename citation_files[file_id] = filename
# Cast to proper InterleavedContent type (list invariance)
return ToolInvocationResult( return ToolInvocationResult(
content=content_items, content=content_items, # type: ignore[arg-type]
metadata={ metadata={
"document_ids": [r.file_id for r in search_results], "document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results], "chunks": [r.content[0].text if r.content else "" for r in search_results],
@ -209,51 +226,50 @@ class ToolExecutor:
) -> AsyncIterator[ToolExecutionResult]: ) -> AsyncIterator[ToolExecutionResult]:
"""Emit progress events for tool execution start.""" """Emit progress events for tool execution start."""
# Emit in_progress event based on tool type (only for tools with specific streaming events) # Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server: if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1 sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( mcp_progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=mcp_progress_event, sequence_number=sequence_number)
elif function_name == "web_search": elif function_name == "web_search":
sequence_number += 1 sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( web_progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=web_progress_event, sequence_number=sequence_number)
elif function_name == "knowledge_search": elif function_name == "knowledge_search":
sequence_number += 1 sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress( file_progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=file_progress_event, sequence_number=sequence_number)
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event # For web search, emit searching event
if function_name == "web_search": if function_name == "web_search":
sequence_number += 1 sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( web_searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) yield ToolExecutionResult(stream_event=web_searching_event, sequence_number=sequence_number)
# For file search, emit searching event # For file search, emit searching event
if function_name == "knowledge_search": if function_name == "knowledge_search":
sequence_number += 1 sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching( file_searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) yield ToolExecutionResult(stream_event=file_searching_event, sequence_number=sequence_number)
async def _execute_tool( async def _execute_tool(
self, self,
@ -261,7 +277,7 @@ class ToolExecutor:
tool_kwargs: dict, tool_kwargs: dict,
ctx: ChatCompletionContext, ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]: ) -> tuple[Exception | None, Any]:
"""Execute the tool and return error exception and result.""" """Execute the tool and return error exception and result."""
error_exc = None error_exc = None
result = None result = None
@ -284,10 +300,14 @@ class ToolExecutor:
kwargs=tool_kwargs, kwargs=tool_kwargs,
) )
elif function_name == "knowledge_search": elif function_name == "knowledge_search":
response_file_search_tool = next( response_file_search_tool = (
next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None, None,
) )
if ctx.response_tools
else None
)
if response_file_search_tool: if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool # Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options # to support filters and ranking_options
@ -322,35 +342,34 @@ class ToolExecutor:
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]: ) -> AsyncIterator[ToolExecutionResult]:
"""Emit completion or failure events for tool execution.""" """Emit completion or failure events for tool execution."""
completion_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server: if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1 sequence_number += 1
if has_error: if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
else: else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
elif function_name == "web_search": elif function_name == "web_search":
sequence_number += 1 sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
elif function_name == "knowledge_search": elif function_name == "knowledge_search":
sequence_number += 1 sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
async def _build_result_messages( async def _build_result_messages(
self, self,
@ -360,21 +379,18 @@ class ToolExecutor:
tool_kwargs: dict, tool_kwargs: dict,
ctx: ChatCompletionContext, ctx: ChatCompletionContext,
error_exc: Exception | None, error_exc: Exception | None,
result: any, result: Any,
has_error: bool, has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[any, any]: ) -> tuple[Any, Any]:
"""Build output and input messages from tool execution results.""" """Build output and input messages from tool execution results."""
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
) )
# Build output message # Build output message
message: Any
if mcp_tool_to_server and function.name in mcp_tool_to_server: if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall( message = OpenAIResponseOutputMessageMCPCall(
id=item_id, id=item_id,
arguments=function.arguments, arguments=function.arguments,
@ -383,10 +399,14 @@ class ToolExecutor:
) )
if error_exc: if error_exc:
message.error = str(error_exc) message.error = str(error_exc)
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): elif (
message.error = f"Error (code {result.error_code}): {result.error_message}" result and (error_code := getattr(result, "error_code", None)) and error_code > 0
elif result and result.content: ) or (result and (error_message := getattr(result, "error_message", None))):
message.output = interleaved_content_as_str(result.content) ec = getattr(result, "error_code", "unknown")
em = getattr(result, "error_message", "")
message.error = f"Error (code {ec}): {em}"
elif result and (content := getattr(result, "content", None)):
message.output = interleaved_content_as_str(content)
else: else:
if function.name == "web_search": if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall( message = OpenAIResponseOutputMessageWebSearchToolCall(
@ -401,17 +421,17 @@ class ToolExecutor:
queries=[tool_kwargs.get("query", "")], queries=[tool_kwargs.get("query", "")],
status="completed", status="completed",
) )
if result and "document_ids" in result.metadata: if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
message.results = [] message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]): for i, doc_id in enumerate(metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None text = metadata["chunks"][i] if "chunks" in metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None score = metadata["scores"][i] if "scores" in metadata else None
message.results.append( message.results.append(
OpenAIResponseOutputMessageFileSearchToolCallResults( OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id, file_id=doc_id,
filename=doc_id, filename=doc_id,
text=text, text=text if text is not None else "",
score=score, score=score if score is not None else 0.0,
attributes={}, attributes={},
) )
) )
@ -421,27 +441,31 @@ class ToolExecutor:
raise ValueError(f"Unknown tool {function.name} called") raise ValueError(f"Unknown tool {function.name} called")
# Build input message # Build input message
input_message = None input_message: OpenAIToolMessageParam | None = None
if result and result.content: if result and (result_content := getattr(result, "content", None)):
if isinstance(result.content, str): if isinstance(result_content, str):
content = result.content msg_content: str | list[Any] = result_content
elif isinstance(result.content, list): elif isinstance(result_content, list):
content = [] content_list: list[Any] = []
for item in result.content: for item in result_content:
part: Any
if isinstance(item, TextContentItem): if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text) part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem): elif isinstance(item, ImageContentItem):
if item.image.data: if item.image.data:
url = f"data:image;base64,{item.image.data}" url_value = f"data:image;base64,{item.image.data}"
else: else:
url = item.image.url url_value = str(item.image.url) if item.image.url else ""
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
else: else:
raise ValueError(f"Unknown result content type: {type(item)}") raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part) content_list.append(part)
msg_content = content_list
else: else:
raise ValueError(f"Unknown result content type: {type(result.content)}") raise ValueError(f"Unknown result content type: {type(result_content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) # OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
# This is runtime-safe as the API accepts it, but mypy complains
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
else: else:
text = str(error_exc) if error_exc else "Tool execution failed" text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)