From db2cd9e8f32d88a1eab8fea2a739351a10c1843d Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Wed, 18 Jun 2025 21:50:55 -0700
Subject: [PATCH] feat: support filters in file search (#2472)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?
Move to use vector_stores.search for the file search tool in Responses, which supports filters.

closes #2435

## Test Plan
Added an e2e test with filters.

myenv ❯ llama stack run llama_stack/templates/fireworks/run.yaml
pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search and filters' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=meta-llama/Llama-3.3-70B-Instruct
---
 docs/_static/llama-stack-spec.html            |  32 +-
 docs/_static/llama-stack-spec.yaml            |  17 +-
 llama_stack/apis/agents/openai_responses.py   |   6 +-
 llama_stack/apis/vector_io/vector_io.py       |   7 +-
 llama_stack/distribution/routers/vector_io.py |   3 +-
 .../inline/agents/meta_reference/agents.py    |   1 +
 .../agents/meta_reference/openai_responses.py |  98 ++++++-
 .../remote/vector_io/chroma/chroma.py         |   3 +-
 .../remote/vector_io/milvus/milvus.py         |   3 +-
 .../remote/vector_io/qdrant/qdrant.py         |   3 +-
 .../utils/memory/openai_vector_store_mixin.py |  54 +++-
 .../meta_reference/test_openai_responses.py   |  11 +-
 .../openai_api/test_responses.py              | 274 ++++++++++++++++++
 13 files changed, 449 insertions(+), 63 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 486068258..5407f9808 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -7296,7 +7296,7 @@
             }
           },
           "additionalProperties": false,
-          "title": "FileSearchRankingOptions"
+          "title": "SearchRankingOptions"
         }
       },
       "additionalProperties": false,
@@ -13478,28 +13478,16 @@
         },
         "ranking_options": {
           "type": "object",
-          "additionalProperties": {
-            "oneOf": [
-              {
-                "type": "null"
-              },
-              {
-                "type": "boolean"
-              },
-              {
-                "type": "number"
-              },
-              {
-                "type": "string"
-              },
-              {
-                "type": "array"
-              },
-              {
-                "type": "object"
-              }
-            ]
+          "properties": {
+            "ranker": {
+              "type": "string"
+            },
+            "score_threshold": {
+              "type": "number",
+              "default": 0.0
+            }
           },
+          "additionalProperties": false,
           "description": "Ranking options for fine-tuning the search results."
         },
         "rewrite_query": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 72f60032e..a354e4bd0 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -5179,7 +5179,7 @@ components:
         type: number
         default: 0.0
       additionalProperties: false
-      title: FileSearchRankingOptions
+      title: SearchRankingOptions
     additionalProperties: false
     required:
       - type
@@ -9408,14 +9408,13 @@ components:
           Maximum number of results to return (1 to 50 inclusive, default 10).
       ranking_options:
         type: object
-        additionalProperties:
-          oneOf:
-            - type: 'null'
-            - type: boolean
-            - type: number
-            - type: string
-            - type: array
-            - type: object
+        properties:
+          ranker:
+            type: string
+          score_threshold:
+            type: number
+            default: 0.0
+        additionalProperties: false
        description: >-
          Ranking options for fine-tuning the search results.
rewrite_query: diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 2e1cb257a..addb72f14 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal from pydantic import BaseModel, Field from typing_extensions import TypedDict +from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions from llama_stack.schema_utils import json_schema_type, register_schema # NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably @@ -400,11 +401,6 @@ class OpenAIResponseInputToolFunction(BaseModel): strict: bool | None = None -class FileSearchRankingOptions(BaseModel): - ranker: str | None = None - score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0) - - @json_schema_type class OpenAIResponseInputToolFileSearch(BaseModel): type: Literal["file_search"] = "file_search" diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 77d4cfc5a..20cc594cc 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -157,6 +157,11 @@ VectorStoreChunkingStrategy = Annotated[ register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy") +class SearchRankingOptions(BaseModel): + ranker: str | None = None + score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0) + + @json_schema_type class VectorStoreFileLastError(BaseModel): code: Literal["server_error"] | Literal["rate_limit_exceeded"] @@ -319,7 +324,7 @@ class VectorIO(Protocol): query: str | list[str], filters: dict[str, Any] | None = None, max_num_results: int | None = 10, - ranking_options: dict[str, Any] | None = None, + ranking_options: SearchRankingOptions | None = None, rewrite_query: bool | None = False, ) -> VectorStoreSearchResponsePage: """Search for chunks in a vector store. 
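
A minimal usage sketch of the updated `openai_search_vector_store` signature: `ranking_options` is now the typed `SearchRankingOptions` model rather than a loose dict, and `filters` takes the OpenAI-style attribute filter shape. The `vector_io` handle and the vector store id are placeholders, not part of this patch.

```python
from llama_stack.apis.vector_io import SearchRankingOptions


async def search_us_marketing_chunks(vector_io, vector_store_id: str):
    # Typed ranking options instead of {"score_threshold": ...}
    ranking = SearchRankingOptions(score_threshold=0.2)

    # A single comparison filter, or "and"/"or" over a list of sub-filters.
    filters = {
        "type": "and",
        "filters": [
            {"type": "eq", "key": "region", "value": "us"},
            {"type": "eq", "key": "category", "value": "marketing"},
        ],
    }

    page = await vector_io.openai_search_vector_store(
        vector_store_id=vector_store_id,
        query="Q1 promotional campaigns",
        filters=filters,
        max_num_results=5,
        ranking_options=ranking,
        rewrite_query=False,
    )
    # Each result carries file_id, score, attributes, and content items.
    return [(result.file_id, result.score) for result in page.data]
```
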
diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index 44c1fafa7..c4191b8a1 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -14,6 +14,7 @@ from llama_stack.apis.models import ModelType from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, + SearchRankingOptions, VectorIO, VectorStoreDeleteResponse, VectorStoreListResponse, @@ -246,7 +247,7 @@ class VectorIORouter(VectorIO): query: str | list[str], filters: dict[str, Any] | None = None, max_num_results: int | None = 10, - ranking_options: dict[str, Any] | None = None, + ranking_options: SearchRankingOptions | None = None, rewrite_query: bool | None = False, ) -> VectorStoreSearchResponsePage: logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index ea3c5da97..6b05f90dd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -85,6 +85,7 @@ class MetaReferenceAgentsImpl(Agents): tool_groups_api=self.tool_groups_api, tool_runtime_api=self.tool_runtime_api, responses_store=self.responses_store, + vector_io_api=self.vector_io_api, ) async def create_agent( diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 33fcbfa5d..4465a32fe 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import asyncio import json import time import uuid @@ -42,6 +43,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseText, OpenAIResponseTextFormat, ) +from llama_stack.apis.common.content_types import TextContentItem from llama_stack.apis.inference.inference import ( Inference, OpenAIAssistantMessageParam, @@ -64,7 +66,8 @@ from llama_stack.apis.inference.inference import ( OpenAIToolMessageParam, OpenAIUserMessageParam, ) -from llama_stack.apis.tools import RAGQueryConfig, ToolGroups, ToolRuntime +from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.vector_io import VectorIO from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool @@ -214,11 +217,13 @@ class OpenAIResponsesImpl: tool_groups_api: ToolGroups, tool_runtime_api: ToolRuntime, responses_store: ResponsesStore, + vector_io_api: VectorIO, # VectorIO ): self.inference_api = inference_api self.tool_groups_api = tool_groups_api self.tool_runtime_api = tool_runtime_api self.responses_store = responses_store + self.vector_io_api = vector_io_api async def _prepend_previous_response( self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None @@ -666,6 +671,71 @@ class OpenAIResponsesImpl: raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") return chat_tools, mcp_tool_to_server, mcp_list_message + async def _execute_knowledge_search_via_vector_store( + self, + query: str, + response_file_search_tool: OpenAIResponseInputToolFileSearch, + ) -> ToolInvocationResult: + """Execute knowledge search using vector_stores.search API with filters support.""" + search_results = [] + + # Create search tasks for all vector stores + async def search_single_store(vector_store_id): + try: + search_response = await self.vector_io_api.openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=response_file_search_tool.filters, + max_num_results=response_file_search_tool.max_num_results, + ranking_options=response_file_search_tool.ranking_options, + rewrite_query=False, + ) + return search_response.data + except Exception as e: + logger.warning(f"Failed to search vector store {vector_store_id}: {e}") + return [] + + # Run all searches in parallel using gather + search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] + all_results = await asyncio.gather(*search_tasks) + + # Flatten results + for results in all_results: + search_results.extend(results) + + # Convert search results to tool result format matching memory.py + # Format the results as interleaved content similar to memory.py + content_items = [] + content_items.append( + TextContentItem( + text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n" + ) + ) + + for i, result_item in enumerate(search_results): + chunk_text = result_item.content[0].text if result_item.content else "" + metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" + if result_item.attributes: + metadata_text += f", attributes: {result_item.attributes}" + text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" + content_items.append(TextContentItem(text=text_content)) + + content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) + 
content_items.append( + TextContentItem( + text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n', + ) + ) + + return ToolInvocationResult( + content=content_items, + metadata={ + "document_ids": [r.file_id for r in search_results], + "chunks": [r.content[0].text if r.content else "" for r in search_results], + "scores": [r.score for r in search_results], + }, + ) + async def _execute_tool_call( self, tool_call: OpenAIChatCompletionToolCall, @@ -693,21 +763,19 @@ class OpenAIResponsesImpl: tool_name=function.name, kwargs=tool_kwargs, ) - else: - if function.name == "knowledge_search": - response_file_search_tool = next( - t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch) + elif function.name == "knowledge_search": + response_file_search_tool = next( + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None + ) + if response_file_search_tool: + # Use vector_stores.search API instead of knowledge_search tool + # to support filters and ranking_options + query = tool_kwargs.get("query", "") + result = await self._execute_knowledge_search_via_vector_store( + query=query, + response_file_search_tool=response_file_search_tool, ) - if response_file_search_tool: - if response_file_search_tool.filters: - logger.warning("Filters are not yet supported for file_search tool") - if response_file_search_tool.ranking_options: - logger.warning("Ranking options are not yet supported for file_search tool") - tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids - tool_kwargs["query_config"] = RAGQueryConfig( - mode="vector", - max_chunks=response_file_search_tool.max_num_results, - ) + else: result = await self.tool_runtime_api.invoke_tool( tool_name=function.name, kwargs=tool_kwargs, diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 027cdcb11..12c1b5022 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -17,6 +17,7 @@ from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, + SearchRankingOptions, VectorIO, VectorStoreDeleteResponse, VectorStoreListResponse, @@ -249,7 +250,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): query: str | list[str], filters: dict[str, Any] | None = None, max_num_results: int | None = 10, - ranking_options: dict[str, Any] | None = None, + ranking_options: SearchRankingOptions | None = None, rewrite_query: bool | None = False, ) -> VectorStoreSearchResponsePage: raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 42ab4fa3e..31a9535db 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -19,6 +19,7 @@ from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, + SearchRankingOptions, VectorIO, VectorStoreDeleteResponse, VectorStoreListResponse, @@ -247,7 +248,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): query: str | list[str], filters: dict[str, Any] | None = None, max_num_results: int | None = 10, - ranking_options: dict[str, Any] | None 
= None, + ranking_options: SearchRankingOptions | None = None, rewrite_query: bool | None = False, ) -> VectorStoreSearchResponsePage: raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index fa7782f04..1ebf861e2 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -17,6 +17,7 @@ from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, + SearchRankingOptions, VectorIO, VectorStoreDeleteResponse, VectorStoreListResponse, @@ -249,7 +250,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): query: str | list[str], filters: dict[str, Any] | None = None, max_num_results: int | None = 10, - ranking_options: dict[str, Any] | None = None, + ranking_options: SearchRankingOptions | None = None, rewrite_query: bool | None = False, ) -> VectorStoreSearchResponsePage: raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index f9701897a..13d2d7423 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -15,6 +15,7 @@ from llama_stack.apis.files import Files from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( QueryChunksResponse, + SearchRankingOptions, VectorStoreContent, VectorStoreDeleteResponse, VectorStoreListResponse, @@ -296,7 +297,7 @@ class OpenAIVectorStoreMixin(ABC): query: str | list[str], filters: dict[str, Any] | None = None, max_num_results: int | None = 10, - ranking_options: dict[str, Any] | None = None, + ranking_options: SearchRankingOptions | None = None, rewrite_query: bool | None = False, # search_mode: Literal["keyword", "vector", "hybrid"] = "vector", ) -> VectorStoreSearchResponsePage: @@ -314,7 +315,11 @@ class OpenAIVectorStoreMixin(ABC): search_query = query try: - score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0 + score_threshold = ( + ranking_options.score_threshold + if ranking_options and ranking_options.score_threshold is not None + else 0.0 + ) params = { "max_chunks": max_num_results * CHUNK_MULTIPLIER, "score_threshold": score_threshold, @@ -399,12 +404,49 @@ class OpenAIVectorStoreMixin(ABC): def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: """Check if metadata matches the provided filters.""" - for key, value in filters.items(): + if not filters: + return True + + filter_type = filters.get("type") + + if filter_type in ["eq", "ne", "gt", "gte", "lt", "lte"]: + # Comparison filter + key = filters.get("key") + value = filters.get("value") + if key not in metadata: return False - if metadata[key] != value: - return False - return True + + metadata_value = metadata[key] + + if filter_type == "eq": + return bool(metadata_value == value) + elif filter_type == "ne": + return bool(metadata_value != value) + elif filter_type == "gt": + return bool(metadata_value > value) + elif filter_type == "gte": + return bool(metadata_value >= value) + elif filter_type == "lt": + return bool(metadata_value < value) + elif filter_type == "lte": + return bool(metadata_value <= value) + else: + raise 
ValueError(f"Unsupported filter type: {filter_type}") + + elif filter_type == "and": + # All filters must match + sub_filters = filters.get("filters", []) + return all(self._matches_filters(metadata, f) for f in sub_filters) + + elif filter_type == "or": + # At least one filter must match + sub_filters = filters.get("filters", []) + return any(self._matches_filters(metadata, f) for f in sub_filters) + + else: + # Unknown filter type, default to no match + raise ValueError(f"Unsupported filter type: {filter_type}") async def openai_attach_file_to_vector_store( self, diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 34f22c39f..6bf1b7e0c 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -71,12 +71,21 @@ def mock_responses_store(): @pytest.fixture -def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store): +def mock_vector_io_api(): + vector_io_api = AsyncMock() + return vector_io_api + + +@pytest.fixture +def openai_responses_impl( + mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api +): return OpenAIResponsesImpl( inference_api=mock_inference_api, tool_groups_api=mock_tool_groups_api, tool_runtime_api=mock_tool_runtime_api, responses_store=mock_responses_store, + vector_io_api=mock_vector_io_api, ) diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py index 1c9cdaa3a..08bbb2252 100644 --- a/tests/verifications/openai_api/test_responses.py +++ b/tests/verifications/openai_api/test_responses.py @@ -714,3 +714,277 @@ def test_response_text_format(request, openai_client, model, provider, verificat assert "paris" in response.output_text.lower() if text_format["type"] == "json_schema": assert "paris" in json.loads(response.output_text)["capital"].lower() + + +@pytest.fixture +def vector_store_with_filtered_files(request, openai_client, model, provider, verification_config, tmp_path_factory): + """Create a vector store with multiple files that have different attributes for filtering tests.""" + if isinstance(openai_client, LlamaStackAsLibraryClient): + pytest.skip("Responses API file search is not yet supported in library client.") + + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + vector_store = _new_vector_store(openai_client, "test_vector_store_with_filters") + tmp_path = tmp_path_factory.mktemp("filter_test_files") + + # Create multiple files with different attributes + files_data = [ + { + "name": "us_marketing_q1.txt", + "content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.", + "attributes": { + "region": "us", + "category": "marketing", + "date": 1672531200, # Jan 1, 2023 + }, + }, + { + "name": "us_engineering_q2.txt", + "content": "US technical updates for Q2 2023. New features deployed in the US region.", + "attributes": { + "region": "us", + "category": "engineering", + "date": 1680307200, # Apr 1, 2023 + }, + }, + { + "name": "eu_marketing_q1.txt", + "content": "European advertising campaign results for Q1 2023. 
Strong growth in EU markets.", + "attributes": { + "region": "eu", + "category": "marketing", + "date": 1672531200, # Jan 1, 2023 + }, + }, + { + "name": "asia_sales_q3.txt", + "content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.", + "attributes": { + "region": "asia", + "category": "sales", + "date": 1688169600, # Jul 1, 2023 + }, + }, + ] + + file_ids = [] + for file_data in files_data: + # Create file + file_path = tmp_path / file_data["name"] + file_path.write_text(file_data["content"]) + + # Upload file + file_response = _upload_file(openai_client, file_data["name"], str(file_path)) + file_ids.append(file_response.id) + + # Attach file to vector store with attributes + file_attach_response = openai_client.vector_stores.files.create( + vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"] + ) + + # Wait for attachment + while file_attach_response.status == "in_progress": + time.sleep(0.1) + file_attach_response = openai_client.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=file_response.id, + ) + assert file_attach_response.status == "completed" + + yield vector_store + + # Cleanup: delete vector store and files + try: + openai_client.vector_stores.delete(vector_store_id=vector_store.id) + for file_id in file_ids: + try: + openai_client.files.delete(file_id=file_id) + except Exception: + pass # File might already be deleted + except Exception: + pass # Best effort cleanup + + +def test_response_file_search_filter_by_region(openai_client, model, vector_store_with_filtered_files): + """Test file search with region equality filter.""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": {"type": "eq", "key": "region", "value": "us"}, + } + ] + + response = openai_client.responses.create( + model=model, + input="What are the updates from the US region?", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + # Verify file search was called with US filter + assert len(response.output) > 1 + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should only return US files (not EU or Asia files) + for result in response.output[0].results: + assert "us" in result.text.lower() or "US" in result.text + # Ensure non-US regions are NOT returned + assert "european" not in result.text.lower() + assert "asia" not in result.text.lower() + + +def test_response_file_search_filter_by_category(openai_client, model, vector_store_with_filtered_files): + """Test file search with category equality filter.""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": {"type": "eq", "key": "category", "value": "marketing"}, + } + ] + + response = openai_client.responses.create( + model=model, + input="Show me all marketing reports", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should only return marketing files (not engineering or sales) + for result in response.output[0].results: + # Marketing files should have promotional/advertising content + assert "promotional" in result.text.lower() or "advertising" in result.text.lower() + # Ensure non-marketing categories are NOT returned + assert "technical" not in 
result.text.lower() + assert "revenue figures" not in result.text.lower() + + +def test_response_file_search_filter_by_date_range(openai_client, model, vector_store_with_filtered_files): + """Test file search with date range filter using compound AND.""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": { + "type": "and", + "filters": [ + { + "type": "gte", + "key": "date", + "value": 1672531200, # Jan 1, 2023 + }, + { + "type": "lt", + "key": "date", + "value": 1680307200, # Apr 1, 2023 + }, + ], + }, + } + ] + + response = openai_client.responses.create( + model=model, + input="What happened in Q1 2023?", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should only return Q1 files (not Q2 or Q3) + for result in response.output[0].results: + assert "q1" in result.text.lower() + # Ensure non-Q1 quarters are NOT returned + assert "q2" not in result.text.lower() + assert "q3" not in result.text.lower() + + +def test_response_file_search_filter_compound_and(openai_client, model, vector_store_with_filtered_files): + """Test file search with compound AND filter (region AND category).""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": { + "type": "and", + "filters": [ + {"type": "eq", "key": "region", "value": "us"}, + {"type": "eq", "key": "category", "value": "engineering"}, + ], + }, + } + ] + + response = openai_client.responses.create( + model=model, + input="What are the engineering updates from the US?", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should only return US engineering files + assert len(response.output[0].results) >= 1 + for result in response.output[0].results: + assert "us" in result.text.lower() and "technical" in result.text.lower() + # Ensure it's not from other regions or categories + assert "european" not in result.text.lower() and "asia" not in result.text.lower() + assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower() + + +def test_response_file_search_filter_compound_or(openai_client, model, vector_store_with_filtered_files): + """Test file search with compound OR filter (marketing OR sales).""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": { + "type": "or", + "filters": [ + {"type": "eq", "key": "category", "value": "marketing"}, + {"type": "eq", "key": "category", "value": "sales"}, + ], + }, + } + ] + + response = openai_client.responses.create( + model=model, + input="Show me marketing and sales documents", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should return marketing and sales files, but NOT engineering + categories_found = set() + for result in response.output[0].results: + text_lower = result.text.lower() + if "promotional" in text_lower or "advertising" in text_lower: + categories_found.add("marketing") + if "revenue figures" in text_lower: + categories_found.add("sales") + # Ensure engineering files 
are NOT returned + assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}" + + # Verify we got at least one of the expected categories + assert len(categories_found) > 0, "Should have found at least one marketing or sales file" + assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"
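
For reference, a condensed sketch of the end-to-end flow these tests exercise: calling the Responses API with a `file_search` tool that carries an attribute filter, against a running stack. The base URL matches the test plan above; the API key, model, and vector store id are placeholders.

```python
from openai import OpenAI

# Point an OpenAI-compatible client at the stack's OpenAI-compatible endpoint.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

tools = [
    {
        "type": "file_search",
        "vector_store_ids": ["<vector-store-id>"],
        # Only chunks from files whose attributes match the filter are searched.
        "filters": {"type": "eq", "key": "region", "value": "us"},
    }
]

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="What are the updates from the US region?",
    tools=tools,
    stream=False,
    include=["file_search_call.results"],  # surface the retrieved chunks
)

file_search_call = response.output[0]
assert file_search_call.type == "file_search_call"
for result in file_search_call.results:
    print(result.text)
```
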