From a23ee35b240af05e944fedc79303a621340e961e Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Tue, 4 Nov 2025 13:10:46 -0800 Subject: [PATCH] reverting some formatting changes --- .../meta_reference/responses/tool_executor.py | 88 ++++++------------- 1 file changed, 25 insertions(+), 63 deletions(-) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index 07d5dfc7c..e76807a10 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -26,7 +26,10 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageMCPCall, OpenAIResponseOutputMessageWebSearchToolCall, ) -from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, +) from llama_stack.apis.inference import ( OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, @@ -66,9 +69,7 @@ class ToolExecutor: ) -> AsyncIterator[ToolExecutionResult]: tool_call_id = tool_call.id function = tool_call.function - tool_kwargs = ( - json.loads(function.arguments) if function and function.arguments else {} - ) + tool_kwargs = json.loads(function.arguments) if function and function.arguments else {} if not function or not tool_call_id or not function.name: yield ToolExecutionResult(sequence_number=sequence_number) @@ -76,12 +77,7 @@ class ToolExecutor: # Emit progress events for tool execution start async for event_result in self._emit_progress_events( - function.name, - ctx, - sequence_number, - output_index, - item_id, - mcp_tool_to_server, + function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server ): sequence_number = event_result.sequence_number yield event_result @@ -97,10 +93,7 @@ class ToolExecutor: or ( result and ( - ( - (error_code := getattr(result, "error_code", None)) - and error_code > 0 - ) + ((error_code := getattr(result, "error_code", None)) and error_code > 0) or getattr(result, "error_message", None) ) ) @@ -122,9 +115,7 @@ class ToolExecutor: final_output_message=output_message, final_input_message=input_message, citation_files=( - metadata.get("citation_files") - if result and (metadata := getattr(result, "metadata", None)) - else None + metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None ), ) @@ -153,10 +144,7 @@ class ToolExecutor: return [] # Run all searches in parallel using gather - search_tasks = [ - search_single_store(vid) - for vid in response_file_search_tool.vector_store_ids - ] + search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] all_results = await asyncio.gather(*search_tasks) # Flatten results @@ -175,9 +163,7 @@ class ToolExecutor: chunk_text = result_item.content[0].text if result_item.content else "" # Get file_id from attributes if result_item.file_id is empty file_id = result_item.file_id or ( - result_item.attributes.get("document_id") - if result_item.attributes - else None + result_item.attributes.get("document_id") if result_item.attributes else None ) metadata_text = f"document_id: {file_id}, score: {result_item.score}" if result_item.attributes: @@ -189,9 +175,7 @@ class ToolExecutor: content_items.append(TextContentItem(text=text_content)) unique_files.add(file_id) - content_items.append( - TextContentItem(text="END of knowledge_search tool results.\n") - ) + content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) citation_instruction = "" if unique_files: @@ -226,9 +210,7 @@ class ToolExecutor: content=content_items, # type: ignore[arg-type] metadata={ "document_ids": [r.file_id for r in search_results], - "chunks": [ - r.content[0].text if r.content else "" for r in search_results - ], + "chunks": [r.content[0].text if r.content else "" for r in search_results], "scores": [r.score for r in search_results], "citation_files": citation_files, }, @@ -339,11 +321,7 @@ class ToolExecutor: elif function_name == "knowledge_search": response_file_search_tool = ( next( - ( - t - for t in ctx.response_tools - if isinstance(t, OpenAIResponseInputToolFileSearch) - ), + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None, ) if ctx.response_tools @@ -389,39 +367,33 @@ class ToolExecutor: mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed( sequence_number=sequence_number, ) - yield ToolExecutionResult( - stream_event=mcp_failed_event, sequence_number=sequence_number - ) + yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number) else: - mcp_completed_event = ( - OpenAIResponseObjectStreamResponseMcpCallCompleted( + mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( sequence_number=sequence_number, ) - ) yield ToolExecutionResult( stream_event=mcp_completed_event, sequence_number=sequence_number ) elif function_name == "web_search": sequence_number += 1 - web_completion_event = ( - OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) - ) + yield ToolExecutionResult( stream_event=web_completion_event, sequence_number=sequence_number ) elif function_name == "knowledge_search": sequence_number += 1 - file_completion_event = ( - OpenAIResponseObjectStreamResponseFileSearchCallCompleted( + file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) - ) + yield ToolExecutionResult( stream_event=file_completion_event, sequence_number=sequence_number ) @@ -454,11 +426,9 @@ class ToolExecutor: ) if error_exc: message.error = str(error_exc) - elif ( - result - and (error_code := getattr(result, "error_code", None)) - and error_code > 0 - ) or (result and getattr(result, "error_message", None)): + elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or ( + result and getattr(result, "error_message", None) + ): ec = getattr(result, "error_code", "unknown") em = getattr(result, "error_message", "") message.error = f"Error (code {ec}): {em}" @@ -478,11 +448,7 @@ class ToolExecutor: queries=[tool_kwargs.get("query", "")], status="completed", ) - if ( - result - and (metadata := getattr(result, "metadata", None)) - and "document_ids" in metadata - ): + if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata: message.results = [] for i, doc_id in enumerate(metadata["document_ids"]): text = metadata["chunks"][i] if "chunks" in metadata else None @@ -518,9 +484,7 @@ class ToolExecutor: url_value = f"data:image;base64,{item.image.data}" else: url_value = str(item.image.url) if item.image.url else "" - part = OpenAIChatCompletionContentPartImageParam( - image_url=OpenAIImageURL(url=url_value) - ) + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value)) else: raise ValueError(f"Unknown result content type: {type(item)}") content_list.append(part) @@ -532,8 +496,6 @@ class ToolExecutor: input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type] else: text = str(error_exc) if error_exc else "Tool execution failed" - input_message = OpenAIToolMessageParam( - content=text, tool_call_id=tool_call_id - ) + input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) return message, input_message