feat(responses): add MCP argument streaming and content part events

- Add content part events (response.content_part.added/done) for granular text streaming
- Implement MCP-specific argument streaming (response.mcp_call.arguments.delta/done)
- Differentiate between MCP and function call streaming events
- Update unit and integration tests for new streaming events
- Ensure proper event ordering and OpenAI spec compliance

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
commit e48d062233 (parent 8638537d14)
Ashwin Bharambe, 2025-08-12 23:02:39 -07:00
4 changed files with 242 additions and 35 deletions
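
For orientation, a minimal sketch of what a consumer of the new events might look like. The event `type` strings and the `item_id`/`delta`/`arguments` fields follow the schema added in this commit; the accumulation loop itself is hypothetical and not part of the change:

    async def collect(stream):
        """Accumulate streamed text and tool-call arguments, keyed by output item."""
        text: dict[str, str] = {}
        args: dict[str, str] = {}
        async for event in stream:
            if event.type == "response.output_text.delta":
                text[event.item_id] = text.get(event.item_id, "") + event.delta
            elif event.type in ("response.mcp_call.arguments.delta", "response.function_call_arguments.delta"):
                args[event.item_id] = args.get(event.item_id, "") + event.delta
            elif event.type in ("response.mcp_call.arguments.done", "response.function_call_arguments.done"):
                # the .done events carry the complete argument string
                args[event.item_id] = event.arguments
        return text, args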


@@ -623,6 +623,58 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
     type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed"
 
 
+@json_schema_type
+class OpenAIResponseContentPart(BaseModel):
+    """Base class for response content parts."""
+
+    id: str
+    type: str
+
+
+@json_schema_type
+class OpenAIResponseContentPartText(OpenAIResponseContentPart):
+    """Text content part for streaming responses."""
+
+    text: str
+    type: Literal["text"] = "text"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
+    """Streaming event for when a new content part is added to a response item.
+
+    :param response_id: Unique identifier of the response containing this content
+    :param item_id: Unique identifier of the output item containing this content part
+    :param part: The content part that was added
+    :param sequence_number: Sequential number for ordering streaming events
+    :param type: Event type identifier, always "response.content_part.added"
+    """
+
+    response_id: str
+    item_id: str
+    part: OpenAIResponseContentPart
+    sequence_number: int
+    type: Literal["response.content_part.added"] = "response.content_part.added"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
+    """Streaming event for when a content part is completed.
+
+    :param response_id: Unique identifier of the response containing this content
+    :param item_id: Unique identifier of the output item containing this content part
+    :param part: The completed content part
+    :param sequence_number: Sequential number for ordering streaming events
+    :param type: Event type identifier, always "response.content_part.done"
+    """
+
+    response_id: str
+    item_id: str
+    part: OpenAIResponseContentPart
+    sequence_number: int
+    type: Literal["response.content_part.done"] = "response.content_part.done"
+
+
 OpenAIResponseObjectStream = Annotated[
     OpenAIResponseObjectStreamResponseCreated
     | OpenAIResponseObjectStreamResponseOutputItemAdded
@@ -642,6 +694,8 @@ OpenAIResponseObjectStream = Annotated[
     | OpenAIResponseObjectStreamResponseMcpCallInProgress
     | OpenAIResponseObjectStreamResponseMcpCallFailed
     | OpenAIResponseObjectStreamResponseMcpCallCompleted
+    | OpenAIResponseObjectStreamResponseContentPartAdded
+    | OpenAIResponseObjectStreamResponseContentPartDone
     | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
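
The `Field(discriminator="type")` annotation on this union is what routes a raw event payload to the matching model class. A self-contained sketch of the mechanism, assuming pydantic v2 and using two stand-in event models instead of the full union above:

    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter

    class ContentPartAdded(BaseModel):
        type: Literal["response.content_part.added"] = "response.content_part.added"
        item_id: str

    class McpCallArgumentsDelta(BaseModel):
        type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
        item_id: str
        delta: str

    Event = Annotated[ContentPartAdded | McpCallArgumentsDelta, Field(discriminator="type")]

    # validate_python dispatches on the "type" value to the right model
    event = TypeAdapter(Event).validate_python(
        {"type": "response.mcp_call.arguments.delta", "item_id": "fc_1", "delta": '{"city": "'}
    )
    assert isinstance(event, McpCallArgumentsDelta)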


@@ -20,6 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,
     OpenAIDeleteResponseObject,
+    OpenAIResponseContentPartText,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContent,
@@ -32,9 +33,13 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
+    OpenAIResponseObjectStreamResponseContentPartAdded,
+    OpenAIResponseObjectStreamResponseContentPartDone,
     OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
+    OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
     OpenAIResponseObjectStreamResponseMcpCallCompleted,
     OpenAIResponseObjectStreamResponseMcpCallFailed,
     OpenAIResponseObjectStreamResponseMcpCallInProgress,
@@ -475,6 +480,9 @@ class OpenAIResponsesImpl:
         message_item_id = f"msg_{uuid.uuid4()}"
         # Track tool call items for streaming events
         tool_call_item_ids: dict[int, str] = {}
+        # Track content parts for streaming events
+        content_part_id: str | None = None
+        content_part_emitted = False
 
         async for chunk in completion_result:
             chat_response_id = chunk.id
@@ -483,6 +491,20 @@
             for chunk_choice in chunk.choices:
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:
+                    # Emit content_part.added event for first text chunk
+                    if not content_part_emitted:
+                        content_part_id = f"cp_text_{uuid.uuid4()}"
+                        content_part_emitted = True
+                        sequence_number += 1
+                        yield OpenAIResponseObjectStreamResponseContentPartAdded(
+                            response_id=response_id,
+                            item_id=message_item_id,
+                            part=OpenAIResponseContentPartText(
+                                id=content_part_id,
+                                text="",  # Will be filled incrementally via text deltas
+                            ),
+                            sequence_number=sequence_number,
+                        )
                     sequence_number += 1
                     yield OpenAIResponseObjectStreamResponseOutputTextDelta(
                         content_index=0,
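
The `content_part_emitted` flag gates an open-once lifecycle around the text deltas; the matching content_part.done is emitted after the chunk loop, in a later hunk below. A miniature of the same pattern in isolation, with hypothetical dict events standing in for the pydantic models:

    def stream_text_part(text_chunks):
        """Emit added -> delta* -> done around a sequence of text fragments."""
        part_opened = False
        buffer = []
        for chunk in text_chunks:
            if not part_opened:
                part_opened = True
                yield {"type": "response.content_part.added", "text": ""}
            buffer.append(chunk)
            yield {"type": "response.output_text.delta", "delta": chunk}
        if part_opened:
            yield {"type": "response.content_part.done", "text": "".join(buffer)}

    events = list(stream_text_part(["Dub", "lin"]))
    assert events[0]["type"] == "response.content_part.added"
    assert events[-1] == {"type": "response.content_part.done", "text": "Dublin"}
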
@@ -529,16 +551,33 @@
                                 sequence_number=sequence_number,
                             )
 
-                        # Stream function call arguments as they arrive
+                        # Stream tool call arguments as they arrive (differentiate between MCP and function calls)
                         if tool_call.function and tool_call.function.arguments:
                             tool_call_item_id = tool_call_item_ids[tool_call.index]
                             sequence_number += 1
-                            yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
-                                delta=tool_call.function.arguments,
-                                item_id=tool_call_item_id,
-                                output_index=len(output_messages),
-                                sequence_number=sequence_number,
-                            )
+
+                            # Check if this is an MCP tool call
+                            is_mcp_tool = (
+                                ctx.mcp_tool_to_server
+                                and tool_call.function.name
+                                and tool_call.function.name in ctx.mcp_tool_to_server
+                            )
+                            if is_mcp_tool:
+                                # Emit MCP-specific argument delta event
+                                yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
+                                    delta=tool_call.function.arguments,
+                                    item_id=tool_call_item_id,
+                                    output_index=len(output_messages),
+                                    sequence_number=sequence_number,
+                                )
+                            else:
+                                # Emit function call argument delta event
+                                yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
+                                    delta=tool_call.function.arguments,
+                                    item_id=tool_call_item_id,
+                                    output_index=len(output_messages),
+                                    sequence_number=sequence_number,
+                                )
 
                         # Accumulate arguments for final response (only for subsequent chunks)
                         if not is_new_tool_call:
@@ -546,27 +585,56 @@
                                 response_tool_call.function.arguments or ""
                             ) + tool_call.function.arguments
 
-        # Emit function_call_arguments.done events for completed tool calls
+        # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
         for tool_call_index in sorted(chat_response_tool_calls.keys()):
             tool_call_item_id = tool_call_item_ids[tool_call_index]
             final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
+            tool_call_name = chat_response_tool_calls[tool_call_index].function.name
+
+            # Check if this is an MCP tool call
+            is_mcp_tool = ctx.mcp_tool_to_server and tool_call_name and tool_call_name in ctx.mcp_tool_to_server
             sequence_number += 1
-            yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
-                arguments=final_arguments,
-                item_id=tool_call_item_id,
-                output_index=len(output_messages),
-                sequence_number=sequence_number,
-            )
+            if is_mcp_tool:
+                # Emit MCP-specific argument done event
+                yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(
+                    arguments=final_arguments,
+                    item_id=tool_call_item_id,
+                    output_index=len(output_messages),
+                    sequence_number=sequence_number,
+                )
+            else:
+                # Emit function call argument done event
+                yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
+                    arguments=final_arguments,
+                    item_id=tool_call_item_id,
+                    output_index=len(output_messages),
+                    sequence_number=sequence_number,
+                )
 
         # Convert collected chunks to complete response
         if chat_response_tool_calls:
             tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
-            # when there are tool calls, we need to clear the content
-            chat_response_content = []
         else:
             tool_calls = None
+
+        # Emit content_part.done event if text content was streamed (before content gets cleared)
+        if content_part_emitted and content_part_id:
+            final_text = "".join(chat_response_content)
+            sequence_number += 1
+            yield OpenAIResponseObjectStreamResponseContentPartDone(
+                response_id=response_id,
+                item_id=message_item_id,
+                part=OpenAIResponseContentPartText(
+                    id=content_part_id,
+                    text=final_text,
+                ),
+                sequence_number=sequence_number,
+            )
+
+        # Clear content when there are tool calls (OpenAI spec behavior)
+        if chat_response_tool_calls:
+            chat_response_content = []
+
         assistant_message = OpenAIAssistantMessageParam(
             content="".join(chat_response_content),
             tool_calls=tool_calls,
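
The MCP/function split above reduces to a single membership test against `ctx.mcp_tool_to_server`. A miniature of that routing plus the delta accumulation, with a plain set standing in for the server mapping, a hypothetical `get_weather` tool, and dict events in place of the pydantic models:

    import json

    def route_argument_fragments(tool_name, fragments, mcp_tools):
        """Tag each arguments fragment as MCP or plain function-call, then emit done."""
        prefix = "response.mcp_call.arguments" if tool_name in mcp_tools else "response.function_call_arguments"
        accumulated = ""
        for frag in fragments:
            accumulated += frag
            yield {"type": f"{prefix}.delta", "delta": frag}
        yield {"type": f"{prefix}.done", "arguments": accumulated}

    events = list(route_argument_fragments("get_weather", ['{"city": "Du', 'blin"}'], {"get_weather"}))
    assert events[-1]["type"] == "response.mcp_call.arguments.done"
    assert json.loads(events[-1]["arguments"]) == {"city": "Dublin"}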


@@ -590,9 +590,17 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
     # Verify tool call streaming events are present
     chunk_types = [chunk.type for chunk in chunks]
 
-    # Should have function call arguments delta events for tool calls
-    delta_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.delta"]
-    done_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.done"]
+    # Should have function call or MCP arguments delta/done events for tool calls
+    delta_events = [
+        chunk
+        for chunk in chunks
+        if chunk.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"]
+    ]
+    done_events = [
+        chunk
+        for chunk in chunks
+        if chunk.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"]
+    ]
 
     # Should have output item events for tool calls
     item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"]
@@ -606,8 +614,12 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
     assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
 
     # Since this test involves MCP tool calls, we should see streaming events
-    assert len(delta_events) > 0, f"Expected function_call_arguments.delta events, got chunk types: {chunk_types}"
-    assert len(done_events) > 0, f"Expected function_call_arguments.done events, got chunk types: {chunk_types}"
+    assert len(delta_events) > 0, (
+        f"Expected function_call_arguments.delta or mcp_call.arguments.delta events, got chunk types: {chunk_types}"
+    )
+    assert len(done_events) > 0, (
+        f"Expected function_call_arguments.done or mcp_call.arguments.done events, got chunk types: {chunk_types}"
+    )
 
     # Should have output item events for function calls
     assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}"
@@ -670,22 +682,32 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
         assert isinstance(done_event.output_index, int), "Output index should be integer"
         assert done_event.output_index >= 0, "Output index should be non-negative"
 
-    # Group function call argument events by item_id (these should have proper tracking)
-    function_call_events_by_item_id = {}
+    # Group function call and MCP argument events by item_id (these should have proper tracking)
+    argument_events_by_item_id = {}
     for chunk in chunks:
         if hasattr(chunk, "item_id") and chunk.type in [
             "response.function_call_arguments.delta",
             "response.function_call_arguments.done",
+            "response.mcp_call.arguments.delta",
+            "response.mcp_call.arguments.done",
         ]:
             item_id = chunk.item_id
-            if item_id not in function_call_events_by_item_id:
-                function_call_events_by_item_id[item_id] = []
-            function_call_events_by_item_id[item_id].append(chunk)
+            if item_id not in argument_events_by_item_id:
+                argument_events_by_item_id[item_id] = []
+            argument_events_by_item_id[item_id].append(chunk)
 
-    for item_id, related_events in function_call_events_by_item_id.items():
-        # Should have at least one delta and one done event for a complete function call
-        delta_events = [e for e in related_events if e.type == "response.function_call_arguments.delta"]
-        done_events = [e for e in related_events if e.type == "response.function_call_arguments.done"]
+    for item_id, related_events in argument_events_by_item_id.items():
+        # Should have at least one delta and one done event for a complete tool call
+        delta_events = [
+            e
+            for e in related_events
+            if e.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"]
+        ]
+        done_events = [
+            e
+            for e in related_events
+            if e.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"]
+        ]
         assert len(delta_events) > 0, f"Item {item_id} should have at least one delta event"
         assert len(done_events) == 1, f"Item {item_id} should have exactly one done event"
@@ -694,6 +716,43 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
         for event in related_events:
             assert event.item_id == item_id, f"Event should have consistent item_id {item_id}, got {event.item_id}"
 
+    # Verify content part events if they exist (for text streaming)
+    content_part_added_events = [chunk for chunk in chunks if chunk.type == "response.content_part.added"]
+    content_part_done_events = [chunk for chunk in chunks if chunk.type == "response.content_part.done"]
+
+    # Content part events should be paired (if any exist)
+    if len(content_part_added_events) > 0:
+        assert len(content_part_done_events) > 0, (
+            "Should have content_part.done events if content_part.added events exist"
+        )
+
+        # Verify content part event structure
+        for added_event in content_part_added_events:
+            assert hasattr(added_event, "response_id"), "Content part added event should have response_id"
+            assert hasattr(added_event, "item_id"), "Content part added event should have item_id"
+            assert hasattr(added_event, "part"), "Content part added event should have part"
+            # Part might be a dict or object, handle both cases
+            if hasattr(added_event.part, "id"):
+                assert added_event.part.id, "Content part should have id"
+                assert added_event.part.type, "Content part should have type"
+            else:
+                assert "id" in added_event.part, "Content part should have id"
+                assert "type" in added_event.part, "Content part should have type"
+
+        for done_event in content_part_done_events:
+            assert hasattr(done_event, "response_id"), "Content part done event should have response_id"
+            assert hasattr(done_event, "item_id"), "Content part done event should have item_id"
+            assert hasattr(done_event, "part"), "Content part done event should have part"
+            # Part might be a dict or object, handle both cases
+            # Note: In some scenarios (e.g., with tool calls), text content might be empty
+            if hasattr(done_event.part, "text"):
+                # Text can be empty in tool call scenarios, so we just check it exists
+                assert hasattr(done_event.part, "text"), "Content part should have text field when done"
+            else:
+                # For dict case, text field might not be present if content was empty
+                # This is valid behavior when tool calls are present
+                pass
+
     # Basic pairing check: each output_item.added should be followed by some activity
     # (but we can't enforce strict 1:1 pairing due to the complexity of multi-turn scenarios)
     assert len(item_added_events) > 0, "Should have at least one output_item.added event"
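
For orientation, one plausible end-to-end ordering for a turn that makes a single MCP call and then streams text. The exact interleaving depends on the model's chunking, which is why the assertions above check pairing and structure rather than a strict global sequence:

    response.created
    response.output_item.added          (MCP tool call item)
    response.mcp_call.arguments.delta   (one or more)
    response.mcp_call.arguments.done
    response.mcp_call.in_progress
    response.mcp_call.completed
    response.output_item.added          (assistant message item)
    response.content_part.added
    response.output_text.delta          (one or more)
    response.content_part.done
    response.completed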


@@ -136,9 +136,12 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
         input=input_text,
         model=model,
         temperature=0.1,
+        stream=True,  # Enable streaming to test content part events
     )
 
-    # Verify
+    # For streaming response, collect all chunks
+    chunks = [chunk async for chunk in result]
+
     mock_inference_api.openai_chat_completion.assert_called_once_with(
         model=model,
         messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
@@ -147,11 +150,32 @@
         stream=True,
         temperature=0.1,
     )
+
+    # Should have content part events for text streaming
+    # Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
+    assert len(chunks) >= 4
+    assert chunks[0].type == "response.created"
+
+    # Check for content part events
+    content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
+    content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"]
+    text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"]
+
+    assert len(content_part_added_events) >= 1, "Should have content_part.added event for text"
+    assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
+    assert len(text_delta_events) >= 1, "Should have text delta events"
+
+    # Verify final event is completion
+    assert chunks[-1].type == "response.completed"
+
+    # When streaming, the final response is in the last chunk
+    final_response = chunks[-1].response
+    assert final_response.model == model
+    assert len(final_response.output) == 1
+    assert isinstance(final_response.output[0], OpenAIResponseMessage)
+
     openai_responses_impl.responses_store.store_response_object.assert_called_once()
-    assert result.model == model
-    assert len(result.output) == 1
-    assert isinstance(result.output[0], OpenAIResponseMessage)
-    assert result.output[0].content[0].text == "Dublin"
+    assert final_response.output[0].content[0].text == "Dublin"
 
 
 async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
@@ -272,6 +296,8 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
     # Check that we got the content from our mocked tool execution result
     chunks = [chunk async for chunk in result]
 
+    # Verify event types
     # Should have: response.created, output_item.added, function_call_arguments.delta,
     # function_call_arguments.done, output_item.done, response.completed
     assert len(chunks) == 6