feat: support filters in file search (#2472)

# What does this PR do?
Move to use vector_stores.search for file search tool in Responses,
which supports filters.

closes #2435 
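
For illustration, a rough sketch of the kind of request this enables, using the OpenAI Python client pointed at a locally running stack. The base URL, API key, vector store ID, and attribute key/value below are placeholders, not part of this PR.

```python
# Sketch only: assumes a Llama Stack server exposing the OpenAI-compatible
# Responses API and an existing vector store whose files carry a "region" attribute.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="What was our Q1 revenue in EMEA?",
    tools=[
        {
            "type": "file_search",
            "vector_store_ids": ["vs_123"],  # placeholder vector store ID
            "max_num_results": 5,
            # filters follow the OpenAI comparison/compound filter shape
            "filters": {"type": "eq", "key": "region", "value": "emea"},
        }
    ],
)
print(response.output_text)
```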

## Test Plan
Added an e2e test with filters.
myenv ❯ llama stack run llama_stack/templates/fireworks/run.yaml

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search and filters' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=meta-llama/Llama-3.3-70B-Instruct
ehhuang 2025-06-18 21:50:55 -07:00 committed by GitHub
parent fd37a50e6a
commit db2cd9e8f3
13 changed files with 449 additions and 63 deletions


@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
from llama_stack.schema_utils import json_schema_type, register_schema
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
@@ -400,11 +401,6 @@ class OpenAIResponseInputToolFunction(BaseModel):
strict: bool | None = None
-class FileSearchRankingOptions(BaseModel):
-    ranker: str | None = None
-    score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
type: Literal["file_search"] = "file_search"
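
The rest of the class body is truncated above; based on the fields referenced later in this diff (vector_store_ids, filters, max_num_results, ranking_options), constructing the tool might look roughly like the sketch below. Values are illustrative only.

```python
# Sketch: field names taken from their usage elsewhere in this diff; defaults may differ.
from llama_stack.apis.agents.openai_responses import (
    FileSearchRankingOptions,
    OpenAIResponseInputToolFileSearch,
)

tool = OpenAIResponseInputToolFileSearch(
    vector_store_ids=["vs_123"],  # placeholder ID
    filters={
        "type": "and",
        "filters": [
            {"type": "eq", "key": "region", "value": "us"},
            {"type": "gte", "key": "year", "value": 2024},
        ],
    },
    max_num_results=10,
    ranking_options=FileSearchRankingOptions(score_threshold=0.2),
)
```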


@@ -157,6 +157,11 @@ VectorStoreChunkingStrategy = Annotated[
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
+class SearchRankingOptions(BaseModel):
+    ranker: str | None = None
+    score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
@json_schema_type
class VectorStoreFileLastError(BaseModel):
code: Literal["server_error"] | Literal["rate_limit_exceeded"]
@@ -319,7 +324,7 @@ class VectorIO(Protocol):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
-ranking_options: dict[str, Any] | None = None,
+ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store.


@@ -14,6 +14,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -246,7 +247,7 @@ class VectorIORouter(VectorIO):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
-ranking_options: dict[str, Any] | None = None,
+ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")


@@ -85,6 +85,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_groups_api=self.tool_groups_api,
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
)
async def create_agent(


@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import time
import uuid
@@ -42,6 +43,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import (
Inference,
OpenAIAssistantMessageParam,
@@ -64,7 +66,8 @@ from llama_stack.apis.inference.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
-from llama_stack.apis.tools import RAGQueryConfig, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
@@ -214,11 +217,13 @@ class OpenAIResponsesImpl:
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
@@ -666,6 +671,71 @@ class OpenAIResponsesImpl:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_knowledge_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
"""Execute knowledge search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
async def search_single_store(vector_store_id):
try:
search_response = await self.vector_io_api.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=response_file_search_tool.filters,
max_num_results=response_file_search_tool.max_num_results,
ranking_options=response_file_search_tool.ranking_options,
rewrite_query=False,
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
all_results = await asyncio.gather(*search_tasks)
# Flatten results
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
)
)
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
},
)
async def _execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
@@ -693,21 +763,19 @@
                tool_name=function.name,
                kwargs=tool_kwargs,
            )
-        else:
-            if function.name == "knowledge_search":
-                response_file_search_tool = next(
-                    t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
-                )
-                if response_file_search_tool:
-                    if response_file_search_tool.filters:
-                        logger.warning("Filters are not yet supported for file_search tool")
-                    if response_file_search_tool.ranking_options:
-                        logger.warning("Ranking options are not yet supported for file_search tool")
-                    tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
-                    tool_kwargs["query_config"] = RAGQueryConfig(
-                        mode="vector",
-                        max_chunks=response_file_search_tool.max_num_results,
-                    )
+        elif function.name == "knowledge_search":
+            response_file_search_tool = next(
+                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
+            )
+            if response_file_search_tool:
+                # Use vector_stores.search API instead of knowledge_search tool
+                # to support filters and ranking_options
+                query = tool_kwargs.get("query", "")
+                result = await self._execute_knowledge_search_via_vector_store(
+                    query=query,
+                    response_file_search_tool=response_file_search_tool,
+                )
+        else:
            result = await self.tool_runtime_api.invoke_tool(
                tool_name=function.name,
                kwargs=tool_kwargs,
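
The `_execute_knowledge_search_via_vector_store` helper added above fans out one search task per vector store and treats a failing store as contributing zero results rather than failing the whole tool call. That pattern, in isolation, looks roughly like the sketch below; `search_one` and the store IDs are placeholders.

```python
import asyncio
from typing import Any, Awaitable, Callable


async def search_all_stores(
    search_one: Callable[[str], Awaitable[list[Any]]],
    store_ids: list[str],
) -> list[Any]:
    async def safe_search(store_id: str) -> list[Any]:
        try:
            return await search_one(store_id)
        except Exception:
            return []  # a broken store should not sink the whole call

    # Run every store search concurrently, then flatten the per-store result lists.
    per_store = await asyncio.gather(*(safe_search(sid) for sid in store_ids))
    return [item for results in per_store for item in results]
```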


@@ -17,6 +17,7 @@ from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -249,7 +250,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
-ranking_options: dict[str, Any] | None = None,
+ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")


@@ -19,6 +19,7 @@ from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -247,7 +248,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
-ranking_options: dict[str, Any] | None = None,
+ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")


@@ -17,6 +17,7 @@ from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -249,7 +250,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
-ranking_options: dict[str, Any] | None = None,
+ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")


@@ -15,6 +15,7 @@ from llama_stack.apis.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
QueryChunksResponse,
SearchRankingOptions,
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreListResponse,
@@ -296,7 +297,7 @@ class OpenAIVectorStoreMixin(ABC):
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
-ranking_options: dict[str, Any] | None = None,
+ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
# search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponsePage:
@@ -314,7 +315,11 @@
            search_query = query
        try:
-            score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
+            score_threshold = (
+                ranking_options.score_threshold
+                if ranking_options and ranking_options.score_threshold is not None
+                else 0.0
+            )
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
@@ -399,12 +404,49 @@
    def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
        """Check if metadata matches the provided filters."""
-        for key, value in filters.items():
-            if key not in metadata:
-                return False
-            if metadata[key] != value:
-                return False
-        return True
+        if not filters:
+            return True
+        filter_type = filters.get("type")
+        if filter_type in ["eq", "ne", "gt", "gte", "lt", "lte"]:
+            # Comparison filter
+            key = filters.get("key")
+            value = filters.get("value")
+            if key not in metadata:
+                return False
+            metadata_value = metadata[key]
+            if filter_type == "eq":
+                return bool(metadata_value == value)
+            elif filter_type == "ne":
+                return bool(metadata_value != value)
+            elif filter_type == "gt":
+                return bool(metadata_value > value)
+            elif filter_type == "gte":
+                return bool(metadata_value >= value)
+            elif filter_type == "lt":
+                return bool(metadata_value < value)
+            elif filter_type == "lte":
+                return bool(metadata_value <= value)
+            else:
+                raise ValueError(f"Unsupported filter type: {filter_type}")
+        elif filter_type == "and":
+            # All filters must match
+            sub_filters = filters.get("filters", [])
+            return all(self._matches_filters(metadata, f) for f in sub_filters)
+        elif filter_type == "or":
+            # At least one filter must match
+            sub_filters = filters.get("filters", [])
+            return any(self._matches_filters(metadata, f) for f in sub_filters)
+        else:
+            # Unknown filter type, default to no match
+            raise ValueError(f"Unsupported filter type: {filter_type}")
async def openai_attach_file_to_vector_store(
self,
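
To make the filter grammar handled by `_matches_filters` concrete, here are two illustrative payloads it would accept; the keys and values are made up.

```python
# Single comparison: matches when metadata["region"] == "us".
eq_filter = {"type": "eq", "key": "region", "value": "us"}

# Compound filter: region == "us" AND year >= 2024. Each sub-filter is itself
# a comparison node or another compound ("and"/"or") node, evaluated recursively.
and_filter = {
    "type": "and",
    "filters": [
        {"type": "eq", "key": "region", "value": "us"},
        {"type": "gte", "key": "year", "value": 2024},
    ],
}

# Given metadata {"region": "us", "year": 2025} both filters match;
# with {"region": "eu", "year": 2025} neither does.
```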