feat(responses): add output_text delta events to responses (#2265)

This adds initial streaming support to the Responses API. 

This PR makes sure that the _first_ inference call made to chat
completions streams out.

There's more to be done:
 - tool call output tokens need to stream out when possible
- we need to loop through multiple rounds of inference and they all need
to stream out.

## Test Plan

Added a test. Executed as:

```
FIREWORKS_API_KEY=... \
  pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
  --provider=stack:fireworks --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```

Then, started a llama stack fireworks distro and tested against it like
this:

```
OPENAI_API_KEY=blah \
   pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
   --base-url http://localhost:8321/v1/openai/v1 \
  --model meta-llama/Llama-4-Scout-17B-16E-Instruct 
```
This commit is contained in:
Ashwin Bharambe 2025-05-27 13:07:14 -07:00 committed by GitHub
parent 6ee319ae08
commit 5cdb29758a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 493 additions and 160 deletions

View file

@ -7540,6 +7540,9 @@
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated"
},
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta"
},
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
}
@ -7548,6 +7551,7 @@
"propertyName": "type",
"mapping": {
"response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated",
"response.output_text.delta": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta",
"response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
}
}
@ -7590,6 +7594,41 @@
],
"title": "OpenAIResponseObjectStreamResponseCreated"
},
"OpenAIResponseObjectStreamResponseOutputTextDelta": {
"type": "object",
"properties": {
"content_index": {
"type": "integer"
},
"delta": {
"type": "string"
},
"item_id": {
"type": "string"
},
"output_index": {
"type": "integer"
},
"sequence_number": {
"type": "integer"
},
"type": {
"type": "string",
"const": "response.output_text.delta",
"default": "response.output_text.delta"
}
},
"additionalProperties": false,
"required": [
"content_index",
"delta",
"item_id",
"output_index",
"sequence_number",
"type"
],
"title": "OpenAIResponseObjectStreamResponseOutputTextDelta"
},
"CreateUploadSessionRequest": {
"type": "object",
"properties": {

View file

@ -5294,11 +5294,13 @@ components:
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
discriminator:
propertyName: type
mapping:
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
response.output_text.delta: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
"OpenAIResponseObjectStreamResponseCompleted":
type: object
@ -5330,6 +5332,33 @@ components:
- type
title: >-
OpenAIResponseObjectStreamResponseCreated
"OpenAIResponseObjectStreamResponseOutputTextDelta":
type: object
properties:
content_index:
type: integer
delta:
type: string
item_id:
type: string
output_index:
type: integer
sequence_number:
type: integer
type:
type: string
const: response.output_text.delta
default: response.output_text.delta
additionalProperties: false
required:
- content_index
- delta
- item_id
- output_index
- sequence_number
- type
title: >-
OpenAIResponseObjectStreamResponseOutputTextDelta
CreateUploadSessionRequest:
type: object
properties:

View file

@ -149,6 +149,16 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
type: Literal["response.created"] = "response.created"
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
content_index: int
delta: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.output_text.delta"] = "response.output_text.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
response: OpenAIResponseObject
@ -156,7 +166,9 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
OpenAIResponseObjectStream = Annotated[
OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated
| OpenAIResponseObjectStreamResponseOutputTextDelta
| OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"),
]
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any, cast
@ -29,10 +30,12 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (
@ -255,110 +258,14 @@ class OpenAIResponsesImpl:
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def create_openai_response(
async def _process_response_choices(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
chat_response: OpenAIChatCompletion,
ctx: ChatCompletionContext,
tools: list[OpenAIResponseInputTool] | None,
) -> list[OpenAIResponseOutput]:
"""Handle tool execution and response message creation."""
output_messages: list[OpenAIResponseOutput] = []
stream = False if stream is None else stream
# Huge TODO: we need to run this in a loop, until morale improves
# Create context to run "chat completion"
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
# Run inference
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
# Collect output
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
# TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
async for chunk in chat_response:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# TODO: this only works for text content
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
# Ensure we don't have any empty type field in the tool call dict.
# The OpenAI client used by providers often returns a type=None here.
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
else:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
@ -380,7 +287,128 @@ class OpenAIResponsesImpl:
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# Create response object
return output_messages
async def _store_response(
self,
response: OpenAIResponseObject,
original_input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(original_input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=original_input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in original_input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
stream = False if stream is None else stream
original_input = input # Keep reference for storage
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Tool setup
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
inference_result = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
if stream:
return self._create_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
original_input=original_input,
model=model,
store=store,
tools=tools,
)
else:
return await self._create_non_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
original_input=original_input,
model=model,
store=store,
tools=tools,
)
async def _create_non_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
original_input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> OpenAIResponseObject:
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response,
ctx=ctx,
tools=tools,
)
)
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
@ -393,45 +421,135 @@ class OpenAIResponsesImpl:
# Store response if requested
if store:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
await self._store_response(
response=response,
original_input=original_input,
)
if stream:
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
# TODO: response created should actually get emitted much earlier in the process
yield OpenAIResponseObjectStreamResponseCreated(response=response)
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
return async_response()
return response
async def _create_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
original_input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
initial_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="in_progress",
output=output_messages.copy(),
)
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# For streaming, inference_result is an async iterator of chunks
# Stream chunks and emit delta events as they arrive
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in inference_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response_obj = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response_obj,
ctx=ctx,
tools=tools,
)
)
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="completed",
output=output_messages,
)
if store:
await self._store_response(
response=final_response,
original_input=original_input,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> tuple[
@ -441,7 +559,6 @@ class OpenAIResponsesImpl:
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseOutputMessageMCPListTools,
)
from llama_stack.apis.tools.tools import Tool

View file

@ -232,9 +232,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert len(chunks) > 0
assert chunks[0].response.output[0].type == "function_call"
assert chunks[0].response.output[0].name == "get_weather"
assert len(chunks) == 2 # Should have response.created and response.completed
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check response.completed event (should have the tool call)
assert chunks[1].type == "response.completed"
assert len(chunks[1].response.output) == 1
assert chunks[1].response.output[0].type == "function_call"
assert chunks[1].response.output[0].name == "get_weather"
@pytest.mark.asyncio

View file

@ -10,17 +10,17 @@ from tests.verifications.openai_api.fixtures.fixtures import _load_all_verificat
def pytest_generate_tests(metafunc):
"""Dynamically parametrize tests based on the selected provider and config."""
if "model" in metafunc.fixturenames:
model = metafunc.config.getoption("model")
if model:
metafunc.parametrize("model", [model])
return
provider = metafunc.config.getoption("provider")
if not provider:
print("Warning: --provider not specified. Skipping model parametrization.")
metafunc.parametrize("model", [])
return
model = metafunc.config.getoption("model")
if model:
metafunc.parametrize("model", [model])
return
try:
config_data = _load_all_verification_configs()
except (OSError, FileNotFoundError) as e:

View file

@ -77,11 +77,12 @@ test_response_image:
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
output: "llama"
# the models are really poor at tool calling after seeing images :/
test_response_multi_turn_image:
test_name: test_response_multi_turn_image
test_params:
case:
- case_id: "llama_image_search"
- case_id: "llama_image_understanding"
turns:
- input:
- role: user
@ -91,7 +92,5 @@ test_response_multi_turn_image:
- type: input_image
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
output: "llama"
- input: "Search the web using the search tool for the animal from the previous response. Your search query should be a single phrase that includes the animal's name and the words 'maverick', 'scout' and 'llm'"
tools:
- type: web_search
output: "model"
- input: "What country do you find this animal primarily in? What continent?"
output: "peru"

View file

@ -7,6 +7,7 @@
import json
import httpx
import openai
import pytest
from llama_stack import LlamaStackAsLibraryClient
@ -61,23 +62,151 @@ def test_response_streaming_basic(request, openai_client, model, provider, verif
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
import time
response = openai_client.responses.create(
model=model,
input=case["input"],
stream=True,
)
streamed_content = []
# Track events and timing to verify proper streaming
events = []
event_times = []
response_id = ""
start_time = time.time()
for chunk in response:
if chunk.type == "response.completed":
current_time = time.time()
event_times.append(current_time - start_time)
events.append(chunk)
if chunk.type == "response.created":
# Verify response.created is emitted first and immediately
assert len(events) == 1, "response.created should be the first event"
assert event_times[0] < 0.1, "response.created should be emitted immediately"
assert chunk.response.status == "in_progress"
response_id = chunk.response.id
streamed_content.append(chunk.response.output_text.strip())
assert len(streamed_content) > 0
assert case["output"].lower() in "".join(streamed_content).lower()
elif chunk.type == "response.completed":
# Verify response.completed comes after response.created
assert len(events) >= 2, "response.completed should come after response.created"
assert chunk.response.status == "completed"
assert chunk.response.id == response_id, "Response ID should be consistent"
# Verify content quality
output_text = chunk.response.output_text.lower().strip()
assert len(output_text) > 0, "Response should have content"
assert case["output"].lower() in output_text, f"Expected '{case['output']}' in response"
# Verify we got both required events
event_types = [event.type for event in events]
assert "response.created" in event_types, "Missing response.created event"
assert "response.completed" in event_types, "Missing response.completed event"
# Verify event order
created_index = event_types.index("response.created")
completed_index = event_types.index("response.completed")
assert created_index < completed_index, "response.created should come before response.completed"
# Verify stored response matches streamed response
retrieved_response = openai_client.responses.retrieve(response_id=response_id)
assert retrieved_response.output_text == "".join(streamed_content)
final_event = events[-1]
assert retrieved_response.output_text == final_event.response.output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_basic"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_streaming_incremental_content(request, openai_client, model, provider, verification_config, case):
"""Test that streaming actually delivers content incrementally, not just at the end."""
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
import time
response = openai_client.responses.create(
model=model,
input=case["input"],
stream=True,
)
# Track all events and their content to verify incremental streaming
events = []
content_snapshots = []
event_times = []
start_time = time.time()
for chunk in response:
current_time = time.time()
event_times.append(current_time - start_time)
events.append(chunk)
# Track content at each event based on event type
if chunk.type == "response.output_text.delta":
# For delta events, track the delta content
content_snapshots.append(chunk.delta)
elif hasattr(chunk, "response") and hasattr(chunk.response, "output_text"):
# For response.created/completed events, track the full output_text
content_snapshots.append(chunk.response.output_text)
else:
content_snapshots.append("")
# Verify we have the expected events
event_types = [event.type for event in events]
assert "response.created" in event_types, "Missing response.created event"
assert "response.completed" in event_types, "Missing response.completed event"
# Check if we have incremental content updates
created_index = event_types.index("response.created")
completed_index = event_types.index("response.completed")
# The key test: verify content progression
created_content = content_snapshots[created_index]
completed_content = content_snapshots[completed_index]
# Verify that response.created has empty or minimal content
assert len(created_content) == 0, f"response.created should have empty content, got: {repr(created_content[:100])}"
# Verify that response.completed has the full content
assert len(completed_content) > 0, "response.completed should have content"
assert case["output"].lower() in completed_content.lower(), f"Expected '{case['output']}' in final content"
# Check for true incremental streaming by looking for delta events
delta_events = [i for i, event_type in enumerate(event_types) if event_type == "response.output_text.delta"]
# Assert that we have delta events (true incremental streaming)
assert len(delta_events) > 0, "Expected delta events for true incremental streaming, but found none"
# Verify delta events have content and accumulate to final content
delta_content_total = ""
non_empty_deltas = 0
for delta_idx in delta_events:
delta_content = content_snapshots[delta_idx]
if delta_content:
delta_content_total += delta_content
non_empty_deltas += 1
# Assert that we have meaningful delta content
assert non_empty_deltas > 0, "Delta events found but none contain content"
assert len(delta_content_total) > 0, "Delta events found but total delta content is empty"
# Verify that the accumulated delta content matches the final content
assert delta_content_total.strip() == completed_content.strip(), (
f"Delta content '{delta_content_total}' should match final content '{completed_content}'"
)
# Verify timing: delta events should come between created and completed
for delta_idx in delta_events:
assert created_index < delta_idx < completed_index, (
f"Delta event at index {delta_idx} should be between created ({created_index}) and completed ({completed_index})"
)
@pytest.mark.parametrize(
@ -178,7 +307,7 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
exc_type = (
AuthenticationRequiredError
if isinstance(openai_client, LlamaStackAsLibraryClient)
else httpx.HTTPStatusError
else (httpx.HTTPStatusError, openai.AuthenticationError)
)
with pytest.raises(exc_type):
openai_client.responses.create(