mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-21 17:33:12 +00:00
refactor(tests): make the responses tests nicer (#3161)
# What does this PR do? A _bunch_ on cleanup for the Responses tests. - Got rid of YAML test cases, moved them to just use simple pydantic models - Splitting the large monolithic test file into multiple focused test files: - `test_basic_responses.py` for basic and image response tests - `test_tool_responses.py` for tool-related tests - `test_file_search.py` for file search specific tests - Adding a `StreamingValidator` helper class to standardize streaming response validation ## Test Plan Run the tests: ``` pytest -s -v tests/integration/non_ci/responses/ \ --stack-config=starter \ --text-model openai/gpt-4o \ --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ -k "client_with_models" ```
This commit is contained in:
parent
ba664474de
commit
8ed69978f9
11 changed files with 1312 additions and 1736 deletions
318
tests/integration/non_ci/responses/test_file_search.py
Normal file
318
tests/integration/non_ci/responses/test_file_search.py
Normal file
|
@ -0,0 +1,318 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
|
||||
from .helpers import new_vector_store, upload_file
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_format",
|
||||
# Not testing json_object because most providers don't actually support it.
|
||||
[
|
||||
{"type": "text"},
|
||||
{
|
||||
"type": "json_schema",
|
||||
"name": "capitals",
|
||||
"description": "A schema for the capital of each country",
|
||||
"schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
|
||||
"strict": True,
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_response_text_format(compat_client, text_model_id, text_format):
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API text format is not yet supported in library client.")
|
||||
|
||||
stream = False
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What is the capital of France?",
|
||||
stream=stream,
|
||||
text={"format": text_format},
|
||||
)
|
||||
# by_alias=True is needed because otherwise Pydantic renames our "schema" field
|
||||
assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format
|
||||
assert "paris" in response.output_text.lower()
|
||||
if text_format["type"] == "json_schema":
|
||||
assert "paris" in json.loads(response.output_text)["capital"].lower()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory):
|
||||
"""Create a vector store with multiple files that have different attributes for filtering tests."""
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
vector_store = new_vector_store(compat_client, "test_vector_store_with_filters")
|
||||
tmp_path = tmp_path_factory.mktemp("filter_test_files")
|
||||
|
||||
# Create multiple files with different attributes
|
||||
files_data = [
|
||||
{
|
||||
"name": "us_marketing_q1.txt",
|
||||
"content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
|
||||
"attributes": {
|
||||
"region": "us",
|
||||
"category": "marketing",
|
||||
"date": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "us_engineering_q2.txt",
|
||||
"content": "US technical updates for Q2 2023. New features deployed in the US region.",
|
||||
"attributes": {
|
||||
"region": "us",
|
||||
"category": "engineering",
|
||||
"date": 1680307200, # Apr 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "eu_marketing_q1.txt",
|
||||
"content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
|
||||
"attributes": {
|
||||
"region": "eu",
|
||||
"category": "marketing",
|
||||
"date": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "asia_sales_q3.txt",
|
||||
"content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
|
||||
"attributes": {
|
||||
"region": "asia",
|
||||
"category": "sales",
|
||||
"date": 1688169600, # Jul 1, 2023
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
file_ids = []
|
||||
for file_data in files_data:
|
||||
# Create file
|
||||
file_path = tmp_path / file_data["name"]
|
||||
file_path.write_text(file_data["content"])
|
||||
|
||||
# Upload file
|
||||
file_response = upload_file(compat_client, file_data["name"], str(file_path))
|
||||
file_ids.append(file_response.id)
|
||||
|
||||
# Attach file to vector store with attributes
|
||||
file_attach_response = compat_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
attributes=file_data["attributes"],
|
||||
)
|
||||
|
||||
# Wait for attachment
|
||||
while file_attach_response.status == "in_progress":
|
||||
time.sleep(0.1)
|
||||
file_attach_response = compat_client.vector_stores.files.retrieve(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
assert file_attach_response.status == "completed"
|
||||
|
||||
yield vector_store
|
||||
|
||||
# Cleanup: delete vector store and files
|
||||
try:
|
||||
compat_client.vector_stores.delete(vector_store_id=vector_store.id)
|
||||
for file_id in file_ids:
|
||||
try:
|
||||
compat_client.files.delete(file_id=file_id)
|
||||
except Exception:
|
||||
pass # File might already be deleted
|
||||
except Exception:
|
||||
pass # Best effort cleanup
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_region(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with region equality filter."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {"type": "eq", "key": "region", "value": "us"},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What are the updates from the US region?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify file search was called with US filter
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return US files (not EU or Asia files)
|
||||
for result in response.output[0].results:
|
||||
assert "us" in result.text.lower() or "US" in result.text
|
||||
# Ensure non-US regions are NOT returned
|
||||
assert "european" not in result.text.lower()
|
||||
assert "asia" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_category(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with category equality filter."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {"type": "eq", "key": "category", "value": "marketing"},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Show me all marketing reports",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return marketing files (not engineering or sales)
|
||||
for result in response.output[0].results:
|
||||
# Marketing files should have promotional/advertising content
|
||||
assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
|
||||
# Ensure non-marketing categories are NOT returned
|
||||
assert "technical" not in result.text.lower()
|
||||
assert "revenue figures" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_date_range(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with date range filter using compound AND."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "and",
|
||||
"filters": [
|
||||
{
|
||||
"type": "gte",
|
||||
"key": "date",
|
||||
"value": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
{
|
||||
"type": "lt",
|
||||
"key": "date",
|
||||
"value": 1680307200, # Apr 1, 2023
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What happened in Q1 2023?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return Q1 files (not Q2 or Q3)
|
||||
for result in response.output[0].results:
|
||||
assert "q1" in result.text.lower()
|
||||
# Ensure non-Q1 quarters are NOT returned
|
||||
assert "q2" not in result.text.lower()
|
||||
assert "q3" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_compound_and(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with compound AND filter (region AND category)."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "and",
|
||||
"filters": [
|
||||
{"type": "eq", "key": "region", "value": "us"},
|
||||
{"type": "eq", "key": "category", "value": "engineering"},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What are the engineering updates from the US?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return US engineering files
|
||||
assert len(response.output[0].results) >= 1
|
||||
for result in response.output[0].results:
|
||||
assert "us" in result.text.lower() and "technical" in result.text.lower()
|
||||
# Ensure it's not from other regions or categories
|
||||
assert "european" not in result.text.lower() and "asia" not in result.text.lower()
|
||||
assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_compound_or(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with compound OR filter (marketing OR sales)."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "or",
|
||||
"filters": [
|
||||
{"type": "eq", "key": "category", "value": "marketing"},
|
||||
{"type": "eq", "key": "category", "value": "sales"},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Show me marketing and sales documents",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should return marketing and sales files, but NOT engineering
|
||||
categories_found = set()
|
||||
for result in response.output[0].results:
|
||||
text_lower = result.text.lower()
|
||||
if "promotional" in text_lower or "advertising" in text_lower:
|
||||
categories_found.add("marketing")
|
||||
if "revenue figures" in text_lower:
|
||||
categories_found.add("sales")
|
||||
# Ensure engineering files are NOT returned
|
||||
assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
|
||||
|
||||
# Verify we got at least one of the expected categories
|
||||
assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
|
||||
assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"
|
Loading…
Add table
Add a link
Reference in a new issue