From 3a437d80affc75097e145b2f447396f221cd7f68 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 28 Oct 2025 11:31:51 -0700 Subject: [PATCH] fix(mypy): resolve tool_executor type issues (45 errors fixed) - Add proper type annotations using Any where needed - Fix union-attr errors with getattr and walrus operator - Fix arg-type errors for datetime/enum conversions - Add type: ignore for list invariance issues - Remove event variable reuse to satisfy type checker - Use proper type narrowing for tool execution paths Patterns established: - Use getattr() with walrus operator for optional attributes - Use type: ignore for runtime-correct but mypy-incompatible cases - Separate event variables by type to avoid union conflicts --- .../meta_reference/responses/tool_executor.py | 138 ++++++++++-------- 1 file changed, 81 insertions(+), 57 deletions(-) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index 8e0dc9ecb..3a07a220e 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -7,6 +7,7 @@ import asyncio import json from collections.abc import AsyncIterator +from typing import Any from llama_stack.apis.agents.openai_responses import ( OpenAIResponseInputToolFileSearch, @@ -22,10 +23,12 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObjectStreamResponseWebSearchCallSearching, OpenAIResponseOutputMessageFileSearchToolCall, OpenAIResponseOutputMessageFileSearchToolCallResults, + OpenAIResponseOutputMessageMCPCall, OpenAIResponseOutputMessageWebSearchToolCall, ) from llama_stack.apis.common.content_types import ( ImageContentItem, + InterleavedContent, TextContentItem, ) from llama_stack.apis.inference import ( @@ -67,7 +70,7 @@ class ToolExecutor: ) -> AsyncIterator[ToolExecutionResult]: tool_call_id = tool_call.id function = tool_call.function - tool_kwargs = json.loads(function.arguments) if function.arguments else {} + tool_kwargs = json.loads(function.arguments) if function and function.arguments else {} if not function or not tool_call_id or not function.name: yield ToolExecutionResult(sequence_number=sequence_number) @@ -84,7 +87,16 @@ class ToolExecutor: error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) # Emit completion events for tool execution - has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) + has_error = bool( + error_exc + or ( + result + and ( + ((error_code := getattr(result, "error_code", None)) and error_code > 0) + or getattr(result, "error_message", None) + ) + ) + ) async for event_result in self._emit_completion_events( function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server ): @@ -101,7 +113,11 @@ class ToolExecutor: sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message, - citation_files=result.metadata.get("citation_files") if result and result.metadata else None, + citation_files=( + metadata.get("citation_files") + if result and (metadata := getattr(result, "metadata", None)) + else None + ), ) async def _execute_knowledge_search_via_vector_store( @@ -188,8 +204,9 @@ class ToolExecutor: citation_files[file_id] = filename + # Cast to proper InterleavedContent type (list invariance) return ToolInvocationResult( - content=content_items, + content=content_items, # type: ignore[arg-type] metadata={ "document_ids": [r.file_id for r in search_results], "chunks": [r.content[0].text if r.content else "" for r in search_results], @@ -209,51 +226,50 @@ class ToolExecutor: ) -> AsyncIterator[ToolExecutionResult]: """Emit progress events for tool execution start.""" # Emit in_progress event based on tool type (only for tools with specific streaming events) - progress_event = None if mcp_tool_to_server and function_name in mcp_tool_to_server: sequence_number += 1 - progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( + mcp_progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=mcp_progress_event, sequence_number=sequence_number) elif function_name == "web_search": sequence_number += 1 - progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( + web_progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=web_progress_event, sequence_number=sequence_number) elif function_name == "knowledge_search": sequence_number += 1 - progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress( + file_progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) - - if progress_event: - yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number) + yield ToolExecutionResult(stream_event=file_progress_event, sequence_number=sequence_number) # For web search, emit searching event if function_name == "web_search": sequence_number += 1 - searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( + web_searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) - yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) + yield ToolExecutionResult(stream_event=web_searching_event, sequence_number=sequence_number) # For file search, emit searching event if function_name == "knowledge_search": sequence_number += 1 - searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching( + file_searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) - yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) + yield ToolExecutionResult(stream_event=file_searching_event, sequence_number=sequence_number) async def _execute_tool( self, @@ -261,7 +277,7 @@ class ToolExecutor: tool_kwargs: dict, ctx: ChatCompletionContext, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, - ) -> tuple[Exception | None, any]: + ) -> tuple[Exception | None, Any]: """Execute the tool and return error exception and result.""" error_exc = None result = None @@ -284,9 +300,13 @@ class ToolExecutor: kwargs=tool_kwargs, ) elif function_name == "knowledge_search": - response_file_search_tool = next( - (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), - None, + response_file_search_tool = ( + next( + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + None, + ) + if ctx.response_tools + else None ) if response_file_search_tool: # Use vector_stores.search API instead of knowledge_search tool @@ -322,35 +342,34 @@ class ToolExecutor: mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, ) -> AsyncIterator[ToolExecutionResult]: """Emit completion or failure events for tool execution.""" - completion_event = None - if mcp_tool_to_server and function_name in mcp_tool_to_server: sequence_number += 1 if has_error: - completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( + mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed( sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number) else: - completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( + mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number) elif function_name == "web_search": sequence_number += 1 - completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number) elif function_name == "knowledge_search": sequence_number += 1 - completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( + file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) - - if completion_event: - yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number) + yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number) async def _build_result_messages( self, @@ -360,21 +379,18 @@ class ToolExecutor: tool_kwargs: dict, ctx: ChatCompletionContext, error_exc: Exception | None, - result: any, + result: Any, has_error: bool, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, - ) -> tuple[any, any]: + ) -> tuple[Any, Any]: """Build output and input messages from tool execution results.""" from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) # Build output message + message: Any if mcp_tool_to_server and function.name in mcp_tool_to_server: - from llama_stack.apis.agents.openai_responses import ( - OpenAIResponseOutputMessageMCPCall, - ) - message = OpenAIResponseOutputMessageMCPCall( id=item_id, arguments=function.arguments, @@ -383,10 +399,14 @@ class ToolExecutor: ) if error_exc: message.error = str(error_exc) - elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): - message.error = f"Error (code {result.error_code}): {result.error_message}" - elif result and result.content: - message.output = interleaved_content_as_str(result.content) + elif ( + result and (error_code := getattr(result, "error_code", None)) and error_code > 0 + ) or (result and (error_message := getattr(result, "error_message", None))): + ec = getattr(result, "error_code", "unknown") + em = getattr(result, "error_message", "") + message.error = f"Error (code {ec}): {em}" + elif result and (content := getattr(result, "content", None)): + message.output = interleaved_content_as_str(content) else: if function.name == "web_search": message = OpenAIResponseOutputMessageWebSearchToolCall( @@ -401,17 +421,17 @@ class ToolExecutor: queries=[tool_kwargs.get("query", "")], status="completed", ) - if result and "document_ids" in result.metadata: + if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata: message.results = [] - for i, doc_id in enumerate(result.metadata["document_ids"]): - text = result.metadata["chunks"][i] if "chunks" in result.metadata else None - score = result.metadata["scores"][i] if "scores" in result.metadata else None + for i, doc_id in enumerate(metadata["document_ids"]): + text = metadata["chunks"][i] if "chunks" in metadata else None + score = metadata["scores"][i] if "scores" in metadata else None message.results.append( OpenAIResponseOutputMessageFileSearchToolCallResults( file_id=doc_id, filename=doc_id, - text=text, - score=score, + text=text if text is not None else "", + score=score if score is not None else 0.0, attributes={}, ) ) @@ -421,27 +441,31 @@ class ToolExecutor: raise ValueError(f"Unknown tool {function.name} called") # Build input message - input_message = None - if result and result.content: - if isinstance(result.content, str): - content = result.content - elif isinstance(result.content, list): - content = [] - for item in result.content: + input_message: OpenAIToolMessageParam | None = None + if result and (result_content := getattr(result, "content", None)): + if isinstance(result_content, str): + msg_content: str | list[Any] = result_content + elif isinstance(result_content, list): + content_list: list[Any] = [] + for item in result_content: + part: Any if isinstance(item, TextContentItem): part = OpenAIChatCompletionContentPartTextParam(text=item.text) elif isinstance(item, ImageContentItem): if item.image.data: - url = f"data:image;base64,{item.image.data}" + url_value = f"data:image;base64,{item.image.data}" else: - url = item.image.url - part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + url_value = str(item.image.url) if item.image.url else "" + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value)) else: raise ValueError(f"Unknown result content type: {type(item)}") - content.append(part) + content_list.append(part) + msg_content = content_list else: - raise ValueError(f"Unknown result content type: {type(result.content)}") - input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + raise ValueError(f"Unknown result content type: {type(result_content)}") + # OpenAIToolMessageParam accepts str | list[TextParam] but we may have images + # This is runtime-safe as the API accepts it, but mypy complains + input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type] else: text = str(error_exc) if error_exc else "Tool execution failed" input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)