mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-21 09:23:13 +00:00
# What does this PR do? A _bunch_ of cleanup for the Responses tests. - Replaced the YAML test cases with simple pydantic models - Split the large monolithic test file into multiple focused test files: - `test_basic_responses.py` for basic and image response tests - `test_tool_responses.py` for tool-related tests - `test_file_search.py` for file-search-specific tests - Added a `StreamingValidator` helper class to standardize streaming response validation ## Test Plan Run the tests: ``` pytest -s -v tests/integration/non_ci/responses/ \ --stack-config=starter \ --text-model openai/gpt-4o \ --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ -k "client_with_models" ```
335 lines
14 KiB
Python
335 lines
14 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
import os
|
|
|
|
import httpx
|
|
import openai
|
|
import pytest
|
|
from fixtures.test_cases import (
|
|
custom_tool_test_cases,
|
|
file_search_test_cases,
|
|
mcp_tool_test_cases,
|
|
multi_turn_tool_execution_streaming_test_cases,
|
|
multi_turn_tool_execution_test_cases,
|
|
web_search_test_cases,
|
|
)
|
|
from helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment
|
|
from streaming_assertions import StreamingValidator
|
|
|
|
from llama_stack import LlamaStackAsLibraryClient
|
|
from llama_stack.core.datatypes import AuthenticationRequiredError
|
|
from tests.common.mcp import dependency_tools, make_mcp_server
|
|
|
|
|
|
@pytest.mark.parametrize("case", web_search_test_cases)
def test_response_non_streaming_web_search(compat_client, text_model_id, case):
    """Non-streaming web search yields a completed search call followed by an assistant message.

    The first output item must be a completed ``web_search_call``, and the second a completed
    assistant ``message`` with content. The overall output text must contain ``case.expected``
    (case-insensitive).
    """
    response = compat_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=case.tools,
        stream=False,
    )

    assert len(response.output) > 1
    search_call, message = response.output[0], response.output[1]

    assert search_call.type == "web_search_call"
    assert search_call.status == "completed"

    assert message.type == "message"
    assert message.status == "completed"
    assert message.role == "assistant"
    assert len(message.content) > 0

    answer = response.output_text.lower().strip()
    assert case.expected.lower() in answer
|
|
|
|
|
|
@pytest.mark.parametrize("case", file_search_test_cases)
def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case):
    """Non-streaming file search against a freshly populated vector store.

    Uploads the case's document and attaches it to a new vector store. The document is either
    inline ``file_content`` written under ``tmp_path``, or a file under the ``fixtures``
    directory named by ``file_path``.

    The test then checks two things. First, the response starts with a completed
    ``file_search_call`` whose top result mentions ``case.expected`` with a positive score.
    Second, the final output text also mentions ``case.expected``.

    Raises:
        ValueError: if the case provides neither ``file_content`` nor ``file_path``.
    """
    if isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("Responses API file search is not yet supported in library client.")

    vector_store = new_vector_store(compat_client, "test_vector_store")

    if case.file_content:
        file_name = "test_response_non_streaming_file_search.txt"
        file_path = tmp_path / file_name
        file_path.write_text(case.file_content)
    elif case.file_path:
        file_path = os.path.join(os.path.dirname(__file__), "fixtures", case.file_path)
        file_name = os.path.basename(file_path)
    else:
        raise ValueError("No file content or path provided for case")

    file_response = upload_file(compat_client, file_name, file_path)

    # Attach our file to the vector store
    compat_client.vector_stores.files.create(
        vector_store_id=vector_store.id,
        file_id=file_response.id,
    )

    # Wait for the file to be attached
    wait_for_file_attachment(compat_client, vector_store.id, file_response.id)

    # Point every file_search tool at this test's vector store.
    # Build new tool dicts instead of mutating case.tools in place. The case objects are
    # module-level parametrize values shared by every generated test, so in-place edits
    # would leak this run's vector store id into the shared fixture.
    tools = [
        {**tool, "vector_store_ids": [vector_store.id]} if tool["type"] == "file_search" else tool
        for tool in case.tools
    ]

    # Create the response request, which should query our vector store
    response = compat_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=tools,
        stream=False,
        include=["file_search_call.results"],
    )

    # Verify the file_search_tool was called
    assert len(response.output) > 1
    assert response.output[0].type == "file_search_call"
    assert response.output[0].status == "completed"
    assert response.output[0].queries  # ensure it's some non-empty list
    assert response.output[0].results
    assert case.expected.lower() in response.output[0].results[0].text.lower()
    assert response.output[0].results[0].score > 0

    # Verify the output_text generated by the response
    assert case.expected.lower() in response.output_text.lower().strip()
|
|
|
|
|
|
def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id):
    """File search over a vector store with no files attached still runs but finds nothing.

    The model must still issue at least one query and produce some output text. The
    ``file_search_call`` must report no results.
    """
    if isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("Responses API file search is not yet supported in library client.")

    store = new_vector_store(compat_client, "test_vector_store")
    search_tool = {"type": "file_search", "vector_store_ids": [store.id]}

    # Ask a question; no file has been attached to `store` in this test
    response = compat_client.responses.create(
        model=text_model_id,
        input="How many experts does the Llama 4 Maverick model have?",
        tools=[search_tool],
        stream=False,
        include=["file_search_call.results"],
    )

    assert len(response.output) > 1
    search_call = response.output[0]
    assert search_call.type == "file_search_call"
    assert search_call.status == "completed"
    assert search_call.queries  # at least one query was issued
    assert not search_call.results  # ...but nothing came back

    # The model should still answer with some text
    assert response.output_text
|
|
|
|
|
|
@pytest.mark.parametrize("case", mcp_tool_test_cases)
def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
    """Exercise a non-streaming MCP tool call, then repeat it against an auth-protected server.

    First pass (open in-process server):
    - The output begins with an ``mcp_list_tools`` item for ``localmcp`` listing exactly
      ``get_boiling_point`` and ``greet_everyone``.
    - That item is followed by an error-free ``get_boiling_point`` call.
    - The last output item's text mentions the boiling point.

    Second pass (server requiring ``test-token``):
    - The request must raise an authentication error.
    - After an ``Authorization: Bearer test-token`` header is added to the MCP tools, the
      request succeeds.
    """
    if not isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("in-process MCP server is only supported in library client")

    def create_response(mcp_tools):
        return compat_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=mcp_tools,
            stream=False,
        )

    with make_mcp_server() as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)
        response = create_response(tools)

        assert len(response.output) >= 3
        listing, call = response.output[0], response.output[1]

        assert listing.type == "mcp_list_tools"
        assert listing.server_label == "localmcp"
        assert len(listing.tools) == 2
        assert {tool.name for tool in listing.tools} == {"get_boiling_point", "greet_everyone"}

        assert call.type == "mcp_call"
        assert call.name == "get_boiling_point"
        expected_arguments = {"liquid_name": "myawesomeliquid", "celsius": True}
        assert json.loads(call.arguments) == expected_arguments
        assert call.error is None
        assert "-100" in call.output

        # The model may call the tool more than once, so read the text from the last output item
        last_text = response.output[-1].content[0].text
        assert "boiling point" in last_text.lower()

    with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)

        # Past the skip above only the library client reaches this point. The HTTP error
        # types are presumably kept for non-library clients.
        if isinstance(compat_client, LlamaStackAsLibraryClient):
            exc_type = AuthenticationRequiredError
        else:
            exc_type = (httpx.HTTPStatusError, openai.AuthenticationError)

        # Without credentials the request must be rejected
        with pytest.raises(exc_type):
            create_response(tools)

        # Attach the bearer token to every MCP tool and retry
        for tool in tools:
            if tool["type"] != "mcp":
                continue
            tool["headers"] = {"Authorization": "Bearer test-token"}

        response = create_response(tools)
        assert len(response.output) >= 3
|
|
|
|
|
|
@pytest.mark.parametrize("case", custom_tool_test_cases)
def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
    """The custom function tool surfaces as exactly one completed ``get_weather`` function_call item."""
    response = compat_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=case.tools,
        stream=False,
    )

    assert len(response.output) == 1
    function_call = response.output[0]
    assert function_call.type == "function_call"
    assert function_call.status == "completed"
    assert function_call.name == "get_weather"
|
|
|
|
|
|
@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
    """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence.

    The test runs against an in-process MCP server exposing ``dependency_tools()``. It checks:
    - exactly one tool listing, advertising all five dependency tools;
    - every MCP call finished without error;
    - the last message is a completed assistant message with content;
    - the output text contains ``case.expected``.
    """
    if not isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("in-process MCP server is only supported in library client")

    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        response = compat_client.responses.create(
            input=case.input,
            model=text_model_id,
            tools=setup_mcp_tools(case.tools, mcp_server_info),
        )

        # Group the output items by the three types this test inspects
        grouped = {"mcp_list_tools": [], "mcp_call": [], "message": []}
        for item in response.output:
            bucket = grouped.get(item.type)
            if bucket is not None:
                bucket.append(item)
        listings = grouped["mcp_list_tools"]
        calls = grouped["mcp_call"]
        messages = grouped["message"]

        # Exactly one tool listing, advertising all five dependency tools
        assert len(listings) == 1, f"Expected exactly 1 mcp_list_tools, got {len(listings)}"
        listing = listings[0]
        assert listing.server_label == "localmcp"
        assert len(listing.tools) == 5  # dependency_tools() exposes five tools
        assert {t.name for t in listing.tools} == {
            "get_user_id",
            "get_user_permissions",
            "check_file_access",
            "get_experiment_id",
            "get_experiment_results",
        }

        # At least one tool call, and none of them failed
        assert len(calls) >= 1, f"Expected at least 1 mcp_call, got {len(calls)}"
        for mcp_call in calls:
            assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

        # The last message is the assistant's completed answer
        assert len(messages) >= 1, f"Expected at least 1 message output, got {len(messages)}"
        final_message = messages[-1]
        assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
        assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
        assert len(final_message.content) > 0, "Final message should have content"

        expected_output = case.expected
        assert expected_output.lower() in response.output_text.lower(), (
            f"Expected '{expected_output}' to appear in response: {response.output_text}"
        )
|
|
|
|
|
|
@pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases)
def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
    """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence.

    Streams a response from an in-process MCP server exposing ``dependency_tools()`` and runs
    the shared ``StreamingValidator`` checks over the emitted events. It then inspects the
    response carried by the final event:
    - exactly one ``mcp_list_tools`` item with the five dependency tools;
    - at least one error-free ``mcp_call``;
    - a completed assistant message whose output text contains ``case.expected``.
    """
    if not isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("in-process MCP server is only supported in library client")

    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)

        stream = compat_client.responses.create(
            input=case.input,
            model=text_model_id,
            tools=tools,
            stream=True,
        )

        # Consume the stream while the MCP server is still running
        chunks = list(stream)
        assert chunks, "Expected the stream to yield at least one event"

        # Use validator for common streaming checks
        validator = StreamingValidator(chunks)
        validator.assert_basic_event_sequence()
        validator.assert_response_consistency()
        validator.assert_has_tool_calls()
        validator.assert_has_mcp_events()
        validator.assert_rich_streaming()

        # The final event must carry the response. Assert this rather than guarding with
        # `if`: otherwise a missing response would skip every check below and the test
        # would pass vacuously.
        final_chunk = chunks[-1]
        assert hasattr(final_chunk, "response"), (
            f"Final streaming event should carry the response, got {type(final_chunk).__name__}"
        )
        final_response = final_chunk.response

        # Verify multi-turn MCP tool execution results
        mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
        mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
        message_outputs = [output for output in final_response.output if output.type == "message"]

        # Should have exactly 1 MCP list tools message (at the beginning)
        assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
        assert mcp_list_tools[0].server_label == "localmcp"
        assert len(mcp_list_tools[0].tools) == 5  # dependency_tools() exposes five tools
        expected_tool_names = {
            "get_user_id",
            "get_user_permissions",
            "check_file_access",
            "get_experiment_id",
            "get_experiment_results",
        }
        assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names

        # Should have at least 1 MCP call (the model should call at least one tool)
        assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"

        # No MCP call may report an error (verifies tool execution works)
        for mcp_call in mcp_calls:
            assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

        # Should have at least one final message response
        assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"

        # Final message should be from assistant and completed
        final_message = message_outputs[-1]
        assert final_message.role == "assistant", (
            f"Final message should be from assistant, got {final_message.role}"
        )
        assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
        assert len(final_message.content) > 0, "Final message should have content"

        # Check that the expected output appears in the response
        expected_output = case.expected
        assert expected_output.lower() in final_response.output_text.lower(), (
            f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
        )
|