feat: support filters in file search (#2472)

# What does this PR do?
Move to use vector_stores.search for file search tool in Responses,
which supports filters.

closes #2435 

## Test Plan
Added e2e test with fitlers.
myenv ❯ llama stack run llama_stack/templates/fireworks/run.yaml

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search and filters' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=meta-llama/Llama-3.3-70B-Instruct
This commit is contained in:
ehhuang 2025-06-18 21:50:55 -07:00 committed by GitHub
parent fd37a50e6a
commit db2cd9e8f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 449 additions and 63 deletions

View file

@ -71,12 +71,21 @@ def mock_responses_store():
@pytest.fixture
def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store):
def mock_vector_io_api():
vector_io_api = AsyncMock()
return vector_io_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
tool_groups_api=mock_tool_groups_api,
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
)

View file

@ -714,3 +714,277 @@ def test_response_text_format(request, openai_client, model, provider, verificat
assert "paris" in response.output_text.lower()
if text_format["type"] == "json_schema":
assert "paris" in json.loads(response.output_text)["capital"].lower()
@pytest.fixture
def vector_store_with_filtered_files(request, openai_client, model, provider, verification_config, tmp_path_factory):
"""Create a vector store with multiple files that have different attributes for filtering tests."""
if isinstance(openai_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
vector_store = _new_vector_store(openai_client, "test_vector_store_with_filters")
tmp_path = tmp_path_factory.mktemp("filter_test_files")
# Create multiple files with different attributes
files_data = [
{
"name": "us_marketing_q1.txt",
"content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
"attributes": {
"region": "us",
"category": "marketing",
"date": 1672531200, # Jan 1, 2023
},
},
{
"name": "us_engineering_q2.txt",
"content": "US technical updates for Q2 2023. New features deployed in the US region.",
"attributes": {
"region": "us",
"category": "engineering",
"date": 1680307200, # Apr 1, 2023
},
},
{
"name": "eu_marketing_q1.txt",
"content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
"attributes": {
"region": "eu",
"category": "marketing",
"date": 1672531200, # Jan 1, 2023
},
},
{
"name": "asia_sales_q3.txt",
"content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
"attributes": {
"region": "asia",
"category": "sales",
"date": 1688169600, # Jul 1, 2023
},
},
]
file_ids = []
for file_data in files_data:
# Create file
file_path = tmp_path / file_data["name"]
file_path.write_text(file_data["content"])
# Upload file
file_response = _upload_file(openai_client, file_data["name"], str(file_path))
file_ids.append(file_response.id)
# Attach file to vector store with attributes
file_attach_response = openai_client.vector_stores.files.create(
vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"]
)
# Wait for attachment
while file_attach_response.status == "in_progress":
time.sleep(0.1)
file_attach_response = openai_client.vector_stores.files.retrieve(
vector_store_id=vector_store.id,
file_id=file_response.id,
)
assert file_attach_response.status == "completed"
yield vector_store
# Cleanup: delete vector store and files
try:
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
for file_id in file_ids:
try:
openai_client.files.delete(file_id=file_id)
except Exception:
pass # File might already be deleted
except Exception:
pass # Best effort cleanup
def test_response_file_search_filter_by_region(openai_client, model, vector_store_with_filtered_files):
"""Test file search with region equality filter."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {"type": "eq", "key": "region", "value": "us"},
}
]
response = openai_client.responses.create(
model=model,
input="What are the updates from the US region?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
# Verify file search was called with US filter
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return US files (not EU or Asia files)
for result in response.output[0].results:
assert "us" in result.text.lower() or "US" in result.text
# Ensure non-US regions are NOT returned
assert "european" not in result.text.lower()
assert "asia" not in result.text.lower()
def test_response_file_search_filter_by_category(openai_client, model, vector_store_with_filtered_files):
"""Test file search with category equality filter."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {"type": "eq", "key": "category", "value": "marketing"},
}
]
response = openai_client.responses.create(
model=model,
input="Show me all marketing reports",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return marketing files (not engineering or sales)
for result in response.output[0].results:
# Marketing files should have promotional/advertising content
assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
# Ensure non-marketing categories are NOT returned
assert "technical" not in result.text.lower()
assert "revenue figures" not in result.text.lower()
def test_response_file_search_filter_by_date_range(openai_client, model, vector_store_with_filtered_files):
"""Test file search with date range filter using compound AND."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "and",
"filters": [
{
"type": "gte",
"key": "date",
"value": 1672531200, # Jan 1, 2023
},
{
"type": "lt",
"key": "date",
"value": 1680307200, # Apr 1, 2023
},
],
},
}
]
response = openai_client.responses.create(
model=model,
input="What happened in Q1 2023?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return Q1 files (not Q2 or Q3)
for result in response.output[0].results:
assert "q1" in result.text.lower()
# Ensure non-Q1 quarters are NOT returned
assert "q2" not in result.text.lower()
assert "q3" not in result.text.lower()
def test_response_file_search_filter_compound_and(openai_client, model, vector_store_with_filtered_files):
"""Test file search with compound AND filter (region AND category)."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "and",
"filters": [
{"type": "eq", "key": "region", "value": "us"},
{"type": "eq", "key": "category", "value": "engineering"},
],
},
}
]
response = openai_client.responses.create(
model=model,
input="What are the engineering updates from the US?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return US engineering files
assert len(response.output[0].results) >= 1
for result in response.output[0].results:
assert "us" in result.text.lower() and "technical" in result.text.lower()
# Ensure it's not from other regions or categories
assert "european" not in result.text.lower() and "asia" not in result.text.lower()
assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
def test_response_file_search_filter_compound_or(openai_client, model, vector_store_with_filtered_files):
"""Test file search with compound OR filter (marketing OR sales)."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "or",
"filters": [
{"type": "eq", "key": "category", "value": "marketing"},
{"type": "eq", "key": "category", "value": "sales"},
],
},
}
]
response = openai_client.responses.create(
model=model,
input="Show me marketing and sales documents",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should return marketing and sales files, but NOT engineering
categories_found = set()
for result in response.output[0].results:
text_lower = result.text.lower()
if "promotional" in text_lower or "advertising" in text_lower:
categories_found.add("marketing")
if "revenue figures" in text_lower:
categories_found.add("sales")
# Ensure engineering files are NOT returned
assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
# Verify we got at least one of the expected categories
assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"