mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat(responses): implement full multi-turn support (#2295)
I think the implementation needs more simplification. Spent way too much time trying to get the tests pass with models not co-operating :( Finally had to switch claude-sonnet to get things to pass reliably. ### Test Plan ``` export TAVILY_SEARCH_API_KEY=... export OPENAI_API_KEY=... uv run pytest -p no:warnings \ -s -v tests/verifications/openai_api/test_responses.py \ --provider=stack:starter \ --model openai/gpt-4o ```
This commit is contained in:
parent
cac7d404a2
commit
dbe4e84aca
9 changed files with 593 additions and 136 deletions
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
# we want the mcp server to be authenticated OR not, depends
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
|
||||
# Unfortunately the toolgroup id must be tied to the tool names because the registry
|
||||
|
@ -13,15 +14,158 @@ from contextlib import contextmanager
|
|||
MCP_TOOLGROUP_ID = "mcp::localmcp"
|
||||
|
||||
|
||||
def default_tools():
|
||||
"""Default tools for backward compatibility."""
|
||||
from mcp import types
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
async def greet_everyone(
|
||||
url: str, ctx: Context
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
return [types.TextContent(type="text", text="Hello, world!")]
|
||||
|
||||
async def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
|
||||
"""
|
||||
Returns the boiling point of a liquid in Celsius or Fahrenheit.
|
||||
|
||||
:param liquid_name: The name of the liquid
|
||||
:param celsius: Whether to return the boiling point in Celsius
|
||||
:return: The boiling point of the liquid in Celcius or Fahrenheit
|
||||
"""
|
||||
if liquid_name.lower() == "myawesomeliquid":
|
||||
if celsius:
|
||||
return -100
|
||||
else:
|
||||
return -212
|
||||
else:
|
||||
return -1
|
||||
|
||||
return {"greet_everyone": greet_everyone, "get_boiling_point": get_boiling_point}
|
||||
|
||||
|
||||
def dependency_tools():
|
||||
"""Tools with natural dependencies for multi-turn testing."""
|
||||
from mcp import types
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
async def get_user_id(username: str, ctx: Context) -> str:
|
||||
"""
|
||||
Get the user ID for a given username. This ID is needed for other operations.
|
||||
|
||||
:param username: The username to look up
|
||||
:return: The user ID for the username
|
||||
"""
|
||||
# Simple mapping for testing
|
||||
user_mapping = {"alice": "user_12345", "bob": "user_67890", "charlie": "user_11111", "admin": "user_00000"}
|
||||
return user_mapping.get(username.lower(), "user_99999")
|
||||
|
||||
async def get_user_permissions(user_id: str, ctx: Context) -> str:
|
||||
"""
|
||||
Get the permissions for a user ID. Requires a valid user ID from get_user_id.
|
||||
|
||||
:param user_id: The user ID to check permissions for
|
||||
:return: The permissions for the user
|
||||
"""
|
||||
# Permission mapping based on user IDs
|
||||
permission_mapping = {
|
||||
"user_12345": "read,write", # alice
|
||||
"user_67890": "read", # bob
|
||||
"user_11111": "admin", # charlie
|
||||
"user_00000": "superadmin", # admin
|
||||
"user_99999": "none", # unknown users
|
||||
}
|
||||
return permission_mapping.get(user_id, "none")
|
||||
|
||||
async def check_file_access(user_id: str, filename: str, ctx: Context) -> str:
|
||||
"""
|
||||
Check if a user can access a specific file. Requires a valid user ID.
|
||||
|
||||
:param user_id: The user ID to check access for
|
||||
:param filename: The filename to check access to
|
||||
:return: Whether the user can access the file (yes/no)
|
||||
"""
|
||||
# Get permissions first
|
||||
permission_mapping = {
|
||||
"user_12345": "read,write", # alice
|
||||
"user_67890": "read", # bob
|
||||
"user_11111": "admin", # charlie
|
||||
"user_00000": "superadmin", # admin
|
||||
"user_99999": "none", # unknown users
|
||||
}
|
||||
permissions = permission_mapping.get(user_id, "none")
|
||||
|
||||
# Check file access based on permissions and filename
|
||||
if permissions == "superadmin":
|
||||
access = "yes"
|
||||
elif permissions == "admin":
|
||||
access = "yes" if not filename.startswith("secret_") else "no"
|
||||
elif "write" in permissions:
|
||||
access = "yes" if filename.endswith(".txt") else "no"
|
||||
elif "read" in permissions:
|
||||
access = "yes" if filename.endswith(".txt") or filename.endswith(".md") else "no"
|
||||
else:
|
||||
access = "no"
|
||||
|
||||
return [types.TextContent(type="text", text=access)]
|
||||
|
||||
async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
|
||||
"""
|
||||
Get the experiment ID for a given experiment name. This ID is needed to get results.
|
||||
|
||||
:param experiment_name: The name of the experiment
|
||||
:return: The experiment ID
|
||||
"""
|
||||
# Simple mapping for testing
|
||||
experiment_mapping = {
|
||||
"temperature_test": "exp_001",
|
||||
"pressure_test": "exp_002",
|
||||
"chemical_reaction": "exp_003",
|
||||
"boiling_point": "exp_004",
|
||||
}
|
||||
exp_id = experiment_mapping.get(experiment_name.lower(), "exp_999")
|
||||
return exp_id
|
||||
|
||||
async def get_experiment_results(experiment_id: str, ctx: Context) -> str:
|
||||
"""
|
||||
Get the results for an experiment ID. Requires a valid experiment ID from get_experiment_id.
|
||||
|
||||
:param experiment_id: The experiment ID to get results for
|
||||
:return: The experiment results
|
||||
"""
|
||||
# Results mapping based on experiment IDs
|
||||
results_mapping = {
|
||||
"exp_001": "Temperature: 25°C, Status: Success",
|
||||
"exp_002": "Pressure: 1.2 atm, Status: Success",
|
||||
"exp_003": "Yield: 85%, Status: Complete",
|
||||
"exp_004": "Boiling Point: 100°C, Status: Verified",
|
||||
"exp_999": "No results found",
|
||||
}
|
||||
results = results_mapping.get(experiment_id, "Invalid experiment ID")
|
||||
return results
|
||||
|
||||
return {
|
||||
"get_user_id": get_user_id,
|
||||
"get_user_permissions": get_user_permissions,
|
||||
"check_file_access": check_file_access,
|
||||
"get_experiment_id": get_experiment_id,
|
||||
"get_experiment_results": get_experiment_results,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def make_mcp_server(required_auth_token: str | None = None):
|
||||
def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Callable] | None = None):
|
||||
"""
|
||||
Create an MCP server with the specified tools.
|
||||
|
||||
:param required_auth_token: Optional auth token required for access
|
||||
:param tools: Dictionary of tool_name -> tool_function. If None, uses default tools.
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from mcp import types
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import Response
|
||||
|
@ -29,35 +173,18 @@ def make_mcp_server(required_auth_token: str | None = None):
|
|||
|
||||
server = FastMCP("FastMCP Test Server", log_level="WARNING")
|
||||
|
||||
@server.tool()
|
||||
async def greet_everyone(
|
||||
url: str, ctx: Context
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
return [types.TextContent(type="text", text="Hello, world!")]
|
||||
tools = tools or default_tools()
|
||||
|
||||
@server.tool()
|
||||
async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
|
||||
"""
|
||||
Returns the boiling point of a liquid in Celcius or Fahrenheit.
|
||||
|
||||
:param liquid_name: The name of the liquid
|
||||
:param celcius: Whether to return the boiling point in Celcius
|
||||
:return: The boiling point of the liquid in Celcius or Fahrenheit
|
||||
"""
|
||||
if liquid_name.lower() == "polyjuice":
|
||||
if celcius:
|
||||
return -100
|
||||
else:
|
||||
return -212
|
||||
else:
|
||||
return -1
|
||||
# Register all tools with the server
|
||||
for tool_func in tools.values():
|
||||
server.tool()(tool_func)
|
||||
|
||||
sse = SseServerTransport("/messages/")
|
||||
|
||||
async def handle_sse(request):
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
auth_header: str | None = request.headers.get("Authorization")
|
||||
auth_token = None
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.split(" ")[1]
|
||||
|
|
|
@ -224,16 +224,16 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
],
|
||||
)
|
||||
|
||||
# Verify
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert len(chunks) == 2 # Should have response.created and response.completed
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert len(chunks) == 2 # Should have response.created and response.completed
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
|
|
@ -36,7 +36,7 @@ test_response_mcp_tool:
|
|||
test_params:
|
||||
case:
|
||||
- case_id: "boiling_point_tool"
|
||||
input: "What is the boiling point of polyjuice?"
|
||||
input: "What is the boiling point of myawesomeliquid in Celsius?"
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
|
@ -94,3 +94,43 @@ test_response_multi_turn_image:
|
|||
output: "llama"
|
||||
- input: "What country do you find this animal primarily in? What continent?"
|
||||
output: "peru"
|
||||
|
||||
test_response_multi_turn_tool_execution:
|
||||
test_name: test_response_multi_turn_tool_execution
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "user_file_access_check"
|
||||
input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
output: "yes"
|
||||
- case_id: "experiment_results_lookup"
|
||||
input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
output: "100°C"
|
||||
|
||||
test_response_multi_turn_tool_execution_streaming:
|
||||
test_name: test_response_multi_turn_tool_execution_streaming
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "user_permissions_workflow"
|
||||
input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
stream: true
|
||||
output: "no"
|
||||
- case_id: "experiment_analysis_streaming"
|
||||
input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
stream: true
|
||||
output: "85%"
|
||||
|
|
|
@ -12,7 +12,7 @@ import pytest
|
|||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||
from tests.common.mcp import make_mcp_server
|
||||
from tests.common.mcp import dependency_tools, make_mcp_server
|
||||
from tests.verifications.openai_api.fixtures.fixtures import (
|
||||
case_id_generator,
|
||||
get_base_test_name,
|
||||
|
@ -280,6 +280,7 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
|
|||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert len(response.output) >= 3
|
||||
list_tools = response.output[0]
|
||||
assert list_tools.type == "mcp_list_tools"
|
||||
|
@ -290,11 +291,12 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
|
|||
call = response.output[1]
|
||||
assert call.type == "mcp_call"
|
||||
assert call.name == "get_boiling_point"
|
||||
assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
|
||||
assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
|
||||
assert call.error is None
|
||||
assert "-100" in call.output
|
||||
|
||||
message = response.output[2]
|
||||
# sometimes the model will call the tool again, so we need to get the last message
|
||||
message = response.output[-1]
|
||||
text_content = message.content[0].text
|
||||
assert "boiling point" in text_content.lower()
|
||||
|
||||
|
@ -393,3 +395,154 @@ def test_response_non_streaming_multi_turn_image(request, openai_client, model,
|
|||
previous_response_id = response.id
|
||||
output_text = response.output_text.lower()
|
||||
assert turn["output"].lower() in output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_multi_turn_tool_execution(
|
||||
request, openai_client, model, provider, verification_config, case
|
||||
):
|
||||
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
tools = case["tools"]
|
||||
# Replace the placeholder URL with the actual server URL
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
|
||||
response = openai_client.responses.create(
|
||||
input=case["input"],
|
||||
model=model,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Verify we have MCP tool calls in the output
|
||||
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
|
||||
# Should have exactly 1 MCP list tools message (at the beginning)
|
||||
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
|
||||
assert mcp_list_tools[0].server_label == "localmcp"
|
||||
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
|
||||
expected_tool_names = {
|
||||
"get_user_id",
|
||||
"get_user_permissions",
|
||||
"check_file_access",
|
||||
"get_experiment_id",
|
||||
"get_experiment_results",
|
||||
}
|
||||
assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
|
||||
|
||||
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
|
||||
for mcp_call in mcp_calls:
|
||||
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
|
||||
|
||||
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
|
||||
|
||||
final_message = message_outputs[-1]
|
||||
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
|
||||
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
|
||||
assert len(final_message.content) > 0, "Final message should have content"
|
||||
|
||||
expected_output = case["output"]
|
||||
assert expected_output.lower() in response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {response.output_text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
async def test_response_streaming_multi_turn_tool_execution(
|
||||
request, openai_client, model, provider, verification_config, case
|
||||
):
|
||||
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
tools = case["tools"]
|
||||
# Replace the placeholder URL with the actual server URL
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
|
||||
stream = openai_client.responses.create(
|
||||
input=case["input"],
|
||||
model=model,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
# Should have at least response.created and response.completed
|
||||
assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"
|
||||
|
||||
# First chunk should be response.created
|
||||
assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"
|
||||
|
||||
# Last chunk should be response.completed
|
||||
assert chunks[-1].type == "response.completed", (
|
||||
f"Last chunk should be response.completed, got {chunks[-1].type}"
|
||||
)
|
||||
|
||||
# Get the final response from the last chunk
|
||||
final_chunk = chunks[-1]
|
||||
if hasattr(final_chunk, "response"):
|
||||
final_response = final_chunk.response
|
||||
|
||||
# Verify multi-turn MCP tool execution results
|
||||
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in final_response.output if output.type == "message"]
|
||||
|
||||
# Should have exactly 1 MCP list tools message (at the beginning)
|
||||
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
|
||||
assert mcp_list_tools[0].server_label == "localmcp"
|
||||
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
|
||||
expected_tool_names = {
|
||||
"get_user_id",
|
||||
"get_user_permissions",
|
||||
"check_file_access",
|
||||
"get_experiment_id",
|
||||
"get_experiment_results",
|
||||
}
|
||||
assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
|
||||
|
||||
# Should have at least 1 MCP call (the model should call at least one tool)
|
||||
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
|
||||
|
||||
# All MCP calls should be completed (verifies our tool execution works)
|
||||
for mcp_call in mcp_calls:
|
||||
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
|
||||
|
||||
# Should have at least one final message response
|
||||
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
|
||||
|
||||
# Final message should be from assistant and completed
|
||||
final_message = message_outputs[-1]
|
||||
assert final_message.role == "assistant", (
|
||||
f"Final message should be from assistant, got {final_message.role}"
|
||||
)
|
||||
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
|
||||
assert len(final_message.content) > 0, "Final message should have content"
|
||||
|
||||
# Check that the expected output appears in the response
|
||||
expected_output = case["output"]
|
||||
assert expected_output.lower() in final_response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue