feat(responses): implement full multi-turn support (#2295)

I think the implementation needs more simplification. Spent way too much
time trying to get the tests to pass with models not co-operating :(
Finally had to switch to claude-sonnet to get things to pass reliably.

### Test Plan

```
export TAVILY_SEARCH_API_KEY=...
export OPENAI_API_KEY=...

uv run pytest -p no:warnings \
  -s -v tests/verifications/openai_api/test_responses.py \
  --provider=stack:starter \
  --model openai/gpt-4o
```
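
For reference, a minimal sketch of the multi-turn flow these tests exercise. This is illustrative only and not part of the diff; the base URL, prompt, and MCP server URL below are assumptions.

```python
# Illustrative sketch: drive the Responses API with an MCP tool, the way the
# new multi-turn tests do. base_url and server_url are hypothetical values.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1")  # hypothetical Llama Stack endpoint

response = client.responses.create(
    model="openai/gpt-4o",
    input="Get alice's user ID, then check whether that ID can access 'document.txt'.",
    tools=[
        {
            "type": "mcp",
            "server_label": "localmcp",
            "server_url": "http://localhost:8000/sse",  # hypothetical MCP server URL
        }
    ],
)

# A multi-turn run interleaves mcp_list_tools, one or more mcp_call items,
# and a final assistant message in response.output.
for item in response.output:
    print(item.type)
print(response.output_text)
```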
Ashwin Bharambe 2025-06-02 15:35:49 -07:00 committed by GitHub
parent cac7d404a2
commit dbe4e84aca
9 changed files with 593 additions and 136 deletions

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
# we want the mcp server to be authenticated OR not, depends
from collections.abc import Callable
from contextlib import contextmanager
# Unfortunately the toolgroup id must be tied to the tool names because the registry
@@ -13,15 +14,158 @@ from contextlib import contextmanager
MCP_TOOLGROUP_ID = "mcp::localmcp"
def default_tools():
    """Default tools for backward compatibility."""
    from mcp import types
    from mcp.server.fastmcp import Context

    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]

    async def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
        """
        Returns the boiling point of a liquid in Celsius or Fahrenheit.

        :param liquid_name: The name of the liquid
        :param celsius: Whether to return the boiling point in Celsius
        :return: The boiling point of the liquid in Celsius or Fahrenheit
        """
        if liquid_name.lower() == "myawesomeliquid":
            if celsius:
                return -100
            else:
                return -212
        else:
            return -1

    return {"greet_everyone": greet_everyone, "get_boiling_point": get_boiling_point}
def dependency_tools():
    """Tools with natural dependencies for multi-turn testing."""
    from mcp import types
    from mcp.server.fastmcp import Context

    async def get_user_id(username: str, ctx: Context) -> str:
        """
        Get the user ID for a given username. This ID is needed for other operations.

        :param username: The username to look up
        :return: The user ID for the username
        """
        # Simple mapping for testing
        user_mapping = {"alice": "user_12345", "bob": "user_67890", "charlie": "user_11111", "admin": "user_00000"}
        return user_mapping.get(username.lower(), "user_99999")

    async def get_user_permissions(user_id: str, ctx: Context) -> str:
        """
        Get the permissions for a user ID. Requires a valid user ID from get_user_id.

        :param user_id: The user ID to check permissions for
        :return: The permissions for the user
        """
        # Permission mapping based on user IDs
        permission_mapping = {
            "user_12345": "read,write",  # alice
            "user_67890": "read",  # bob
            "user_11111": "admin",  # charlie
            "user_00000": "superadmin",  # admin
            "user_99999": "none",  # unknown users
        }
        return permission_mapping.get(user_id, "none")

    async def check_file_access(user_id: str, filename: str, ctx: Context) -> str:
        """
        Check if a user can access a specific file. Requires a valid user ID.

        :param user_id: The user ID to check access for
        :param filename: The filename to check access to
        :return: Whether the user can access the file (yes/no)
        """
        # Get permissions first
        permission_mapping = {
            "user_12345": "read,write",  # alice
            "user_67890": "read",  # bob
            "user_11111": "admin",  # charlie
            "user_00000": "superadmin",  # admin
            "user_99999": "none",  # unknown users
        }
        permissions = permission_mapping.get(user_id, "none")

        # Check file access based on permissions and filename
        if permissions == "superadmin":
            access = "yes"
        elif permissions == "admin":
            access = "yes" if not filename.startswith("secret_") else "no"
        elif "write" in permissions:
            access = "yes" if filename.endswith(".txt") else "no"
        elif "read" in permissions:
            access = "yes" if filename.endswith(".txt") or filename.endswith(".md") else "no"
        else:
            access = "no"

        return [types.TextContent(type="text", text=access)]

    async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
        """
        Get the experiment ID for a given experiment name. This ID is needed to get results.

        :param experiment_name: The name of the experiment
        :return: The experiment ID
        """
        # Simple mapping for testing
        experiment_mapping = {
            "temperature_test": "exp_001",
            "pressure_test": "exp_002",
            "chemical_reaction": "exp_003",
            "boiling_point": "exp_004",
        }
        exp_id = experiment_mapping.get(experiment_name.lower(), "exp_999")
        return exp_id

    async def get_experiment_results(experiment_id: str, ctx: Context) -> str:
        """
        Get the results for an experiment ID. Requires a valid experiment ID from get_experiment_id.

        :param experiment_id: The experiment ID to get results for
        :return: The experiment results
        """
        # Results mapping based on experiment IDs
        results_mapping = {
            "exp_001": "Temperature: 25°C, Status: Success",
            "exp_002": "Pressure: 1.2 atm, Status: Success",
            "exp_003": "Yield: 85%, Status: Complete",
            "exp_004": "Boiling Point: 100°C, Status: Verified",
            "exp_999": "No results found",
        }
        results = results_mapping.get(experiment_id, "Invalid experiment ID")
        return results

    return {
        "get_user_id": get_user_id,
        "get_user_permissions": get_user_permissions,
        "check_file_access": check_file_access,
        "get_experiment_id": get_experiment_id,
        "get_experiment_results": get_experiment_results,
    }
@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Callable] | None = None):
    """
    Create an MCP server with the specified tools.

    :param required_auth_token: Optional auth token required for access
    :param tools: Dictionary of tool_name -> tool_function. If None, uses default tools.
    """
    import threading
    import time

    import httpx
    import uvicorn
    from mcp import types
    from mcp.server.fastmcp import Context, FastMCP
    from mcp.server.fastmcp import FastMCP
    from mcp.server.sse import SseServerTransport
    from starlette.applications import Starlette
    from starlette.responses import Response
@@ -29,35 +173,18 @@ def make_mcp_server(required_auth_token: str | None = None):
    server = FastMCP("FastMCP Test Server", log_level="WARNING")

    @server.tool()
    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]
    tools = tools or default_tools()

    @server.tool()
    async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
        """
        Returns the boiling point of a liquid in Celcius or Fahrenheit.

        :param liquid_name: The name of the liquid
        :param celcius: Whether to return the boiling point in Celcius
        :return: The boiling point of the liquid in Celcius or Fahrenheit
        """
        if liquid_name.lower() == "polyjuice":
            if celcius:
                return -100
            else:
                return -212
        else:
            return -1
    # Register all tools with the server
    for tool_func in tools.values():
        server.tool()(tool_func)

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        from starlette.exceptions import HTTPException

        auth_header = request.headers.get("Authorization")
        auth_header: str | None = request.headers.get("Authorization")
        auth_token = None
        if auth_header and auth_header.startswith("Bearer "):
            auth_token = auth_header.split(" ")[1]
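
For reference, a minimal usage sketch of the new `tools` parameter, mirroring how the response tests later in this diff wire it up; the surrounding call to `openai_client.responses.create(...)` is implied, not shown.

```python
# Usage sketch (mirrors the tests below): start a local MCP server serving the
# dependency tools, then point an "mcp" tool config at its URL.
from tests.common.mcp import dependency_tools, make_mcp_server

with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
    tools = [
        {
            "type": "mcp",
            "server_label": "localmcp",
            "server_url": mcp_server_info["server_url"],
        }
    ]
    # ... pass `tools` to openai_client.responses.create(...) while the server is up
```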

View file

@@ -224,16 +224,16 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
        ],
    )

    # Verify
    # Check that we got the content from our mocked tool execution result
    chunks = [chunk async for chunk in result]
    assert len(chunks) == 2  # Should have response.created and response.completed

    # Verify inference API was called correctly (after iterating over result)
    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
    assert first_call.kwargs["messages"][0].content == input_text
    assert first_call.kwargs["tools"] is not None
    assert first_call.kwargs["temperature"] == 0.1

    # Check that we got the content from our mocked tool execution result
    chunks = [chunk async for chunk in result]
    assert len(chunks) == 2  # Should have response.created and response.completed

    # Check response.created event (should have empty output)
    assert chunks[0].type == "response.created"
    assert len(chunks[0].response.output) == 0

View file

@@ -36,7 +36,7 @@ test_response_mcp_tool:
  test_params:
    case:
      - case_id: "boiling_point_tool"
        input: "What is the boiling point of polyjuice?"
        input: "What is the boiling point of myawesomeliquid in Celsius?"
        tools:
          - type: mcp
            server_label: "localmcp"
@@ -94,3 +94,43 @@ test_response_multi_turn_image:
            output: "llama"
          - input: "What country do you find this animal primarily in? What continent?"
            output: "peru"

test_response_multi_turn_tool_execution:
  test_name: test_response_multi_turn_tool_execution
  test_params:
    case:
      - case_id: "user_file_access_check"
        input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        output: "yes"
      - case_id: "experiment_results_lookup"
        input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        output: "100°C"

test_response_multi_turn_tool_execution_streaming:
  test_name: test_response_multi_turn_tool_execution_streaming
  test_params:
    case:
      - case_id: "user_permissions_workflow"
        input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        stream: true
        output: "no"
      - case_id: "experiment_analysis_streaming"
        input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        stream: true
        output: "85%"

View file

@@ -12,7 +12,7 @@ import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from tests.common.mcp import make_mcp_server
from tests.common.mcp import dependency_tools, make_mcp_server
from tests.verifications.openai_api.fixtures.fixtures import (
    case_id_generator,
    get_base_test_name,
@@ -280,6 +280,7 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
            tools=tools,
            stream=False,
        )

        assert len(response.output) >= 3
        list_tools = response.output[0]
        assert list_tools.type == "mcp_list_tools"
@@ -290,11 +291,12 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
        call = response.output[1]
        assert call.type == "mcp_call"
        assert call.name == "get_boiling_point"
        assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
        assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
        assert call.error is None
        assert "-100" in call.output

        message = response.output[2]
        # sometimes the model will call the tool again, so we need to get the last message
        message = response.output[-1]
        text_content = message.content[0].text
        assert "boiling point" in text_content.lower()
@@ -393,3 +395,154 @@ def test_response_non_streaming_multi_turn_image(request, openai_client, model,
        previous_response_id = response.id
        output_text = response.output_text.lower()
        assert turn["output"].lower() in output_text
@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_response_non_streaming_multi_turn_tool_execution(
    request, openai_client, model, provider, verification_config, case
):
    """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        tools = case["tools"]
        # Replace the placeholder URL with the actual server URL
        for tool in tools:
            if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
                tool["server_url"] = mcp_server_info["server_url"]

        response = openai_client.responses.create(
            input=case["input"],
            model=model,
            tools=tools,
        )

        # Verify we have MCP tool calls in the output
        mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
        mcp_calls = [output for output in response.output if output.type == "mcp_call"]
        message_outputs = [output for output in response.output if output.type == "message"]

        # Should have exactly 1 MCP list tools message (at the beginning)
        assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
        assert mcp_list_tools[0].server_label == "localmcp"
        assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
        expected_tool_names = {
            "get_user_id",
            "get_user_permissions",
            "check_file_access",
            "get_experiment_id",
            "get_experiment_results",
        }
        assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names

        assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
        for mcp_call in mcp_calls:
            assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

        assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"

        final_message = message_outputs[-1]
        assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
        assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
        assert len(final_message.content) > 0, "Final message should have content"

        expected_output = case["output"]
        assert expected_output.lower() in response.output_text.lower(), (
            f"Expected '{expected_output}' to appear in response: {response.output_text}"
        )
@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
    ids=case_id_generator,
)
async def test_response_streaming_multi_turn_tool_execution(
    request, openai_client, model, provider, verification_config, case
):
    """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        tools = case["tools"]
        # Replace the placeholder URL with the actual server URL
        for tool in tools:
            if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
                tool["server_url"] = mcp_server_info["server_url"]

        stream = openai_client.responses.create(
            input=case["input"],
            model=model,
            tools=tools,
            stream=True,
        )

        chunks = []
        async for chunk in stream:
            chunks.append(chunk)

        # Should have at least response.created and response.completed
        assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"

        # First chunk should be response.created
        assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"

        # Last chunk should be response.completed
        assert chunks[-1].type == "response.completed", (
            f"Last chunk should be response.completed, got {chunks[-1].type}"
        )

        # Get the final response from the last chunk
        final_chunk = chunks[-1]
        if hasattr(final_chunk, "response"):
            final_response = final_chunk.response

            # Verify multi-turn MCP tool execution results
            mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
            mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
            message_outputs = [output for output in final_response.output if output.type == "message"]

            # Should have exactly 1 MCP list tools message (at the beginning)
            assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
            assert mcp_list_tools[0].server_label == "localmcp"
            assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
            expected_tool_names = {
                "get_user_id",
                "get_user_permissions",
                "check_file_access",
                "get_experiment_id",
                "get_experiment_results",
            }
            assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names

            # Should have at least 1 MCP call (the model should call at least one tool)
            assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"

            # All MCP calls should be completed (verifies our tool execution works)
            for mcp_call in mcp_calls:
                assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

            # Should have at least one final message response
            assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"

            # Final message should be from assistant and completed
            final_message = message_outputs[-1]
            assert final_message.role == "assistant", (
                f"Final message should be from assistant, got {final_message.role}"
            )
            assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
            assert len(final_message.content) > 0, "Final message should have content"

            # Check that the expected output appears in the response
            expected_output = case["output"]
            assert expected_output.lower() in final_response.output_text.lower(), (
                f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
            )