feat(responses): implement full multi-turn support

Ashwin Bharambe 2025-05-27 14:32:21 -07:00
parent 6bb174bb05
commit fd15a6832c
5 changed files with 577 additions and 125 deletions

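At a high level, this change turns the Responses implementation into a tool-execution loop: server-side tool calls (e.g. MCP tools) are executed by Llama Stack itself and their results fed back into the next inference round, while client-side function tool calls are still returned to the caller; the loop is capped by the new max_infer_iters parameter. A minimal sketch of how a client exercises this, modeled on the verification tests in this commit — the base URL, API key, model name, and MCP server URL are placeholders, not values from the commit:

from openai import OpenAI

# Assumed Llama Stack OpenAI-compatible endpoint; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="unused")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model
    input="Get the user ID for 'alice', then check if that user can access 'document.txt'.",
    tools=[
        {
            "type": "mcp",
            "server_label": "localmcp",
            "server_url": "http://localhost:8000/sse",  # placeholder MCP server
        }
    ],
)

# With this commit the server performs the get_user_id -> check_file_access chain
# itself; response.output contains an mcp_list_tools entry, the mcp_call entries,
# and a final assistant message.
print(response.output_text)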
View file

@@ -258,6 +258,19 @@ class OpenAIResponsesImpl:
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
def _is_function_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool] | None,
) -> bool:
"""Check if a tool call is a function tool call (client-side) vs non-function (server-side)."""
if not tools:
return False
# If the first tool is a function, assume all tools are functions
# This matches the logic in _process_response_choices
return tools[0].type == "function"
async def _process_response_choices(
self,
chat_response: OpenAIChatCompletion,
@@ -270,7 +283,7 @@ class OpenAIResponsesImpl:
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if tools[0].type == "function":
if self._is_function_tool_call(choice.message.tool_calls[0], tools):
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
@@ -332,6 +345,7 @@ class OpenAIResponsesImpl:
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
):
stream = False if stream is None else stream
@@ -358,58 +372,100 @@ class OpenAIResponsesImpl:
temperature=temperature,
)
inference_result = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
# Fork to streaming vs non-streaming - let each handle ALL inference rounds
if stream:
return self._create_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
max_infer_iters=max_infer_iters,
)
else:
return await self._create_non_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
max_infer_iters=max_infer_iters,
)
async def _create_non_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
max_infer_iters: int | None,
) -> OpenAIResponseObject:
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
# Implement tool execution loop - handle ALL inference rounds including the first
n_iter = 0
messages = ctx.messages.copy()
current_response = None
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response,
while True:
# Do inference (including the first one)
inference_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.tools,
stream=False,
temperature=ctx.temperature,
)
current_response = OpenAIChatCompletion(**inference_result.model_dump())
# Separate function vs non-function tool calls
function_tool_calls = []
non_function_tool_calls = []
for choice in current_response.choices:
if choice.message.tool_calls and tools:
for tool_call in choice.message.tool_calls:
if self._is_function_tool_call(tool_call, tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
# Process response choices based on tool call types
if function_tool_calls:
# For function tool calls, use existing logic and return immediately
current_output_messages = await self._process_response_choices(
chat_response=current_response,
ctx=ctx,
tools=tools,
)
)
output_messages.extend(current_output_messages)
break
elif non_function_tool_calls:
# For non-function tool calls, execute them and continue loop
for choice in current_response.choices:
tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
output_messages.extend(tool_outputs)
# Add assistant message and tool responses to messages for next iteration
messages.append(choice.message)
messages.extend(tool_response_messages)
n_iter += 1
if n_iter >= (max_infer_iters or 10):
break
# Continue with next iteration of the loop
continue
else:
# No tool calls - convert response to message and we're done
for choice in current_response.choices:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
break
response = OpenAIResponseObject(
created_at=chat_response.created,
created_at=current_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
@@ -429,13 +485,13 @@ class OpenAIResponsesImpl:
async def _create_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
max_infer_iters: int | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
@@ -453,8 +509,21 @@ class OpenAIResponsesImpl:
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# For streaming, inference_result is an async iterator of chunks
# Stream chunks and emit delta events as they arrive
# Implement tool execution loop for streaming - handle ALL inference rounds including the first
n_iter = 0
messages = ctx.messages.copy()
while True:
# Do inference (including the first one) - streaming
current_inference_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.tools,
stream=True,
temperature=ctx.temperature,
)
# Process streaming chunks and build complete response
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
@@ -466,7 +535,7 @@ class OpenAIResponsesImpl:
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in inference_result:
async for chunk in current_inference_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
@@ -487,12 +556,12 @@ class OpenAIResponsesImpl:
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
# Don't attempt to concatenate arguments if we don't have any new arguments
if tool_call.function.arguments:
# Guard against an initial None argument before we concatenate
response_tool_call.function.arguments = (
@@ -513,7 +582,7 @@ class OpenAIResponsesImpl:
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response_obj = OpenAIChatCompletion(
current_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
@@ -526,14 +595,49 @@ class OpenAIResponsesImpl:
model=chunk_model,
)
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response_obj,
# Separate function vs non-function tool calls
function_tool_calls = []
non_function_tool_calls = []
for choice in current_response.choices:
if choice.message.tool_calls and tools:
for tool_call in choice.message.tool_calls:
if self._is_function_tool_call(tool_call, tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
# Process response choices based on tool call types
if function_tool_calls:
# For function tool calls, use existing logic and break
current_output_messages = await self._process_response_choices(
chat_response=current_response,
ctx=ctx,
tools=tools,
)
)
output_messages.extend(current_output_messages)
break
elif non_function_tool_calls:
# For non-function tool calls, execute them and continue loop
for choice in current_response.choices:
tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
output_messages.extend(tool_outputs)
# Add assistant message and tool responses to messages for next iteration
messages.append(choice.message)
messages.extend(tool_response_messages)
n_iter += 1
if n_iter >= (max_infer_iters or 10):
break
# Continue with next iteration of the loop
continue
else:
# No tool calls - convert response to message and we're done
for choice in current_response.choices:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
break
# Create final response
final_response = OpenAIResponseObject(
@@ -646,6 +750,30 @@ class OpenAIResponsesImpl:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_tool_calls_only(
self,
choice: OpenAIChoice,
ctx: ChatCompletionContext,
) -> tuple[list[OpenAIResponseOutput], list[OpenAIMessageParam]]:
"""Execute tool calls and return output messages and tool response messages for next inference."""
output_messages: list[OpenAIResponseOutput] = []
tool_response_messages: list[OpenAIMessageParam] = []
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages, tool_response_messages
if not choice.message.tool_calls:
return output_messages, tool_response_messages
for tool_call in choice.message.tool_calls:
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if further_input:
tool_response_messages.append(further_input)
return output_messages, tool_response_messages
async def _execute_tool_and_return_final_output(
self,
choice: OpenAIChoice,

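Both the non-streaming and streaming paths above now share the same control flow: client-side function tool calls end the loop and are returned to the caller, server-side tool calls are executed and their results appended to the conversation for the next round, and plain assistant messages terminate the loop. A simplified, self-contained sketch of that loop, using invented placeholder types and callables (ToolCall, Turn, infer, execute_tool) rather than the real Llama Stack classes:

from dataclasses import dataclass, field

@dataclass
class ToolCall:
    name: str
    is_function: bool  # client-side function tool vs server-side (e.g. MCP) tool

@dataclass
class Turn:
    tool_calls: list[ToolCall] = field(default_factory=list)
    text: str = ""

def run_tool_loop(infer, execute_tool, max_infer_iters: int = 10) -> list[str]:
    """Simplified shape of the multi-turn loop added in this commit (not the real implementation)."""
    output: list[str] = []
    n_iter = 0
    while True:
        turn: Turn = infer()
        function_calls = [tc for tc in turn.tool_calls if tc.is_function]
        server_calls = [tc for tc in turn.tool_calls if not tc.is_function]
        if function_calls:
            # Function tool calls are handed back to the client; the loop ends here.
            output.extend(f"function_call:{tc.name}" for tc in function_calls)
            break
        elif server_calls:
            # Server-side tool calls are executed and their results fed back into
            # the next inference round, bounded by max_infer_iters.
            output.extend(execute_tool(tc) for tc in server_calls)
            n_iter += 1
            if n_iter >= max_infer_iters:
                break
        else:
            # No tool calls: the model's message is the final output.
            output.append(turn.text)
            break
    return output

# Tiny usage example with canned turns:
turns = iter([Turn(tool_calls=[ToolCall("get_user_id", is_function=False)]), Turn(text="yes")])
print(run_tool_loop(lambda: next(turns), lambda tc: f"mcp_call:{tc.name}"))
# ['mcp_call:get_user_id', 'yes']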
View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
# The MCP server may or may not require authentication, depending on the test
from collections.abc import Callable
from contextlib import contextmanager
# Unfortunately the toolgroup id must be tied to the tool names because the registry
@@ -13,29 +14,16 @@ from contextlib import contextmanager
MCP_TOOLGROUP_ID = "mcp::localmcp"
@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
import threading
import time
import httpx
import uvicorn
def default_tools():
"""Default tools for backward compatibility."""
from mcp import types
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route
from mcp.server.fastmcp import Context
server = FastMCP("FastMCP Test Server", log_level="WARNING")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
@server.tool()
async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit.
@@ -52,12 +40,151 @@ def make_mcp_server(required_auth_token: str | None = None):
else:
return -1
return {"greet_everyone": greet_everyone, "get_boiling_point": get_boiling_point}
def dependency_tools():
"""Tools with natural dependencies for multi-turn testing."""
from mcp import types
from mcp.server.fastmcp import Context
async def get_user_id(username: str, ctx: Context) -> str:
"""
Get the user ID for a given username. This ID is needed for other operations.
:param username: The username to look up
:return: The user ID for the username
"""
# Simple mapping for testing
user_mapping = {"alice": "user_12345", "bob": "user_67890", "charlie": "user_11111", "admin": "user_00000"}
return user_mapping.get(username.lower(), "user_99999")
async def get_user_permissions(user_id: str, ctx: Context) -> str:
"""
Get the permissions for a user ID. Requires a valid user ID from get_user_id.
:param user_id: The user ID to check permissions for
:return: The permissions for the user
"""
# Permission mapping based on user IDs
permission_mapping = {
"user_12345": "read,write", # alice
"user_67890": "read", # bob
"user_11111": "admin", # charlie
"user_00000": "superadmin", # admin
"user_99999": "none", # unknown users
}
return permission_mapping.get(user_id, "none")
async def check_file_access(user_id: str, filename: str, ctx: Context) -> str:
"""
Check if a user can access a specific file. Requires a valid user ID.
:param user_id: The user ID to check access for
:param filename: The filename to check access to
:return: Whether the user can access the file (yes/no)
"""
# Get permissions first
permission_mapping = {
"user_12345": "read,write", # alice
"user_67890": "read", # bob
"user_11111": "admin", # charlie
"user_00000": "superadmin", # admin
"user_99999": "none", # unknown users
}
permissions = permission_mapping.get(user_id, "none")
# Check file access based on permissions and filename
if permissions == "superadmin":
access = "yes"
elif permissions == "admin":
access = "yes" if not filename.startswith("secret_") else "no"
elif "write" in permissions:
access = "yes" if filename.endswith(".txt") else "no"
elif "read" in permissions:
access = "yes" if filename.endswith(".txt") or filename.endswith(".md") else "no"
else:
access = "no"
return [types.TextContent(type="text", text=access)]
async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
"""
Get the experiment ID for a given experiment name. This ID is needed to get results.
:param experiment_name: The name of the experiment
:return: The experiment ID
"""
# Simple mapping for testing
experiment_mapping = {
"temperature_test": "exp_001",
"pressure_test": "exp_002",
"chemical_reaction": "exp_003",
"boiling_point": "exp_004",
}
exp_id = experiment_mapping.get(experiment_name.lower(), "exp_999")
return exp_id
async def get_experiment_results(experiment_id: str, ctx: Context) -> str:
"""
Get the results for an experiment ID. Requires a valid experiment ID from get_experiment_id.
:param experiment_id: The experiment ID to get results for
:return: The experiment results
"""
# Results mapping based on experiment IDs
results_mapping = {
"exp_001": "Temperature: 25°C, Status: Success",
"exp_002": "Pressure: 1.2 atm, Status: Success",
"exp_003": "Yield: 85%, Status: Complete",
"exp_004": "Boiling Point: 100°C, Status: Verified",
"exp_999": "No results found",
}
results = results_mapping.get(experiment_id, "Invalid experiment ID")
return results
return {
"get_user_id": get_user_id,
"get_user_permissions": get_user_permissions,
"check_file_access": check_file_access,
"get_experiment_id": get_experiment_id,
"get_experiment_results": get_experiment_results,
}
@contextmanager
def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Callable] | None = None):
"""
Create an MCP server with the specified tools.
:param required_auth_token: Optional auth token required for access
:param tools: Dictionary of tool_name -> tool_function. If None, uses default tools.
"""
import threading
import time
import httpx
import uvicorn
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route
server = FastMCP("FastMCP Test Server", log_level="WARNING")
tools = tools or default_tools()
# Register all tools with the server
for tool_func in tools.values():
server.tool()(tool_func)
sse = SseServerTransport("/messages/")
async def handle_sse(request):
from starlette.exceptions import HTTPException
auth_header = request.headers.get("Authorization")
auth_header: str | None = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]

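Typical usage of the reworked helper, mirroring the tests in this commit: make_mcp_server keeps its old behaviour when called without arguments and serves the new dependency tools when given the dependency_tools() mapping. The mcp_server_info dict exposing the server URL is assumed to behave as it does in the existing tests.

from tests.common.mcp import dependency_tools, make_mcp_server

# Default tools (greet_everyone, get_boiling_point), as before this commit.
with make_mcp_server() as mcp_server_info:
    print(mcp_server_info["server_url"])

# Dependency tools for multi-turn testing (get_user_id, get_user_permissions, ...).
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
    print(mcp_server_info["server_url"])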
View file

@@ -224,16 +224,16 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
],
)
# Verify
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert len(chunks) == 2 # Should have response.created and response.completed
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert len(chunks) == 2 # Should have response.created and response.completed
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0

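The reordering above matters because the streaming path now runs inference lazily inside the returned async generator, so — as the new comment notes — the mocked openai_chat_completion only shows up in call_args_list after the stream has been consumed. A minimal standalone illustration of that AsyncMock behaviour (not the actual test):

import asyncio
from unittest.mock import AsyncMock

async def main():
    api = AsyncMock(return_value="chunk")

    async def stream():
        # The mocked call is recorded only when the generator is iterated,
        # not when stream() is created.
        yield await api()

    gen = stream()
    assert api.call_args_list == []             # nothing recorded yet
    assert [c async for c in gen] == ["chunk"]  # consuming the stream triggers the call
    assert len(api.call_args_list) == 1

asyncio.run(main())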
View file

@@ -94,3 +94,43 @@ test_response_multi_turn_image:
output: "llama"
- input: "What country do you find this animal primarily in? What continent?"
output: "peru"
test_response_multi_turn_tool_execution:
test_name: test_response_multi_turn_tool_execution
test_params:
case:
- case_id: "user_file_access_check"
input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Tell me the final result."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
output: "yes"
- case_id: "experiment_results_lookup"
input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
output: "100°C"
test_response_multi_turn_tool_execution_streaming:
test_name: test_response_multi_turn_tool_execution_streaming
test_params:
case:
- case_id: "user_permissions_workflow"
input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
stream: true
output: "no"
- case_id: "experiment_analysis_streaming"
input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
stream: true
output: "85%"

View file

@@ -12,7 +12,7 @@ import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from tests.common.mcp import make_mcp_server
from tests.common.mcp import dependency_tools, make_mcp_server
from tests.verifications.openai_api.fixtures.fixtures import (
case_id_generator,
get_base_test_name,
@@ -280,6 +280,7 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
tools=tools,
stream=False,
)
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
@@ -294,7 +295,12 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
assert call.error is None
assert "-100" in call.output
message = response.output[2]
from rich.pretty import pprint
pprint(response)
# sometimes the model will call the tool again, so we need to get the last message
message = response.output[-1]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
@@ -393,3 +399,154 @@ def test_response_non_streaming_multi_turn_image(request, openai_client, model,
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn["output"].lower() in output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_multi_turn_tool_execution(
request, openai_client, model, provider, verification_config, case
):
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = case["tools"]
# Replace the placeholder URL with the actual server URL
for tool in tools:
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
tool["server_url"] = mcp_server_info["server_url"]
response = openai_client.responses.create(
input=case["input"],
model=model,
tools=tools,
)
# Verify we have MCP tool calls in the output
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
message_outputs = [output for output in response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
final_message = message_outputs[-1]
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
expected_output = case["output"]
assert expected_output.lower() in response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {response.output_text}"
)
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
ids=case_id_generator,
)
async def test_response_streaming_multi_turn_tool_execution(
request, openai_client, model, provider, verification_config, case
):
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = case["tools"]
# Replace the placeholder URL with the actual server URL
for tool in tools:
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
tool["server_url"] = mcp_server_info["server_url"]
stream = openai_client.responses.create(
input=case["input"],
model=model,
tools=tools,
stream=True,
)
chunks = []
async for chunk in stream:
chunks.append(chunk)
# Should have at least response.created and response.completed
assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"
# First chunk should be response.created
assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"
# Last chunk should be response.completed
assert chunks[-1].type == "response.completed", (
f"Last chunk should be response.completed, got {chunks[-1].type}"
)
# Get the final response from the last chunk
final_chunk = chunks[-1]
if hasattr(final_chunk, "response"):
final_response = final_chunk.response
# Verify multi-turn MCP tool execution results
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
message_outputs = [output for output in final_response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
# Should have at least 1 MCP call (the model should call at least one tool)
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
# All MCP calls should be completed (verifies our tool execution works)
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
# Should have at least one final message response
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
# Final message should be from assistant and completed
final_message = message_outputs[-1]
assert final_message.role == "assistant", (
f"Final message should be from assistant, got {final_message.role}"
)
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
# Check that the expected output appears in the response
expected_output = case["output"]
assert expected_output.lower() in final_response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
)