From 59793ac63b546f09af5a4b717432fafcebf93019 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Tue, 4 Nov 2025 12:51:19 -0800 Subject: [PATCH] minor linting change --- .../meta_reference/responses/tool_executor.py | 126 ++++++++++++------ 1 file changed, 88 insertions(+), 38 deletions(-) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index d6ec0e849..07d5dfc7c 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -26,10 +26,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageMCPCall, OpenAIResponseOutputMessageWebSearchToolCall, ) -from llama_stack.apis.common.content_types import ( - ImageContentItem, - TextContentItem, -) +from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.inference import ( OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, @@ -69,7 +66,9 @@ class ToolExecutor: ) -> AsyncIterator[ToolExecutionResult]: tool_call_id = tool_call.id function = tool_call.function - tool_kwargs = json.loads(function.arguments) if function and function.arguments else {} + tool_kwargs = ( + json.loads(function.arguments) if function and function.arguments else {} + ) if not function or not tool_call_id or not function.name: yield ToolExecutionResult(sequence_number=sequence_number) @@ -77,13 +76,20 @@ class ToolExecutor: # Emit progress events for tool execution start async for event_result in self._emit_progress_events( - function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server + function.name, + ctx, + sequence_number, + output_index, + item_id, + mcp_tool_to_server, ): sequence_number = event_result.sequence_number yield event_result # Execute the actual tool call - error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) + error_exc, result = await self._execute_tool( + function.name, tool_kwargs, ctx, mcp_tool_to_server + ) # Emit completion events for tool execution has_error = bool( @@ -91,20 +97,23 @@ class ToolExecutor: or ( result and ( - ((error_code := getattr(result, "error_code", None)) and error_code > 0) + ( + (error_code := getattr(result, "error_code", None)) + and error_code > 0 + ) or getattr(result, "error_message", None) ) ) ) async for event_result in self._emit_completion_events( - function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server, + function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server ): sequence_number = event_result.sequence_number yield event_result # Build result messages from tool execution output_message, input_message = await self._build_result_messages( - function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server, + function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server ) # Yield the final result @@ -113,7 +122,9 @@ class ToolExecutor: final_output_message=output_message, final_input_message=input_message, citation_files=( - metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None + metadata.get("citation_files") + if result and (metadata := getattr(result, "metadata", None)) + else None ), ) @@ -142,7 +153,10 @@ class ToolExecutor: return [] # Run all searches in parallel using gather - search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] + search_tasks = [ + search_single_store(vid) + for vid in response_file_search_tool.vector_store_ids + ] all_results = await asyncio.gather(*search_tasks) # Flatten results @@ -161,17 +175,23 @@ class ToolExecutor: chunk_text = result_item.content[0].text if result_item.content else "" # Get file_id from attributes if result_item.file_id is empty file_id = result_item.file_id or ( - result_item.attributes.get("document_id") if result_item.attributes else None + result_item.attributes.get("document_id") + if result_item.attributes + else None ) metadata_text = f"document_id: {file_id}, score: {result_item.score}" if result_item.attributes: metadata_text += f", attributes: {result_item.attributes}" - text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n" + text_content = ( + f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n" + ) content_items.append(TextContentItem(text=text_content)) unique_files.add(file_id) - content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) + content_items.append( + TextContentItem(text="END of knowledge_search tool results.\n") + ) citation_instruction = "" if unique_files: @@ -206,7 +226,9 @@ class ToolExecutor: content=content_items, # type: ignore[arg-type] metadata={ "document_ids": [r.file_id for r in search_results], - "chunks": [r.content[0].text if r.content else "" for r in search_results], + "chunks": [ + r.content[0].text if r.content else "" for r in search_results + ], "scores": [r.score for r in search_results], "citation_files": citation_files, }, @@ -317,7 +339,11 @@ class ToolExecutor: elif function_name == "knowledge_search": response_file_search_tool = ( next( - (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + ( + t + for t in ctx.response_tools + if isinstance(t, OpenAIResponseInputToolFileSearch) + ), None, ) if ctx.response_tools @@ -363,28 +389,42 @@ class ToolExecutor: mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed( sequence_number=sequence_number, ) - yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number) - else: - mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( - sequence_number=sequence_number, + yield ToolExecutionResult( + stream_event=mcp_failed_event, sequence_number=sequence_number + ) + else: + mcp_completed_event = ( + OpenAIResponseObjectStreamResponseMcpCallCompleted( + sequence_number=sequence_number, + ) + ) + yield ToolExecutionResult( + stream_event=mcp_completed_event, sequence_number=sequence_number ) - yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number) elif function_name == "web_search": sequence_number += 1 - web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( - item_id=item_id, - output_index=output_index, - sequence_number=sequence_number, + web_completion_event = ( + OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + ) + yield ToolExecutionResult( + stream_event=web_completion_event, sequence_number=sequence_number ) - yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number) elif function_name == "knowledge_search": sequence_number += 1 - file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( - item_id=item_id, - output_index=output_index, - sequence_number=sequence_number, + file_completion_event = ( + OpenAIResponseObjectStreamResponseFileSearchCallCompleted( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + ) + yield ToolExecutionResult( + stream_event=file_completion_event, sequence_number=sequence_number ) - yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number) async def _build_result_messages( self, @@ -414,9 +454,11 @@ class ToolExecutor: ) if error_exc: message.error = str(error_exc) - elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or ( - result and getattr(result, "error_message", None) - ): + elif ( + result + and (error_code := getattr(result, "error_code", None)) + and error_code > 0 + ) or (result and getattr(result, "error_message", None)): ec = getattr(result, "error_code", "unknown") em = getattr(result, "error_message", "") message.error = f"Error (code {ec}): {em}" @@ -436,7 +478,11 @@ class ToolExecutor: queries=[tool_kwargs.get("query", "")], status="completed", ) - if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata: + if ( + result + and (metadata := getattr(result, "metadata", None)) + and "document_ids" in metadata + ): message.results = [] for i, doc_id in enumerate(metadata["document_ids"]): text = metadata["chunks"][i] if "chunks" in metadata else None @@ -472,7 +518,9 @@ class ToolExecutor: url_value = f"data:image;base64,{item.image.data}" else: url_value = str(item.image.url) if item.image.url else "" - part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value)) + part = OpenAIChatCompletionContentPartImageParam( + image_url=OpenAIImageURL(url=url_value) + ) else: raise ValueError(f"Unknown result content type: {type(item)}") content_list.append(part) @@ -484,6 +532,8 @@ class ToolExecutor: input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type] else: text = str(error_exc) if error_exc else "Tool execution failed" - input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) + input_message = OpenAIToolMessageParam( + content=text, tool_call_id=tool_call_id + ) return message, input_message