feat(responses): improve streaming for function calls

Ashwin Bharambe 2025-08-12 18:06:02 -07:00
parent 6358d0a478
commit 674478c851
2 changed files with 236 additions and 28 deletions
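Taken together, the change makes a streamed function call surface as an ordered event sequence: response.output_item.added, one or more response.function_call_arguments.delta events, response.function_call_arguments.done, and finally response.output_item.done. Below is a minimal client-side sketch of consuming that sequence; the collect_function_calls helper and the `stream` iterable are hypothetical, while the event type strings and field names follow the diff below.

# A minimal consumption sketch. `stream` is assumed to be any iterable of
# the response stream events added in this commit; the helper name and the
# print statements are illustrative, not part of the API.
def collect_function_calls(stream):
    """Accumulate streamed function-call arguments keyed by item_id."""
    arguments_by_item: dict[str, str] = {}
    for event in stream:
        if event.type == "response.output_item.added":
            # A new in-progress tool call item was announced.
            print(f"tool call started: {event.item.id}")
        elif event.type == "response.function_call_arguments.delta":
            # Argument JSON arrives as incremental string fragments.
            arguments_by_item[event.item_id] = (
                arguments_by_item.get(event.item_id, "") + event.delta
            )
        elif event.type == "response.function_call_arguments.done":
            # The done event carries the full argument string; it should
            # match what the deltas accumulated to.
            assert arguments_by_item.get(event.item_id, "") == event.arguments
        elif event.type == "response.output_item.done":
            print(f"tool call finished: {event.item.id}")
    return arguments_by_item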

@@ -33,6 +33,10 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
     OpenAIResponseObjectStreamResponseCreated,
+    OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
+    OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseOutputItemAdded,
+    OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseObjectStreamResponseOutputTextDelta,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
@@ -73,7 +77,9 @@ from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
-from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+from llama_stack.providers.utils.inference.openai_compat import (
+    convert_tooldef_to_openai_tool,
+)
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore

 logger = get_logger(name=__name__, category="openai_responses")
@@ -82,7 +88,7 @@ OPENAI_RESPONSES_PREFIX = "openai_responses:"


 async def _convert_response_content_to_chat_content(
-    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
+    content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
 ) -> str | list[OpenAIChatCompletionContentPartParam]:
     """
     Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
@@ -150,7 +156,9 @@ async def _convert_response_input_to_chat_messages(
     return messages


-async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def _convert_chat_choice_to_response_message(
+    choice: OpenAIChoice,
+) -> OpenAIResponseMessage:
     """
     Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
     """
@@ -172,7 +180,9 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open
     )


-async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
+async def _convert_response_text_to_chat_response_format(
+    text: OpenAIResponseText,
+) -> OpenAIResponseFormatParam:
     """
     Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
     """
@@ -228,7 +238,9 @@ class OpenAIResponsesImpl:
         self.vector_io_api = vector_io_api

     async def _prepend_previous_response(
-        self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
+        self,
+        input: str | list[OpenAIResponseInput],
+        previous_response_id: str | None = None,
     ):
         if previous_response_id:
             previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
@@ -446,6 +458,8 @@ class OpenAIResponsesImpl:

         # Create a placeholder message item for delta events
         message_item_id = f"msg_{uuid.uuid4()}"
+        # Track tool call items for streaming events
+        tool_call_item_ids: dict[int, str] = {}

         async for chunk in completion_result:
             chat_response_id = chunk.id
@@ -472,18 +486,61 @@
                 if chunk_choice.delta.tool_calls:
                     for tool_call in chunk_choice.delta.tool_calls:
                         response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
-                        if response_tool_call:
-                            # Don't attempt to concatenate arguments if we don't have any new arguments
-                            if tool_call.function.arguments:
-                                # Guard against an initial None argument before we concatenate
-                                response_tool_call.function.arguments = (
-                                    response_tool_call.function.arguments or ""
-                                ) + tool_call.function.arguments
-                        else:
+                        # Create new tool call entry if this is the first chunk for this index
+                        if response_tool_call is None:
                             tool_call_dict: dict[str, Any] = tool_call.model_dump()
                             tool_call_dict.pop("type", None)
                             response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
                             chat_response_tool_calls[tool_call.index] = response_tool_call
+
+                            # Create item ID for this tool call for streaming events
+                            tool_call_item_id = f"fc_{uuid.uuid4()}"
+                            tool_call_item_ids[tool_call.index] = tool_call_item_id
+
+                            # Emit output_item.added event for the new function call
+                            sequence_number += 1
+                            function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
+                                arguments="",  # Will be filled incrementally via delta events
+                                call_id=tool_call.id or "",
+                                name=tool_call.function.name if tool_call.function else "",
+                                id=tool_call_item_id,
+                                status="in_progress",
+                            )
+                            yield OpenAIResponseObjectStreamResponseOutputItemAdded(
+                                response_id=response_id,
+                                item=function_call_item,
+                                output_index=len(output_messages),
+                                sequence_number=sequence_number,
+                            )
+
+                        # Stream function call arguments as they arrive
+                        if tool_call.function and tool_call.function.arguments:
+                            tool_call_item_id = tool_call_item_ids[tool_call.index]
+                            sequence_number += 1
+                            yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
+                                delta=tool_call.function.arguments,
+                                item_id=tool_call_item_id,
+                                output_index=len(output_messages),
+                                sequence_number=sequence_number,
+                            )
+
+                            # Accumulate arguments for final response
+                            response_tool_call.function.arguments = (
+                                response_tool_call.function.arguments or ""
+                            ) + tool_call.function.arguments
+
+        # Emit function_call_arguments.done events for completed tool calls
+        for tool_call_index in sorted(chat_response_tool_calls.keys()):
+            if tool_call_index in tool_call_item_ids:
+                tool_call_item_id = tool_call_item_ids[tool_call_index]
+                final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
+                sequence_number += 1
+                yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
+                    arguments=final_arguments,
+                    item_id=tool_call_item_id,
+                    output_index=len(output_messages),
+                    sequence_number=sequence_number,
+                )

         # Convert collected chunks to complete response
         if chat_response_tool_calls:
@@ -532,18 +589,56 @@
                 tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
                 if tool_call_log:
                     output_messages.append(tool_call_log)
+
+                    # Emit output_item.done event for completed non-function tool call
+                    # Find the item_id for this tool call
+                    matching_item_id = None
+                    for index, item_id in tool_call_item_ids.items():
+                        response_tool_call = chat_response_tool_calls.get(index)
+                        if response_tool_call and response_tool_call.id == tool_call.id:
+                            matching_item_id = item_id
+                            break
+
+                    if matching_item_id:
+                        sequence_number += 1
+                        yield OpenAIResponseObjectStreamResponseOutputItemDone(
+                            response_id=response_id,
+                            item=tool_call_log,
+                            output_index=len(output_messages) - 1,
+                            sequence_number=sequence_number,
+                        )
+
                 if tool_response_message:
                     next_turn_messages.append(tool_response_message)

             for tool_call in function_tool_calls:
-                output_messages.append(
-                    OpenAIResponseOutputMessageFunctionToolCall(
-                        arguments=tool_call.function.arguments or "",
-                        call_id=tool_call.id,
-                        name=tool_call.function.name or "",
-                        id=f"fc_{uuid.uuid4()}",
-                        status="completed",
-                    )
-                )
+                # Find the item_id for this tool call from our tracking dictionary
+                matching_item_id = None
+                for index, item_id in tool_call_item_ids.items():
+                    response_tool_call = chat_response_tool_calls.get(index)
+                    if response_tool_call and response_tool_call.id == tool_call.id:
+                        matching_item_id = item_id
+                        break
+
+                # Use existing item_id or create new one if not found
+                final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
+
+                function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
+                    arguments=tool_call.function.arguments or "",
+                    call_id=tool_call.id,
+                    name=tool_call.function.name or "",
+                    id=final_item_id,
+                    status="completed",
+                )
+                output_messages.append(function_call_item)
+
+                # Emit output_item.done event for completed function call
+                sequence_number += 1
+                yield OpenAIResponseObjectStreamResponseOutputItemDone(
+                    response_id=response_id,
+                    item=function_call_item,
+                    output_index=len(output_messages) - 1,
+                    sequence_number=sequence_number,
+                )

             if not function_tool_calls and not non_function_tool_calls:
@@ -779,7 +874,8 @@
             )
         elif function.name == "knowledge_search":
             response_file_search_tool = next(
-                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
+                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
+                None,
             )
             if response_file_search_tool:
                 # Use vector_stores.search API instead of knowledge_search tool
@@ -798,7 +894,9 @@
             error_exc = e

         if function.name in ctx.mcp_tool_to_server:
-            from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
+            from llama_stack.apis.agents.openai_responses import (
+                OpenAIResponseOutputMessageMCPCall,
+            )

             message = OpenAIResponseOutputMessageMCPCall(
                 id=tool_call_id,
@@ -850,7 +948,10 @@
         if isinstance(result.content, str):
             content = result.content
         elif isinstance(result.content, list):
-            from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+            from llama_stack.apis.common.content_types import (
+                ImageContentItem,
+                TextContentItem,
+            )

             content = []
             for item in result.content:
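Distilled from the hunks above, the server-side bookkeeping is: assign an fc_-prefixed item id the first time a tool-call index appears, stream each argument fragment as a delta while accumulating it, and emit one done event per index once the model stream ends. A simplified sketch of that pattern follows; plain dicts stand in for the typed event classes, and stream_tool_call_chunks with its tuple input is hypothetical.

import uuid

def stream_tool_call_chunks(chunks):
    """Simplified generator showing the added/delta/done bookkeeping.

    `chunks` yields (index, call_id, name, arguments_fragment) tuples; the
    real code reads these off OpenAI chat-completion tool-call deltas.
    """
    item_ids: dict[int, str] = {}   # tool_call.index -> streaming item id
    arguments: dict[int, str] = {}  # tool_call.index -> accumulated args
    sequence_number = 0

    for index, call_id, name, fragment in chunks:
        if index not in item_ids:
            # First chunk for this index: announce the new in-progress item.
            item_ids[index] = f"fc_{uuid.uuid4()}"
            sequence_number += 1
            yield {"type": "response.output_item.added",
                   "item_id": item_ids[index], "call_id": call_id,
                   "name": name, "sequence_number": sequence_number}
        if fragment:
            # Stream the fragment, then accumulate it for the final event.
            sequence_number += 1
            yield {"type": "response.function_call_arguments.delta",
                   "item_id": item_ids[index], "delta": fragment,
                   "sequence_number": sequence_number}
            arguments[index] = arguments.get(index, "") + fragment

    # Input exhausted: emit one done event per tracked tool call.
    for index in sorted(item_ids):
        sequence_number += 1
        yield {"type": "response.function_call_arguments.done",
               "item_id": item_ids[index],
               "arguments": arguments.get(index, ""),
               "sequence_number": sequence_number}

The single monotonically increasing sequence_number is what lets clients totally order added, delta, and done events even when several tool calls stream interleaved.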

@@ -384,12 +384,18 @@ def test_response_non_streaming_mcp_tool(request, compat_client, text_model_id,
     assert list_tools.type == "mcp_list_tools"
     assert list_tools.server_label == "localmcp"
     assert len(list_tools.tools) == 2
-    assert {t.name for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
+    assert {t.name for t in list_tools.tools} == {
+        "get_boiling_point",
+        "greet_everyone",
+    }

     call = response.output[1]
     assert call.type == "mcp_call"
     assert call.name == "get_boiling_point"
-    assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
+    assert json.loads(call.arguments) == {
+        "liquid_name": "myawesomeliquid",
+        "celsius": True,
+    }
     assert call.error is None
     assert "-100" in call.output
@@ -581,6 +587,105 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
         f"Last chunk should be response.completed, got {chunks[-1].type}"
     )

+    # Verify tool call streaming events are present
+    chunk_types = [chunk.type for chunk in chunks]
+
+    # Should have function call arguments delta events for tool calls
+    delta_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.delta"]
+    done_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.done"]
+
+    # Should have output item events for tool calls
+    item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"]
+    item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"]
+
+    # Verify we have substantial streaming activity (not just batch events)
+    assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
+
+    # Since this test involves MCP tool calls, we should see streaming events
+    assert len(delta_events) > 0, f"Expected function_call_arguments.delta events, got chunk types: {chunk_types}"
+    assert len(done_events) > 0, f"Expected function_call_arguments.done events, got chunk types: {chunk_types}"
+
+    # Should have output item events for function calls
+    assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}"
+    assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}"
+
+    # Verify delta events have proper structure
+    for delta_event in delta_events:
+        assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field"
+        assert hasattr(delta_event, "item_id"), "Delta event should have 'item_id' field"
+        assert hasattr(delta_event, "sequence_number"), "Delta event should have 'sequence_number' field"
+        assert delta_event.delta, "Delta should not be empty"
+
+    # Verify done events have proper structure
+    for done_event in done_events:
+        assert hasattr(done_event, "arguments"), "Done event should have 'arguments' field"
+        assert hasattr(done_event, "item_id"), "Done event should have 'item_id' field"
+        assert done_event.arguments, "Final arguments should not be empty"
+
+    # Verify output item added events have proper structure
+    for added_event in item_added_events:
+        assert hasattr(added_event, "item"), "Added event should have 'item' field"
+        assert hasattr(added_event, "output_index"), "Added event should have 'output_index' field"
+        assert hasattr(added_event, "sequence_number"), "Added event should have 'sequence_number' field"
+        assert hasattr(added_event, "response_id"), "Added event should have 'response_id' field"
+        assert added_event.item.type in ["function_call", "mcp_call"], "Added item should be a tool call"
+        assert added_event.item.status == "in_progress", "Added item should be in progress"
+        assert added_event.response_id, "Response ID should not be empty"
+        assert isinstance(added_event.output_index, int), "Output index should be integer"
+        assert added_event.output_index >= 0, "Output index should be non-negative"
+
+    # Verify output item done events have proper structure
+    for done_event in item_done_events:
+        assert hasattr(done_event, "item"), "Done event should have 'item' field"
+        assert hasattr(done_event, "output_index"), "Done event should have 'output_index' field"
+        assert hasattr(done_event, "sequence_number"), "Done event should have 'sequence_number' field"
+        assert hasattr(done_event, "response_id"), "Done event should have 'response_id' field"
+        assert done_event.item.type in ["function_call", "mcp_call"], "Done item should be a tool call"
+        # Note: MCP calls don't have a status field, only function calls do
+        if done_event.item.type == "function_call":
+            assert done_event.item.status == "completed", "Function call should be completed"
+        assert done_event.response_id, "Response ID should not be empty"
+        assert isinstance(done_event.output_index, int), "Output index should be integer"
+        assert done_event.output_index >= 0, "Output index should be non-negative"
+
+    # Group function call argument events by item_id (these should have proper tracking)
+    function_call_events_by_item_id = {}
+    for chunk in chunks:
+        if hasattr(chunk, "item_id") and chunk.type in [
+            "response.function_call_arguments.delta",
+            "response.function_call_arguments.done",
+        ]:
+            item_id = chunk.item_id
+            if item_id not in function_call_events_by_item_id:
+                function_call_events_by_item_id[item_id] = []
+            function_call_events_by_item_id[item_id].append(chunk)
+
+    for item_id, related_events in function_call_events_by_item_id.items():
+        # Should have at least one delta and one done event for a complete function call
+        delta_events = [e for e in related_events if e.type == "response.function_call_arguments.delta"]
+        done_events = [e for e in related_events if e.type == "response.function_call_arguments.done"]
+        assert len(delta_events) > 0, f"Item {item_id} should have at least one delta event"
+        assert len(done_events) == 1, f"Item {item_id} should have exactly one done event"
+
+        # Verify all events have the same item_id
+        for event in related_events:
+            assert event.item_id == item_id, f"Event should have consistent item_id {item_id}, got {event.item_id}"
+
+    # Basic pairing check: each output_item.added should be followed by some activity
+    # (but we can't enforce strict 1:1 pairing due to the complexity of multi-turn scenarios)
+    assert len(item_added_events) > 0, "Should have at least one output_item.added event"
+
+    # Verify response_id consistency across all events
+    response_ids = set()
+    for chunk in chunks:
+        if hasattr(chunk, "response_id"):
+            response_ids.add(chunk.response_id)
+        elif hasattr(chunk, "response") and hasattr(chunk.response, "id"):
+            response_ids.add(chunk.response.id)
+    assert len(response_ids) == 1, f"All events should reference the same response_id, found: {response_ids}"
+
     # Get the final response from the last chunk
     final_chunk = chunks[-1]
     if hasattr(final_chunk, "response"):
@@ -722,7 +827,9 @@ def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_fact

         # Attach file to vector store with attributes
         file_attach_response = compat_client.vector_stores.files.create(
-            vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"]
+            vector_store_id=vector_store.id,
+            file_id=file_response.id,
+            attributes=file_data["attributes"],
         )

         # Wait for attachment
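The test above deliberately stops short of ordering guarantees, but since every event carries a monotonically increasing sequence_number, a stricter per-item check is possible. A sketch under the same assumptions as the test (the assert_per_item_event_order helper is hypothetical; chunk objects and field names are those exercised above):

def assert_per_item_event_order(chunks):
    """For each item_id: added precedes all deltas, which precede done."""
    by_item: dict[str, dict[str, list[int]]] = {}
    for chunk in chunks:
        # Delta/done events carry item_id; added/done item events carry item.id.
        item_id = getattr(chunk, "item_id", None) or getattr(
            getattr(chunk, "item", None), "id", None
        )
        if item_id is None:
            continue
        buckets = by_item.setdefault(item_id, {"added": [], "delta": [], "done": []})
        if chunk.type == "response.output_item.added":
            buckets["added"].append(chunk.sequence_number)
        elif chunk.type == "response.function_call_arguments.delta":
            buckets["delta"].append(chunk.sequence_number)
        elif chunk.type == "response.function_call_arguments.done":
            buckets["done"].append(chunk.sequence_number)

    for item_id, buckets in by_item.items():
        if buckets["added"] and buckets["delta"]:
            assert max(buckets["added"]) < min(buckets["delta"]), item_id
        if buckets["delta"] and buckets["done"]:
            assert max(buckets["delta"]) < min(buckets["done"]), item_id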