feat(responses): stream progress of tool calls (#3135)

# What does this PR do?
Enhances tool execution streaming by adding support for real-time progress events during tool calls. This implementation adds streaming events for MCP and web search tools, including in-progress, searching, completed, and failed states. 

The refactored `_execute_tool_call` method now returns an async iterator that yields streaming events throughout the tool execution lifecycle.

## Test Plan
Updated the integration test `test_response_streaming_multi_turn_tool_execution` to verify the presence and structure of new streaming events, including:
- Checking for MCP in-progress and completed events
- Verifying that progress events contain required fields (item_id, output_index, sequence_number)
- Ensuring completed events have the necessary sequence_number field
This commit is contained in:
Ashwin Bharambe 2025-08-13 16:31:25 -07:00 committed by GitHub
parent 5b312a80b9
commit 8638537d14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 141 additions and 18 deletions

View file

@ -35,9 +35,15 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
@ -87,6 +93,15 @@ logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
class ToolExecutionResult(BaseModel):
"""Result of streaming tool execution."""
stream_event: OpenAIResponseObjectStream | None = None
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
async def _convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
) -> str | list[OpenAIChatCompletionContentPartParam]:
@ -587,19 +602,38 @@ class OpenAIResponsesImpl:
# execute non-function tool calls
for tool_call in non_function_tool_calls:
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in tool_call_item_ids.items():
response_tool_call = chat_response_tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use a fallback item_id if not found
if not matching_item_id:
matching_item_id = f"tc_{uuid.uuid4()}"
# Execute tool call with streaming
tool_call_log = None
tool_response_message = None
async for result in self._execute_tool_call(
tool_call, ctx, sequence_number, response_id, len(output_messages), matching_item_id
):
if result.stream_event:
# Forward streaming events
sequence_number = result.sequence_number
yield result.stream_event
if result.final_output_message is not None:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
sequence_number = result.sequence_number
if tool_call_log:
output_messages.append(tool_call_log)
# Emit output_item.done event for completed non-function tool call
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in tool_call_item_ids.items():
response_tool_call = chat_response_tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
if matching_item_id:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
@ -848,7 +882,11 @@ class OpenAIResponsesImpl:
self,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
sequence_number: int,
response_id: str,
output_index: int,
item_id: str,
) -> AsyncIterator[ToolExecutionResult]:
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
@ -858,8 +896,41 @@ class OpenAIResponsesImpl:
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
return None, None
yield ToolExecutionResult(sequence_number=sequence_number)
return
# Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
elif function.name == "web_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event
if function.name == "web_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# Execute the actual tool call
error_exc = None
result = None
try:
@ -894,6 +965,33 @@ class OpenAIResponsesImpl:
except Exception as e:
error_exc = e
# Emit completion or failure event based on result (only for tools with specific streaming events)
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
completion_event = None
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
sequence_number += 1
if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number,
)
else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number,
)
elif function.name == "web_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
# Build the result message and input message
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
@ -907,9 +1005,9 @@ class OpenAIResponsesImpl:
)
if error_exc:
message.error = str(error_exc)
elif (result.error_code and result.error_code > 0) or result.error_message:
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result.content:
elif result and result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
@ -917,7 +1015,7 @@ class OpenAIResponsesImpl:
id=tool_call_id,
status="completed",
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
if has_error:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
@ -925,7 +1023,7 @@ class OpenAIResponsesImpl:
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if "document_ids" in result.metadata:
if result and "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
@ -939,7 +1037,7 @@ class OpenAIResponsesImpl:
attributes={},
)
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
if has_error:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
@ -971,10 +1069,13 @@ class OpenAIResponsesImpl:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc)
text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=message, final_input_message=input_message
)
def _is_function_tool_call(

View file

@ -598,6 +598,10 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"]
item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"]
# Should have tool execution progress events
mcp_in_progress_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.in_progress"]
mcp_completed_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.completed"]
# Verify we have substantial streaming activity (not just batch events)
assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
@ -609,6 +613,24 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}"
assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}"
# Should have tool execution progress events
assert len(mcp_in_progress_events) > 0, (
f"Expected response.mcp_call.in_progress events, got chunk types: {chunk_types}"
)
assert len(mcp_completed_events) > 0, (
f"Expected response.mcp_call.completed events, got chunk types: {chunk_types}"
)
# MCP failed events are optional (only if errors occur)
# Verify progress events have proper structure
for progress_event in mcp_in_progress_events:
assert hasattr(progress_event, "item_id"), "Progress event should have 'item_id' field"
assert hasattr(progress_event, "output_index"), "Progress event should have 'output_index' field"
assert hasattr(progress_event, "sequence_number"), "Progress event should have 'sequence_number' field"
for completed_event in mcp_completed_events:
assert hasattr(completed_event, "sequence_number"), "Completed event should have 'sequence_number' field"
# Verify delta events have proper structure
for delta_event in delta_events:
assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field"