diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 104f15010..fbb5a608a 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -35,9 +35,15 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseMcpCallCompleted,
+    OpenAIResponseObjectStreamResponseMcpCallFailed,
+    OpenAIResponseObjectStreamResponseMcpCallInProgress,
     OpenAIResponseObjectStreamResponseOutputItemAdded,
     OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseObjectStreamResponseOutputTextDelta,
+    OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
+    OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
+    OpenAIResponseObjectStreamResponseWebSearchCallSearching,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
@@ -87,6 +93,15 @@ logger = get_logger(name=__name__, category="openai_responses")
 
 OPENAI_RESPONSES_PREFIX = "openai_responses:"
 
+
+class ToolExecutionResult(BaseModel):
+    """Result of streaming tool execution."""
+
+    stream_event: OpenAIResponseObjectStream | None = None
+    sequence_number: int
+    final_output_message: OpenAIResponseOutput | None = None
+    final_input_message: OpenAIMessageParam | None = None
+
 
 async def _convert_response_content_to_chat_content(
     content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
 ) -> str | list[OpenAIChatCompletionContentPartParam]:
@@ -587,19 +602,38 @@ class OpenAIResponsesImpl:
 
         # execute non-function tool calls
         for tool_call in non_function_tool_calls:
-            tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
+            # Find the item_id for this tool call
+            matching_item_id = None
+            for index, item_id in tool_call_item_ids.items():
+                response_tool_call = chat_response_tool_calls.get(index)
+                if response_tool_call and response_tool_call.id == tool_call.id:
+                    matching_item_id = item_id
+                    break
+
+            # Use a fallback item_id if not found
+            if not matching_item_id:
+                matching_item_id = f"tc_{uuid.uuid4()}"
+
+            # Execute tool call with streaming
+            tool_call_log = None
+            tool_response_message = None
+            async for result in self._execute_tool_call(
+                tool_call, ctx, sequence_number, response_id, len(output_messages), matching_item_id
+            ):
+                if result.stream_event:
+                    # Forward streaming events
+                    sequence_number = result.sequence_number
+                    yield result.stream_event
+
+                if result.final_output_message is not None:
+                    tool_call_log = result.final_output_message
+                    tool_response_message = result.final_input_message
+                    sequence_number = result.sequence_number
+
             if tool_call_log:
                 output_messages.append(tool_call_log)
 
                 # Emit output_item.done event for completed non-function tool call
-                # Find the item_id for this tool call
-                matching_item_id = None
-                for index, item_id in tool_call_item_ids.items():
-                    response_tool_call = chat_response_tool_calls.get(index)
-                    if response_tool_call and response_tool_call.id == tool_call.id:
-                        matching_item_id = item_id
-                        break
-
                 if matching_item_id:
                     sequence_number += 1
                     yield OpenAIResponseObjectStreamResponseOutputItemDone(
@@ -848,7 +882,11 @@ class OpenAIResponsesImpl:
         self,
         tool_call: OpenAIChatCompletionToolCall,
         ctx: ChatCompletionContext,
-    ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
+        sequence_number: int,
+        response_id: str,
+        output_index: int,
+        item_id: str,
+    ) -> AsyncIterator[ToolExecutionResult]:
         from llama_stack.providers.utils.inference.prompt_adapter import (
             interleaved_content_as_str,
         )
@@ -858,8 +896,41 @@ class OpenAIResponsesImpl:
         tool_kwargs = json.loads(function.arguments) if function.arguments else {}
 
         if not function or not tool_call_id or not function.name:
-            return None, None
+            yield ToolExecutionResult(sequence_number=sequence_number)
+            return
 
+        # Emit an in_progress event based on tool type (only for tools with specific streaming events)
+        progress_event = None
+        if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
+            sequence_number += 1
+            progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        elif function.name == "web_search":
+            sequence_number += 1
+            progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        # Note: knowledge_search and other custom tools don't have specific streaming events in the OpenAI spec
+
+        if progress_event:
+            yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
+
+        # For web search, emit a searching event
+        if function.name == "web_search":
+            sequence_number += 1
+            searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+            yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
+
+        # Execute the actual tool call
         error_exc = None
         result = None
         try:
@@ -894,6 +965,33 @@ class OpenAIResponsesImpl:
         except Exception as e:
             error_exc = e
 
+        # Emit a completion or failure event based on the result (only for tools with specific streaming events)
+        has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
+        completion_event = None
+
+        if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
+            sequence_number += 1
+            if has_error:
+                completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
+                    sequence_number=sequence_number,
+                )
+            else:
+                completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
+                    sequence_number=sequence_number,
+                )
+        elif function.name == "web_search":
+            sequence_number += 1
+            completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        # Note: knowledge_search and other custom tools don't have specific completion events in the OpenAI spec
+
+        if completion_event:
+            yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
+
+        # Build the result message and input message
         if function.name in ctx.mcp_tool_to_server:
             from llama_stack.apis.agents.openai_responses import (
                 OpenAIResponseOutputMessageMCPCall,
@@ -907,9 +1005,9 @@ class OpenAIResponsesImpl:
             )
             if error_exc:
                 message.error = str(error_exc)
-            elif (result.error_code and result.error_code > 0) or result.error_message:
+            elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
                 message.error = f"Error (code {result.error_code}): {result.error_message}"
-            elif result.content:
+            elif result and result.content:
                 message.output = interleaved_content_as_str(result.content)
         else:
             if function.name == "web_search":
@@ -917,7 +1015,7 @@ class OpenAIResponsesImpl:
                     id=tool_call_id,
                     status="completed",
                 )
-                if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                if has_error:
                     message.status = "failed"
             elif function.name == "knowledge_search":
                 message = OpenAIResponseOutputMessageFileSearchToolCall(
@@ -925,7 +1023,7 @@ class OpenAIResponsesImpl:
                     queries=[tool_kwargs.get("query", "")],
                     status="completed",
                 )
-                if "document_ids" in result.metadata:
+                if result and "document_ids" in result.metadata:
                     message.results = []
                     for i, doc_id in enumerate(result.metadata["document_ids"]):
                         text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
@@ -939,7 +1037,7 @@ class OpenAIResponsesImpl:
                                 attributes={},
                             )
                         )
-                if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                if has_error:
                     message.status = "failed"
             else:
                 raise ValueError(f"Unknown tool {function.name} called")
@@ -971,10 +1069,13 @@ class OpenAIResponsesImpl:
                     raise ValueError(f"Unknown result content type: {type(result.content)}")
                 input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
         else:
-            text = str(error_exc)
+            text = str(error_exc) if error_exc else "Tool execution failed"
             input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
 
-        return message, input_message
+        # Yield the final result
+        yield ToolExecutionResult(
+            sequence_number=sequence_number, final_output_message=message, final_input_message=input_message
+        )
 
     def _is_function_tool_call(
diff --git a/tests/integration/non_ci/responses/test_responses.py b/tests/integration/non_ci/responses/test_responses.py
index 6092346b0..776e3cf30 100644
--- a/tests/integration/non_ci/responses/test_responses.py
+++ b/tests/integration/non_ci/responses/test_responses.py
@@ -598,6 +598,10 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id):
     item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"]
     item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"]
 
+    # Should have tool execution progress events
+    mcp_in_progress_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.in_progress"]
+    mcp_completed_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.completed"]
+
     # Verify we have substantial streaming activity (not just batch events)
     assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
 
@@ -609,6 +613,24 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id):
     assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}"
     assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}"
 
+    # Should have tool execution progress events
+    assert len(mcp_in_progress_events) > 0, (
+        f"Expected response.mcp_call.in_progress events, got chunk types: {chunk_types}"
+    )
+    assert len(mcp_completed_events) > 0, (
+        f"Expected response.mcp_call.completed events, got chunk types: {chunk_types}"
+    )
+    # MCP failed events are optional (emitted only when errors occur)
+
+    # Verify progress events have proper structure
+    for progress_event in mcp_in_progress_events:
+        assert hasattr(progress_event, "item_id"), "Progress event should have 'item_id' field"
"output_index"), "Progress event should have 'output_index' field" + assert hasattr(progress_event, "sequence_number"), "Progress event should have 'sequence_number' field" + + for completed_event in mcp_completed_events: + assert hasattr(completed_event, "sequence_number"), "Completed event should have 'sequence_number' field" + # Verify delta events have proper structure for delta_event in delta_events: assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field"