feat(tests): introduce a test "suite" concept to encompass dirs, options (#3339)

Our integration tests need to be grouped because each group often needs a specific set of models to work with. We already separated the vision tests for this reason, and we also keep a separate set of tests that exercise the "Responses" API.

This PR makes that grouping official as a "suite" concept, so it is easy to target a group and apply all of the testing infrastructure (for example, record-replay) to every group uniformly.

There are three suites declared:
- base
- vision
- responses

Note that our CI currently runs the "base" and "vision" suites.
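
Conceptually, a suite is just a named bundle of test directories plus default options. A minimal sketch of what such a registry could look like (the class name, fields, and paths below are assumptions for illustration, not the actual implementation):
```python
# Rough sketch of a suite registry -- names, fields, and paths here are
# illustrative assumptions, not the actual definitions in the repo.
from dataclasses import dataclass, field


@dataclass
class Suite:
    name: str
    roots: list[str]                      # test directories the suite collects
    default_options: dict[str, str] = field(default_factory=dict)


SUITES = {
    "base": Suite("base", roots=["tests/integration"]),
    "vision": Suite(
        "vision",
        roots=["tests/integration/inference"],
        default_options={"vision_model_id": "<a-vision-capable-model>"},
    ),
    "responses": Suite("responses", roots=["tests/integration/responses"]),
}
```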

You can use the `--suite` option when running pytest (or any of the testing scripts or workflows). For example:
```
OLLAMA_URL=http://localhost:11434 \
  pytest -s -v tests/integration/ --stack-config starter --suite vision
```
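
For reference, the `--suite` option can be wired into pytest roughly as below. This is a minimal sketch that assumes the hypothetical `SUITES` registry from the earlier snippet (including the import path); the project's real conftest.py wiring may look different:
```python
# conftest.py (illustrative sketch only, not the project's actual wiring)
from tests.integration.suites import SUITES  # hypothetical home of the registry sketched above


def pytest_addoption(parser):
    parser.addoption("--suite", default="base", choices=sorted(SUITES), help="Test suite to run")


def pytest_collection_modifyitems(config, items):
    suite = SUITES[config.getoption("--suite")]
    selected, deselected = [], []
    for item in items:
        # Keep only tests that live under one of the suite's directories
        (selected if any(root in str(item.path) for root in suite.roots) else deselected).append(item)
    if deselected:
        config.hook.pytest_deselected(items=deselected)
        items[:] = selected
```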
Authored by Ashwin Bharambe on 2025-09-05 13:58:49 -07:00; committed by GitHub.
parent 0c2757a05b
commit 47b640370e
25 changed files with 255 additions and 161 deletions


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
import pytest
import yaml
from openai import OpenAI
from llama_stack import LlamaStackAsLibraryClient
# --- Helper Functions ---
def _load_all_verification_configs():
"""Load and aggregate verification configs from the conf/ directory."""
# Note: Path is relative to *this* file (fixtures.py)
conf_dir = Path(__file__).parent.parent.parent / "conf"
if not conf_dir.is_dir():
# Use pytest.fail if called during test collection, otherwise raise error
# For simplicity here, we'll raise an error, assuming direct calls
# are less likely or can handle it.
raise FileNotFoundError(f"Verification config directory not found at {conf_dir}")
all_provider_configs = {}
yaml_files = list(conf_dir.glob("*.yaml"))
if not yaml_files:
raise FileNotFoundError(f"No YAML configuration files found in {conf_dir}")
for config_path in yaml_files:
provider_name = config_path.stem
try:
with open(config_path) as f:
provider_config = yaml.safe_load(f)
if provider_config:
all_provider_configs[provider_name] = provider_config
else:
# Log warning if possible, or just skip empty files silently
print(f"Warning: Config file {config_path} is empty or invalid.")
except Exception as e:
raise OSError(f"Error loading config file {config_path}: {e}") from e
return {"providers": all_provider_configs}
# --- End Helper Functions ---
@pytest.fixture(scope="session")
def verification_config():
"""Pytest fixture to provide the loaded verification config."""
try:
return _load_all_verification_configs()
except (OSError, FileNotFoundError) as e:
pytest.fail(str(e)) # Fail test collection if config loading fails
@pytest.fixture(scope="session")
def provider(request, verification_config):
provider = request.config.getoption("--provider")
base_url = request.config.getoption("--base-url")
if provider and base_url and verification_config["providers"][provider]["base_url"] != base_url:
raise ValueError(f"Provider {provider} is not supported for base URL {base_url}")
if not provider:
if not base_url:
raise ValueError("Provider and base URL are not provided")
for provider, metadata in verification_config["providers"].items():
if metadata["base_url"] == base_url:
provider = provider
break
return provider
@pytest.fixture(scope="session")
def base_url(request, provider, verification_config):
return request.config.getoption("--base-url") or verification_config.get("providers", {}).get(provider, {}).get(
"base_url"
)
@pytest.fixture(scope="session")
def api_key(request, provider, verification_config):
provider_conf = verification_config.get("providers", {}).get(provider, {})
api_key_env_var = provider_conf.get("api_key_var")
key_from_option = request.config.getoption("--api-key")
key_from_env = os.getenv(api_key_env_var) if api_key_env_var else None
final_key = key_from_option or key_from_env
return final_key
@pytest.fixture
def model_mapping(provider, providers_model_mapping):
return providers_model_mapping[provider]
@pytest.fixture(scope="session")
def openai_client(base_url, api_key, provider):
# Simplify running against a local Llama Stack
if base_url and "localhost" in base_url and not api_key:
api_key = "empty"
if provider.startswith("stack:"):
parts = provider.split(":")
if len(parts) != 2:
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
config = parts[1]
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
return client
return OpenAI(
base_url=base_url,
api_key=api_key,
)

Three binary files added (rendered as image previews in the diff viewer, not shown here): 108 KiB, 148 KiB, and 139 KiB.


@@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import pytest
from pydantic import BaseModel
class ResponsesTestCase(BaseModel):
# Input can be a simple string or complex message structure
input: str | list[dict[str, Any]]
expected: str
# Tools as flexible dict structure (gets validated at runtime by the API)
tools: list[dict[str, Any]] | None = None
# Multi-turn conversations with input/output pairs
turns: list[tuple[str | list[dict[str, Any]], str]] | None = None
# File search specific fields
file_content: str | None = None
file_path: str | None = None
# Streaming flag
stream: bool | None = None
# Basic response test cases
basic_test_cases = [
pytest.param(
ResponsesTestCase(
input="Which planet do humans live on?",
expected="earth",
),
id="earth",
),
pytest.param(
ResponsesTestCase(
input="Which planet has rings around it with a name starting with letter S?",
expected="saturn",
),
id="saturn",
),
pytest.param(
ResponsesTestCase(
input=[
{
"role": "user",
"content": [
{
"type": "input_text",
"text": "what teams are playing in this image?",
}
],
},
{
"role": "user",
"content": [
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg",
}
],
},
],
expected="brooklyn nets",
),
id="image_input",
),
]
# Multi-turn test cases
multi_turn_test_cases = [
pytest.param(
ResponsesTestCase(
input="", # Not used for multi-turn
expected="", # Not used for multi-turn
turns=[
("Which planet do humans live on?", "earth"),
("What is the name of the planet from your previous response?", "earth"),
],
),
id="earth",
),
]
# Web search test cases
web_search_test_cases = [
pytest.param(
ResponsesTestCase(
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "web_search", "search_context_size": "low"}],
expected="128",
),
id="llama_experts",
),
]
# File search test cases
file_search_test_cases = [
pytest.param(
ResponsesTestCase(
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "file_search"}],
expected="128",
file_content="Llama 4 Maverick has 128 experts",
),
id="llama_experts",
),
pytest.param(
ResponsesTestCase(
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "file_search"}],
expected="128",
file_path="pdfs/llama_stack_and_models.pdf",
),
id="llama_experts_pdf",
),
]
# MCP tool test cases
mcp_tool_test_cases = [
pytest.param(
ResponsesTestCase(
input="What is the boiling point of myawesomeliquid in Celsius?",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="Hello, world!",
),
id="boiling_point_tool",
),
]
# Custom tool test cases
custom_tool_test_cases = [
pytest.param(
ResponsesTestCase(
input="What's the weather like in San Francisco?",
tools=[
{
"type": "function",
"name": "get_weather",
"description": "Get current temperature for a given location.",
"parameters": {
"additionalProperties": False,
"properties": {
"location": {
"description": "City and country e.g. Bogotá, Colombia",
"type": "string",
}
},
"required": ["location"],
"type": "object",
},
}
],
expected="", # No specific expected output for custom tools
),
id="sf_weather",
),
]
# Image test cases
image_test_cases = [
pytest.param(
ResponsesTestCase(
input=[
{
"role": "user",
"content": [
{
"type": "input_text",
"text": "Identify the type of animal in this image.",
},
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
},
],
},
],
expected="llama",
),
id="llama_image",
),
]
# Multi-turn image test cases
multi_turn_image_test_cases = [
pytest.param(
ResponsesTestCase(
input="", # Not used for multi-turn
expected="", # Not used for multi-turn
turns=[
(
[
{
"role": "user",
"content": [
{
"type": "input_text",
"text": "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'.",
},
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
},
],
},
],
"llama",
),
(
"What country do you find this animal primarily in? What continent?",
"peru",
),
],
),
id="llama_image_understanding",
),
]
# Multi-turn tool execution test cases
multi_turn_tool_execution_test_cases = [
pytest.param(
ResponsesTestCase(
input="I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response.",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="yes",
),
id="user_file_access_check",
),
pytest.param(
ResponsesTestCase(
input="I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me the boiling point in Celsius.",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="100°C",
),
id="experiment_results_lookup",
),
]
# Multi-turn tool execution streaming test cases
multi_turn_tool_execution_streaming_test_cases = [
pytest.param(
ResponsesTestCase(
input="Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response.",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="no",
stream=True,
),
id="user_permissions_workflow",
),
pytest.param(
ResponsesTestCase(
input="I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Return only one tool call per step. Please stream your analysis process.",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="85%",
stream=True,
),
id="experiment_analysis_streaming",
),
]


@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
def new_vector_store(openai_client, name):
"""Create a new vector store, cleaning up any existing one with the same name."""
# Ensure we don't reuse an existing vector store
vector_stores = openai_client.vector_stores.list()
for vector_store in vector_stores:
if vector_store.name == name:
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
# Create a new vector store
vector_store = openai_client.vector_stores.create(name=name)
return vector_store
def upload_file(openai_client, name, file_path):
"""Upload a file, cleaning up any existing file with the same name."""
# Ensure we don't reuse an existing file
files = openai_client.files.list()
for file in files:
if file.filename == name:
openai_client.files.delete(file_id=file.id)
# Upload a text file with our document content
return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants")
def wait_for_file_attachment(compat_client, vector_store_id, file_id):
"""Wait for a file to be attached to a vector store."""
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store_id,
file_id=file_id,
)
while file_attach_response.status == "in_progress":
time.sleep(0.1)
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store_id,
file_id=file_id,
)
assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}"
assert not file_attach_response.last_error
return file_attach_response
def setup_mcp_tools(tools, mcp_server_info):
"""Replace placeholder MCP server URLs with actual server info."""
# Create a deep copy to avoid modifying the original test case
import copy
tools_copy = copy.deepcopy(tools)
for tool in tools_copy:
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
tool["server_url"] = mcp_server_info["server_url"]
return tools_copy


@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
class StreamingValidator:
"""Helper class for validating streaming response events."""
def __init__(self, chunks: list[Any]):
self.chunks = chunks
self.event_types = [chunk.type for chunk in chunks]
def assert_basic_event_sequence(self):
"""Verify basic created -> completed event sequence."""
assert len(self.chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(self.chunks)}"
assert self.chunks[0].type == "response.created", (
f"First chunk should be response.created, got {self.chunks[0].type}"
)
assert self.chunks[-1].type == "response.completed", (
f"Last chunk should be response.completed, got {self.chunks[-1].type}"
)
# Verify event order
created_index = self.event_types.index("response.created")
completed_index = self.event_types.index("response.completed")
assert created_index < completed_index, "response.created should come before response.completed"
def assert_response_consistency(self):
"""Verify response ID consistency across events."""
response_ids = set()
for chunk in self.chunks:
if hasattr(chunk, "response_id"):
response_ids.add(chunk.response_id)
elif hasattr(chunk, "response") and hasattr(chunk.response, "id"):
response_ids.add(chunk.response.id)
assert len(response_ids) == 1, f"All events should reference the same response_id, found: {response_ids}"
def assert_has_incremental_content(self):
"""Verify that content is delivered incrementally via delta events."""
delta_events = [
i for i, event_type in enumerate(self.event_types) if event_type == "response.output_text.delta"
]
assert len(delta_events) > 0, "Expected delta events for true incremental streaming, but found none"
# Verify delta events have content
non_empty_deltas = 0
delta_content_total = ""
for delta_idx in delta_events:
chunk = self.chunks[delta_idx]
if hasattr(chunk, "delta") and chunk.delta:
delta_content_total += chunk.delta
non_empty_deltas += 1
assert non_empty_deltas > 0, "Delta events found but none contain content"
assert len(delta_content_total) > 0, "Delta events found but total delta content is empty"
return delta_content_total
def assert_content_quality(self, expected_content: str):
"""Verify the final response contains expected content."""
final_chunk = self.chunks[-1]
if hasattr(final_chunk, "response"):
output_text = final_chunk.response.output_text.lower().strip()
assert len(output_text) > 0, "Response should have content"
assert expected_content.lower() in output_text, f"Expected '{expected_content}' in response"
def assert_has_tool_calls(self):
"""Verify tool call streaming events are present."""
# Check for tool call events
delta_events = [
chunk
for chunk in self.chunks
if chunk.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"]
]
done_events = [
chunk
for chunk in self.chunks
if chunk.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"]
]
assert len(delta_events) > 0, f"Expected tool call delta events, got chunk types: {self.event_types}"
assert len(done_events) > 0, f"Expected tool call done events, got chunk types: {self.event_types}"
# Verify output item events
item_added_events = [chunk for chunk in self.chunks if chunk.type == "response.output_item.added"]
item_done_events = [chunk for chunk in self.chunks if chunk.type == "response.output_item.done"]
assert len(item_added_events) > 0, (
f"Expected response.output_item.added events, got chunk types: {self.event_types}"
)
assert len(item_done_events) > 0, (
f"Expected response.output_item.done events, got chunk types: {self.event_types}"
)
def assert_has_mcp_events(self):
"""Verify MCP-specific streaming events are present."""
# Tool execution progress events
mcp_in_progress_events = [chunk for chunk in self.chunks if chunk.type == "response.mcp_call.in_progress"]
mcp_completed_events = [chunk for chunk in self.chunks if chunk.type == "response.mcp_call.completed"]
assert len(mcp_in_progress_events) > 0, (
f"Expected response.mcp_call.in_progress events, got chunk types: {self.event_types}"
)
assert len(mcp_completed_events) > 0, (
f"Expected response.mcp_call.completed events, got chunk types: {self.event_types}"
)
# MCP list tools events
mcp_list_tools_in_progress_events = [
chunk for chunk in self.chunks if chunk.type == "response.mcp_list_tools.in_progress"
]
mcp_list_tools_completed_events = [
chunk for chunk in self.chunks if chunk.type == "response.mcp_list_tools.completed"
]
assert len(mcp_list_tools_in_progress_events) > 0, (
f"Expected response.mcp_list_tools.in_progress events, got chunk types: {self.event_types}"
)
assert len(mcp_list_tools_completed_events) > 0, (
f"Expected response.mcp_list_tools.completed events, got chunk types: {self.event_types}"
)
def assert_rich_streaming(self, min_chunks: int = 10):
"""Verify we have substantial streaming activity."""
assert len(self.chunks) > min_chunks, (
f"Expected rich streaming with many events, got only {len(self.chunks)} chunks"
)
def validate_event_structure(self):
"""Validate the structure of various event types."""
for chunk in self.chunks:
if chunk.type == "response.created":
assert chunk.response.status == "in_progress"
elif chunk.type == "response.completed":
assert chunk.response.status == "completed"
elif hasattr(chunk, "item_id"):
assert chunk.item_id, "Events with item_id should have non-empty item_id"
elif hasattr(chunk, "sequence_number"):
assert isinstance(chunk.sequence_number, int), "sequence_number should be an integer"


@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import pytest
from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_image_test_cases, multi_turn_test_cases
from .streaming_assertions import StreamingValidator
@pytest.mark.parametrize("case", basic_test_cases)
def test_response_non_streaming_basic(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
stream=False,
)
output_text = response.output_text.lower().strip()
assert len(output_text) > 0
assert case.expected.lower() in output_text
retrieved_response = compat_client.responses.retrieve(response_id=response.id)
assert retrieved_response.output_text == response.output_text
next_response = compat_client.responses.create(
model=text_model_id,
input="Repeat your previous response in all caps.",
previous_response_id=response.id,
)
next_output_text = next_response.output_text.strip()
assert case.expected.upper() in next_output_text
@pytest.mark.parametrize("case", basic_test_cases)
def test_response_streaming_basic(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
stream=True,
)
# Track events and timing to verify proper streaming
events = []
event_times = []
response_id = ""
start_time = time.time()
for chunk in response:
current_time = time.time()
event_times.append(current_time - start_time)
events.append(chunk)
if chunk.type == "response.created":
# Verify response.created is emitted first and immediately
assert len(events) == 1, "response.created should be the first event"
assert event_times[0] < 0.1, "response.created should be emitted immediately"
assert chunk.response.status == "in_progress"
response_id = chunk.response.id
elif chunk.type == "response.completed":
# Verify response.completed comes after response.created
assert len(events) >= 2, "response.completed should come after response.created"
assert chunk.response.status == "completed"
assert chunk.response.id == response_id, "Response ID should be consistent"
# Verify content quality
output_text = chunk.response.output_text.lower().strip()
assert len(output_text) > 0, "Response should have content"
assert case.expected.lower() in output_text, f"Expected '{case.expected}' in response"
# Use validator for common checks
validator = StreamingValidator(events)
validator.assert_basic_event_sequence()
validator.assert_response_consistency()
# Verify stored response matches streamed response
retrieved_response = compat_client.responses.retrieve(response_id=response_id)
final_event = events[-1]
assert retrieved_response.output_text == final_event.response.output_text
@pytest.mark.parametrize("case", basic_test_cases)
def test_response_streaming_incremental_content(compat_client, text_model_id, case):
"""Test that streaming actually delivers content incrementally, not just at the end."""
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
stream=True,
)
# Track all events and their content to verify incremental streaming
events = []
content_snapshots = []
event_times = []
start_time = time.time()
for chunk in response:
current_time = time.time()
event_times.append(current_time - start_time)
events.append(chunk)
# Track content at each event based on event type
if chunk.type == "response.output_text.delta":
# For delta events, track the delta content
content_snapshots.append(chunk.delta)
elif hasattr(chunk, "response") and hasattr(chunk.response, "output_text"):
# For response.created/completed events, track the full output_text
content_snapshots.append(chunk.response.output_text)
else:
content_snapshots.append("")
validator = StreamingValidator(events)
validator.assert_basic_event_sequence()
# Check if we have incremental content updates
event_types = [event.type for event in events]
created_index = event_types.index("response.created")
completed_index = event_types.index("response.completed")
# The key test: verify content progression
created_content = content_snapshots[created_index]
completed_content = content_snapshots[completed_index]
# Verify that response.created has empty content
assert len(created_content) == 0, f"response.created should have empty content, got: {repr(created_content[:100])}"
# Verify that response.completed has the full content
assert len(completed_content) > 0, "response.completed should have content"
assert case.expected.lower() in completed_content.lower(), f"Expected '{case.expected}' in final content"
# Use validator for incremental content checks
delta_content_total = validator.assert_has_incremental_content()
# Verify that the accumulated delta content matches the final content
assert delta_content_total.strip() == completed_content.strip(), (
f"Delta content '{delta_content_total}' should match final content '{completed_content}'"
)
# Verify timing: delta events should come between created and completed
delta_events = [i for i, event_type in enumerate(event_types) if event_type == "response.output_text.delta"]
for delta_idx in delta_events:
assert created_index < delta_idx < completed_index, (
f"Delta event at index {delta_idx} should be between created ({created_index}) and completed ({completed_index})"
)
@pytest.mark.parametrize("case", multi_turn_test_cases)
def test_response_non_streaming_multi_turn(compat_client, text_model_id, case):
previous_response_id = None
for turn_input, turn_expected in case.turns:
response = compat_client.responses.create(
model=text_model_id,
input=turn_input,
previous_response_id=previous_response_id,
)
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn_expected.lower() in output_text
@pytest.mark.parametrize("case", image_test_cases)
def test_response_non_streaming_image(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
stream=False,
)
output_text = response.output_text.lower()
assert case.expected.lower() in output_text
@pytest.mark.parametrize("case", multi_turn_image_test_cases)
def test_response_non_streaming_multi_turn_image(compat_client, text_model_id, case):
previous_response_id = None
for turn_input, turn_expected in case.turns:
response = compat_client.responses.create(
model=text_model_id,
input=turn_input,
previous_response_id=previous_response_id,
)
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn_expected.lower() in output_text


@@ -0,0 +1,318 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import time
import pytest
from llama_stack import LlamaStackAsLibraryClient
from .helpers import new_vector_store, upload_file
@pytest.mark.parametrize(
"text_format",
# Not testing json_object because most providers don't actually support it.
[
{"type": "text"},
{
"type": "json_schema",
"name": "capitals",
"description": "A schema for the capital of each country",
"schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
"strict": True,
},
],
)
def test_response_text_format(compat_client, text_model_id, text_format):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API text format is not yet supported in library client.")
stream = False
response = compat_client.responses.create(
model=text_model_id,
input="What is the capital of France?",
stream=stream,
text={"format": text_format},
)
# by_alias=True is needed because otherwise Pydantic renames our "schema" field
assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format
assert "paris" in response.output_text.lower()
if text_format["type"] == "json_schema":
assert "paris" in json.loads(response.output_text)["capital"].lower()
@pytest.fixture
def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory):
"""Create a vector store with multiple files that have different attributes for filtering tests."""
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store_with_filters")
tmp_path = tmp_path_factory.mktemp("filter_test_files")
# Create multiple files with different attributes
files_data = [
{
"name": "us_marketing_q1.txt",
"content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
"attributes": {
"region": "us",
"category": "marketing",
"date": 1672531200, # Jan 1, 2023
},
},
{
"name": "us_engineering_q2.txt",
"content": "US technical updates for Q2 2023. New features deployed in the US region.",
"attributes": {
"region": "us",
"category": "engineering",
"date": 1680307200, # Apr 1, 2023
},
},
{
"name": "eu_marketing_q1.txt",
"content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
"attributes": {
"region": "eu",
"category": "marketing",
"date": 1672531200, # Jan 1, 2023
},
},
{
"name": "asia_sales_q3.txt",
"content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
"attributes": {
"region": "asia",
"category": "sales",
"date": 1688169600, # Jul 1, 2023
},
},
]
file_ids = []
for file_data in files_data:
# Create file
file_path = tmp_path / file_data["name"]
file_path.write_text(file_data["content"])
# Upload file
file_response = upload_file(compat_client, file_data["name"], str(file_path))
file_ids.append(file_response.id)
# Attach file to vector store with attributes
file_attach_response = compat_client.vector_stores.files.create(
vector_store_id=vector_store.id,
file_id=file_response.id,
attributes=file_data["attributes"],
)
# Wait for attachment
while file_attach_response.status == "in_progress":
time.sleep(0.1)
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store.id,
file_id=file_response.id,
)
assert file_attach_response.status == "completed"
yield vector_store
# Cleanup: delete vector store and files
try:
compat_client.vector_stores.delete(vector_store_id=vector_store.id)
for file_id in file_ids:
try:
compat_client.files.delete(file_id=file_id)
except Exception:
pass # File might already be deleted
except Exception:
pass # Best effort cleanup
def test_response_file_search_filter_by_region(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with region equality filter."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {"type": "eq", "key": "region", "value": "us"},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What are the updates from the US region?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
# Verify file search was called with US filter
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return US files (not EU or Asia files)
for result in response.output[0].results:
assert "us" in result.text.lower() or "US" in result.text
# Ensure non-US regions are NOT returned
assert "european" not in result.text.lower()
assert "asia" not in result.text.lower()
def test_response_file_search_filter_by_category(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with category equality filter."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {"type": "eq", "key": "category", "value": "marketing"},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="Show me all marketing reports",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return marketing files (not engineering or sales)
for result in response.output[0].results:
# Marketing files should have promotional/advertising content
assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
# Ensure non-marketing categories are NOT returned
assert "technical" not in result.text.lower()
assert "revenue figures" not in result.text.lower()
def test_response_file_search_filter_by_date_range(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with date range filter using compound AND."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "and",
"filters": [
{
"type": "gte",
"key": "date",
"value": 1672531200, # Jan 1, 2023
},
{
"type": "lt",
"key": "date",
"value": 1680307200, # Apr 1, 2023
},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What happened in Q1 2023?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return Q1 files (not Q2 or Q3)
for result in response.output[0].results:
assert "q1" in result.text.lower()
# Ensure non-Q1 quarters are NOT returned
assert "q2" not in result.text.lower()
assert "q3" not in result.text.lower()
def test_response_file_search_filter_compound_and(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with compound AND filter (region AND category)."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "and",
"filters": [
{"type": "eq", "key": "region", "value": "us"},
{"type": "eq", "key": "category", "value": "engineering"},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What are the engineering updates from the US?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return US engineering files
assert len(response.output[0].results) >= 1
for result in response.output[0].results:
assert "us" in result.text.lower() and "technical" in result.text.lower()
# Ensure it's not from other regions or categories
assert "european" not in result.text.lower() and "asia" not in result.text.lower()
assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
def test_response_file_search_filter_compound_or(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with compound OR filter (marketing OR sales)."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "or",
"filters": [
{"type": "eq", "key": "category", "value": "marketing"},
{"type": "eq", "key": "category", "value": "sales"},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="Show me marketing and sales documents",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should return marketing and sales files, but NOT engineering
categories_found = set()
for result in response.output[0].results:
text_lower = result.text.lower()
if "promotional" in text_lower or "advertising" in text_lower:
categories_found.add("marketing")
if "revenue figures" in text_lower:
categories_found.add("sales")
# Ensure engineering files are NOT returned
assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
# Verify we got at least one of the expected categories
assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"


@@ -0,0 +1,474 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import httpx
import openai
import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.core.datatypes import AuthenticationRequiredError
from tests.common.mcp import dependency_tools, make_mcp_server
from .fixtures.test_cases import (
custom_tool_test_cases,
file_search_test_cases,
mcp_tool_test_cases,
multi_turn_tool_execution_streaming_test_cases,
multi_turn_tool_execution_test_cases,
web_search_test_cases,
)
from .helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment
from .streaming_assertions import StreamingValidator
@pytest.mark.parametrize("case", web_search_test_cases)
def test_response_non_streaming_web_search(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=case.tools,
stream=False,
)
assert len(response.output) > 1
assert response.output[0].type == "web_search_call"
assert response.output[0].status == "completed"
assert response.output[1].type == "message"
assert response.output[1].status == "completed"
assert response.output[1].role == "assistant"
assert len(response.output[1].content) > 0
assert case.expected.lower() in response.output_text.lower().strip()
@pytest.mark.parametrize("case", file_search_test_cases)
def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store")
if case.file_content:
file_name = "test_response_non_streaming_file_search.txt"
file_path = tmp_path / file_name
file_path.write_text(case.file_content)
elif case.file_path:
file_path = os.path.join(os.path.dirname(__file__), "fixtures", case.file_path)
file_name = os.path.basename(file_path)
else:
raise ValueError("No file content or path provided for case")
file_response = upload_file(compat_client, file_name, file_path)
# Attach our file to the vector store
compat_client.vector_stores.files.create(
vector_store_id=vector_store.id,
file_id=file_response.id,
)
# Wait for the file to be attached
wait_for_file_attachment(compat_client, vector_store.id, file_response.id)
# Update our tools with the right vector store id
tools = case.tools
for tool in tools:
if tool["type"] == "file_search":
tool["vector_store_ids"] = [vector_store.id]
# Create the response request, which should query our vector store
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
include=["file_search_call.results"],
)
# Verify the file_search_tool was called
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].queries # ensure it's some non-empty list
assert response.output[0].results
assert case.expected.lower() in response.output[0].results[0].text.lower()
assert response.output[0].results[0].score > 0
# Verify the output_text generated by the response
assert case.expected.lower() in response.output_text.lower().strip()
def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store")
# Create the response request, which should query our vector store
response = compat_client.responses.create(
model=text_model_id,
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
stream=False,
include=["file_search_call.results"],
)
# Verify the file_search_tool was called
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].queries # ensure it's some non-empty list
assert not response.output[0].results # ensure we don't get any results
# Verify some output_text was generated by the response
assert response.output_text
@pytest.mark.parametrize("case", mcp_tool_test_cases)
def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
if not isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("in-process MCP server is only supported in library client")
with make_mcp_server() as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
)
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
assert list_tools.server_label == "localmcp"
assert len(list_tools.tools) == 2
assert {t.name for t in list_tools.tools} == {
"get_boiling_point",
"greet_everyone",
}
call = response.output[1]
assert call.type == "mcp_call"
assert call.name == "get_boiling_point"
assert json.loads(call.arguments) == {
"liquid_name": "myawesomeliquid",
"celsius": True,
}
assert call.error is None
assert "-100" in call.output
# sometimes the model will call the tool again, so we need to get the last message
message = response.output[-1]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
exc_type = (
AuthenticationRequiredError
if isinstance(compat_client, LlamaStackAsLibraryClient)
else (httpx.HTTPStatusError, openai.AuthenticationError)
)
with pytest.raises(exc_type):
compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
)
for tool in tools:
if tool["type"] == "mcp":
tool["headers"] = {"Authorization": "Bearer test-token"}
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
)
assert len(response.output) >= 3
@pytest.mark.parametrize("case", mcp_tool_test_cases)
def test_response_sequential_mcp_tool(compat_client, text_model_id, case):
if not isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("in-process MCP server is only supported in library client")
with make_mcp_server() as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
)
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
assert list_tools.server_label == "localmcp"
assert len(list_tools.tools) == 2
assert {t.name for t in list_tools.tools} == {
"get_boiling_point",
"greet_everyone",
}
call = response.output[1]
assert call.type == "mcp_call"
assert call.name == "get_boiling_point"
assert json.loads(call.arguments) == {
"liquid_name": "myawesomeliquid",
"celsius": True,
}
assert call.error is None
assert "-100" in call.output
# sometimes the model will call the tool again, so we need to get the last message
message = response.output[-1]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
response2 = compat_client.responses.create(
model=text_model_id, input=case.input, tools=tools, stream=False, previous_response_id=response.id
)
assert len(response2.output) >= 1
message = response2.output[-1]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
@pytest.mark.parametrize("case", custom_tool_test_cases)
def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=case.tools,
stream=False,
)
assert len(response.output) == 1
assert response.output[0].type == "function_call"
assert response.output[0].status == "completed"
assert response.output[0].name == "get_weather"
@pytest.mark.parametrize("case", custom_tool_test_cases)
def test_response_function_call_ordering_1(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=case.tools,
stream=False,
)
assert len(response.output) == 1
assert response.output[0].type == "function_call"
assert response.output[0].status == "completed"
assert response.output[0].name == "get_weather"
inputs = []
inputs.append(
{
"role": "user",
"content": case.input,
}
)
inputs.append(
{
"type": "function_call_output",
"output": "It is raining.",
"call_id": response.output[0].call_id,
}
)
response = compat_client.responses.create(
model=text_model_id, input=inputs, tools=case.tools, stream=False, previous_response_id=response.id
)
assert len(response.output) == 1
def test_response_function_call_ordering_2(compat_client, text_model_id):
tools = [
{
"type": "function",
"name": "get_weather",
"description": "Get current temperature for a given location.",
"parameters": {
"additionalProperties": False,
"properties": {
"location": {
"description": "City and country e.g. Bogotá, Colombia",
"type": "string",
}
},
"required": ["location"],
"type": "object",
},
}
]
inputs = [
{
"role": "user",
"content": "Is the weather better in San Francisco or Los Angeles?",
}
]
response = compat_client.responses.create(
model=text_model_id,
input=inputs,
tools=tools,
stream=False,
)
for output in response.output:
if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
inputs.append(output)
for output in response.output:
if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
weather = "It is raining."
if "Los Angeles" in output.arguments:
weather = "It is cloudy."
inputs.append(
{
"type": "function_call_output",
"output": weather,
"call_id": output.call_id,
}
)
response = compat_client.responses.create(
model=text_model_id,
input=inputs,
tools=tools,
stream=False,
)
assert len(response.output) == 1
assert "Los Angeles" in response.output_text
@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
if not isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("in-process MCP server is only supported in library client")
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
response = compat_client.responses.create(
input=case.input,
model=text_model_id,
tools=tools,
)
# Verify we have MCP tool calls in the output
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
message_outputs = [output for output in response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
final_message = message_outputs[-1]
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
expected_output = case.expected
assert expected_output.lower() in response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {response.output_text}"
)
@pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases)
def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
if not isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("in-process MCP server is only supported in library client")
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
stream = compat_client.responses.create(
input=case.input,
model=text_model_id,
tools=tools,
stream=True,
)
chunks = []
for chunk in stream:
chunks.append(chunk)
# Use validator for common streaming checks
validator = StreamingValidator(chunks)
validator.assert_basic_event_sequence()
validator.assert_response_consistency()
validator.assert_has_tool_calls()
validator.assert_has_mcp_events()
validator.assert_rich_streaming()
# Get the final response from the last chunk
final_chunk = chunks[-1]
if hasattr(final_chunk, "response"):
final_response = final_chunk.response
# Verify multi-turn MCP tool execution results
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
message_outputs = [output for output in final_response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
# Should have at least 1 MCP call (the model should call at least one tool)
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
# All MCP calls should be completed (verifies our tool execution works)
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
# Should have at least one final message response
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
# Final message should be from assistant and completed
final_message = message_outputs[-1]
assert final_message.role == "assistant", (
f"Final message should be from assistant, got {final_message.role}"
)
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
# Check that the expected output appears in the response
expected_output = case.expected
assert expected_output.lower() in final_response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
)