Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)
chore(tests): refactor and move responses tests away from verifications (#3068)
This PR kills the verifications infrastructure, which is no longer used; it was previously relocated to the `llama-stack-evals` (https://github.com/meta-llama/llama-stack-evals) repository. The responses tests used this infrastructure, but that was never strictly necessary, just mildly convenient back when @bbrownin introduced them. On Discord, we agreed that these tests can move to our regular integration test infra.

## Test Plan

Some tests currently fail (although they do run!). I will send a follow-up PR that makes them all pass.
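For reference, a minimal sketch of how the relocated suite can be invoked from Python via pytest. The `--stack-config` and `--text-model` flags and their values are assumptions about the regular integration test harness, not something this diff defines; adjust to your local setup.

```python
# Hedged sketch: run the relocated responses tests through pytest from Python.
# The CLI options below are assumed to be provided by the integration test
# harness; only the test path comes from this diff.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main(
            [
                "tests/integration/non_ci/responses",
                "--stack-config=starter",  # assumed stack config name
                "--text-model=meta-llama/Llama-3.3-70B-Instruct",  # assumed model id
                "-v",
            ]
        )
    )
```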
This commit is contained in:
parent
342550c1e2
commit
5f1ddd35e4
36 changed files with 93 additions and 13032 deletions
5
tests/integration/non_ci/responses/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
5
tests/integration/non_ci/responses/fixtures/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
137
tests/integration/non_ci/responses/fixtures/fixtures.py
Normal file
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import re
from pathlib import Path

import pytest
import yaml
from openai import OpenAI

from llama_stack import LlamaStackAsLibraryClient

# --- Helper Functions ---


def _load_all_verification_configs():
    """Load and aggregate verification configs from the conf/ directory."""
    # Note: Path is relative to *this* file (fixtures.py)
    conf_dir = Path(__file__).parent.parent.parent / "conf"
    if not conf_dir.is_dir():
        # Use pytest.fail if called during test collection, otherwise raise error
        # For simplicity here, we'll raise an error, assuming direct calls
        # are less likely or can handle it.
        raise FileNotFoundError(f"Verification config directory not found at {conf_dir}")

    all_provider_configs = {}
    yaml_files = list(conf_dir.glob("*.yaml"))
    if not yaml_files:
        raise FileNotFoundError(f"No YAML configuration files found in {conf_dir}")

    for config_path in yaml_files:
        provider_name = config_path.stem
        try:
            with open(config_path) as f:
                provider_config = yaml.safe_load(f)
                if provider_config:
                    all_provider_configs[provider_name] = provider_config
                else:
                    # Log warning if possible, or just skip empty files silently
                    print(f"Warning: Config file {config_path} is empty or invalid.")
        except Exception as e:
            raise OSError(f"Error loading config file {config_path}: {e}") from e

    return {"providers": all_provider_configs}


def case_id_generator(case):
    """Generate a test ID from the case's 'case_id' field, or use a default."""
    case_id = case.get("case_id")
    if isinstance(case_id, str | int):
        return re.sub(r"\W|^(?=\d)", "_", str(case_id))
    return None


# Helper to get the base test name from the request object
def get_base_test_name(request):
    return request.node.originalname


# --- End Helper Functions ---


@pytest.fixture(scope="session")
def verification_config():
    """Pytest fixture to provide the loaded verification config."""
    try:
        return _load_all_verification_configs()
    except (OSError, FileNotFoundError) as e:
        pytest.fail(str(e))  # Fail test collection if config loading fails


@pytest.fixture(scope="session")
def provider(request, verification_config):
    provider = request.config.getoption("--provider")
    base_url = request.config.getoption("--base-url")

    if provider and base_url and verification_config["providers"][provider]["base_url"] != base_url:
        raise ValueError(f"Provider {provider} is not supported for base URL {base_url}")

    if not provider:
        if not base_url:
            raise ValueError("Provider and base URL are not provided")
        for provider, metadata in verification_config["providers"].items():
            if metadata["base_url"] == base_url:
                provider = provider
                break

    return provider


@pytest.fixture(scope="session")
def base_url(request, provider, verification_config):
    return request.config.getoption("--base-url") or verification_config.get("providers", {}).get(provider, {}).get(
        "base_url"
    )


@pytest.fixture(scope="session")
def api_key(request, provider, verification_config):
    provider_conf = verification_config.get("providers", {}).get(provider, {})
    api_key_env_var = provider_conf.get("api_key_var")

    key_from_option = request.config.getoption("--api-key")
    key_from_env = os.getenv(api_key_env_var) if api_key_env_var else None

    final_key = key_from_option or key_from_env
    return final_key


@pytest.fixture
def model_mapping(provider, providers_model_mapping):
    return providers_model_mapping[provider]


@pytest.fixture(scope="session")
def openai_client(base_url, api_key, provider):
    # Simplify running against a local Llama Stack
    if base_url and "localhost" in base_url and not api_key:
        api_key = "empty"
    if provider.startswith("stack:"):
        parts = provider.split(":")
        if len(parts) != 2:
            raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
        config = parts[1]
        client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
        if not client.initialize():
            raise RuntimeError("Initialization failed")
        return client

    return OpenAI(
        base_url=base_url,
        api_key=api_key,
    )
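For context, a minimal sketch (not part of this diff) of the provider entry shape these fixtures expect from a `conf/*.yaml` file, expressed as a Python dict. The keys mirror what the code above reads (`base_url`, `api_key_var`); the provider name and values are illustrative assumptions.

```python
# Hedged sketch of the verification config structure the fixtures consume,
# plus a helper that resolves (base_url, api_key) the same way the
# base_url/api_key fixtures do. Provider name and URLs are examples only.
import os

EXAMPLE_VERIFICATION_CONFIG = {
    "providers": {
        "fireworks": {  # hypothetical provider entry
            "base_url": "https://api.fireworks.ai/inference/v1",
            "api_key_var": "FIREWORKS_API_KEY",
        }
    }
}


def resolve_connection(provider_name: str, config: dict = EXAMPLE_VERIFICATION_CONFIG):
    """Return (base_url, api_key) for a provider, mirroring the fixture logic."""
    provider_conf = config.get("providers", {}).get(provider_name, {})
    base_url = provider_conf.get("base_url")
    api_key_var = provider_conf.get("api_key_var")
    api_key = os.getenv(api_key_var) if api_key_var else None
    return base_url, api_key
```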
Binary file not shown (image, 108 KiB).
Binary file not shown (image, 148 KiB).
Binary file not shown (image, 139 KiB).
16
tests/integration/non_ci/responses/fixtures/load.py
Normal file
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

import yaml


def load_test_cases(name: str):
    fixture_dir = Path(__file__).parent / "test_cases"
    yaml_path = fixture_dir / f"{name}.yaml"
    with open(yaml_path) as f:
        return yaml.safe_load(f)
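These loaders are consumed exactly the way `test_responses.py` (later in this diff) does it: load a suite by name, then parametrize over its cases using `case_id_generator` for readable test IDs. A condensed example of that pattern:

```python
# Usage pattern mirrored from test_responses.py in this diff: load the
# "responses" suite from test_cases/responses.yaml and parametrize a test
# over the cases of one scenario.
import pytest

from .fixtures.fixtures import case_id_generator
from .fixtures.load import load_test_cases

responses_test_cases = load_test_cases("responses")  # reads test_cases/responses.yaml


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_basic"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_case_shape(case):
    # Every basic case carries an input prompt and an expected output substring.
    assert "input" in case
    assert "output" in case
```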
Binary file not shown.
|
@ -0,0 +1,397 @@
|
|||
test_chat_basic:
|
||||
test_name: test_chat_basic
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "earth"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
output: Earth
|
||||
- case_id: "saturn"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet has rings around it with a name starting with letter
|
||||
S?
|
||||
role: user
|
||||
output: Saturn
|
||||
test_chat_input_validation:
|
||||
test_name: test_chat_input_validation
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "messages_missing"
|
||||
input:
|
||||
messages: []
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
- case_id: "messages_role_invalid"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: fake_role
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
- case_id: "tool_choice_invalid"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
tool_choice: invalid
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
- case_id: "tool_choice_no_tools"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
tool_choice: required
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
- case_id: "tools_type_invalid"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
tools:
|
||||
- type: invalid
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
test_chat_image:
|
||||
test_name: test_chat_image
|
||||
test_params:
|
||||
case:
|
||||
- input:
|
||||
messages:
|
||||
- content:
|
||||
- text: What is in this image?
|
||||
type: text
|
||||
- image_url:
|
||||
url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg
|
||||
type: image_url
|
||||
role: user
|
||||
output: llama
|
||||
test_chat_structured_output:
|
||||
test_name: test_chat_structured_output
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "calendar"
|
||||
input:
|
||||
messages:
|
||||
- content: Extract the event information.
|
||||
role: system
|
||||
- content: Alice and Bob are going to a science fair on Friday.
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: calendar_event
|
||||
schema:
|
||||
properties:
|
||||
date:
|
||||
title: Date
|
||||
type: string
|
||||
name:
|
||||
title: Name
|
||||
type: string
|
||||
participants:
|
||||
items:
|
||||
type: string
|
||||
title: Participants
|
||||
type: array
|
||||
required:
|
||||
- name
|
||||
- date
|
||||
- participants
|
||||
title: CalendarEvent
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_calendar_event
|
||||
- case_id: "math"
|
||||
input:
|
||||
messages:
|
||||
- content: You are a helpful math tutor. Guide the user through the solution
|
||||
step by step.
|
||||
role: system
|
||||
- content: how can I solve 8x + 7 = -23
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: math_reasoning
|
||||
schema:
|
||||
$defs:
|
||||
Step:
|
||||
properties:
|
||||
explanation:
|
||||
title: Explanation
|
||||
type: string
|
||||
output:
|
||||
title: Output
|
||||
type: string
|
||||
required:
|
||||
- explanation
|
||||
- output
|
||||
title: Step
|
||||
type: object
|
||||
properties:
|
||||
final_answer:
|
||||
title: Final Answer
|
||||
type: string
|
||||
steps:
|
||||
items:
|
||||
$ref: '#/$defs/Step'
|
||||
title: Steps
|
||||
type: array
|
||||
required:
|
||||
- steps
|
||||
- final_answer
|
||||
title: MathReasoning
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_math_reasoning
|
||||
test_tool_calling:
|
||||
test_name: test_tool_calling
|
||||
test_params:
|
||||
case:
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful assistant that can use tools to get information.
|
||||
role: system
|
||||
- content: What's the weather like in San Francisco?
|
||||
role: user
|
||||
tools:
|
||||
- function:
|
||||
description: Get current temperature for a given location.
|
||||
name: get_weather
|
||||
parameters:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
location:
|
||||
description: "City and country e.g. Bogot\xE1, Colombia"
|
||||
type: string
|
||||
required:
|
||||
- location
|
||||
type: object
|
||||
type: function
|
||||
output: get_weather_tool_call
|
||||
|
||||
test_chat_multi_turn_tool_calling:
|
||||
test_name: test_chat_multi_turn_tool_calling
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "text_then_weather_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "What's the name of the Sun in latin?"
|
||||
- - role: user
|
||||
content: "What's the weather like in San Francisco?"
|
||||
tools:
|
||||
- function:
|
||||
description: Get the current weather
|
||||
name: get_weather
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
location:
|
||||
description: "The city and state (both required), e.g. San Francisco, CA."
|
||||
type: string
|
||||
required: ["location"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': '70 degrees and foggy'}"
|
||||
expected:
|
||||
- num_tool_calls: 0
|
||||
answer: ["sol"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_weather
|
||||
tool_arguments:
|
||||
location: "San Francisco, CA"
|
||||
- num_tool_calls: 0
|
||||
answer: ["foggy", "70 degrees"]
|
||||
- case_id: "weather_tool_then_text"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "What's the weather like in San Francisco?"
|
||||
tools:
|
||||
- function:
|
||||
description: Get the current weather
|
||||
name: get_weather
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
location:
|
||||
description: "The city and state (both required), e.g. San Francisco, CA."
|
||||
type: string
|
||||
required: ["location"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': '70 degrees and foggy'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_weather
|
||||
tool_arguments:
|
||||
location: "San Francisco, CA"
|
||||
- num_tool_calls: 0
|
||||
answer: ["foggy", "70 degrees"]
|
||||
- case_id: "add_product_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "Please add a new product with name 'Widget', price 19.99, in stock, and tags ['new', 'sale'] and give me the product id."
|
||||
tools:
|
||||
- function:
|
||||
description: Add a new product
|
||||
name: addProduct
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: "Name of the product"
|
||||
type: string
|
||||
price:
|
||||
description: "Price of the product"
|
||||
type: number
|
||||
inStock:
|
||||
description: "Availability status of the product."
|
||||
type: boolean
|
||||
tags:
|
||||
description: "List of product tags"
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required: ["name", "price", "inStock"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'Successfully added product with id: 123'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: addProduct
|
||||
tool_arguments:
|
||||
name: "Widget"
|
||||
price: 19.99
|
||||
inStock: true
|
||||
tags:
|
||||
- "new"
|
||||
- "sale"
|
||||
- num_tool_calls: 0
|
||||
answer: ["123", "product id: 123"]
|
||||
- case_id: "get_then_create_event_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: system
|
||||
content: "Todays date is 2025-03-01."
|
||||
- role: user
|
||||
content: "Do i have any meetings on March 3rd at 10 am? Yes or no?"
|
||||
- - role: user
|
||||
content: "Alright then, Create an event named 'Team Building', scheduled for that time same time, in the 'Main Conference Room' and add Alice, Bob, Charlie to it. Give me the created event id."
|
||||
tools:
|
||||
- function:
|
||||
description: Create a new event
|
||||
name: create_event
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: "Name of the event"
|
||||
type: string
|
||||
date:
|
||||
description: "Date of the event in ISO format"
|
||||
type: string
|
||||
time:
|
||||
description: "Event Time (HH:MM)"
|
||||
type: string
|
||||
location:
|
||||
description: "Location of the event"
|
||||
type: string
|
||||
participants:
|
||||
description: "List of participant names"
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required: ["name", "date", "time", "location", "participants"]
|
||||
type: function
|
||||
- function:
|
||||
description: Get an event by date and time
|
||||
name: get_event
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
date:
|
||||
description: "Date of the event in ISO format"
|
||||
type: string
|
||||
time:
|
||||
description: "Event Time (HH:MM)"
|
||||
type: string
|
||||
required: ["date", "time"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'No events found for 2025-03-03 at 10:00'}"
|
||||
- response: "{'response': 'Successfully created new event with id: e_123'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_event
|
||||
tool_arguments:
|
||||
date: "2025-03-03"
|
||||
time: "10:00"
|
||||
- num_tool_calls: 0
|
||||
answer: ["no", "no events found", "no meetings"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: create_event
|
||||
tool_arguments:
|
||||
name: "Team Building"
|
||||
date: "2025-03-03"
|
||||
time: "10:00"
|
||||
location: "Main Conference Room"
|
||||
participants:
|
||||
- "Alice"
|
||||
- "Bob"
|
||||
- "Charlie"
|
||||
- num_tool_calls: 0
|
||||
answer: ["e_123", "event id: e_123"]
|
||||
- case_id: "compare_monthly_expense_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: system
|
||||
content: "Todays date is 2025-03-01."
|
||||
- role: user
|
||||
content: "what was my monthly expense in Jan of this year?"
|
||||
- - role: user
|
||||
content: "Was it less than Feb of last year? Only answer with yes or no."
|
||||
tools:
|
||||
- function:
|
||||
description: Get monthly expense summary
|
||||
name: getMonthlyExpenseSummary
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
month:
|
||||
description: "Month of the year (1-12)"
|
||||
type: integer
|
||||
year:
|
||||
description: "Year"
|
||||
type: integer
|
||||
required: ["month", "year"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'Total expenses for January 2025: $1000'}"
|
||||
- response: "{'response': 'Total expenses for February 2024: $2000'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: getMonthlyExpenseSummary
|
||||
tool_arguments:
|
||||
month: 1
|
||||
year: 2025
|
||||
- num_tool_calls: 0
|
||||
answer: ["1000", "$1,000", "1,000"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: getMonthlyExpenseSummary
|
||||
tool_arguments:
|
||||
month: 2
|
||||
year: 2024
|
||||
- num_tool_calls: 0
|
||||
answer: ["yes"]
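Below is a minimal sketch (an assumption, not part of this diff) of how one of the chat-completion cases above could be exercised against an OpenAI-compatible client: pass the case's messages and optional tools through, then check for the expected substring in the reply.

```python
# Hedged sketch of a chat-completion case runner. The case dict shape
# (input.messages, optional input.tools, output) comes from the YAML above;
# the runner function itself is illustrative, not the project's test code.
from openai import OpenAI


def run_chat_case(client: OpenAI, model: str, case: dict) -> None:
    kwargs = {
        "model": model,
        "messages": case["input"]["messages"],
        "stream": False,
    }
    if "tools" in case["input"]:
        kwargs["tools"] = case["input"]["tools"]

    response = client.chat.completions.create(**kwargs)
    content = (response.choices[0].message.content or "").lower()

    expected = case["output"]
    if isinstance(expected, str):
        # Simple substring check, matching how the basic cases are written.
        assert expected.lower() in content
```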
@@ -0,0 +1,166 @@
test_response_basic:
  test_name: test_response_basic
  test_params:
    case:
      - case_id: "earth"
        input: "Which planet do humans live on?"
        output: "earth"
      - case_id: "saturn"
        input: "Which planet has rings around it with a name starting with letter S?"
        output: "saturn"
      - case_id: "image_input"
        input:
          - role: user
            content:
              - type: input_text
                text: "what teams are playing in this image?"
          - role: user
            content:
              - type: input_image
                image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg"
        output: "brooklyn nets"

test_response_multi_turn:
  test_name: test_response_multi_turn
  test_params:
    case:
      - case_id: "earth"
        turns:
          - input: "Which planet do humans live on?"
            output: "earth"
          - input: "What is the name of the planet from your previous response?"
            output: "earth"

test_response_web_search:
  test_name: test_response_web_search
  test_params:
    case:
      - case_id: "llama_experts"
        input: "How many experts does the Llama 4 Maverick model have?"
        tools:
          - type: web_search
            search_context_size: "low"
        output: "128"

test_response_file_search:
  test_name: test_response_file_search
  test_params:
    case:
      - case_id: "llama_experts"
        input: "How many experts does the Llama 4 Maverick model have?"
        tools:
          - type: file_search
            # vector_store_ids param for file_search tool gets added by the test runner
        file_content: "Llama 4 Maverick has 128 experts"
        output: "128"
      - case_id: "llama_experts_pdf"
        input: "How many experts does the Llama 4 Maverick model have?"
        tools:
          - type: file_search
            # vector_store_ids param for file_search tool gets added by the test runner
        file_path: "pdfs/llama_stack_and_models.pdf"
        output: "128"

test_response_mcp_tool:
  test_name: test_response_mcp_tool
  test_params:
    case:
      - case_id: "boiling_point_tool"
        input: "What is the boiling point of myawesomeliquid in Celsius?"
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        output: "Hello, world!"

test_response_custom_tool:
  test_name: test_response_custom_tool
  test_params:
    case:
      - case_id: "sf_weather"
        input: "What's the weather like in San Francisco?"
        tools:
          - type: function
            name: get_weather
            description: Get current temperature for a given location.
            parameters:
              additionalProperties: false
              properties:
                location:
                  description: "City and country e.g. Bogot\xE1, Colombia"
                  type: string
              required:
                - location
              type: object

test_response_image:
  test_name: test_response_image
  test_params:
    case:
      - case_id: "llama_image"
        input:
          - role: user
            content:
              - type: input_text
                text: "Identify the type of animal in this image."
              - type: input_image
                image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
        output: "llama"

# the models are really poor at tool calling after seeing images :/
test_response_multi_turn_image:
  test_name: test_response_multi_turn_image
  test_params:
    case:
      - case_id: "llama_image_understanding"
        turns:
          - input:
              - role: user
                content:
                  - type: input_text
                    text: "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'."
                  - type: input_image
                    image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
            output: "llama"
          - input: "What country do you find this animal primarily in? What continent?"
            output: "peru"

test_response_multi_turn_tool_execution:
  test_name: test_response_multi_turn_tool_execution
  test_params:
    case:
      - case_id: "user_file_access_check"
        input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        output: "yes"
      - case_id: "experiment_results_lookup"
        input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        output: "100°C"

test_response_multi_turn_tool_execution_streaming:
  test_name: test_response_multi_turn_tool_execution_streaming
  test_params:
    case:
      - case_id: "user_permissions_workflow"
        input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        stream: true
        output: "no"
      - case_id: "experiment_analysis_streaming"
        input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        stream: true
        output: "85%"
922
tests/integration/non_ci/responses/test_responses.py
Normal file
|
@ -0,0 +1,922 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError
|
||||
from tests.common.mcp import dependency_tools, make_mcp_server
|
||||
|
||||
from .fixtures.fixtures import case_id_generator
|
||||
from .fixtures.load import load_test_cases
|
||||
|
||||
responses_test_cases = load_test_cases("responses")
|
||||
|
||||
|
||||
def _new_vector_store(openai_client, name):
|
||||
# Ensure we don't reuse an existing vector store
|
||||
vector_stores = openai_client.vector_stores.list()
|
||||
for vector_store in vector_stores:
|
||||
if vector_store.name == name:
|
||||
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
|
||||
|
||||
# Create a new vector store
|
||||
vector_store = openai_client.vector_stores.create(
|
||||
name=name,
|
||||
)
|
||||
return vector_store
|
||||
|
||||
|
||||
def _upload_file(openai_client, name, file_path):
|
||||
# Ensure we don't reuse an existing file
|
||||
files = openai_client.files.list()
|
||||
for file in files:
|
||||
if file.filename == name:
|
||||
openai_client.files.delete(file_id=file.id)
|
||||
|
||||
# Upload a text file with our document content
|
||||
return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_basic"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_basic(request, compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
stream=False,
|
||||
)
|
||||
output_text = response.output_text.lower().strip()
|
||||
assert len(output_text) > 0
|
||||
assert case["output"].lower() in output_text
|
||||
|
||||
retrieved_response = compat_client.responses.retrieve(response_id=response.id)
|
||||
assert retrieved_response.output_text == response.output_text
|
||||
|
||||
next_response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Repeat your previous response in all caps.",
|
||||
previous_response_id=response.id,
|
||||
)
|
||||
next_output_text = next_response.output_text.strip()
|
||||
assert case["output"].upper() in next_output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_basic"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_streaming_basic(request, compat_client, text_model_id, case):
|
||||
import time
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Track events and timing to verify proper streaming
|
||||
events = []
|
||||
event_times = []
|
||||
response_id = ""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for chunk in response:
|
||||
current_time = time.time()
|
||||
event_times.append(current_time - start_time)
|
||||
events.append(chunk)
|
||||
|
||||
if chunk.type == "response.created":
|
||||
# Verify response.created is emitted first and immediately
|
||||
assert len(events) == 1, "response.created should be the first event"
|
||||
assert event_times[0] < 0.1, "response.created should be emitted immediately"
|
||||
assert chunk.response.status == "in_progress"
|
||||
response_id = chunk.response.id
|
||||
|
||||
elif chunk.type == "response.completed":
|
||||
# Verify response.completed comes after response.created
|
||||
assert len(events) >= 2, "response.completed should come after response.created"
|
||||
assert chunk.response.status == "completed"
|
||||
assert chunk.response.id == response_id, "Response ID should be consistent"
|
||||
|
||||
# Verify content quality
|
||||
output_text = chunk.response.output_text.lower().strip()
|
||||
assert len(output_text) > 0, "Response should have content"
|
||||
assert case["output"].lower() in output_text, f"Expected '{case['output']}' in response"
|
||||
|
||||
# Verify we got both required events
|
||||
event_types = [event.type for event in events]
|
||||
assert "response.created" in event_types, "Missing response.created event"
|
||||
assert "response.completed" in event_types, "Missing response.completed event"
|
||||
|
||||
# Verify event order
|
||||
created_index = event_types.index("response.created")
|
||||
completed_index = event_types.index("response.completed")
|
||||
assert created_index < completed_index, "response.created should come before response.completed"
|
||||
|
||||
# Verify stored response matches streamed response
|
||||
retrieved_response = compat_client.responses.retrieve(response_id=response_id)
|
||||
final_event = events[-1]
|
||||
assert retrieved_response.output_text == final_event.response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_basic"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_streaming_incremental_content(request, compat_client, text_model_id, case):
|
||||
"""Test that streaming actually delivers content incrementally, not just at the end."""
|
||||
import time
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Track all events and their content to verify incremental streaming
|
||||
events = []
|
||||
content_snapshots = []
|
||||
event_times = []
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for chunk in response:
|
||||
current_time = time.time()
|
||||
event_times.append(current_time - start_time)
|
||||
events.append(chunk)
|
||||
|
||||
# Track content at each event based on event type
|
||||
if chunk.type == "response.output_text.delta":
|
||||
# For delta events, track the delta content
|
||||
content_snapshots.append(chunk.delta)
|
||||
elif hasattr(chunk, "response") and hasattr(chunk.response, "output_text"):
|
||||
# For response.created/completed events, track the full output_text
|
||||
content_snapshots.append(chunk.response.output_text)
|
||||
else:
|
||||
content_snapshots.append("")
|
||||
|
||||
# Verify we have the expected events
|
||||
event_types = [event.type for event in events]
|
||||
assert "response.created" in event_types, "Missing response.created event"
|
||||
assert "response.completed" in event_types, "Missing response.completed event"
|
||||
|
||||
# Check if we have incremental content updates
|
||||
created_index = event_types.index("response.created")
|
||||
completed_index = event_types.index("response.completed")
|
||||
|
||||
# The key test: verify content progression
|
||||
created_content = content_snapshots[created_index]
|
||||
completed_content = content_snapshots[completed_index]
|
||||
|
||||
# Verify that response.created has empty or minimal content
|
||||
assert len(created_content) == 0, f"response.created should have empty content, got: {repr(created_content[:100])}"
|
||||
|
||||
# Verify that response.completed has the full content
|
||||
assert len(completed_content) > 0, "response.completed should have content"
|
||||
assert case["output"].lower() in completed_content.lower(), f"Expected '{case['output']}' in final content"
|
||||
|
||||
# Check for true incremental streaming by looking for delta events
|
||||
delta_events = [i for i, event_type in enumerate(event_types) if event_type == "response.output_text.delta"]
|
||||
|
||||
# Assert that we have delta events (true incremental streaming)
|
||||
assert len(delta_events) > 0, "Expected delta events for true incremental streaming, but found none"
|
||||
|
||||
# Verify delta events have content and accumulate to final content
|
||||
delta_content_total = ""
|
||||
non_empty_deltas = 0
|
||||
|
||||
for delta_idx in delta_events:
|
||||
delta_content = content_snapshots[delta_idx]
|
||||
if delta_content:
|
||||
delta_content_total += delta_content
|
||||
non_empty_deltas += 1
|
||||
|
||||
# Assert that we have meaningful delta content
|
||||
assert non_empty_deltas > 0, "Delta events found but none contain content"
|
||||
assert len(delta_content_total) > 0, "Delta events found but total delta content is empty"
|
||||
|
||||
# Verify that the accumulated delta content matches the final content
|
||||
assert delta_content_total.strip() == completed_content.strip(), (
|
||||
f"Delta content '{delta_content_total}' should match final content '{completed_content}'"
|
||||
)
|
||||
|
||||
# Verify timing: delta events should come between created and completed
|
||||
for delta_idx in delta_events:
|
||||
assert created_index < delta_idx < completed_index, (
|
||||
f"Delta event at index {delta_idx} should be between created ({created_index}) and completed ({completed_index})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_multi_turn"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_multi_turn(request, compat_client, text_model_id, case):
|
||||
previous_response_id = None
|
||||
for turn in case["turns"]:
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=turn["input"],
|
||||
previous_response_id=previous_response_id,
|
||||
tools=turn["tools"] if "tools" in turn else None,
|
||||
)
|
||||
previous_response_id = response.id
|
||||
output_text = response.output_text.lower()
|
||||
assert turn["output"].lower() in output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_web_search"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_web_search(request, compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
tools=case["tools"],
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "web_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[1].type == "message"
|
||||
assert response.output[1].status == "completed"
|
||||
assert response.output[1].role == "assistant"
|
||||
assert len(response.output[1].content) > 0
|
||||
assert case["output"].lower() in response.output_text.lower().strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_file_search"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_file_search(request, compat_client, text_model_id, tmp_path, case):
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
vector_store = _new_vector_store(compat_client, "test_vector_store")
|
||||
|
||||
if "file_content" in case:
|
||||
file_name = "test_response_non_streaming_file_search.txt"
|
||||
file_path = tmp_path / file_name
|
||||
file_path.write_text(case["file_content"])
|
||||
elif "file_path" in case:
|
||||
file_path = os.path.join(os.path.dirname(__file__), "fixtures", case["file_path"])
|
||||
file_name = os.path.basename(file_path)
|
||||
else:
|
||||
raise ValueError(f"No file content or path provided for case {case['case_id']}")
|
||||
|
||||
file_response = _upload_file(compat_client, file_name, file_path)
|
||||
|
||||
# Attach our file to the vector store
|
||||
file_attach_response = compat_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
|
||||
# Wait for the file to be attached
|
||||
while file_attach_response.status == "in_progress":
|
||||
time.sleep(0.1)
|
||||
file_attach_response = compat_client.vector_stores.files.retrieve(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}"
|
||||
assert not file_attach_response.last_error
|
||||
|
||||
# Update our tools with the right vector store id
|
||||
tools = case["tools"]
|
||||
for tool in tools:
|
||||
if tool["type"] == "file_search":
|
||||
tool["vector_store_ids"] = [vector_store.id]
|
||||
|
||||
# Create the response request, which should query our vector store
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries # ensure it's some non-empty list
|
||||
assert response.output[0].results
|
||||
assert case["output"].lower() in response.output[0].results[0].text.lower()
|
||||
assert response.output[0].results[0].score > 0
|
||||
|
||||
# Verify the output_text generated by the response
|
||||
assert case["output"].lower() in response.output_text.lower().strip()
|
||||
|
||||
|
||||
def test_response_non_streaming_file_search_empty_vector_store(request, compat_client, text_model_id):
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
vector_store = _new_vector_store(compat_client, "test_vector_store")
|
||||
|
||||
# Create the response request, which should query our vector store
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="How many experts does the Llama 4 Maverick model have?",
|
||||
tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries # ensure it's some non-empty list
|
||||
assert not response.output[0].results # ensure we don't get any results
|
||||
|
||||
# Verify some output_text was generated by the response
|
||||
assert response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_mcp_tool(request, compat_client, text_model_id, case):
|
||||
with make_mcp_server() as mcp_server_info:
|
||||
tools = case["tools"]
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert len(response.output) >= 3
|
||||
list_tools = response.output[0]
|
||||
assert list_tools.type == "mcp_list_tools"
|
||||
assert list_tools.server_label == "localmcp"
|
||||
assert len(list_tools.tools) == 2
|
||||
assert {t.name for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
|
||||
|
||||
call = response.output[1]
|
||||
assert call.type == "mcp_call"
|
||||
assert call.name == "get_boiling_point"
|
||||
assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
|
||||
assert call.error is None
|
||||
assert "-100" in call.output
|
||||
|
||||
# sometimes the model will call the tool again, so we need to get the last message
|
||||
message = response.output[-1]
|
||||
text_content = message.content[0].text
|
||||
assert "boiling point" in text_content.lower()
|
||||
|
||||
with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
|
||||
tools = case["tools"]
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
|
||||
exc_type = (
|
||||
AuthenticationRequiredError
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient)
|
||||
else (httpx.HTTPStatusError, openai.AuthenticationError)
|
||||
)
|
||||
with pytest.raises(exc_type):
|
||||
compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
tool["headers"] = {"Authorization": "Bearer test-token"}
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) >= 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_custom_tool"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_custom_tool(request, compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
tools=case["tools"],
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_image"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_image(request, compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case["input"],
|
||||
stream=False,
|
||||
)
|
||||
output_text = response.output_text.lower()
|
||||
assert case["output"].lower() in output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_multi_turn_image"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_multi_turn_image(request, compat_client, text_model_id, case):
|
||||
previous_response_id = None
|
||||
for turn in case["turns"]:
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=turn["input"],
|
||||
previous_response_id=previous_response_id,
|
||||
tools=turn["tools"] if "tools" in turn else None,
|
||||
)
|
||||
previous_response_id = response.id
|
||||
output_text = response.output_text.lower()
|
||||
assert turn["output"].lower() in output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_multi_turn_tool_execution(request, compat_client, text_model_id, case):
|
||||
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
tools = case["tools"]
|
||||
# Replace the placeholder URL with the actual server URL
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
input=case["input"],
|
||||
model=text_model_id,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Verify we have MCP tool calls in the output
|
||||
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
|
||||
|
||||
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
|
||||
# Should have exactly 1 MCP list tools message (at the beginning)
|
||||
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
|
||||
assert mcp_list_tools[0].server_label == "localmcp"
|
||||
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
|
||||
expected_tool_names = {
|
||||
"get_user_id",
|
||||
"get_user_permissions",
|
||||
"check_file_access",
|
||||
"get_experiment_id",
|
||||
"get_experiment_results",
|
||||
}
|
||||
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
|
||||
|
||||
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
|
||||
for mcp_call in mcp_calls:
|
||||
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
|
||||
|
||||
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
|
||||
|
||||
final_message = message_outputs[-1]
|
||||
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
|
||||
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
|
||||
assert len(final_message.content) > 0, "Final message should have content"
|
||||
|
||||
expected_output = case["output"]
|
||||
assert expected_output.lower() in response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {response.output_text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
async def test_response_streaming_multi_turn_tool_execution(request, compat_client, text_model_id, case):
|
||||
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
tools = case["tools"]
|
||||
# Replace the placeholder URL with the actual server URL
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
|
||||
stream = compat_client.responses.create(
|
||||
input=case["input"],
|
||||
model=text_model_id,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks = []
|
||||
for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
# Should have at least response.created and response.completed
|
||||
assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"
|
||||
|
||||
# First chunk should be response.created
|
||||
assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"
|
||||
|
||||
# Last chunk should be response.completed
|
||||
assert chunks[-1].type == "response.completed", (
|
||||
f"Last chunk should be response.completed, got {chunks[-1].type}"
|
||||
)
|
||||
|
||||
# Get the final response from the last chunk
|
||||
final_chunk = chunks[-1]
|
||||
if hasattr(final_chunk, "response"):
|
||||
final_response = final_chunk.response
|
||||
|
||||
# Verify multi-turn MCP tool execution results
|
||||
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in final_response.output if output.type == "message"]
|
||||
|
||||
# Should have exactly 1 MCP list tools message (at the beginning)
|
||||
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
|
||||
assert mcp_list_tools[0].server_label == "localmcp"
|
||||
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
|
||||
expected_tool_names = {
|
||||
"get_user_id",
|
||||
"get_user_permissions",
|
||||
"check_file_access",
|
||||
"get_experiment_id",
|
||||
"get_experiment_results",
|
||||
}
|
||||
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
|
||||
|
||||
# Should have at least 1 MCP call (the model should call at least one tool)
|
||||
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
|
||||
|
||||
# All MCP calls should be completed (verifies our tool execution works)
|
||||
for mcp_call in mcp_calls:
|
||||
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
|
||||
|
||||
# Should have at least one final message response
|
||||
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
|
||||
|
||||
# Final message should be from assistant and completed
|
||||
final_message = message_outputs[-1]
|
||||
assert final_message.role == "assistant", (
|
||||
f"Final message should be from assistant, got {final_message.role}"
|
||||
)
|
||||
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
|
||||
assert len(final_message.content) > 0, "Final message should have content"
|
||||
|
||||
# Check that the expected output appears in the response
|
||||
expected_output = case["output"]
|
||||
assert expected_output.lower() in final_response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_format",
|
||||
# Not testing json_object because most providers don't actually support it.
|
||||
[
|
||||
{"type": "text"},
|
||||
{
|
||||
"type": "json_schema",
|
||||
"name": "capitals",
|
||||
"description": "A schema for the capital of each country",
|
||||
"schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
|
||||
"strict": True,
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_response_text_format(request, compat_client, text_model_id, text_format):
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API text format is not yet supported in library client.")
|
||||
|
||||
stream = False
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What is the capital of France?",
|
||||
stream=stream,
|
||||
text={"format": text_format},
|
||||
)
|
||||
# by_alias=True is needed because otherwise Pydantic renames our "schema" field
|
||||
assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format
|
||||
assert "paris" in response.output_text.lower()
|
||||
if text_format["type"] == "json_schema":
|
||||
assert "paris" in json.loads(response.output_text)["capital"].lower()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store_with_filtered_files(request, compat_client, text_model_id, tmp_path_factory):
|
||||
"""Create a vector store with multiple files that have different attributes for filtering tests."""
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
vector_store = _new_vector_store(compat_client, "test_vector_store_with_filters")
|
||||
tmp_path = tmp_path_factory.mktemp("filter_test_files")
|
||||
|
||||
# Create multiple files with different attributes
|
||||
files_data = [
|
||||
{
|
||||
"name": "us_marketing_q1.txt",
|
||||
"content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
|
||||
"attributes": {
|
||||
"region": "us",
|
||||
"category": "marketing",
|
||||
"date": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "us_engineering_q2.txt",
|
||||
"content": "US technical updates for Q2 2023. New features deployed in the US region.",
|
||||
"attributes": {
|
||||
"region": "us",
|
||||
"category": "engineering",
|
||||
"date": 1680307200, # Apr 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "eu_marketing_q1.txt",
|
||||
"content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
|
||||
"attributes": {
|
||||
"region": "eu",
|
||||
"category": "marketing",
|
||||
"date": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "asia_sales_q3.txt",
|
||||
"content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
|
||||
"attributes": {
|
||||
"region": "asia",
|
||||
"category": "sales",
|
||||
"date": 1688169600, # Jul 1, 2023
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
file_ids = []
|
||||
for file_data in files_data:
|
||||
# Create file
|
||||
file_path = tmp_path / file_data["name"]
|
||||
file_path.write_text(file_data["content"])
|
||||
|
||||
# Upload file
|
||||
file_response = _upload_file(compat_client, file_data["name"], str(file_path))
|
||||
file_ids.append(file_response.id)
|
||||
|
||||
# Attach file to vector store with attributes
|
||||
file_attach_response = compat_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"]
|
||||
)
|
||||
|
||||
# Wait for attachment
|
||||
while file_attach_response.status == "in_progress":
|
||||
time.sleep(0.1)
|
||||
file_attach_response = compat_client.vector_stores.files.retrieve(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
assert file_attach_response.status == "completed"
|
||||
|
||||
yield vector_store
|
||||
|
||||
# Cleanup: delete vector store and files
|
||||
try:
|
||||
compat_client.vector_stores.delete(vector_store_id=vector_store.id)
|
||||
for file_id in file_ids:
|
||||
try:
|
||||
compat_client.files.delete(file_id=file_id)
|
||||
except Exception:
|
||||
pass # File might already be deleted
|
||||
except Exception:
|
||||
pass # Best effort cleanup
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_region(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with region equality filter."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {"type": "eq", "key": "region", "value": "us"},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What are the updates from the US region?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify file search was called with US filter
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return US files (not EU or Asia files)
|
||||
for result in response.output[0].results:
|
||||
assert "us" in result.text.lower() or "US" in result.text
|
||||
# Ensure non-US regions are NOT returned
|
||||
assert "european" not in result.text.lower()
|
||||
assert "asia" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_category(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with category equality filter."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {"type": "eq", "key": "category", "value": "marketing"},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Show me all marketing reports",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return marketing files (not engineering or sales)
|
||||
for result in response.output[0].results:
|
||||
# Marketing files should have promotional/advertising content
|
||||
assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
|
||||
# Ensure non-marketing categories are NOT returned
|
||||
assert "technical" not in result.text.lower()
|
||||
assert "revenue figures" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_date_range(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with date range filter using compound AND."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "and",
|
||||
"filters": [
|
||||
{
|
||||
"type": "gte",
|
||||
"key": "date",
|
||||
"value": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
{
|
||||
"type": "lt",
|
||||
"key": "date",
|
||||
"value": 1680307200, # Apr 1, 2023
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What happened in Q1 2023?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return Q1 files (not Q2 or Q3)
|
||||
for result in response.output[0].results:
|
||||
assert "q1" in result.text.lower()
|
||||
# Ensure non-Q1 quarters are NOT returned
|
||||
assert "q2" not in result.text.lower()
|
||||
assert "q3" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_compound_and(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with compound AND filter (region AND category)."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "and",
|
||||
"filters": [
|
||||
{"type": "eq", "key": "region", "value": "us"},
|
||||
{"type": "eq", "key": "category", "value": "engineering"},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What are the engineering updates from the US?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return US engineering files
|
||||
assert len(response.output[0].results) >= 1
|
||||
for result in response.output[0].results:
|
||||
assert "us" in result.text.lower() and "technical" in result.text.lower()
|
||||
# Ensure it's not from other regions or categories
|
||||
assert "european" not in result.text.lower() and "asia" not in result.text.lower()
|
||||
assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_compound_or(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with compound OR filter (marketing OR sales)."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "or",
|
||||
"filters": [
|
||||
{"type": "eq", "key": "category", "value": "marketing"},
|
||||
{"type": "eq", "key": "category", "value": "sales"},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Show me marketing and sales documents",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should return marketing and sales files, but NOT engineering
|
||||
categories_found = set()
|
||||
for result in response.output[0].results:
|
||||
text_lower = result.text.lower()
|
||||
if "promotional" in text_lower or "advertising" in text_lower:
|
||||
categories_found.add("marketing")
|
||||
if "revenue figures" in text_lower:
|
||||
categories_found.add("sales")
|
||||
# Ensure engineering files are NOT returned
|
||||
assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
|
||||
|
||||
# Verify we got at least one of the expected categories
|
||||
assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
|
||||
assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"
|