diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 486068258..5407f9808 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -7296,7 +7296,7 @@
}
},
"additionalProperties": false,
- "title": "FileSearchRankingOptions"
+ "title": "SearchRankingOptions"
}
},
"additionalProperties": false,
@@ -13478,28 +13478,16 @@
},
"ranking_options": {
"type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
+ "properties": {
+ "ranker": {
+ "type": "string"
+ },
+ "score_threshold": {
+ "type": "number",
+ "default": 0.0
+ }
},
+ "additionalProperties": false,
"description": "Ranking options for fine-tuning the search results."
},
"rewrite_query": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 72f60032e..a354e4bd0 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -5179,7 +5179,7 @@ components:
type: number
default: 0.0
additionalProperties: false
- title: FileSearchRankingOptions
+ title: SearchRankingOptions
additionalProperties: false
required:
- type
@@ -9408,14 +9408,13 @@ components:
Maximum number of results to return (1 to 50 inclusive, default 10).
ranking_options:
type: object
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
+ properties:
+ ranker:
+ type: string
+ score_threshold:
+ type: number
+ default: 0.0
+ additionalProperties: false
description: >-
Ranking options for fine-tuning the search results.
rewrite_query:
diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py
index 2e1cb257a..addb72f14 100644
--- a/llama_stack/apis/agents/openai_responses.py
+++ b/llama_stack/apis/agents/openai_responses.py
@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
+from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
from llama_stack.schema_utils import json_schema_type, register_schema
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
@@ -400,11 +401,6 @@ class OpenAIResponseInputToolFunction(BaseModel):
strict: bool | None = None
-class FileSearchRankingOptions(BaseModel):
- ranker: str | None = None
- score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
-
-
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
type: Literal["file_search"] = "file_search"
diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py
index 77d4cfc5a..20cc594cc 100644
--- a/llama_stack/apis/vector_io/vector_io.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -157,6 +157,11 @@ VectorStoreChunkingStrategy = Annotated[
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
+class SearchRankingOptions(BaseModel):
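+    # Ranking options accepted by the OpenAI-compatible vector store search API;
+    # score_threshold is constrained by validation to the range [0.0, 1.0].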
+ ranker: str | None = None
+ score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
+
+
@json_schema_type
class VectorStoreFileLastError(BaseModel):
code: Literal["server_error"] | Literal["rate_limit_exceeded"]
@@ -319,7 +324,7 @@ class VectorIO(Protocol):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
- ranking_options: dict[str, Any] | None = None,
+ ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store.
diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py
index 44c1fafa7..c4191b8a1 100644
--- a/llama_stack/distribution/routers/vector_io.py
+++ b/llama_stack/distribution/routers/vector_io.py
@@ -14,6 +14,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
+ SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -246,7 +247,7 @@ class VectorIORouter(VectorIO):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
- ranking_options: dict[str, Any] | None = None,
+ ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index ea3c5da97..6b05f90dd 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -85,6 +85,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_groups_api=self.tool_groups_api,
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
+ vector_io_api=self.vector_io_api,
)
async def create_agent(
diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 33fcbfa5d..4465a32fe 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import asyncio
import json
import time
import uuid
@@ -42,6 +43,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
OpenAIResponseTextFormat,
)
+from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import (
Inference,
OpenAIAssistantMessageParam,
@@ -64,7 +66,8 @@ from llama_stack.apis.inference.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
-from llama_stack.apis.tools import RAGQueryConfig, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
@@ -214,11 +217,13 @@ class OpenAIResponsesImpl:
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
+        vector_io_api: VectorIO,  # used by the file_search tool to call vector_stores.search
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
+ self.vector_io_api = vector_io_api
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
@@ -666,6 +671,71 @@ class OpenAIResponsesImpl:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
+ async def _execute_knowledge_search_via_vector_store(
+ self,
+ query: str,
+ response_file_search_tool: OpenAIResponseInputToolFileSearch,
+ ) -> ToolInvocationResult:
+ """Execute knowledge search using vector_stores.search API with filters support."""
+ search_results = []
+
+ # Create search tasks for all vector stores
+ async def search_single_store(vector_store_id):
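+            # A failure in any single store is logged and treated as zero results so
+            # the remaining stores can still contribute to the response.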
+ try:
+ search_response = await self.vector_io_api.openai_search_vector_store(
+ vector_store_id=vector_store_id,
+ query=query,
+ filters=response_file_search_tool.filters,
+ max_num_results=response_file_search_tool.max_num_results,
+ ranking_options=response_file_search_tool.ranking_options,
+ rewrite_query=False,
+ )
+ return search_response.data
+ except Exception as e:
+ logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
+ return []
+
+ # Run all searches in parallel using gather
+ search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
+ all_results = await asyncio.gather(*search_tasks)
+
+ # Flatten results
+ for results in all_results:
+ search_results.extend(results)
+
+        # Format the results as interleaved content, matching the knowledge_search
+        # output produced by memory.py
+ content_items = []
+ content_items.append(
+ TextContentItem(
+ text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
+ )
+ )
+
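+        # Each result may carry multiple content items; only the first text item is
+        # included here, together with the file id, score, and any attributes.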
+ for i, result_item in enumerate(search_results):
+ chunk_text = result_item.content[0].text if result_item.content else ""
+ metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
+ if result_item.attributes:
+ metadata_text += f", attributes: {result_item.attributes}"
+ text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
+ content_items.append(TextContentItem(text=text_content))
+
+ content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+ content_items.append(
+ TextContentItem(
+ text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
+ )
+ )
+
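+        # document_ids, chunks, and scores in the metadata are parallel lists with
+        # one entry per retrieved chunk.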
+ return ToolInvocationResult(
+ content=content_items,
+ metadata={
+ "document_ids": [r.file_id for r in search_results],
+ "chunks": [r.content[0].text if r.content else "" for r in search_results],
+ "scores": [r.score for r in search_results],
+ },
+ )
+
async def _execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
@@ -693,21 +763,19 @@ class OpenAIResponsesImpl:
tool_name=function.name,
kwargs=tool_kwargs,
)
- else:
- if function.name == "knowledge_search":
- response_file_search_tool = next(
- t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
+ elif function.name == "knowledge_search":
+ response_file_search_tool = next(
+ (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
+ )
+ if response_file_search_tool:
+ # Use vector_stores.search API instead of knowledge_search tool
+ # to support filters and ranking_options
+ query = tool_kwargs.get("query", "")
+ result = await self._execute_knowledge_search_via_vector_store(
+ query=query,
+ response_file_search_tool=response_file_search_tool,
)
- if response_file_search_tool:
- if response_file_search_tool.filters:
- logger.warning("Filters are not yet supported for file_search tool")
- if response_file_search_tool.ranking_options:
- logger.warning("Ranking options are not yet supported for file_search tool")
- tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
- tool_kwargs["query_config"] = RAGQueryConfig(
- mode="vector",
- max_chunks=response_file_search_tool.max_num_results,
- )
+ else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=tool_kwargs,
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 027cdcb11..12c1b5022 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -17,6 +17,7 @@ from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
+ SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -249,7 +250,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
- ranking_options: dict[str, Any] | None = None,
+ ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 42ab4fa3e..31a9535db 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -19,6 +19,7 @@ from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
+ SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -247,7 +248,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
- ranking_options: dict[str, Any] | None = None,
+ ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index fa7782f04..1ebf861e2 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -17,6 +17,7 @@ from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
+ SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -249,7 +250,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
- ranking_options: dict[str, Any] | None = None,
+ ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index f9701897a..13d2d7423 100644
--- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -15,6 +15,7 @@ from llama_stack.apis.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
QueryChunksResponse,
+ SearchRankingOptions,
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -296,7 +297,7 @@ class OpenAIVectorStoreMixin(ABC):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
- ranking_options: dict[str, Any] | None = None,
+ ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
# search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponsePage:
@@ -314,7 +315,11 @@ class OpenAIVectorStoreMixin(ABC):
search_query = query
try:
- score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
+ score_threshold = (
+ ranking_options.score_threshold
+ if ranking_options and ranking_options.score_threshold is not None
+ else 0.0
+ )
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
@@ -399,12 +404,49 @@ class OpenAIVectorStoreMixin(ABC):
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
- for key, value in filters.items():
+ if not filters:
+ return True
+
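+        # Filters follow the OpenAI vector-store search grammar: leaf comparisons
+        # ("eq", "ne", "gt", "gte", "lt", "lte") on a single key, or compound
+        # "and"/"or" nodes wrapping a list of sub-filters.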
+ filter_type = filters.get("type")
+
+ if filter_type in ["eq", "ne", "gt", "gte", "lt", "lte"]:
+ # Comparison filter
+ key = filters.get("key")
+ value = filters.get("value")
+
if key not in metadata:
return False
- if metadata[key] != value:
- return False
- return True
+
+ metadata_value = metadata[key]
+
+ if filter_type == "eq":
+ return bool(metadata_value == value)
+ elif filter_type == "ne":
+ return bool(metadata_value != value)
+ elif filter_type == "gt":
+ return bool(metadata_value > value)
+ elif filter_type == "gte":
+ return bool(metadata_value >= value)
+ elif filter_type == "lt":
+ return bool(metadata_value < value)
+ elif filter_type == "lte":
+ return bool(metadata_value <= value)
+ else:
+ raise ValueError(f"Unsupported filter type: {filter_type}")
+
+ elif filter_type == "and":
+ # All filters must match
+ sub_filters = filters.get("filters", [])
+ return all(self._matches_filters(metadata, f) for f in sub_filters)
+
+ elif filter_type == "or":
+ # At least one filter must match
+ sub_filters = filters.get("filters", [])
+ return any(self._matches_filters(metadata, f) for f in sub_filters)
+
+ else:
+            # Unknown filter type: raise rather than silently matching nothing
+ raise ValueError(f"Unsupported filter type: {filter_type}")
async def openai_attach_file_to_vector_store(
self,
diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
index 34f22c39f..6bf1b7e0c 100644
--- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py
+++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
@@ -71,12 +71,21 @@ def mock_responses_store():
@pytest.fixture
-def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store):
+def mock_vector_io_api():
+ vector_io_api = AsyncMock()
+ return vector_io_api
+
+
+@pytest.fixture
+def openai_responses_impl(
+ mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
+):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
tool_groups_api=mock_tool_groups_api,
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
+ vector_io_api=mock_vector_io_api,
)
diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py
index 1c9cdaa3a..08bbb2252 100644
--- a/tests/verifications/openai_api/test_responses.py
+++ b/tests/verifications/openai_api/test_responses.py
@@ -714,3 +714,277 @@ def test_response_text_format(request, openai_client, model, provider, verificat
assert "paris" in response.output_text.lower()
if text_format["type"] == "json_schema":
assert "paris" in json.loads(response.output_text)["capital"].lower()
+
+
+@pytest.fixture
+def vector_store_with_filtered_files(request, openai_client, model, provider, verification_config, tmp_path_factory):
+ """Create a vector store with multiple files that have different attributes for filtering tests."""
+ if isinstance(openai_client, LlamaStackAsLibraryClient):
+ pytest.skip("Responses API file search is not yet supported in library client.")
+
+ test_name_base = get_base_test_name(request)
+ if should_skip_test(verification_config, provider, model, test_name_base):
+ pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+ vector_store = _new_vector_store(openai_client, "test_vector_store_with_filters")
+ tmp_path = tmp_path_factory.mktemp("filter_test_files")
+
+ # Create multiple files with different attributes
+ files_data = [
+ {
+ "name": "us_marketing_q1.txt",
+ "content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
+ "attributes": {
+ "region": "us",
+ "category": "marketing",
+ "date": 1672531200, # Jan 1, 2023
+ },
+ },
+ {
+ "name": "us_engineering_q2.txt",
+ "content": "US technical updates for Q2 2023. New features deployed in the US region.",
+ "attributes": {
+ "region": "us",
+ "category": "engineering",
+ "date": 1680307200, # Apr 1, 2023
+ },
+ },
+ {
+ "name": "eu_marketing_q1.txt",
+ "content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
+ "attributes": {
+ "region": "eu",
+ "category": "marketing",
+ "date": 1672531200, # Jan 1, 2023
+ },
+ },
+ {
+ "name": "asia_sales_q3.txt",
+ "content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
+ "attributes": {
+ "region": "asia",
+ "category": "sales",
+ "date": 1688169600, # Jul 1, 2023
+ },
+ },
+ ]
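+    # Attributes are chosen so each filter test has an unambiguous expected match:
+    # two US files plus one EU and one Asia file, across marketing/engineering/sales
+    # categories, with dates spanning Q1-Q3 2023.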
+
+ file_ids = []
+ for file_data in files_data:
+ # Create file
+ file_path = tmp_path / file_data["name"]
+ file_path.write_text(file_data["content"])
+
+ # Upload file
+ file_response = _upload_file(openai_client, file_data["name"], str(file_path))
+ file_ids.append(file_response.id)
+
+ # Attach file to vector store with attributes
+ file_attach_response = openai_client.vector_stores.files.create(
+ vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"]
+ )
+
+ # Wait for attachment
+ while file_attach_response.status == "in_progress":
+ time.sleep(0.1)
+ file_attach_response = openai_client.vector_stores.files.retrieve(
+ vector_store_id=vector_store.id,
+ file_id=file_response.id,
+ )
+ assert file_attach_response.status == "completed"
+
+ yield vector_store
+
+ # Cleanup: delete vector store and files
+ try:
+ openai_client.vector_stores.delete(vector_store_id=vector_store.id)
+ for file_id in file_ids:
+ try:
+ openai_client.files.delete(file_id=file_id)
+ except Exception:
+ pass # File might already be deleted
+ except Exception:
+ pass # Best effort cleanup
+
+
+def test_response_file_search_filter_by_region(openai_client, model, vector_store_with_filtered_files):
+ """Test file search with region equality filter."""
+ tools = [
+ {
+ "type": "file_search",
+ "vector_store_ids": [vector_store_with_filtered_files.id],
+ "filters": {"type": "eq", "key": "region", "value": "us"},
+ }
+ ]
+
+ response = openai_client.responses.create(
+ model=model,
+ input="What are the updates from the US region?",
+ tools=tools,
+ stream=False,
+ include=["file_search_call.results"],
+ )
+
+ # Verify file search was called with US filter
+ assert len(response.output) > 1
+ assert response.output[0].type == "file_search_call"
+ assert response.output[0].status == "completed"
+ assert response.output[0].results
+ # Should only return US files (not EU or Asia files)
+ for result in response.output[0].results:
+ assert "us" in result.text.lower() or "US" in result.text
+ # Ensure non-US regions are NOT returned
+ assert "european" not in result.text.lower()
+ assert "asia" not in result.text.lower()
+
+
+def test_response_file_search_filter_by_category(openai_client, model, vector_store_with_filtered_files):
+ """Test file search with category equality filter."""
+ tools = [
+ {
+ "type": "file_search",
+ "vector_store_ids": [vector_store_with_filtered_files.id],
+ "filters": {"type": "eq", "key": "category", "value": "marketing"},
+ }
+ ]
+
+ response = openai_client.responses.create(
+ model=model,
+ input="Show me all marketing reports",
+ tools=tools,
+ stream=False,
+ include=["file_search_call.results"],
+ )
+
+ assert response.output[0].type == "file_search_call"
+ assert response.output[0].status == "completed"
+ assert response.output[0].results
+ # Should only return marketing files (not engineering or sales)
+ for result in response.output[0].results:
+ # Marketing files should have promotional/advertising content
+ assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
+ # Ensure non-marketing categories are NOT returned
+ assert "technical" not in result.text.lower()
+ assert "revenue figures" not in result.text.lower()
+
+
+def test_response_file_search_filter_by_date_range(openai_client, model, vector_store_with_filtered_files):
+ """Test file search with date range filter using compound AND."""
+ tools = [
+ {
+ "type": "file_search",
+ "vector_store_ids": [vector_store_with_filtered_files.id],
+ "filters": {
+ "type": "and",
+ "filters": [
+ {
+ "type": "gte",
+ "key": "date",
+ "value": 1672531200, # Jan 1, 2023
+ },
+ {
+ "type": "lt",
+ "key": "date",
+ "value": 1680307200, # Apr 1, 2023
+ },
+ ],
+ },
+ }
+ ]
+
+ response = openai_client.responses.create(
+ model=model,
+ input="What happened in Q1 2023?",
+ tools=tools,
+ stream=False,
+ include=["file_search_call.results"],
+ )
+
+ assert response.output[0].type == "file_search_call"
+ assert response.output[0].status == "completed"
+ assert response.output[0].results
+ # Should only return Q1 files (not Q2 or Q3)
+ for result in response.output[0].results:
+ assert "q1" in result.text.lower()
+ # Ensure non-Q1 quarters are NOT returned
+ assert "q2" not in result.text.lower()
+ assert "q3" not in result.text.lower()
+
+
+def test_response_file_search_filter_compound_and(openai_client, model, vector_store_with_filtered_files):
+ """Test file search with compound AND filter (region AND category)."""
+ tools = [
+ {
+ "type": "file_search",
+ "vector_store_ids": [vector_store_with_filtered_files.id],
+ "filters": {
+ "type": "and",
+ "filters": [
+ {"type": "eq", "key": "region", "value": "us"},
+ {"type": "eq", "key": "category", "value": "engineering"},
+ ],
+ },
+ }
+ ]
+
+ response = openai_client.responses.create(
+ model=model,
+ input="What are the engineering updates from the US?",
+ tools=tools,
+ stream=False,
+ include=["file_search_call.results"],
+ )
+
+ assert response.output[0].type == "file_search_call"
+ assert response.output[0].status == "completed"
+ assert response.output[0].results
+ # Should only return US engineering files
+ assert len(response.output[0].results) >= 1
+ for result in response.output[0].results:
+ assert "us" in result.text.lower() and "technical" in result.text.lower()
+ # Ensure it's not from other regions or categories
+ assert "european" not in result.text.lower() and "asia" not in result.text.lower()
+ assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
+
+
+def test_response_file_search_filter_compound_or(openai_client, model, vector_store_with_filtered_files):
+ """Test file search with compound OR filter (marketing OR sales)."""
+ tools = [
+ {
+ "type": "file_search",
+ "vector_store_ids": [vector_store_with_filtered_files.id],
+ "filters": {
+ "type": "or",
+ "filters": [
+ {"type": "eq", "key": "category", "value": "marketing"},
+ {"type": "eq", "key": "category", "value": "sales"},
+ ],
+ },
+ }
+ ]
+
+ response = openai_client.responses.create(
+ model=model,
+ input="Show me marketing and sales documents",
+ tools=tools,
+ stream=False,
+ include=["file_search_call.results"],
+ )
+
+ assert response.output[0].type == "file_search_call"
+ assert response.output[0].status == "completed"
+ assert response.output[0].results
+ # Should return marketing and sales files, but NOT engineering
+ categories_found = set()
+ for result in response.output[0].results:
+ text_lower = result.text.lower()
+ if "promotional" in text_lower or "advertising" in text_lower:
+ categories_found.add("marketing")
+ if "revenue figures" in text_lower:
+ categories_found.add("sales")
+ # Ensure engineering files are NOT returned
+ assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
+
+ # Verify we got at least one of the expected categories
+ assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
+ assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"